In [1]:
import torch
import torch.nn as nn
import torch.nn.init
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

import numpy as np
import pandas as pd
import glob
import os
from PIL import Image

### Preparation of dataset

In [95]:
def create_dataset(texts_path, images_path):
    """Build a dataframe pairing every image file with its product's text description.

    Parameters
    ----------
    texts_path : str
        Path to a CSV with columns ``'Unnamed: 0'`` (the product key; presumably the
        unnamed index column written by ``DataFrame.to_csv``), ``'color'``, ``'name'``
        and ``'description'``.
    images_path : str
        Glob pattern for the image files. Each file name must start with the
        integer product key followed by ``'_'`` (e.g. ``'123_front.jpg'``).

    Returns
    -------
    df : pandas.DataFrame
        One row per image (sorted by path) with columns ``'Image'`` (file basename),
        ``'Text'`` (``"<color> <name> <description>"``) and ``'Product'`` (row
        position of the product in the CSV). Empty, with the same columns, when
        the glob matches nothing.
    products : numpy.ndarray
        Unique ``'Product'`` values, in order of first appearance in ``df``.

    Raises
    ------
    ValueError
        If an image name does not start with an integer key.
    KeyError
        If an image's key has no matching row in the CSV.
    """
    texts_raw = pd.read_csv(texts_path)
    texts_df = pd.DataFrame({
        'key': texts_raw['Unnamed: 0'],
        # NOTE: if color/name/description is missing, the whole text becomes NaN.
        'text': texts_raw['color'] + " " + texts_raw['name'] + " " + texts_raw['description'],
        # Product id = row position in the CSV (independent of the key values).
        'product': np.arange(len(texts_raw)),
    })

    # First row per key wins; dict lookups avoid re-filtering the frame per image.
    first_rows = texts_df.drop_duplicates(subset='key', keep='first')
    keys = first_rows['key'].tolist()
    text_by_key = dict(zip(keys, first_rows['text'].tolist()))
    product_by_key = dict(zip(keys, first_rows['product'].tolist()))

    records = []
    # glob order is filesystem-dependent; sort for a reproducible row order.
    for image in sorted(glob.glob(images_path)):
        img_name = os.path.basename(image)
        key = int(img_name.split('_')[0])
        if key not in text_by_key:
            raise KeyError(f"No text description for image {img_name!r} (key {key})")
        records.append({'Image': img_name, 'Text': text_by_key[key], 'Product': product_by_key[key]})

    # Build the frame once from records (DataFrame.append was removed in pandas 2.0
    # and was quadratic). Explicit columns keep the schema even with no images.
    df = pd.DataFrame(records, columns=["Image", "Text", "Product"])
    return df, df['Product'].unique()

In [115]:
class CustomImageLoader:
    """Load every image belonging to a given set of products as one batch.

    Parameters
    ----------
    annotations_file : pandas.DataFrame
        Despite the name, an in-memory dataframe (e.g. the one returned by
        ``create_dataset``) whose first three columns are image file name, text
        label and product id, and which has a ``'Product'`` column used to filter.
    img_dir : str
        Directory containing the image files.
    transform : callable, optional
        Applied to each loaded PIL image (e.g. a ``transforms.Compose``).
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        self.img_labels = annotations_file
        self.img_dir = img_dir
        self.transform = transform

    def getbatch(self, prod_idx):
        """Return ``(images, labels, products)`` for rows whose product is in ``prod_idx``.

        Each element is a tuple with one entry per matching row, in dataframe
        order. Three empty tuples are returned when no row matches.
        """
        matching = self.img_labels[self.img_labels['Product'].isin(prod_idx)]

        batch = []
        # Iterate the filtered rows directly instead of mapping index labels through
        # positional ``iloc`` -- that only works with a default 0..n-1 index.
        for img_name, label, product in matching.iloc[:, :3].itertuples(index=False, name=None):
            img_path = os.path.join(self.img_dir, img_name)
            # Copy the pixels into memory so the file handle is closed immediately
            # rather than leaking until the lazily-loaded image is garbage-collected.
            with Image.open(img_path) as img:
                image = img.copy()
            if self.transform:
                image = self.transform(image)
            batch.append((image, label, product))

        if not batch:
            return (), (), ()
        images, labels, products = zip(*batch)
        return images, labels, products

In [None]:
# Design sketch of the intended batch layout. It is kept as comments because the
# names below are placeholders that are not defined anywhere; executing them
# raised NameError and broke "Restart & Run All".
#
# Each sample groups one product's text with that product's images, e.g.
#   b_1 = (t, im1, im2, im3)        # text + 3 images
#   b_2 = (t2, i1, i2)              # text + 2 images
# Embedding every element of a sample gives one row per element:
#   emb_1 = (t_embs(t), ...)        # (4, emb_size)
#   emb_2 = ...                     # (3, emb_size)
# and a batch of samples:
#   B = (b_1, b_2, b_3, b_4, b_5)   # intended (batch_size, 4, emb_size)
#
# TODO: samples have different lengths (4 vs 3 above), so stacking them into a
# fixed (batch_size, 4, emb_size) tensor needs padding plus a mask, or a fixed
# number of images per product -- decide before implementing.


In [116]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Transformation of images: resize the PIL image first (Pillow resampling is
# antialiased and avoids the tensor-Resize `antialias` warning on torchvision
# 0.15/0.16), then convert to a float tensor in [0, 1] and map it to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5: [0, 1] -> [-1, 1]; one value applies to every channel
])

# Data paths
descriptions_data = "./processed_data/processedSKUs_nodups.csv"
img_dir = "./processed_data/images/"
# Glob pattern derived from img_dir so the two paths cannot drift apart.
images_folder = os.path.join(img_dir, "*.jpg")

# Creation of organized dataframe: one row per image with its text and product id
annotations_file, products = create_dataset(descriptions_data, images_folder)
# Fail fast with a clear message if the glob matched nothing (e.g. wrong working directory).
assert len(annotations_file) > 0, f"no images matched {images_folder!r}"

# Custom batch loader: a batch contains all images belonging to the requested products
dataset = CustomImageLoader(annotations_file, img_dir, transform=transform)

# Group product ids into chunks of `batch_size` products (not images) per batch
batch_size = 10
products_groups = [products[i:i + batch_size] for i in range(0, len(products), batch_size)]

In [119]:
# Inspect only the first group of products. Slicing to at most one group keeps
# this a no-op when there are no groups, exactly like a for/break loop.
for i in products_groups[:1]:
    X, Y, P = dataset.getbatch(i)