## Prerequisites

In [None]:
# Install the packages used below: Spark (pyspark, py4j), Pillow for image
# handling and PyTorch/torchvision for the model inputs. Versions are not pinned.
!pip install pyspark py4j pillow torch torchvision

Collecting pyspark
  Downloading pyspark-3.5.1.tar.gz (317.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m317.0/317.0 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (f

In [None]:
# Clone the `feature/add_transformers` branch of the OFA repository and install
# the transformers package bundled in it (presumably the source of the OFA
# classes imported in the model cell - TODO confirm).
!git clone --single-branch --branch feature/add_transformers https://github.com/OFA-Sys/OFA.git
!pip install OFA/transformers/
# Clone the OFA-medium checkpoint repository from the Hugging Face Hub.
!git clone https://huggingface.co/OFA-Sys/OFA-medium

Cloning into 'OFA'...
remote: Enumerating objects: 5745, done.[K
remote: Counting objects: 100% (916/916), done.[K
remote: Compressing objects: 100% (254/254), done.[K
remote: Total 5745 (delta 695), reused 662 (delta 662), pack-reused 4829[K
Receiving objects: 100% (5745/5745), 97.78 MiB | 20.77 MiB/s, done.
Resolving deltas: 100% (2243/2243), done.
Updating files: 100% (3223/3223), done.
Processing ./OFA/transformers
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting sacremoses (from transformers==4.18.0.dev0)
  Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: transformers
  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
  Created wheel for transformers: filename=

### Imports

In [None]:
from pyspark.sql import SparkSession, Row
from pyspark.sql.types import BinaryType, IntegerType, StringType, BooleanType, StructType, StructField
from pyspark.sql.functions import udf, col, lit
from pyspark.sql.functions import min as sql_min
from PIL import Image
import numpy as np
import requests, json, csv, io

### Prepare raw data

In [None]:
# Path of the CSV file that holds the photo metadata.
fname_metadata = 'tour_photo.csv'
# Value sent as `serviceKey` to the gallery API (presumably a placeholder to be
# replaced with a real key - TODO confirm).
key = 'TOUR_API_KEY'
# Value sent as `numOfRows` in the request; kept as a string.
num_of_rows = '15'

def collect_metadata():
    """Fetch the photo gallery list from the API and save it as a CSV file.

    Uses the module-level ``key`` and ``num_of_rows`` as request parameters and
    writes one CSV row per returned item to ``fname_metadata``.

    Raises:
        requests.exceptions.RequestException: if the request fails, times out
            or returns an HTTP error status.
        ValueError: if the response contains no items.
    """
    # Request parameters
    base_url = 'http://apis.data.go.kr/B551011/PhotoGalleryService1/galleryList1'
    params = {
        'numOfRows': num_of_rows,
        'pageNo': '1',
        'MobileOS' : 'ETC',
        'MobileApp' : 'AppTest',
        'arrange' : 'A',
        '_type' : 'json',
        'serviceKey' : key}

    # Get request; fail early on HTTP errors instead of parsing an error page
    reply = requests.get(base_url, params=params, timeout=30)
    reply.raise_for_status()
    response = reply.json()['response']

    # Guard against an `items` node that is not a mapping (nothing to index into)
    items_node = response['body']['items']
    items = items_node['item'] if isinstance(items_node, dict) else []
    if not items:
        raise ValueError('The gallery API returned no items.')

    # Save as csv file
    with open(fname_metadata, 'w', newline='', encoding='utf-8') as file:
        # `response` is a dict, so the header has to come from an item record
        writer = csv.DictWriter(file, fieldnames=list(items[0].keys()))
        writer.writeheader()
        writer.writerows(items)


In [None]:
# Folder name to save the photos
save_folder = "tour_photos"

def download():
    """Download every image listed in the metadata CSV into ``save_folder``.

    Reads ``fname_metadata``, takes the ``galWebImageUrl`` field of each record
    and stores the response body under the last path segment of the URL.
    A failed download is reported with ``print`` and skipped.
    """
    import os  # local import: `os` is not among the notebook-level imports

    # The target folder must exist before the first image is written
    os.makedirs(save_folder, exist_ok=True)

    with open(fname_metadata, 'r', newline='', encoding='utf-8') as file:
        # DictReader consumes the header row and yields one dict per record,
        # which is what the column-name lookup below needs
        for row in csv.DictReader(file):
            img_url = row['galWebImageUrl']  # Parse image url
            img_name = img_url.split('/')[-1]  # Extract filename

            try:
                response = requests.get(img_url, timeout=30)
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                print(f"Error downloading image from {img_url}: {e}")
                continue

            # Save
            with open(f"{save_folder}/{img_name}", 'wb') as img_file:
                img_file.write(response.content)


### Spark Session

In [None]:
# Create (or reuse) the Spark session for this notebook, running locally
spark = SparkSession.builder.master('local[*]').appName('K-TourImageCaption').getOrCreate()

### Explore metadata

In [None]:
# Read the metadata CSV using its header row for column names; column types
# are inferred by Spark.
df_meta = spark.read.csv(fname_metadata, header=True, inferSchema=True)
df_meta.printSchema()

root
 |-- galContentId: integer (nullable = true)
 |-- galContentTypeId: integer (nullable = true)
 |-- galTitle: string (nullable = true)
 |-- galWebImageUrl: string (nullable = true)
 |-- galCreatedtime: long (nullable = true)
 |-- galModifiedtime: long (nullable = true)
 |-- galPhotographyMonth: integer (nullable = true)
 |-- galPhotographyLocation: string (nullable = true)
 |-- galPhotographer: string (nullable = true)
 |-- galSearchKeyword: string (nullable = true)



### Load images to spark dataframe

In [None]:
# Load images to spark dataframe
# The 'image' source yields a single struct column named `image`
# (see the schema printed by this cell).
df_img_raw = spark.read.format('image').load(save_folder)

df_img_raw.printSchema()

root
 |-- image: struct (nullable = true)
 |    |-- origin: string (nullable = true)
 |    |-- height: integer (nullable = true)
 |    |-- width: integer (nullable = true)
 |    |-- nChannels: integer (nullable = true)
 |    |-- mode: integer (nullable = true)
 |    |-- data: binary (nullable = true)



### Image <-> Bytes

In [None]:
def decode(data: bytearray, w: int, h: int, c: int = 3) -> Image:
    """Rebuild a PIL image from a raw 8-bit pixel buffer.

    Args:
        data: Pixel bytes in row-major order, ``h * w * c`` values.
        w: Image width in pixels.
        h: Image height in pixels.
        c: Number of channels; 1 produces an 'L' image, anything else 'RGB'.

    Returns:
        The image built from the buffer.
    """
    pixels = np.frombuffer(bytes(data), np.uint8)
    if c == 1:
        # An 'L' image must come from a 2-D array; a trailing channel axis of
        # length 1 is rejected by Image.fromarray
        return Image.fromarray(pixels.reshape((h, w)), mode='L')
    return Image.fromarray(pixels.reshape((h, w, c)), mode='RGB')

def encode(img: Image) -> bytearray:
    """Serialize an image's pixel array into a flat bytearray."""
    pixels = np.array(img)
    return bytearray(pixels.flatten())

## Pre-processing

### Validate images

In [None]:
@udf(returnType=BooleanType())
def validate_image(image_path) -> bool:
    """Report whether the image file behind ``image_path`` passes verification.

    Args:
        image_path: File location whose first 7 characters ("file://") are
            stripped before opening.

    Returns:
        True if the file opens and ``verify()`` succeeds, False on any error
        (rows returning False are dropped by the caller's filter).
    """
    try:
        # Trim first 7 chars "file://" and load; the context manager closes
        # the file handle even when verification fails.
        with Image.open(image_path[7:]) as img:
            # An error will occur if it's broken.
            img.verify()
        return True
    except Exception:
        # Only ordinary errors mark the image as invalid; KeyboardInterrupt
        # and SystemExit are allowed to propagate.
        return False

# Row count before filtering, kept for the comparison printed below
previous_size = df_img_raw.count()

# Drop broken images
df_valid = df_img_raw.where(validate_image(col('image.origin')))
print(f"Previous dataframe's size : {previous_size}")
print(f"Filtered dataframe's size : {df_valid.count()}")

Previous dataframe's size : 15
Filtered dataframe's size : 8


### Indexing

In [None]:
@udf(returnType=StringType())
def parse_id(filepath : str) -> str:
    """Return the file name of ``filepath`` without directories and extension.

    The extension is whatever follows the last dot, so names such as
    ``123.jpeg`` or names without any extension are handled as well as the
    usual three-letter case.
    """
    # Trim out the path
    name = filepath.rsplit('/', 1)[-1]
    # Trim out the ext (only when there is a non-empty stem before the dot)
    stem, dot, _ext = name.rpartition('.')
    return stem if dot and stem else name

# Add an `id` column parsed from the file location, then drop `origin` from the struct
df_indexed = df_valid.withColumn('id', parse_id('image.origin')).withColumn(
    'image', col('image').dropFields('origin')
)

### BGR -> RGB

In [None]:
@udf(returnType=BinaryType())
def bgr2rgb(d : bytearray, w : int, h : int) -> bytearray:
    """Reverse the channel order of a 3-channel, 8-bit pixel buffer."""
    pixels = np.frombuffer(bytes(d), np.uint8).reshape((h, w, 3))
    # Reversing the last axis swaps the first and third channel of every pixel
    swapped = pixels[:, :, ::-1]
    return bytearray(swapped.flatten())

# Replace image.data with the channel-swapped bytes; other struct fields are kept
df_rgb = df_indexed.withColumn(
    'image',
    col('image').withField('data', bgr2rgb('image.data', 'image.width', 'image.height')),
)

### Resize

#### Calculate target size

In [None]:
@udf(returnType=IntegerType())
def get_min(w : int, h : int) -> int:
    """Return the smaller of the two side lengths."""
    return h if h < w else w

# Target size to be cropped: the smallest side length found across all images
len_min = (
    df_rgb
    .select(get_min('image.width', 'image.height').alias('_'))
    .agg(sql_min('_'))
    .first()[0]
)

print(len_min)

532


#### Crop

In [None]:
@udf(returnType=BinaryType())
def crop(row: Row, x: int) -> bytearray:
    """Center-crop an image struct to an ``x`` by ``x`` square.

    Args:
        row: Image struct providing ``width``, ``height`` and ``data``.
        x: Side length of the square to cut out.

    Returns:
        The cropped pixels as a flat bytearray.
    """
    w, h = row['width'], row['height']
    # Calculate the coordinates to trim. Right/bottom are derived from the
    # integer left/top so the box is always exactly x pixels per side; float
    # halves are rounded per edge and could give x-1 or x+1 instead.
    left, top = (w - x) // 2, (h - x) // 2
    box = (left, top, left + x, top + x)
    # Load image from data
    image = decode(row['data'], w, h)
    # Crop and encode to binary data
    return encode(image.crop(box))

# Swap in the cropped bytes and record the new square dimensions in the struct
df_resized = df_rgb.withColumn(
    'image',
    col('image')
    .withField('data', crop('image', lit(len_min)))
    .withField('width', lit(len_min))
    .withField('height', lit(len_min)),
)

### Convert to grayscale (Unused)

In [None]:
@udf(returnType=BinaryType())
def grayscale(data : bytearray, size : int) -> bytearray:
    """Convert a square ``size`` x ``size`` image buffer to mode 'L' bytes."""
    return encode(decode(data, size, size).convert('L'))

# Grayscale variant of the resized images; nChannels is set to 1 to match the data
df_grayscale = df_resized.withColumn(
    'image',
    col('image')
    .withField('data', grayscale('image.data', lit(len_min)))
    .withField('nChannels', lit(1)),
)

### Visualize

In [None]:
def visualize():
    """Display the first image of the RGB, resized and grayscale frames."""
    x = df_rgb.first()['image']
    y = df_resized.first()['image']
    z = df_grayscale.first()['image']

    # Height is taken from the struct's own `height` field rather than reusing
    # `width`, so the calls stay correct for non-square images.
    display(decode(x['data'], x['width'], x['height']))
    display(decode(y['data'], y['width'], y['height']))
    display(Image.frombytes(data=bytes(z['data']),size=(z['width'],z['height']),mode='L'))

### Normalize

In [None]:
@udf(returnType=BinaryType())
def normalize(data: bytearray, w: int,h: int, mean, std) -> bytearray:
    """Min-max scale each channel to [0, 1], then standardize with mean/std.

    Args:
        data: 8-bit pixel bytes, reshaped to ``(h, w, 3)``.
        w: Image width in pixels.
        h: Image height in pixels.
        mean: Value(s) subtracted after scaling; broadcast against the channels.
        std: Value(s) the result is divided by; broadcast against the channels.

    Returns:
        The raw bytes of the normalized float64 array.
    """
    pixels = (
        np
        .frombuffer(bytes(data), np.uint8)
        .reshape((h, w, 3))
        .astype(np.float64)
    )
    lowest = pixels.min(axis=(0, 1))
    highest = pixels.max(axis=(0, 1))
    span = highest - lowest
    # A constant channel has zero range; dividing by it would turn the whole
    # channel into NaN, so such a channel is scaled to 0 instead.
    span[span == 0] = 1.0
    scaled = (pixels - lowest) / span
    normalized = (
        (scaled - np.asarray(mean, dtype=np.float64))
        / np.asarray(std, dtype=np.float64)
    )
    return bytearray(normalized.tobytes())

# Referred to OFA example
mean, std = [0.5], [0.5]

# Replace image.data with the normalized values produced by `normalize`
df_norm = df_resized.withColumn(
    'image',
    col('image').withField(
        'data',
        normalize('image.data', 'image.width', 'image.height', lit(mean), lit(std)),
    ),
)

### Attach human hints

In [None]:
# Match each image to its metadata row and keep the keyword and location
# columns under the shorter names `tags` and `loc`
df_joined = (
    df_norm
    .join(df_meta, df_norm['id'] == df_meta['galContentId'], how='inner')
    .select(
        'id',
        'image',
        col('galSearchKeyword').alias('tags'),
        col('galPhotographyLocation').alias('loc'),
    )
)
df_joined.printSchema()

root
 |-- id: string (nullable = true)
 |-- image: struct (nullable = true)
 |    |-- height: integer (nullable = false)
 |    |-- width: integer (nullable = false)
 |    |-- nChannels: integer (nullable = true)
 |    |-- mode: integer (nullable = true)
 |    |-- data: binary (nullable = true)
 |-- tags: string (nullable = true)
 |-- loc: string (nullable = true)



## Generate captions

#### Initialize model

In [None]:
# NOTE(review): `transforms` is not referenced in the visible cells - confirm it is still needed.
from torchvision import transforms
# OFA tokenizer/model classes and the fairseq-style sequence generator module.
from transformers import OFATokenizer, OFAModel
from transformers.models.ofa.generate import sequence_generator

# NOTE(review): the checkpoint repository was cloned earlier with
# `git clone https://huggingface.co/OFA-Sys/OFA-medium`; confirm that this
# identifier resolves to the intended checkpoint.
ckpt_dir = 'OFA-Sys/ofa-medium'
# Load the tokenizer and the model from the same checkpoint identifier.
tokenizer_ofa = OFATokenizer.from_pretrained(ckpt_dir)
model_ofa = OFAModel.from_pretrained(ckpt_dir, use_cache=True)

#### Prepare inputs

In [None]:
import torch

def to_tensor(data: bytearray, size: int):
    """Turn a float64 pixel buffer into a float32 tensor of shape (1, 3, size, size).

    The buffer is read as a ``(size, size, 3)`` array, the channel axis is
    moved to the front and a leading batch axis is added.
    """
    values = np.frombuffer(bytes(data), np.float64)
    # (height, width, channel) -> (channel, height, width)
    channels_first = values.reshape((size, size, 3)).transpose((2, 0, 1))
    batch = torch.tensor(channels_first).unsqueeze(0)
    return batch.to(torch.float32)

#### Model inference

In [None]:
result_captions = []

# using the generator of fairseq version
# Built from the tokenizer loaded earlier (`tokenizer_ofa`). The loop below
# decodes with the Hugging Face `generate` API, so this object is only kept
# as an alternative decoding path.
generator = sequence_generator.SequenceGenerator(
                    tokenizer=tokenizer_ofa,
                    beam_size=3,
                    max_len_b=50,
                    min_len=0,
                    no_repeat_ngram_size=3,
                )

# One Row(id, caption) per image, filled by the loop below
captions = []
for row in df_joined.collect():
    # `row_id` avoids shadowing the built-in `id`
    row_id = row['id']
    data = row['image']['data']
    size = row['image']['width']
    # Human hints from the metadata; they are not part of the prompt yet
    loc = row['loc']
    tags = row['tags']

    input_tensor = to_tensor(data, size)

    context = f'''
        What does the image describe? Explain it in as much detail as possible.
    '''
    # TODO: optionally append the photo location (`loc`) to the prompt

    inputs = tokenizer_ofa([context], return_tensors="pt").input_ids

    # using the generator of huggingface version
    # (the fairseq-style generation that used to run first was overwritten by
    # this result, so it is no longer executed per row)
    gen = model_ofa.generate(inputs, patch_images=input_tensor, num_beams=3, no_repeat_ngram_size=3)
    caption = tokenizer_ofa.batch_decode(gen, skip_special_tokens=True)[0].strip()

    captions.append(Row(id=row_id, caption=caption))

### Get response

In [None]:
# Explicit schema: the DataFrame can then be built even when `captions` is
# empty, where schema inference would have nothing to infer from.
caption_schema = StructType([
    StructField('id', StringType(), True),
    StructField('caption', StringType(), True),
])
df_c_temp = spark.createDataFrame(captions, schema=caption_schema)

# Joining on the column name keeps a single `id` column; an equality
# expression would keep both sides' `id` and make later references ambiguous.
df_captioned = df_joined.join(df_c_temp, on='id', how='inner')

In [None]:
# Show the metadata keywords next to the generated caption, one field per
# line and without truncation.
df_captioned.select('tags','caption').show(vertical=True, truncate=False)

-RECORD 0------------------------------------------------------------------------------------------------------------------------------------------------
 tags    | 시루떡, 시루팥떡, 팥시루떡, 한국음식, 한식, 칠석, 명절음식, 전통음식, 한복, 젓가락                                                            
 caption | a piece of chocolate cake with a cup of tea                                                                                                   
-RECORD 1------------------------------------------------------------------------------------------------------------------------------------------------
 tags    | 식혜, 한국음식, 한식, 설날, 명절음식, 전통음료                                                                                                
 caption | i'm not sure if i 'd like to use a spoon or spoon, but                                                                                        
-RECORD 2---------------------------------------------------------------------------------------------------------------------