In [1]:
import os
import clip
import torch
import pickle
import numpy as np
import torch.nn as nn
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader, BatchSampler

# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Only the second value returned by clip.load (the preprocess callable) is kept;
# the first returned value is discarded.
_, preprocess = clip.load("ViT-B/32", device=device)

cuda


In [3]:
import load_dataset

# Point the dataset helper at a local extraction directory.
# NOTE(review): this runs after `import load_dataset`; if that module reads
# EXTRACT_DIR at import time, this assignment has no effect — confirm.
os.environ['EXTRACT_DIR'] = "/tmp/GLAMI-1M/"
# Download the archive at the given URL (judging by its file name, the test-only GLAMI-1M archive).
load_dataset.download_dataset(dataset_url="https://huggingface.co/datasets/glami/glami-1m/resolve/main/GLAMI-1M-dataset--test-only.zip")

Downloading: 100%|█████████████████████████| 1.39G/1.39G [01:01<00:00, 22.8MB/s]


Unzipping


100%|████████████████████████████████| 116008/116008 [00:10<00:00, 11393.82it/s]


## Preprocess data

In [4]:
# Load the 'test' split and keep an independent copy of the columns of interest.
df = (
    load_dataset.get_dataframe('test')
    .loc[:, ['item_id', 'image_id', 'name', 'description', 'category_name', 'image_file']]
    .copy()
)
print(f"number of products: {len(df)}")

number of products: 116004


In [5]:
# Keep a single row per distinct image file (pandas keeps the first occurrence
# by default), then narrow the frame to the two columns used downstream:
# `category_name` and `image_file`.
df = (
    df.drop_duplicates(subset=['image_file'])
    .loc[:, ['category_name', 'image_file']]
    .reset_index(drop=True)
)
print(f"number of images: {len(df)}")

number of images: 85577


In [6]:
# Draw a fixed-seed subset of 13000 rows, then report how many distinct image
# files and distinct category names it contains.
df_sample = df.sample(n=13000, random_state=41)
print(f"number of training images: {df_sample['image_file'].nunique()}")
print(f"number of training texts: {df_sample['category_name'].nunique()}")

number of training images: 13000
number of training texts: 186


In [7]:
# Build one natural-language prompt per distinct category name.
category_names = df_sample['category_name'].unique()

# Substring rewrites, applied in this exact order: apostrophes are restored
# before the remaining hyphens become spaces, and " and " becomes " or " last.
_prompt_rewrites = (
    ('women-s', "women's"),
    ('womens', "women's"),
    ('men-s', "men's"),
    ('mens', "men's"),
    ('-', ' '),
    (' and ', ' or '),
)

category_name_to_prompt = {}
for raw_name in category_names:
    readable = raw_name.strip()
    for old, new in _prompt_rewrites:
        readable = readable.replace(old, new)
    category_name_to_prompt[raw_name] = "A photo of a " + readable + ", a type of fashion product"

# Every category in df_sample is a key of the mapping, so a dict lookup suffices.
df_sample['prompt'] = df_sample['category_name'].map(category_name_to_prompt)

In [8]:
# Give every category an integer label equal to its position in `category_names`.
category_name_to_label = {name: index for index, name in enumerate(category_names)}

df_sample['label'] = df_sample['category_name'].map(category_name_to_label)

In [9]:
# Walk the sampled rows once, collecting three parallel lists (same order as df_sample):
# the RGB-converted image, its prompt, and its integer label.
# All images are kept in these in-memory lists for the rest of the notebook.
images_list, prompt_list, label_list = [], [], []

for record in df_sample.itertuples(index=True):
    images_list.append(Image.open(record.image_file).convert("RGB"))
    prompt_list.append(record.prompt)
    label_list.append(record.label)

In [10]:
# Positional split of the three parallel lists:
# the first 11000 entries are the training set, the remaining 2000 the test set.
images_list_train, images_list_test = images_list[:11000], images_list[11000:]
prompt_list_train, prompt_list_test = prompt_list[:11000], prompt_list[11000:]
label_list_train, label_list_test = label_list[:11000], label_list[11000:]

## Create `DataLoader`

In [11]:
# Samples per batch; later passed to BalancedBatchSampler as n_classes (with n_samples=1).
BATCH_SIZE = 32
class FashionDataset(Dataset):
    """
    Map-style dataset over three parallel lists: images, prompts and labels.

    Args:
        images_list: images, indexed by sample position.
        prompt_list: text prompt for each image.
        label_list: label for each image.
        preprocess: callable applied to an image every time it is fetched.
    """

    def __init__(self, images_list, prompt_list, label_list, preprocess):
        self.images_list, self.prompt_list, self.label_list = images_list, prompt_list, label_list
        self.preprocess = preprocess

    def __len__(self):
        """Number of samples, taken from the image list."""
        return len(self.images_list)

    def __getitem__(self, idx):
        """
        Return ``(preprocessed image, prompt tokens, label)`` for position ``idx``.

        The prompt is tokenized on every access as a one-element batch, and the
        single resulting row is returned.
        """
        pixels = self.preprocess(self.images_list[idx])
        tokens = clip.tokenize([self.prompt_list[idx]])[0]
        return pixels, tokens, self.label_list[idx]

In [12]:
# Wrap each split in a FashionDataset; both share the same CLIP preprocess callable.
train_dataset = FashionDataset(
    images_list=images_list_train,
    prompt_list=prompt_list_train,
    label_list=label_list_train,
    preprocess=preprocess,
)
test_dataset = FashionDataset(
    images_list=images_list_test,
    prompt_list=prompt_list_test,
    label_list=label_list_test,
    preprocess=preprocess,
)

In [13]:
# Draws `n_classes` distinct labels per batch; with n_samples=1 (as used below) no two samples in a batch share a label.

class BalancedBatchSampler(BatchSampler):
    """
    Batch sampler yielding index lists of (at most) ``n_classes * n_samples`` items.

    Each batch is built by drawing ``n_classes`` distinct labels without
    replacement and taking the next ``n_samples`` not-yet-used indices of every
    drawn label. With ``n_samples == 1`` no two samples in a batch share a label.
    A label that has fewer than ``n_samples`` items contributes only the items
    it has, so such batches are shorter.

    ``BatchSampler.__init__`` is intentionally not called: this class provides
    its own ``__iter__`` and ``__len__`` and does not wrap another sampler.

    Args:
        labels: one label per dataset item, in dataset order. Either an object
            exposing ``.numpy()`` (e.g. a CPU tensor) or any array-like
            accepted by ``np.asarray``.
        n_classes: number of distinct labels drawn per batch.
        n_samples: number of indices taken per drawn label.

    Raises:
        ValueError: if ``n_classes`` or ``n_samples`` is not positive. Also
            raised by ``np.random.choice`` during iteration when ``n_classes``
            exceeds the number of distinct labels.
    """

    def __init__(self, labels, n_classes, n_samples):
        if n_classes <= 0 or n_samples <= 0:
            # A zero batch size would make __iter__ loop forever.
            raise ValueError("n_classes and n_samples must both be positive")

        self.labels = labels
        # Convert once instead of once per label.
        labels_np = labels.numpy() if hasattr(labels, "numpy") else np.asarray(labels)
        # np.unique returns the distinct labels in sorted order.
        self.labels_set = list(np.unique(labels_np))
        self.label_to_indices = {label: np.where(labels_np == label)[0]
                                 for label in self.labels_set}
        for label in self.labels_set:
            np.random.shuffle(self.label_to_indices[label])
        # Per-label cursor into its shuffled index array.
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        # `<=` (not `<`) so a dataset whose size is an exact multiple of
        # batch_size yields exactly as many batches as __len__ reports.
        while self.count + self.batch_size <= self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                indices.extend(self.label_to_indices[class_][start:start + self.n_samples])
                self.used_label_indices_count[class_] = start + self.n_samples
                # Not enough unused indices left for another draw of this
                # label: reshuffle it and start over from the beginning.
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.batch_size

    def __len__(self):
        """Number of batches produced by one pass of ``__iter__``."""
        return self.n_dataset // self.batch_size
    
# For each split: wrap the labels in a tensor, build a sampler that takes one
# sample from each of BATCH_SIZE distinct labels per batch, and hand it to a
# DataLoader as its batch sampler.
train_labels = torch.tensor(label_list_train)
train_sampler = BalancedBatchSampler(labels=train_labels, n_classes=BATCH_SIZE, n_samples=1)
train_dataloader_sample_batch = DataLoader(dataset=train_dataset, batch_sampler=train_sampler)

test_labels = torch.tensor(label_list_test)
test_sampler = BalancedBatchSampler(labels=test_labels, n_classes=BATCH_SIZE, n_samples=1)
test_dataloader_sample_batch = DataLoader(dataset=test_dataset, batch_sampler=test_sampler)

In [14]:
# save data loader for training
# NOTE: pickle stores classes by reference, so FashionDataset and
# BalancedBatchSampler (both defined in this notebook) must be defined or
# importable in whatever process later loads these files.
# open(..., 'wb') does not create missing parent directories, so make sure the
# output directory exists first.
os.makedirs('data_loader', exist_ok=True)

with open('data_loader/train_dataloader.pkl', 'wb') as f:
    pickle.dump(train_dataloader_sample_batch, f)

with open('data_loader/test_dataloader.pkl', 'wb') as f:
    pickle.dump(test_dataloader_sample_batch, f)