In [1]:
import torch
import torch.nn as nn
import torch.nn.init
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

import numpy as np
import pandas as pd
import glob
import os
from PIL import Image

### Preparation of dataset

In [95]:
def create_dataset(texts_path, images_path):
  """Create a dataframe pairing each image file with its product's text description.

  Args:
    texts_path: CSV with an unnamed first column (read by pandas as ``'Unnamed: 0'``)
      holding the integer key, plus ``color``, ``name`` and ``description`` columns.
    images_path: glob pattern for the images. Each file's base name must start with
      its integer key followed by ``_`` (e.g. ``"123_front.jpg"``).

  Returns:
    ``(df, products)``: ``df`` has columns ``Image`` (file base name), ``Text``
    (``"<color> <name> <description>"``) and ``Product`` (0-based row position of the
    matching CSV row); ``products`` holds the unique ``Product`` values in order of
    first appearance.

  Raises:
    KeyError: if an image's key has no row in the CSV.
  """
  texts_df = pd.read_csv(texts_path)
  texts = texts_df['color'] + " " + texts_df['name'] + " " + texts_df['description']

  # key -> (text, product id). setdefault keeps the first CSV row for a repeated key
  # (as the previous `.iloc[0]` lookup did); one dict replaces a table scan per image.
  lookup = {}
  for product, (key, text) in enumerate(zip(texts_df['Unnamed: 0'], texts)):
    lookup.setdefault(key, (text, product))

  records = []
  # glob order is filesystem-dependent; sorting makes row (and batch) order reproducible.
  for image in sorted(glob.glob(images_path)):
    img_name = os.path.basename(image)
    key = int(img_name.split('_')[0])
    if key not in lookup:
      raise KeyError(f"No description in {texts_path!r} for image {img_name!r} (key {key})")
    text, product = lookup[key]
    records.append({'Image': img_name, 'Text': text, 'Product': product})

  # Build the frame once (DataFrame.append was removed in pandas 2.0 and was quadratic);
  # explicit columns keep 'Product' present even when no image matched.
  df = pd.DataFrame(records, columns=["Image", "Text", "Product"])
  return df, df['Product'].unique()

In [115]:
class CustomImageLoader:
    """Loads every image belonging to a set of products as one batch.

    ``annotations_file`` is a dataframe (despite the name, not a path). Its first three
    columns are read positionally as image file name, text description and product id,
    and it must have a ``'Product'`` column used to select rows -- e.g. the frame
    returned by ``create_dataset``. Images are read from ``img_dir`` and passed through
    ``transform`` when one is given.
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        self.img_labels = annotations_file
        self.img_dir = img_dir
        self.transform = transform

    def getbatch(self, prod_idx):
        """Return ``(images, texts, products)`` for all rows whose product is in ``prod_idx``.

        Each element is a tuple with one entry per selected row, in frame order; if no
        row matches, three empty tuples are returned. Without a transform, each image is
        a loaded copy of the PIL image, so it stays usable after its file is closed.
        """
        selected = self.img_labels[self.img_labels['Product'].isin(prod_idx)]
        images, texts, products = [], [], []
        # Iterate the selected rows directly: looking rows up by index *label* with
        # positional .iloc picks the wrong rows whenever the index is not 0..n-1.
        for img_name, text, product in selected.iloc[:, :3].itertuples(index=False, name=None):
            img_path = os.path.join(self.img_dir, img_name)
            # The context manager closes the file handle; the transform (or copy) runs
            # while the file is still open.
            with Image.open(img_path) as img:
                image = self.transform(img) if self.transform else img.copy()
            images.append(image)
            texts.append(text)
            products.append(product)
        return tuple(images), tuple(texts), tuple(products)

In [None]:
# Design sketch (kept as comments: the names below are placeholders, and executing them
# raised NameError on a fresh kernel). Each product forms one sample made of its text
# description plus all of its images, e.g.
#   b_1 = (t, im1, im2, im3)        # text + 3 images
#   b_2 = (t2, i1, i2)              # text + 2 images
# Encoding a sample gives one embedding per element:
#   emb_1 = (t_embs(t), ...)        # (4, emb_size)
#   emb_2 = ...                     # (3, emb_size)
# A batch groups several products:
#   B = (b_1, b_2, b_3, b_4, b_5)   # intended (batch_size, 4, emb_size)
# TODO: products have different image counts (4 vs 3 rows above), so B is ragged;
# a single (batch_size, k, emb_size) tensor needs padding plus a mask, or per-product pooling.


In [116]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image preprocessing: PIL image -> float tensor in [0, 1], resized to 224x224, then
# each channel mapped to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224,224)),
    transforms.Normalize((0.5,), (0.5,))  # Scale images to [-1, 1]
])

# Data paths. The glob pattern is derived from img_dir so the two cannot drift apart.
descriptions_data = "./processed_data/processedSKUs_nodups.csv"
img_dir = "./processed_data/images/"
images_folder = os.path.join(img_dir, "*.jpg")

# Annotations dataframe (Image, Text, Product) plus the unique product ids.
annotations_file, products = create_dataset(descriptions_data, images_folder)
assert len(products) > 0, f"No images matched {images_folder!r}"

# Custom batch loader: a batch contains every image of the requested products.
dataset = CustomImageLoader(annotations_file, img_dir, transform=transform)

# Split the product ids into groups; batch_size counts *products* per batch, not images.
batch_size = 10
products_groups = [products[i:i + batch_size] for i in range(0, len(products), batch_size)]

In [119]:
# Sanity check: load only the first product group and keep its (images, texts, products).
first_group = next(iter(products_groups), None)
if first_group is not None:
    X, Y, P = dataset.getbatch(first_group)

## Training Boilerplate

In [None]:
from model import FullEncoder, ImageEncoder, TextEncoder
from loss import OurLossFunction
from tqdm import tqdm

In [2]:
class JewelryClassifier:
    """Trainer for joint image/text embeddings: two encoders, a loss and an Adam optimizer.

    Args:
        emb_size: embedding size passed to both ``ImageEncoder`` and ``TextEncoder``.
        grad_clip: max gradient norm used in ``train_emb``; ``<= 0`` disables clipping.
        *loss_args: positional arguments forwarded to ``OurLossFunction``.
        learning_rate: Adam learning rate (keyword-only; the code previously read an
            undefined global). TODO: tune the default.
    """

    def __init__(self, emb_size, grad_clip=2, *loss_args, learning_rate=2e-4):
        self.grad_clip = grad_clip
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.im_enc = ImageEncoder(emb_size).to(self.device)
        self.txt_enc = TextEncoder(emb_size).to(self.device)
        self.criterion = OurLossFunction(*loss_args)
        # .parameters() returns generators, which cannot be combined with `+=`; a list
        # is also safe to iterate twice (optimizer and gradient clipping).
        params = list(self.txt_enc.fc.parameters())
        params += list(self.im_enc.fc.parameters())
        # The image cnn is fine-tuned but roberta is not.
        params += list(self.im_enc.cnn.parameters())
        self.params = params
        self.optimizer = torch.optim.Adam(params, lr=learning_rate)
        self.step = 0

    def state_dict(self):
        """Return ``[image_encoder_state, text_encoder_state]``."""
        state_dict = [self.im_enc.state_dict(), self.txt_enc.state_dict()]
        return state_dict

    def load_state_dict(self, state_dict):
        """Load a list produced by ``state_dict`` back into both encoders."""
        self.im_enc.load_state_dict(state_dict[0])
        self.txt_enc.load_state_dict(state_dict[1])

    def save(self, path):
        """Save both encoders' state dicts to ``path`` with ``torch.save``."""
        torch.save(self.state_dict(), path)

    def on_stage_start(self, stage):
        """Put both encoders in train mode for "TRAIN" or eval mode for "VALID"."""
        if stage == "TRAIN":
            self.im_enc.train()
            self.txt_enc.train()
        elif stage == "VALID":
            self.im_enc.eval()
            self.txt_enc.eval()

    def train(self):
        return self.on_stage_start("TRAIN")

    def eval(self):
        return self.on_stage_start("VALID")

    @staticmethod
    def to_tensor(x, grad=True):
        """Convert ``x`` to a tensor, moved to the GPU when one is available.

        A non-empty list/tuple of tensors (e.g. the images from ``getbatch``) is stacked
        along a new first dimension. Other non-tensor input goes through
        ``torch.as_tensor``, and floating-point results are cast to the default float
        dtype. ``grad`` is accepted only for backward compatibility: inputs never need
        ``requires_grad``; gradient tracking is controlled in ``forward_emb``.
        """
        if isinstance(x, (list, tuple)) and len(x) > 0 and all(isinstance(t, torch.Tensor) for t in x):
            x = torch.stack(list(x))
        elif not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x)
            if x.is_floating_point():
                x = x.to(torch.get_default_dtype())
        if torch.cuda.is_available():
            x = x.cuda()
        return x

    def forward_emb(self, imgs, texts, lengths, grad=True):
        """Compute the image and text embeddings.

        Args:
            imgs: image batch, converted with ``to_tensor``.
            texts: text batch, converted with ``to_tensor``. NOTE(review): this implies
                ``TextEncoder`` takes numeric (e.g. token id) input -- confirm.
            lengths: passed unchanged to the text encoder.
            grad: when False, the encoders run without building an autograd graph.

        Returns:
            ``(imgs_emb, txts_emb)``.
        """
        imgs = self.to_tensor(imgs, grad)
        txts = self.to_tensor(texts, grad)
        with torch.set_grad_enabled(grad):
            imgs_emb = self.im_enc(imgs)
            txts_emb = self.txt_enc(txts, lengths)
        return imgs_emb, txts_emb

    def forward_loss(self, imgs_emb, txts_emb):
        """Compute the loss given pairs of image and text embeddings."""
        loss = self.criterion(imgs_emb, txts_emb)
        return loss

    def train_emb(self, imgs, txts, lengths):
        """Run one optimisation step on a batch and return the loss as a Python float."""
        self.step += 1

        imgs_emb, txts_emb = self.forward_emb(imgs, txts, lengths)

        self.optimizer.zero_grad()
        loss = self.forward_loss(imgs_emb, txts_emb)

        # Backpropagate, optionally clip, then take an Adam step.
        loss.backward()
        if self.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(self.params, self.grad_clip)
        self.optimizer.step()
        return loss.item()

In [None]:
def train(train_loader, val_loader, n_epochs, emb_size, grad_clip=2, *loss_args,
          save_path="jewelry_classifier.pt"):
    """Train a ``JewelryClassifier`` and log average train/validation loss per epoch.

    Args:
        train_loader: iterable of ``(images, texts, text_lengths)`` batches.
        val_loader: iterable of batches with the same layout, evaluated with ``grad=False``.
        n_epochs: number of passes over ``train_loader``.
        emb_size, grad_clip, *loss_args: forwarded to ``JewelryClassifier``.
        save_path: file passed to ``model.save`` after training (keyword-only);
            ``None`` skips saving.

    Returns:
        The trained model.
    """
    model = JewelryClassifier(emb_size, grad_clip, *loss_args)
    val_loss = float("inf")  # shown in the progress bar until the first validation pass
    for epoch in range(1, n_epochs + 1):
        # TODO: Add learning rate scheduler
        model.train()
        losses = []
        running_total = 0.0
        pbar = tqdm(train_loader)
        for im_batch, txt_batch, txt_lengths in pbar:
            loss = model.train_emb(im_batch, txt_batch, txt_lengths)
            losses.append(loss)
            running_total += loss
            pbar.set_description(
                f"AVG Loss={running_total / len(losses):.4f}, Train Loss={loss:.4f}, "
                f"Valid Loss={val_loss:.4f}  (epoch {epoch})")
        # nan instead of ZeroDivisionError when a loader yields no batches.
        avg_loss = running_total / len(losses) if losses else float("nan")

        # Switch to eval mode once per epoch, not once per batch.
        model.eval()
        val_losses = []
        for im_batch, txt_batch, txt_lengths in val_loader:
            im_emb, txt_emb = model.forward_emb(im_batch, txt_batch, txt_lengths, grad=False)
            # float() keeps plain numbers instead of accumulating loss tensors.
            val_losses.append(float(model.forward_loss(im_emb, txt_emb)))
        val_loss = sum(val_losses) / len(val_losses) if val_losses else float("nan")
        # The per-batch bar has finished, so the epoch summary goes on its own line.
        tqdm.write(f"AVG Loss={avg_loss:.4f}, Valid Loss={val_loss:.4f}  (epoch {epoch})")
        # TODO Checkpointing
    if save_path is not None:
        model.save(save_path)
    return model