In [1]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
import os
import boto3
from datasets import load_dataset
import torchvision.transforms as T
import torchvision.transforms.functional as F
from PIL import Image
import io
from transformers import CLIPTokenizerFast
from torchvision.transforms import InterpolationMode

conf = SparkConf().setAppName("PySpark Image Processing with RDDs") \
    .set("spark.executor.memory", "20g").set("spark.driver.memory", "8g")  
sc = SparkContext(conf=conf)
# Initialize Spark session
spark = SparkSession.builder \
    .appName("PySpark Image Processing with RDDs and DataFrames") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/24 12:50:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
img_dataset = load_dataset("imagefolder", data_dir="raw_data/palace/paintings/", split="train")
img_dataset

Resolving data files:   0%|          | 0/3048 [00:00<?, ?it/s]

Dataset({
    features: ['image', 'caption'],
    num_rows: 3047
})

In [3]:
rdd = sc.parallelize(list(zip(img_dataset['image'], img_dataset['caption'])), numSlices=100)

In [4]:
def resize_and_pad_then_resize(img, final_size=(512, 512), padding_mode='constant', fill=0):
    """
    Resize an image to make its longest side equal to the original image's longest side,
    pad the shorter side to make the image a square, then resize to final_size.

    Args:
        img (PIL.Image): The image to resize and pad.
        final_size (tuple): The desired output size (height, width).
        padding_mode (str): Type of padding. Options include 'constant', 'edge', etc.
        fill (int, tuple): Pixel fill value for constant padding. Can be int or tuple.

    Returns:
        PIL.Image: The resized and padded, then resized image.
    """
    original_width, original_height = img.size
    max_side = max(original_width, original_height)

    # Determine new size keeping aspect ratio
    if original_width > original_height:
        scale = max_side / original_width
        new_width = max_side
        new_height = int(original_height * scale)
    else:
        scale = max_side / original_height
        new_height = max_side
        new_width = int(original_width * scale)

    # Resize the image to max_side to keep aspect ratio
    img = F.resize(img, (new_height, new_width), interpolation=InterpolationMode.LANCZOS)

    # Calculate padding amounts
    pad_width = (max_side - new_width) // 2
    pad_height = (max_side - new_height) // 2

    # Apply padding to make it a square
    img = F.pad(img, [pad_width, pad_height, pad_width, pad_height], padding_mode=padding_mode, fill=fill)

    # Final resize to the desired output size
    img = F.resize(img, final_size, interpolation=InterpolationMode.LANCZOS)
    return img


def preprocess_image(image, caption):
    # Setup the transformation pipeline with the updated function
    transform_pipeline = T.Compose([
        T.Lambda(lambda img: resize_and_pad_then_resize(img, final_size=(512, 512), padding_mode='constant', fill=0)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        T.ToPILImage()
    ])
    processed_image = transform_pipeline(image)

    # Tokenization
    tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
    tokens = tokenizer.encode(caption, max_length=77, truncation=True, return_tensors="pt").tolist()[0]

    return (processed_image, tokens)

In [5]:
processed_rdd = rdd.map(lambda x: preprocess_image(x[0], x[1]))

In [6]:
processed_rdd.take(1)

24/04/24 12:51:03 WARN TaskSetManager: Stage 0 contains a task of very large size (23855 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

[(<PIL.Image.Image image mode=RGB size=512x512>,
  [49406, 6046, 29263, 593, 5493, 47986, 49407])]

In [7]:
def image_to_binary(pil_image):
    byte_arr = io.BytesIO()
    pil_image.save(byte_arr, format='JPEG')
    return byte_arr.getvalue()

processed_rdd_2 = processed_rdd.map(lambda x: (image_to_binary(x[0]), x[1]))
processed_rdd_2.take(1)

24/04/24 12:51:08 WARN TaskSetManager: Stage 1 contains a task of very large size (23855 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

[(b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.\' ",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\t\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xc0\x00\x11\x08\x02\x00\x02\x00\x03\x01"\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x1f\x00\x00\x01\x05\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x10\x00\x02\x01\x03\x03\x02\x04\x03\x05\x05\x04\x04\x00\x00\x01}\x01\x02\x03\x00\x04\x11\x05\x12!1A\x06\x13Qa\x07"q\x142\x81\x91\xa1\x08#B\xb1\xc1\x15R\xd1\xf0$3br\x82\t\n\x16\x17\x18\x19\x1a%&\'()*456789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd

In [8]:
from pyspark.sql.types import StructType, StructField, BinaryType, ArrayType, IntegerType

schema = StructType([
    StructField("image", BinaryType(), True),
    StructField("input_ids", ArrayType(IntegerType()), True)
])
df = spark.createDataFrame(processed_rdd_2, schema)

In [9]:
df.show()

24/04/24 12:51:15 WARN TaskSetManager: Stage 2 contains a task of very large size (23855 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

+--------------------+--------------------+
|               image|           input_ids|
+--------------------+--------------------+
|[FF D8 FF E0 00 1...|[49406, 6046, 292...|
|[FF D8 FF E0 00 1...|[49406, 6046, 292...|
|[FF D8 FF E0 00 1...|[49406, 6046, 292...|
|[FF D8 FF E0 00 1...|[49406, 6046, 292...|
|[FF D8 FF E0 00 1...|[49406, 6046, 292...|
|[FF D8 FF E0 00 1...|[49406, 4652, 328...|
|[FF D8 FF E0 00 1...|[49406, 4652, 328...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 5979, 796...|
|[FF D8 FF E0 00 1...|[49406, 59

In [None]:
from datasets import Dataset, Features, Image, Value, Sequence


features = Features({"image": Image(), "input_ids": Sequence(Value("int64"))})
dataset = Dataset.from_spark(df, features=features)
dataset[0]

In [None]:
from torchvision.transforms.functional import to_tensor
import torch


def image_to_tensor(examples):
    # examples are a batch of 4 images
    # we apply the transformation (reference above for what it transfomed to)
    # then apply the tokenization
    examples["pixel_values"] = [to_tensor(image) for image in examples["image"]]
    return examples


train_set = dataset.map(image_to_tensor, remove_columns=['image'], batched=True)

In [None]:
# Save as parquet
import pyarrow.parquet as pq
pq.write_table(train_set.data.table,'test.parquet')

In [None]:
train_set_read_from_parquet = Dataset(pq.read_table('test.parquet'))
train_set_read_from_parquet

In [None]:
import torch
print(type(train_set_read_from_parquet[0].get('pixel_values'))) 
new_img = train_set_read_from_parquet[0].get('pixel_values')
img_tensor = torch.tensor(new_img)
print(img_tensor.shape)
tensor = img_tensor.squeeze(0)
unloader = T.ToPILImage()
image = unloader(tensor)
image