In [1]:
import torch
import torch.nn as nn
import torch.nn.init
import torchvision.models as models
from torch.nn.utils.clip_grad import clip_grad_norm
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from torchtext.models import ROBERTA_BASE_ENCODER
from torchtext.functional import to_tensor

import numpy as np
import numpy.ma as ma
import pandas as pd
import glob
import os
from PIL import Image
from tqdm import tqdm
import warnings
import random

In [2]:
# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation/user warnings raised by library calls later in
# the notebook; consider filtering only specific warning categories instead.
warnings.filterwarnings('ignore')

In [3]:
# Fix every random-number source used in this notebook so that runs are reproducible.
# Credits: https://clay-atlas.com/us/blog/2021/08/24/pytorch-en-set-seed-reproduce/
seed = 1001
for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all,
                np.random.seed, random.seed):
    seed_fn(seed)
del seed_fn
# Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

### 1. Preparation of dataset

A unified dataset is built with one row per image, pairing the image file name with the text description of its product.

In [4]:
def create_dataset(texts_path, images_path):
    """Create a dataframe pairing every image file with its product's text description.

    Args:
        texts_path: CSV file with an unnamed first column ("Unnamed: 0") holding the key
            that image file names start with, plus "color", "name" and "description" columns.
        images_path: glob pattern matching the image files, e.g. "dir/*.jpg".

    Returns:
        (df, product_ids): ``df`` has columns Image_id, Image_name, Text, Product_id with one
        row per image; ``product_ids`` are the distinct Product_id values in order of first
        appearance. Product_id is the row position of the product in the CSV.

    Raises:
        KeyError: if an image's key has no matching row in the CSV.
    """
    texts_df = pd.read_csv(texts_path)
    texts = texts_df['color'] + " " + texts_df['name'] + " " + texts_df['description']

    # key -> (text, product id); when a key repeats, the first CSV row wins.
    key_to_text_product = {}
    for key, text, product in zip(texts_df['Unnamed: 0'], texts, range(len(texts_df))):
        key_to_text_product.setdefault(key, (text, product))

    records = []
    # glob order depends on the filesystem; sort so Image_id is reproducible across machines.
    for image_id, image in enumerate(sorted(glob.glob(images_path))):
        img_name = os.path.basename(image)
        # The key is the part of the file name (extension stripped) before the first '_'.
        key = int(os.path.splitext(img_name)[0].split('_')[0])
        if key not in key_to_text_product:
            raise KeyError(f"No text description for image {img_name!r} (key {key})")
        text, product = key_to_text_product[key]
        records.append({'Image_id': image_id, 'Image_name': img_name, 'Text': text, 'Product_id': product})

    # Build the frame once instead of appending row by row (quadratic, removed in pandas 2).
    df = pd.DataFrame(records, columns=["Image_id", "Image_name", "Text", "Product_id"])
    return df, df['Product_id'].unique()

The `CustomImageLoader` class builds batches of training data. It groups the data by product, so all images of the same product always end up in the same batch.

In [5]:
class CustomImageLoader:
    """Creation of batches of images, texts and their ids, grouped by product.

    Args:
        annotations_file: DataFrame whose first four columns are Image_id, Image_name,
            Text and Product_id (as produced by ``create_dataset``).
        img_dir: directory containing the image files.
        transform: optional callable applied to each (RGB) PIL image.
        batch_size: number of products per batch when iterating over the loader.
    """
    def __init__(self, annotations_file, img_dir, transform=None, batch_size=10):
        self.img_labels = annotations_file
        self.img_dir = img_dir
        self.transform = transform
        self.batch_size = batch_size

    def getbatch(self, prod_idx):
        """Load every image belonging to the products in ``prod_idx``.

        Returns:
            (images, texts, ids): three tuples in row order, where ``ids`` holds
            (Image_id, Product_id) pairs. All three are empty when nothing matches.
        """
        selected = self.img_labels[self.img_labels['Product_id'].isin(prod_idx)]
        images, texts, ids = [], [], []

        # Iterate the selected rows directly: a non-default index must not be
        # mistaken for row positions.
        for row in selected.itertuples(index=False):
            image_id, image_name, text, product_id = row[0], row[1], row[2], row[3]
            img_path = os.path.join(self.img_dir, image_name)
            # Decode eagerly so the file handle is closed; force 3 channels so
            # grayscale/RGBA files match the RGB images.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
            if self.transform:
                image = self.transform(image)
            images.append(image)
            texts.append(text)
            ids.append((image_id, product_id))

        return tuple(images), tuple(texts), tuple(ids)

    def __getitem__(self, prod_idx):
        return self.getbatch(prod_idx)

    def __iter__(self):
        """Yield one batch per group of ``batch_size`` products of this loader's own data."""
        products = self.img_labels['Product_id'].unique()
        for start in range(0, len(products), self.batch_size):
            yield self.getbatch(products[start:start + self.batch_size])

This cell prepares the data for the batch loader and defines the constants used throughout the notebook.

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Transformation of images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(224,224)),
    transforms.Normalize((0.5,), (0.5,))  # Scale images from [0, 1] to [-1, 1]
])

# Data paths. Set JEWELRY_DESCRIPTIONS_CSV / JEWELRY_IMAGES_DIR to point at local copies
# instead of editing the notebook per machine.
descriptions_data = os.environ.get("JEWELRY_DESCRIPTIONS_CSV", "../processedSKUs_nodups.csv")
img_dir = os.environ.get("JEWELRY_IMAGES_DIR", "../images_onlyids/")
images_folder = os.path.join(img_dir, "*.jpg")

# Creation of organized dataframe
annotations_file, products = create_dataset(descriptions_data, images_folder)
# Fail early with a clear message rather than silently training on an empty dataset.
assert len(annotations_file) > 0, f"No images matched {images_folder!r}"

# True constant adjacency matrix: rows - products, columns - images.
ADJACENCY_TRUE = pd.crosstab(annotations_file.Product_id, annotations_file.Image_id)

# Custom batch loader; a batch contains all images of the products it holds.
dataset = CustomImageLoader(annotations_file, img_dir, transform=transform)

# Batches of products. Each product has 2-3 images, so the actual batch size is ~20-30 pairs.
prod_batch_size = 10
products_groups = [products[i:i + prod_batch_size] for i in range(0, len(products), prod_batch_size)]

### 2. Image and text encoders and supporting functions

In [7]:
def normalization(X, dim=-1, eps=1e-12):
    """L2-normalize every vector of X along ``dim``.

    For 2-D input (batch, features) this gives unit-norm rows, exactly as the
    previous hard-coded ``dim=1``. For 3-D input such as the text encoder's
    (batch, tokens, features) output, each token vector is normalized instead
    of normalizing across the token axis.

    Args:
        X: tensor to normalize.
        dim: axis holding the feature values (default: last axis).
        eps: lower bound on the norm, so all-zero vectors yield zeros, not NaN.

    Returns:
        Tensor of the same shape as X.
    """
    return nn.functional.normalize(X, p=2, dim=dim, eps=eps)

In [8]:
class ImageEncoder(nn.Module):
    """Pretrained torchvision CNN (frozen) followed by a trainable linear projection.

    The CNN's last classifier layer is removed, so the backbone emits the
    activations that used to feed it; these are L2-normalized by
    ``normalization`` and projected to ``embedding_size`` by ``self.fc``.
    """

    def __init__(self, embedding_size, cnn_type):
        """Build the encoder.

        Args:
            embedding_size: output size of the projection.
            cnn_type: name of a torchvision.models constructor whose ``classifier``
                has a layer under key '6' (e.g. 'vgg19').
        """
        super(ImageEncoder, self).__init__()
        self.embedding_size = embedding_size  # size of the projected image embedding
        self.cnn = self.load_cnn(cnn_type)

        # Freeze the backbone; only self.fc, created below, stays trainable.
        for backbone_param in self.cnn.parameters():
            backbone_param.requires_grad = False

        head = self.cnn.classifier
        # The projection consumes what the final classifier layer used to consume...
        self.fc = nn.Linear(head._modules['6'].in_features, embedding_size)
        # ...and that final layer is dropped from the backbone.
        self.cnn.classifier = nn.Sequential(*list(head.children())[:-1])

        self.initialization_weights()

    def load_cnn(self, cnn_type):
        """Instantiate torchvision model ``cnn_type`` with pretrained weights."""
        constructor = models.__dict__[cnn_type]
        return constructor(pretrained=True)

    def initialization_weights(self):
        """Xavier/Glorot-uniform weights and zero bias for the projection layer."""
        fan_sum = self.fc.in_features + self.fc.out_features
        bound = np.sqrt(6.) / np.sqrt(fan_sum)
        self.fc.weight.data.uniform_(-bound, bound)
        self.fc.bias.data.zero_()

    def forward(self, X):
        """Embed a batch of images X and project the features into the joint space."""
        backbone_features = self.cnn(X)
        return self.fc(normalization(backbone_features))

In [9]:
ROBERTA_OUT_DIM = 768  # width of the RoBERTa-base features fed to the projection layer

class TextEncoder(nn.Module):
    """Frozen RoBERTa-base encoder followed by a trainable linear projection."""

    def __init__(self, embedding_size):
        """Build the encoder.

        Args:
            embedding_size: output size of the projection.
        """
        super(TextEncoder, self).__init__()
        self.embedding_size = embedding_size  # size of the projected text embedding
        self.roberta = ROBERTA_BASE_ENCODER.get_model()
        self.transform = ROBERTA_BASE_ENCODER.transform()

        # RoBERTa is used as a fixed feature extractor (JewelryClassifier only
        # optimizes self.fc); freezing it avoids building an autograd graph for it.
        for param in self.roberta.parameters():
            param.requires_grad = False

        # Linear projection into the joint embedding space
        self.fc = nn.Linear(ROBERTA_OUT_DIM, embedding_size)

        # Initializing the weights of fully-connected layer, which makes projection to new space
        self.initialization_weights()

    def _roberta_encode(self, batch):
        """Tokenize a list of strings, pad them into one tensor and run RoBERTa."""
        transformed = self.transform(batch)
        model_input = to_tensor(transformed, padding_value=1)
        # Put the token ids on the same device as this module's weights (e.g. after .cuda()).
        model_input = model_input.to(self.fc.weight.device)
        return self.roberta(model_input)

    def initialization_weights(self):
        """Xavier initialization"""
        r = np.sqrt(6.) / np.sqrt(self.fc.in_features + self.fc.out_features)
        self.fc.weight.data.uniform_(-r, r)
        self.fc.bias.data.fill_(0)

    def forward(self, X, lengths=None):
        """Embed a sequence of strings X.

        ``lengths`` is accepted for interface compatibility and ignored. The RoBERTa
        output keeps its extra axis 1 (presumably tokens — TODO confirm), which
        JewelryClassifier averages out after this call.
        """
        features = self._roberta_encode(list(X))

        # Normalization of embeddings
        features = normalization(features)

        # Projection to new space
        features = self.fc(features)

        return features

### 3. Definition of loss class

In [10]:
def similarity_score(img, text):
    """Dot-product similarity matrix between texts and images.

    Args:
        img: 2-D tensor, one image embedding per row.
        text: 2-D tensor, one text embedding per row.

    Returns:
        Tensor of shape (len(text), len(img)) whose entry [i, j] is text[i] . img[j].
    """
    return torch.mm(text, img.t())

In [11]:
class CombinedLoss(nn.Module):
    """Hinge ranking loss between product texts and images, plus an assignment penalty.

    For every product p in a batch, two hinge terms are measured against the
    weakest similarity between p and its own images:
      * the strongest similarity between p's images and any *other* product;
      * the strongest similarity between p and any image of another product.
    Both are clamped at zero and averaged over products. A constant penalty of
    ``lambda_coeff`` per product that no image picks as its most similar product
    is then added.
    """

    def __init__(self, margin=0, lambda_coeff=0.1, adjacency=None):
        """
        Args:
            margin: hinge margin added to both violation terms.
            lambda_coeff: penalty per product that receives no image.
            adjacency: product x image 0/1 DataFrame (index = Product_id,
                columns = Image_id). When None, the notebook-level ADJACENCY_TRUE
                is looked up at call time.
        """
        super(CombinedLoss, self).__init__()
        self.margin = margin
        self.lambda_coeff = lambda_coeff
        self.adjacency = adjacency

    @staticmethod
    def batch_specific_adj_matrix(img_product_ids, adjacency=None):
        """Slice the product x image ground truth down to one batch.

        Args:
            img_product_ids: (batch_size, 2) array of (Image_id, Product_id) rows.
            adjacency: full product x image DataFrame; defaults to ADJACENCY_TRUE.

        Returns:
            DataFrame with one row per distinct product (order of first appearance)
            and one column per image (batch order).
        """
        if adjacency is None:
            adjacency = ADJACENCY_TRUE
        ids = np.asarray(img_product_ids)
        # Product_id / Image_id are labels of the crosstab, not positions.
        adj = adjacency.loc[ids[:, 1], ids[:, 0]]
        return adj[~adj.index.duplicated(keep="first")]

    def _calc_penalty(self, similarities):
        """lambda_coeff times the number of products (rows) no image (column) is assigned to.

        Each image is assigned to its most similar product. The result is a plain
        Python number, so it shifts the loss value but contributes no gradient.
        """
        if isinstance(similarities, torch.Tensor):
            similarities = similarities.detach().cpu().numpy()
        similarities = np.asarray(similarities)
        best_product = similarities.argmax(0)
        images_per_product = np.bincount(best_product, minlength=similarities.shape[0])
        return self.lambda_coeff * int(np.count_nonzero(images_per_product == 0))

    def forward(self, image_emb, label_emb, img_product_ids):
        """Compute the batch loss.

        Args:
            image_emb: (batch_size, emb) image embeddings.
            label_emb: (products_num, emb) text embeddings, one per distinct product
                in order of first appearance in img_product_ids.
            img_product_ids: (batch_size, 2) array of (Image_id, Product_id) rows.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if label_emb does not have one row per distinct product.
        """
        products_num = label_emb.shape[0]
        adj_df = self.batch_specific_adj_matrix(img_product_ids, self.adjacency)

        # Similarity matrix: rows = products, columns = images.
        sim_matrix = similarity_score(image_emb, label_emb)
        is_true = torch.as_tensor(adj_df.to_numpy(dtype=bool), device=sim_matrix.device)
        if is_true.shape != sim_matrix.shape:
            raise ValueError(
                f"adjacency slice has shape {tuple(is_true.shape)} but the similarity "
                f"matrix has shape {tuple(sim_matrix.shape)}")
        inf = float("inf")

        # Masked entries become +/-inf (not 0), so genuine negative similarities
        # are never confused with masked ones.
        # 1st - weakest similarity of each product to its own images.
        product_min_sims = sim_matrix.masked_fill(~is_true, inf).amin(dim=1)
        # 2nd - strongest similarity of each product to an image of another product.
        product_max_sims = sim_matrix.masked_fill(is_true, -inf).amax(dim=1)
        # 3rd - for each product p, strongest similarity between p's images and any other product q.
        # pair_mask[p, q, j] is True when q != p and image j belongs to p.
        other_product = torch.eye(products_num, device=sim_matrix.device) == 0
        pair_mask = other_product.unsqueeze(2) & is_true.unsqueeze(1)
        not_product_max_sims = (
            sim_matrix.unsqueeze(0)
            .expand(products_num, -1, -1)
            .masked_fill(~pair_mask, -inf)
            .amax(dim=(1, 2))
        )

        # Hinges; a term with no competitor is -inf and clamps to 0.
        first_hinge = (self.margin + not_product_max_sims - product_min_sims).clamp(min=0)
        second_hinge = (self.margin + product_max_sims - product_min_sims).clamp(min=0)
        loss = (first_hinge + second_hinge).sum() / products_num
        # Regularizer (a plain number, no gradient)
        return loss + self._calc_penalty(sim_matrix)

### 4. Definition of class and functions for training

The `JewelryClassifier` class combines both encoders: it computes the image and text embeddings, evaluates the loss and performs the optimization step.

In [12]:
class JewelryClassifier:
    """Class, unifying the embeddings creation, loss calculation and optimization."""
    def __init__(self, emb_size=128, grad_clip=2, learning_rate=0.01, *loss_args, **loss_kwargs):
        """
        Args:
            emb_size: size of the joint image/text embedding space.
            grad_clip: max gradient norm; values <= 0 disable clipping.
            learning_rate: Adam learning rate.
            *loss_args, **loss_kwargs: forwarded to CombinedLoss (margin, lambda_coeff).
        """
        print(f"{learning_rate=}, {loss_args=}")
        self.grad_clip = grad_clip
        self.im_enc = ImageEncoder(emb_size, 'vgg19')
        self.txt_enc = TextEncoder(emb_size)
        if torch.cuda.is_available():
            self.im_enc.cuda()
            self.txt_enc.cuda()
        self.criterion = CombinedLoss(*loss_args, **loss_kwargs)
        params = list(self.txt_enc.fc.parameters())
        params += list(self.im_enc.fc.parameters())
        # The CNN backbone is registered too, but ImageEncoder sets requires_grad=False
        # on it, so it receives no gradients and is not updated. RoBERTa is not registered.
        params += self.im_enc.cnn.parameters()
        self.params = params
        self.optimizer = torch.optim.Adam(params, lr=learning_rate)
        self.step = 0

    def state_dict(self):
        state_dict = [self.im_enc.state_dict(), self.txt_enc.state_dict()]
        return state_dict

    def load_state_dict(self, state_dict):
        self.im_enc.load_state_dict(state_dict[0])
        self.txt_enc.load_state_dict(state_dict[1])

    def save(self, path):
        torch.save(self.state_dict(), path)

    def on_stage_start(self, stage):
        if stage == "TRAIN":
            self.im_enc.train()
            self.txt_enc.train()
        elif stage == "VALID":
            self.im_enc.eval()
            self.txt_enc.eval()

    def train(self):
        return self.on_stage_start("TRAIN")

    def eval(self):
        return self.on_stage_start("VALID")

    def forward_emb(self, imgs, txts, lengths=None, grad=True):
        """Compute the image and text embeddings.

        Args:
            imgs: batched image tensor; moved to the image encoder's device.
            txts: texts for the text encoder.
            lengths: passed through to the text encoder.
            grad: when False, no autograd graph is built.
        """
        if isinstance(imgs, torch.Tensor):
            imgs = imgs.to(self.im_enc.fc.weight.device)
        # Never re-enable autograd when the caller is already inside torch.no_grad().
        with torch.set_grad_enabled(grad and torch.is_grad_enabled()):
            imgs_emb = self.im_enc(imgs)
            txts_emb = self.txt_enc(txts, lengths)
        return imgs_emb, txts_emb

    def forward_loss(self, imgs_emb, txts_emb, img_product_ids):
        """Compute the loss given image and text embeddings and their ids"""
        loss = self.criterion(imgs_emb, txts_emb, img_product_ids)
        return loss

    def train_emb(self, imgs, txts, img_product_ids):
        """One training step given images and captions; returns the loss as a float."""
        self.step += 1

        # compute the embeddings
        imgs_emb, txts_emb = self.forward_emb(imgs, txts)

        self.optimizer.zero_grad()
        # The text embeddings have an extra axis 1 (presumably tokens); average it
        # out to get one vector per product text.
        loss = self.forward_loss(imgs_emb, txts_emb.mean(axis=1), img_product_ids)

        # compute gradient and do an optimizer step
        loss.backward()
        if self.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(self.params, self.grad_clip)
        self.optimizer.step()
        return loss.item()

## Accuracy Calculation

In [13]:
def _batch_accuracy_score(model, imgs, txts, img_product_ids):
    model.eval()
    # compute the embeddings
    imgs_emb, txts_emb = model.forward_emb(imgs, txts)
    # sim_matrix: rows=products, columns=images
    sim_matrix = similarity_score(imgs_emb, txts_emb.mean(axis=1))
    # The accuracy score is the normalized number of products that
    # have been mapped to at least one correct image and no incorrect images.
    # NOTE: This is on the batch level so it's not normalized. Check the accuracy_score
    #       function for the normalized version.
    batch_adj_true = CombinedLoss.batch_specific_adj_matrix(img_product_ids).values
    if isinstance(sim_matrix, torch.Tensor):
        sim_matrix = sim_matrix.detach().numpy()
    batch_adj_pred = np.zeros_like(sim_matrix)
    batch_adj_pred[sim_matrix.argmax(0), np.arange(sim_matrix.shape[1])] = 1
    accuracies = []
    for prod_id in range(len(batch_adj_true)):
        true_imgs = batch_adj_true[prod_id, :]
        pred_imgs = batch_adj_pred[prod_id, :]
        # If at least one incorrect image is mapped to the product
        # then the accuracy on this product is 0.
        if -1 in true_imgs - pred_imgs:
            accuracies.append(0)
            continue
        # If no image is mapped to the product, the accuracy is 0
        if len(np.where(pred_imgs==1)[0]) == 0:
            accuracies.append(0)
            continue
        accuracies.append(1)
    return accuracies

def accuracy_score(model, products_groups, data_set=dataset):
    """Per-product 0/1 accuracies over all product groups.

    Each group is loaded as one batch with ``data_set.getbatch``. The first text
    seen for every product serves as its label, and the batch is scored by
    ``_batch_accuracy_score``.

    Returns:
        Flat list of 0/1 values, one per product, in group order.
    """
    scores = []
    for _, group in tqdm(enumerate(products_groups)):
        images, texts, pairs = data_set.getbatch(group)
        image_tensor = torch.stack(images, dim=0)  # tuple of images -> one batched tensor
        pairs = np.array(pairs)  # rows of (Image_id, Product_id)

        # Keep the first text seen for each product, preserving first-appearance order.
        text_by_product = {}
        for text, product in zip(texts, pairs[:, 1]):
            text_by_product.setdefault(product, text)

        scores.extend(_batch_accuracy_score(model, image_tensor, list(text_by_product.values()), pairs))
    return scores

In [14]:
def _load_product_batch(pr_group):
    """Load one group of products from the global ``dataset`` as model inputs.

    Returns:
        (im_batch, labels, img_product_ids): images stacked into one tensor, the
        first text of every distinct product in order of first appearance, and an
        array of (Image_id, Product_id) rows.
    """
    im_batch, txt_batch, img_product_ids = dataset.getbatch(pr_group)
    im_batch = torch.stack(im_batch, dim=0)  # Images are converted to batched tensor
    img_product_ids = np.array(img_product_ids)  # Numpy array from (image,product) pairs

    # One label per product: keep the first text seen for it.
    first_text = {}
    for text, product in zip(txt_batch, img_product_ids[:, 1]):
        first_text.setdefault(product, text)
    return im_batch, list(first_text.values()), img_product_ids


def train(products_groups, val_loader, epochs, emb_size, grad_clip=2, use_valid=False, lr=0.01, *loss_args, **loss_kwargs):
    """The main training function; saves the whole model to 'model.pth'.

    Args:
        products_groups: iterable of Product_id groups; each group is one batch.
        val_loader: iterable of Product_id groups (same format as products_groups)
            used for validation when use_valid is True.
        epochs: number of passes over products_groups.
        emb_size: embedding size of the JewelryClassifier.
        grad_clip: max gradient norm (<= 0 disables clipping).
        use_valid: whether to compute the mean validation loss after each epoch.
        lr: learning rate.
        *loss_args, **loss_kwargs: forwarded to CombinedLoss.

    Returns:
        List of per-batch training losses of the final epoch (empty if epochs == 0).
    """
    model = JewelryClassifier(emb_size, grad_clip, lr, *loss_args, **loss_kwargs)
    losses = []
    for epoch in range(1, epochs + 1):
        model.train()
        losses = []
        for pr_group in tqdm(products_groups, desc=f"epoch {epoch}"):
            im_batch, labels, img_product_ids = _load_product_batch(pr_group)
            losses.append(model.train_emb(im_batch, labels, img_product_ids))
        mean_train = float(np.mean(losses)) if losses else float("nan")
        summary = f"Epoch {epoch}: mean train loss {mean_train:.4f}"

        if use_valid:
            model.eval()
            val_losses = []
            with torch.no_grad():
                for pr_group in val_loader:
                    im_batch, labels, img_product_ids = _load_product_batch(pr_group)
                    im_emb, txt_emb = model.forward_emb(im_batch, labels, grad=False)
                    # Same pooling over axis 1 as JewelryClassifier.train_emb.
                    val_loss = model.forward_loss(im_emb, txt_emb.mean(axis=1), img_product_ids)
                    val_losses.append(float(val_loss))
            mean_val = sum(val_losses) / len(val_losses) if val_losses else float("nan")
            summary += f", mean validation loss {mean_val:.4f}"
        print(summary)

    torch.save(model, 'model.pth')
    return losses

## Loss plots

In [15]:
import matplotlib.pyplot as plt

In [16]:
def plot(losses, set_type="Train", out_path=None):
    """Plot a loss curve and optionally save it to ``out_path``.

    The figure is saved before it is shown: once shown (and closed by the
    notebook backend), a later savefig would write an empty canvas.
    """
    fig, ax = plt.subplots()
    ax.plot(losses)
    ax.set_title(f"{set_type} loss over the epochs")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    if out_path is not None:
        fig.savefig(out_path)
    plt.show()

## Train the model and get accuracy

In [None]:
# Run configuration, passed positionally to train(). LOSS_MARGIN and LOSS_LAMBDA are
# forwarded through JewelryClassifier to CombinedLoss as (margin, lambda_coeff).
N_EPOCHS, EMB_SIZE, GRAD_CLIP, USE_VALID, LEARNING_RATE = 8, 128, 2, False, 0.01
LOSS_MARGIN, LOSS_LAMBDA = 0.2, 0.1
train_losses = train(products_groups, None, N_EPOCHS, EMB_SIZE, GRAD_CLIP, USE_VALID,
                     LEARNING_RATE, LOSS_MARGIN, LOSS_LAMBDA)

learning_rate=0.01, loss_args=(0.2, 0.1)


0it [00:01, ?it/s]


In [None]:
plot(train_losses, set_type="Train")

In [None]:
# train() pickles the whole JewelryClassifier to "model.pth" in the working directory.
# NOTE: torch.load unpickles arbitrary objects - only load checkpoints you created yourself.
MODEL_PATH = "model.pth"
model = torch.load(MODEL_PATH)

In [None]:
accs = accuracy_score(model, products_groups)
# accs holds one 0/1 score per product; their mean is the overall accuracy.
final_accuracy = float(np.mean(accs)) if accs else float("nan")
print("Final Accuracy:", final_accuracy)