In [1]:
from typing import Dict, List, Type

import pandas as pd
import torch
from sklearn.decomposition import PCA
from torch import nn
from torch.nn import functional as F
from torchvision import transforms as T
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from torch.nn.functional import normalize as normalize_emb
from torchvision import models as tv_models
from datasets import load_dataset

In [2]:
# Device the encoders run on: the GPU when one is available, otherwise the CPU,
# so the notebook still runs end-to-end on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
def preprocess_data(dataset):
    """Embed every (image, description) pair of ``dataset`` with frozen encoders.

    Each image is scaled to [0, 1], resized to 224x224, normalized with the
    ImageNet mean/std and passed through a ResNet-50 whose ``fc`` head is
    replaced by ``nn.Identity``. Each non-empty line of the sample's
    ``"description"`` is encoded with the ``all-MiniLM-L6-v2`` sentence
    transformer. Both embeddings are normalized with
    ``torch.nn.functional.normalize`` along their only dimension.

    Args:
        dataset: Sized iterable (``len`` is used for the progress bar) of
            dict-like samples providing the keys ``"image"`` (a PIL image),
            ``"label"``, ``"description"`` (newline-separated captions) and
            ``"img_index"``.

    Returns:
        pd.DataFrame with one row per non-empty description and the columns
        ``img_emb``, ``text_emb``, ``image_index``, ``text`` and ``label``.
        The embedding tensors are not moved off the device they were
        computed on.

    Note:
        Samples whose image mode is ``"L"`` or ``"RGBA"`` are skipped, as are
        samples without any non-empty description line.
    """
    # Transformations
    resize = T.Resize((224, 224))
    normalize = T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    pil_to_image = T.PILToTensor()
    # Get models used to preprocess features
    mini_lm = SentenceTransformer("all-MiniLM-L6-v2").to(DEVICE)
    resnet = tv_models.resnet50(pretrained=True).to(DEVICE)
    resnet.fc = nn.Identity()  # Expose the pooled features instead of class logits
    resnet.eval()
    # Preprocess
    preprocessed = []
    for sample in tqdm(
        iterable=dataset,
        total=len(dataset),
        desc="Processing data",
    ):
        image = sample["image"]
        # The normalization below needs exactly 3 channels; grayscale and
        # RGBA images are dropped (there are 4 images in "L" format).
        if image.mode in ("L", "RGBA"):
            continue
        descriptions = [line for line in sample["description"].split("\n") if line]
        if not descriptions:
            continue
        # PILToTensor yields uint8 values in [0, 255]; the ImageNet mean/std
        # are defined for inputs in [0, 1], so rescale before normalizing.
        pixels = pil_to_image(image).to(DEVICE).float().div(255.0)
        normalized_img = normalize(resize(pixels))
        with torch.no_grad():
            # The image embedding does not depend on the description, so it
            # is computed once per sample. Added batch dim, then dropped it.
            img_emb = resnet(normalized_img.unsqueeze(dim=0))[0]
        for description in descriptions:
            with torch.no_grad():
                text_emb = mini_lm.encode(
                    sentences=description,
                    convert_to_tensor=True,
                )
            preprocessed.append(
                {
                    # normalize_emb returns a new tensor, so rows do not
                    # share storage with each other.
                    "img_emb": normalize_emb(img_emb, dim=0),
                    "text_emb": normalize_emb(text_emb, dim=0),
                    "image_index": sample["img_index"],
                    "text": description,
                    "label": sample["label"],
                }
            )
    return pd.DataFrame(preprocessed)

## Preprocess CUB

In [4]:
!mkdir data/cub/

mkdir: cannot create directory ‘data/cub/’: File exists


In [5]:
# Load the CUB dataset through the `datasets` library.
dataset = load_dataset(path="alkzar90/CC6204-Hackaton-Cub-Dataset")

Found cached dataset cc6204-hackaton-cub-dataset (/home/erthax/.cache/huggingface/datasets/alkzar90___cc6204-hackaton-cub-dataset/default/0.0.0/de850c9086bff0dd6d6eab90f79346241178f65e1a016a50eec240ae9cdf2064)


  0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# The original dataset type is not mutable, so materialize both splits as
# plain lists of samples that can be edited in place.
dataset = {
    split_name: list(dataset[split_name]) for split_name in ("train", "test")
}

In [7]:
# Tag every sample with its position inside its split; preprocess_data reads
# this value back as "img_index".
for split in ("train", "test"):
    for index, sample in enumerate(dataset[split]):
        sample.update(img_index=index)

In [9]:
# Embed the CUB training split.
preprocessed_train = preprocess_data(dataset=dataset["train"])

Processing data: 100%|██████████████████████| 5994/5994 [05:06<00:00, 19.58it/s]


In [10]:
# Persist the embedded CUB training split.
preprocessed_train.to_pickle(path="data/cub/preprocessed_train.pkl")

In [11]:
# Embed the CUB test split.
preprocessed_test = preprocess_data(dataset=dataset["test"])

Processing data: 100%|██████████████████████| 5794/5794 [04:49<00:00, 20.04it/s]


In [12]:
# Persist the embedded CUB test split.
preprocessed_test.to_pickle(path="data/cub/preprocessed_test.pkl")

## Preprocess Hateful Memes

In [4]:
!mkdir data/heatfull_meme/

mkdir: cannot create directory ‘data/heatfull_meme/’: File exists


In [5]:
# The annotations are JSON Lines files (one record per line); the "dev" file
# is used as the test split.
train = pd.read_json("data/heatfull_meme/data/train.jsonl", lines=True)
test = pd.read_json("data/heatfull_meme/data/dev.jsonl", lines=True)

In [6]:
# Turn each DataFrame into a list of per-row dicts, which is the sample
# format preprocess_data iterates over.
train = train.to_dict(orient="records")
test = test.to_dict(orient="records")

In [7]:
# Rename the annotation keys to the ones preprocess_data reads:
# "id" -> "img_index" and "text" -> "description".
for split in (train, test):
    for sample in split:
        # Guarded so that re-running the cell is a no-op instead of raising
        # KeyError on keys that were already renamed.
        if "id" in sample:
            sample["img_index"] = sample.pop("id")
        if "text" in sample:
            sample["description"] = sample.pop("text")

In [8]:
from PIL import Image

In [9]:
import copy

In [10]:
# Attach an in-memory copy of its image to every training sample.
for sample in train:
    # The context manager closes the file even if decoding fails; copy()
    # reads the pixel data so the image stays usable after the file is closed.
    with Image.open(f"data/heatfull_meme/data/{sample['img']}") as img:
        sample["image"] = img.copy()

In [11]:
# Attach an in-memory copy of its image to every test sample.
for sample in test:
    # The context manager closes the file even if decoding fails; copy()
    # reads the pixel data so the image stays usable after the file is closed.
    with Image.open(f"data/heatfull_meme/data/{sample['img']}") as img:
        sample["image"] = img.copy()

In [12]:
# Embed the Hateful Memes training split.
preprocessed_train = preprocess_data(dataset=train)

Processing data: 100%|█████████████████████| 8500/8500 [00:53<00:00, 159.55it/s]


In [13]:
# Persist the embedded Hateful Memes training split.
preprocessed_train.to_pickle(path="data/heatfull_meme/preprocessed_train.pkl")

In [14]:
# Embed the Hateful Memes test split (loaded from the "dev" annotations).
preprocessed_test = preprocess_data(dataset=test)

Processing data: 100%|███████████████████████| 500/500 [00:03<00:00, 153.86it/s]


In [15]:
# Persist the embedded Hateful Memes test split.
preprocessed_test.to_pickle(path="data/heatfull_meme/preprocessed_test.pkl")