# Importing Libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, IterableDataset

from torch import Tensor

from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from transformers import ViTFeatureExtractor, ViTModel
import timm

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import time
import os
import copy
import json
import math
import random

from transformers import BertTokenizer, BertModel
import pickle

# making sure that the whole embedding tensor is printed in output
torch.set_printoptions(threshold=10_000)

In [2]:
torch.__version__

'1.11.0+cu102'

In [3]:
! set CUDA_VISIBLE_DEVICES = '0,1,2,3'

In [7]:
# making sure the feature extraction runs on GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Loading data

## For end-to-end

In [5]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, txt_file, img_file, img_dir_path, transforms):
        self.text_file = txt_file
        self.image_file = img_file
        self.ingredients = []
        self.txt_ids = []
        self.img_id_map = {}

        self.img_dir = img_dir_path
        self.transforms = transforms

        for row in self.text_file:
            id = row["id"]
            self.txt_ids.append(id)
            # title = row["title"]
            # ingredients = row["ingredients"]
            instructions = row["instructions"]

            ingredient_text = ""
            # instructions_text = ""

            ingredient_text = " ".join(instruction["text"] for instruction in instructions)

            self.ingredients.append(ingredient_text)

        for row in self.image_file:
            self.img_id_map[row["id"]] = row["images"][0]


    def __len__(self):
        return len(self.ingredients)
    
    def __getitem__(self, idx):
        text = self.ingredients[idx]
        image_file = self.img_id_map[self.txt_ids[idx]]
        image_path = self.img_dir + image_file[0] + "/" + image_file[1] + "/" + image_file[2] + "/" + image_file[3] + "/" + image_file
        img = Image.open(image_path)
        img = self.transforms(img)
        return img, text

In [6]:
# transforming each image
data_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

train_dataset = Dataset(json.load(open("/common/users/kcm161/data/train/text.json")), json.load(open("/common/users/kcm161/data/train/image.json")), \
    "/common/users/kcm161/train/", data_transforms)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=False)

val_dataset = Dataset(json.load(open("/common/users/kcm161/data/val/text.json")), json.load(open("/common/users/kcm161/data/val/image.json")), \
    "/common/users/kcm161/val/", data_transforms)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False)

test_dataset = Dataset(json.load(open("/common/users/kcm161/data/test/text.json")), json.load(open("/common/users/kcm161/data/test/image.json")), \
    "/common/users/kcm161/test/", data_transforms)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

## For triplet finetuning

In [12]:
triplet_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=False)

# Triplet Finetuning

In [13]:
def generate_triplets(images, texts):
    triplet_images = torch.zeros((16, 3, 224, 224))
    triplet_pos_text = []
    triplet_neg_text = []
    
#     print(images.shape[0])
    
    for i in range(images.shape[0]):
        triplet_images[i] = images[i]
        triplet_pos_text.append(texts[i])
        
#     print(len(triplet_pos_text))
        
    for i in range(images.shape[0]):
        neg_idx = random.randint(0, images.shape[0] - 1)
#         print(neg_idx)
        while neg_idx == i:
            neg_idx = random.randint(0, images.shape[0] - 1)
#             print(neg_idx)

        triplet_neg_text.append(triplet_pos_text[neg_idx])
        
    return triplet_images, triplet_pos_text, triplet_neg_text
    

## Models

In [14]:
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
img_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

for param in img_model.parameters():
    param.requires_grad = False

for param in img_model.encoder.layer[11].parameters():
    param.requires_grad = True

for param in img_model.layernorm.parameters():
    param.requires_grad = True

for param in img_model.pooler.parameters():
    param.requires_grad = True

img_model = nn.DataParallel(img_model, device_ids = [0])
img_model = img_model.to(f'cuda:{img_model.device_ids[0]}')

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text_model = BertModel.from_pretrained("bert-base-uncased")

for param in text_model.parameters():
    param.requires_grad = False

for param in text_model.encoder.layer[11].parameters():
    param.requires_grad = True

for param in text_model.pooler.parameters():
    param.requires_grad = True

text_model = nn.DataParallel(text_model, device_ids = [0])
text_model = text_model.to(f'cuda:{text_model.device_ids[0]}')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Training Prep

In [15]:
class AverageMeter(object):
    # Utility function for timers
    
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [18]:
optimizer_image = torch.optim.Adam(img_model.parameters(), lr=1e-4, weight_decay=0.0)
optimizer_text = torch.optim.Adam(text_model.parameters(), lr=1e-4, weight_decay=0.0)

optimizer_total = torch.optim.Adam(list(text_model.parameters()) + list(img_model.parameters()), lr=1e-4, weight_decay=0.0)

criterion = nn.TripletMarginLoss(margin = 1)
criterion.to(device);

In [16]:
from tqdm.notebook import tqdm

def triplet_train(triplet_loader, img_model, text_model, criterion, optimizer_image, optimizer_text,  optimizer_total, epoch):
    print('Starting training epoch {}'.format(epoch))
    img_model.train()
    text_model.train()
    
    batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()
    
#     optimizer_image.zero_grad()
#     optimizer_text.zero_grad()
    optimizer_total.zero_grad()
    
    train_loss, total_samples, running_loss = 0, 0, 0
    
    batch = 1
    
    with tqdm(total = len(triplet_loader)) as pbar:
        for img, text in triplet_loader:
            
            img, pos_text, neg_text = generate_triplets(img, text)

            image_encodings = img_model(img.to(f'cuda:{text_model.device_ids[0]}'))
            
            pos_encoded_ingredients = tokenizer(pos_text, return_tensors='pt', max_length=512, truncation = True, padding = True).to(f'cuda:{text_model.device_ids[0]}')
            pos_output_ingredients = text_model(**pos_encoded_ingredients)

            neg_encoded_ingredients = tokenizer(neg_text, return_tensors='pt', max_length=512, truncation = True, padding = True).to(f'cuda:{text_model.device_ids[0]}')
            neg_output_ingredients = text_model(**neg_encoded_ingredients)

            loss = criterion(image_encodings["last_hidden_state"][:, 0, :], pos_output_ingredients["last_hidden_state"][:, 0, :], 
                             neg_output_ingredients["last_hidden_state"][:, 0, :])

            data_time.update(time.time() - end)

            # Compute gradient and optimize
#             optimizer_image.zero_grad()
#             optimizer_text.zero_grad()
            optimizer_total.zero_grad()

            loss.backward()
#             optimizer_image.step()
#             optimizer_text.step()
            optimizer_total.step()

            batch_time.update(time.time() - end)
            end = time.time()

            running_loss += loss.item() * img.shape[0]
            total_samples += img.shape[0]

            train_loss += running_loss

            if batch % 50 == 0:
                print('  batch {} loss: {}'.format(batch, running_loss / 50))
                running_loss = 0.

            if batch % 50 == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val} ({batch_time.avg})\t'
                    'Data {data_time.val} ({data_time.avg})\t'.format(
                      epoch, batch, len(train_loader), batch_time=batch_time,
                     data_time=data_time)) 
                pbar.update(50)

            batch += 1

        print('Finished training epoch {}'.format(epoch))
        print('Epoch Loss:', train_loss / total_samples)


## Save models

In [20]:
save_dict = {
    "image_vit_encoder": img_model.state_dict(),
    "text_encoder": text_model.state_dict(),
}

torch.save(save_dict, f'/common/users/kcm161/step3_models_instructions_onlytriplet_e1.pt')

## Load models

In [11]:
model_weights = torch.load("/common/users/kcm161/step3_models_onlytriplet_e1.pt", map_location = "cuda:0")

img_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

for param in img_model.parameters():
    param.requires_grad = False

for param in img_model.encoder.layer[11].parameters():
    param.requires_grad = True

for param in img_model.layernorm.parameters():
    param.requires_grad = True

for param in img_model.pooler.parameters():
    param.requires_grad = True

img_model = nn.DataParallel(img_model, device_ids = [0])
img_model.load_state_dict(model_weights["image_vit_encoder"])
img_model = img_model.to(f'cuda:{img_model.device_ids[0]}')

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text_model = BertModel.from_pretrained("bert-base-uncased")

for param in text_model.parameters():
    param.requires_grad = False

for param in text_model.encoder.layer[11].parameters():
    param.requires_grad = True

for param in text_model.pooler.parameters():
    param.requires_grad = True

text_model = nn.DataParallel(text_model, device_ids = [0])
text_model.load_state_dict(model_weights["text_encoder"])
text_model = text_model.to(f'cuda:{text_model.device_ids[0]}')

optimizer_image = torch.optim.Adam(img_model.parameters(), lr=1e-4, weight_decay=0.0)
optimizer_text = torch.optim.Adam(text_model.parameters(), lr=1e-4, weight_decay=0.0)

optimizer_total = torch.optim.Adam(list(text_model.parameters()) + list(img_model.parameters()), lr=1e-4, weight_decay=0.0)

criterion = nn.TripletMarginLoss(margin = 1)
criterion.to(device);

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Training 

In [19]:
epochs = 10

for epoch in range(1, epochs):
    triplet_train(triplet_train_loader, img_model, text_model, criterion, optimizer_image, optimizer_text, optimizer_total, epoch)
    
    save_dict = {
        "image_vit_encoder": img_model.state_dict(),
        "text_encoder": text_model.state_dict(),
    }
    
    torch.save(save_dict, f'/common/users/kcm161/step3_models_instructions_onlytriplet_e{epoch}.pt')

Starting training epoch 1


  0%|          | 0/17600 [00:00<?, ?it/s]

  batch 50 loss: 15.886473064422608
Epoch: [1][50/17600]	Time 0.875758171081543 (0.9893178224563599)	Data 0.8578240871429443 (0.9697428941726685)	
  batch 100 loss: 15.594302463531495
Epoch: [1][100/17600]	Time 0.6010310649871826 (0.940607397556305)	Data 0.5835087299346924 (0.9218562650680542)	
  batch 150 loss: 11.192723121643066
Epoch: [1][150/17600]	Time 1.110100507736206 (0.948008599281311)	Data 1.0920896530151367 (0.929504861831665)	
  batch 200 loss: 8.379649143218995
Epoch: [1][200/17600]	Time 1.1579561233520508 (0.9301448488235473)	Data 1.1392502784729004 (0.9116637527942657)	
  batch 250 loss: 8.228947830200195
Epoch: [1][250/17600]	Time 1.06022310256958 (0.9674950437545776)	Data 1.0425007343292236 (0.9490633420944213)	
  batch 300 loss: 7.131164169311523
Epoch: [1][300/17600]	Time 1.0772957801818848 (1.0032671268781026)	Data 1.0597848892211914 (0.9849181588490804)	
  batch 350 loss: 7.641985130310059
Epoch: [1][350/17600]	Time 1.1193079948425293 (1.0308578743253436)	Data 1.10

RuntimeError: The size of tensor a (16) must match the size of tensor b (14) at non-singleton dimension 0

# End-to-end training

## Models

### Image Embeddings - ViT

In [9]:
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
img_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

In [10]:
for param in img_model.parameters():
    param.requires_grad = False

for param in img_model.encoder.layer[11].parameters():
    param.requires_grad = True

for param in img_model.layernorm.parameters():
    param.requires_grad = True

for param in img_model.pooler.parameters():
    param.requires_grad = True

img_model = nn.DataParallel(img_model, device_ids = [0])
img_model = img_model.to(f'cuda:{img_model.device_ids[0]}')

### Text Embeddings - BERT

In [11]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text_model = BertModel.from_pretrained("bert-base-uncased")

for param in text_model.parameters():
    param.requires_grad = False

for param in text_model.encoder.layer[11].parameters():
    param.requires_grad = True

for param in text_model.pooler.parameters():
    param.requires_grad = True

text_model = nn.DataParallel(text_model, device_ids = [0])
text_model = text_model.to(f'cuda:{text_model.device_ids[0]}')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Load ViT and BERT weights

In [8]:
model_weights = torch.load("/common/users/kcm161/step3_models_instructions_onlytriplet_e1.pt", map_location = "cuda:0")

img_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

for param in img_model.parameters():
    param.requires_grad = False

img_model = nn.DataParallel(img_model, device_ids = [0])
img_model.load_state_dict(model_weights["image_vit_encoder"])
img_model = img_model.to(f'cuda:{img_model.device_ids[0]}')

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text_model = BertModel.from_pretrained("bert-base-uncased")

for param in text_model.parameters():
    param.requires_grad = False

text_model = nn.DataParallel(text_model, device_ids = [0])
text_model.load_state_dict(model_weights["text_encoder"])
text_model = text_model.to(f'cuda:{text_model.device_ids[0]}')

optimizer_image = torch.optim.Adam(img_model.parameters(), lr=1e-4, weight_decay=0.0)
optimizer_text = torch.optim.Adam(text_model.parameters(), lr=1e-4, weight_decay=0.0)

optimizer_total = torch.optim.Adam(list(text_model.parameters()) + list(img_model.parameters()), lr=1e-4, weight_decay=0.0)

criterion = nn.TripletMarginLoss(margin = 1)
criterion.to(device);

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Transformer

In [9]:
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout = 0.1, max_len = 5000):
        super().__init__()
        self.dropout = nn.Dropout( p = dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        pe = torch.transpose(pe, 0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[0, :x.size(1), :]
        return self.dropout(x)

In [10]:
class CrossModalAttention(nn.Module):
    def __init__(self, model_dim = 768, n_heads = 2, n_layers = 2, num_image_patches = 197, num_classes = 2, drop_rate = 0.1):
        super().__init__()
        self.text_positional = SinusoidalPositionalEncoding(model_dim, dropout= drop_rate)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        self.sep_token = nn.Parameter(torch.zeros(1, 1, model_dim))
        self.image_positional = nn.Parameter(torch.zeros(1, num_image_patches + 1, model_dim))
        self.image_positional_drop = nn.Dropout(p = drop_rate)
        layers = nn.TransformerEncoderLayer(d_model = model_dim, nhead = n_heads, batch_first= True)
        self.encoder = nn.TransformerEncoder(layers, num_layers = n_layers)
        self.cls_projection = nn.Linear(model_dim, num_classes)

    def forward(self, image_features, text_features, src_key_padding_mask = None):
        # image_features = image_features.to(device)

        image_features *= math.sqrt(768)
        text_features *= math.sqrt(768)

        batch_size = image_features.shape[0]

        cls_token = self.cls_token.expand(batch_size, -1, -1)
        image_features = torch.cat((cls_token, image_features), dim = 1)
        image_features = image_features + self.image_positional
        image_features = self.image_positional_drop(image_features)

        text_features = self.text_positional(text_features)
        sep_token = self.sep_token.expand(batch_size, -1, -1)

        transformer_input = torch.cat((image_features, sep_token, text_features), dim = 1)
        if src_key_padding_mask is not None:
            src_key_padding_mask = torch.cat((torch.zeros(image_features.shape[0], image_features.shape[1] + 1).to(f'cuda:{transformer.device_ids[0]}'), 
                                              src_key_padding_mask.to(f'cuda:{transformer.device_ids[0]}')), dim  = 1)

        transformer_outputs = self.encoder(transformer_input, src_key_padding_mask = src_key_padding_mask)
        projected_output = transformer_outputs[:, 0, :]
        
        return self.cls_projection(projected_output)

In [17]:
transformer = CrossModalAttention()
transformer = nn.DataParallel(transformer, device_ids=[0])

## Training prep

In [15]:
class AverageMeter(object):
    # Utility function for timers
    
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0
    
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [18]:
optimizer_image = torch.optim.SGD(img_model.parameters(), lr=1e-3, weight_decay=0.0)
optimizer_text = torch.optim.SGD(text_model.parameters(), lr=1e-3, weight_decay=0.0)
optimizer_transformer = torch.optim.SGD(transformer.parameters(), lr=1e-3, weight_decay=0.0)

criterion = nn.CrossEntropyLoss()

img_model.to(device);
text_model.to(device);
transformer.to(device);
criterion.to(device);

In [19]:
from tqdm.notebook import tqdm

def train(train_loader, img_model, text_model, transformer, criterion, optimizer_image, optimizer_text, optimizer_transformer, epoch):
    print('Starting training epoch {}'.format(epoch))
    img_model.train()
    text_model.train()
    transformer.train()
    
    batch_time, data_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()
    
    optimizer_image.zero_grad()
    optimizer_text.zero_grad()
    optimizer_transformer.zero_grad()
    
    train_loss, total_samples, running_loss = 0, 0, 0
    
    batch = 1
    
    with tqdm(total = len(train_loader)) as pbar:
        for img, text in train_loader:
        
            # Run forward pass
    #         print(img.shape)
            image_encodings = img_model(img.to(f'cuda:{text_model.device_ids[0]}'))

            encoded_ingredients = tokenizer(text, return_tensors='pt', max_length=512, truncation = True, padding = True).to(f'cuda:{text_model.device_ids[0]}')
            output_ingredients = text_model(**encoded_ingredients)

            transformer_image_inputs, transformer_text_inputs, output_attention_mask, ground_truth = get_transformer_input (image_encodings["last_hidden_state"], 
                                                                                                                            output_ingredients["last_hidden_state"], 
                                                                                                                            encoded_ingredients.attention_mask)
            text_padding_mask = ~output_attention_mask.bool()
            outputs = transformer(transformer_image_inputs.to(f'cuda:{transformer.device_ids[0]}'), transformer_text_inputs.to(f'cuda:{transformer.device_ids[0]}'), 
                                  text_padding_mask.to(f'cuda:{transformer.device_ids[0]}'))

            loss = criterion(outputs, ground_truth.to(f'cuda:{transformer.device_ids[0]}')) 

            data_time.update(time.time() - end)

            # Compute gradient and optimize
            optimizer_image.zero_grad()
            optimizer_text.zero_grad()
            optimizer_transformer.zero_grad()

            loss.backward()
            optimizer_image.step()
            optimizer_text.step()
            optimizer_transformer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            running_loss += loss.item() * img.shape[0]
            total_samples += img.shape[0]

            train_loss += running_loss

            if batch % 50 == 0:
                print('  batch {} loss: {}'.format(batch, running_loss / 50))
                running_loss = 0.

            if batch % 50 == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val} ({batch_time.avg})\t'
                    'Data {data_time.val} ({data_time.avg})\t'.format(
                      epoch, batch, len(train_loader), batch_time=batch_time,
                     data_time=data_time)) 
                pbar.update(50)

            batch += 1

        print('Finished training epoch {}'.format(epoch))
        print('Epoch Loss:', train_loss / total_samples)


In [20]:
def get_transformer_input(image_features, text_features, input_attention_mask):
    neg_to_pos_ratio = 2
    
    input_batch_size = image_features.shape[0]
    output_batch_size = (neg_to_pos_ratio + 1) * input_batch_size
    ground_truths = torch.zeros(output_batch_size)
    ground_truths[:input_batch_size] = 1
        
    final_image_features = torch.zeros(output_batch_size, *image_features.shape[1:])
    final_text_features = torch.zeros(output_batch_size, *text_features.shape[1:])
    output_attention_mask = torch.zeros(output_batch_size, *input_attention_mask.shape[1:])
    
    final_image_features[: input_batch_size] = image_features
    final_text_features[:input_batch_size] = text_features
    output_attention_mask[:input_batch_size] = input_attention_mask
    
    for n in range(neg_to_pos_ratio):
        a = torch.randperm(input_batch_size)
        b = torch.zeros(input_batch_size).to(dtype = torch.int64)
        
        for idx in range(input_batch_size):
            c = random.randint(0, input_batch_size - 1)
            while c == a[idx]:
                c = random.randint(0, input_batch_size - 1)
            
            b[idx] = c
            
        final_image_features[((1 + n) * input_batch_size): (2 + n) * input_batch_size] = image_features[a]
        final_text_features[(1 + n) * input_batch_size: (2 + n) * input_batch_size] = text_features[b]
        output_attention_mask[(1 + n) * input_batch_size: (2 + n) * input_batch_size] = input_attention_mask[b]
        
    return final_image_features, final_text_features, output_attention_mask, ground_truths.type(torch.LongTensor)

## Training loop

In [22]:
epochs = 10

for epoch in range(1, epochs):
    train(train_loader, img_model, text_model, transformer, criterion, optimizer_image, optimizer_text, optimizer_transformer, epoch)
    
    save_dict = {
        "image_vit_encoder": img_model.state_dict(),
        "text_encoder": text_model.state_dict(),
        "cm_transformer": transformer.state_dict()
    }
    
    torch.save(save_dict, f'/common/users/kcm161/step3_models_instructions_with_triplet_sgd_e{epoch}.pt')

Starting training epoch 1


  0%|          | 0/17600 [00:00<?, ?it/s]

  batch 50 loss: 10.386340618133545
Epoch: [1][50/17600]	Time 0.9003903865814209 (1.043252534866333)	Data 0.8919963836669922 (1.0347679805755616)	
  batch 100 loss: 10.298310947418212
Epoch: [1][100/17600]	Time 0.6943926811218262 (1.0010074949264527)	Data 0.6859979629516602 (0.9926024055480958)	
  batch 150 loss: 10.261950550079346
Epoch: [1][150/17600]	Time 1.1548550128936768 (1.0066885312398275)	Data 1.1466069221496582 (0.9983069960276286)	
  batch 200 loss: 10.236226749420165
Epoch: [1][200/17600]	Time 1.3503468036651611 (0.9929545593261718)	Data 1.3420495986938477 (0.9845962071418762)	
  batch 250 loss: 10.165457935333253
Epoch: [1][250/17600]	Time 1.214792013168335 (1.035047643661499)	Data 1.2064740657806396 (1.0266626243591308)	
  batch 300 loss: 10.127619132995605
Epoch: [1][300/17600]	Time 1.1201674938201904 (1.070337659517924)	Data 1.1119122505187988 (1.0619704214731853)	
  batch 350 loss: 10.238421554565429
Epoch: [1][350/17600]	Time 1.2172636985778809 (1.101286871773856)	Dat

  0%|          | 0/17600 [00:00<?, ?it/s]

  batch 50 loss: 5.2627639484405515
Epoch: [2][50/17600]	Time 0.9513545036315918 (1.102919192314148)	Data 0.9378535747528076 (1.0902174758911132)	
  batch 100 loss: 5.452137999534607
Epoch: [2][100/17600]	Time 0.7375247478485107 (1.0669990181922913)	Data 0.7230088710784912 (1.054327130317688)	
  batch 150 loss: 5.1708753490448
Epoch: [2][150/17600]	Time 1.2102572917938232 (1.0751028378804526)	Data 1.1963965892791748 (1.0625444253285725)	
  batch 200 loss: 5.248636932373047
Epoch: [2][200/17600]	Time 1.085035800933838 (1.093427301645279)	Data 1.071558952331543 (1.081041761636734)	
  batch 250 loss: 5.340456528663635
Epoch: [2][250/17600]	Time 0.989823579788208 (1.068752371788025)	Data 0.9771020412445068 (1.056390139579773)	
  batch 300 loss: 5.45594244480133
Epoch: [2][300/17600]	Time 0.8946633338928223 (1.062893912792206)	Data 0.8823950290679932 (1.0506064152717591)	
  batch 350 loss: 5.651820096969605
Epoch: [2][350/17600]	Time 1.041299819946289 (1.062803705760411)	Data 1.028723478317

  0%|          | 0/17600 [00:00<?, ?it/s]

  batch 50 loss: 4.8082512092590335
Epoch: [3][50/17600]	Time 0.9691228866577148 (1.1020145511627197)	Data 0.958892822265625 (1.0904083156585693)	
  batch 100 loss: 4.767690508365631
Epoch: [3][100/17600]	Time 0.7561020851135254 (1.0651993584632873)	Data 0.746279239654541 (1.0534481167793275)	
  batch 150 loss: 4.779217371940613
Epoch: [3][150/17600]	Time 1.2258760929107666 (1.0717190281550089)	Data 1.2141444683074951 (1.0599225362141926)	
  batch 200 loss: 4.639702723026276
Epoch: [3][200/17600]	Time 1.0775225162506104 (1.0539909017086029)	Data 1.0653488636016846 (1.0422128760814666)	
  batch 250 loss: 4.766773447990418
Epoch: [3][250/17600]	Time 2.874955892562866 (1.049760768890381)	Data 2.865328550338745 (1.0379829416275024)	
  batch 300 loss: 4.839803433418274
Epoch: [3][300/17600]	Time 0.8955326080322266 (1.4067522064844766)	Data 0.8829221725463867 (1.3952077571551005)	
  batch 350 loss: 4.975383443832397
Epoch: [3][350/17600]	Time 3.424175977706909 (1.6415507629939488)	Data 3.410

KeyboardInterrupt: 

## Save models

In [17]:
save_dict = {
    "image_vit_encoder": img_model.state_dict(),
    "text_encoder": text_model.state_dict(),
    "cm_transformer": transformer.state_dict()
}

torch.save(save_dict, f'/common/users/kcm161/step3_models_with_triplet_e1.pt')

## Load models

In [12]:
model_weights = torch.load("/common/users/kcm161/step3_models_with_triplet_e1.pt", map_location = torch.device("cuda:1"))

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text_model = BertModel.from_pretrained("bert-base-uncased")
img_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

for param in text_model.parameters():
    param.requires_grad = False

for param in img_model.parameters():
    param.requires_grad = False

transformer = CrossModalAttention()

optimizer_image = torch.optim.SGD(img_model.parameters(), lr=1e-3, weight_decay=0.0)
optimizer_text = torch.optim.SGD(text_model.parameters(), lr=1e-3, weight_decay=0.0)
optimizer_transformer = torch.optim.SGD(transformer.parameters(), lr=1e-3, weight_decay=0.0)

criterion = nn.CrossEntropyLoss().to(device)

text_model = nn.DataParallel(text_model, device_ids=[1])
text_model.load_state_dict(model_weights["text_encoder"])
text_model.to((f'cuda:{text_model.device_ids[0]}'));
text_model.train();

img_model = nn.DataParallel(img_model, device_ids=[1])
img_model.load_state_dict(model_weights["image_vit_encoder"])
img_model.to((f'cuda:{img_model.device_ids[0]}'));
img_model.train();

transformer = nn.DataParallel(transformer, device_ids=[1])
# transformer.load_state_dict(model_weights_1["cm_transformer"])
transformer.to((f'cuda:{transformer.device_ids[0]}'));
transformer.train();

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Ablation Studies

In [23]:
len(test_loader)

60740

## Generate validation and test embeddings

### Load models

In [14]:
text_model = BertModel.from_pretrained("bert-base-uncased")
img_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
transformer = CrossModalAttention()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

model_weights = torch.load("/common/users/kcm161/step3_models_with_triplet_sgd_e1.pt", map_location="cuda:1")

text_model = nn.DataParallel(text_model, device_ids=[1])
text_model.load_state_dict(model_weights["text_encoder"])
text_model.to((f'cuda:{text_model.device_ids[0]}'));
text_model.eval();

img_model = nn.DataParallel(img_model, device_ids=[1])
img_model.load_state_dict(model_weights["image_vit_encoder"])
img_model.to((f'cuda:{img_model.device_ids[0]}'));
img_model.eval();

transformer = nn.DataParallel(transformer, device_ids=[1])
transformer.load_state_dict(model_weights["cm_transformer"])
transformer.to((f'cuda:{transformer.device_ids[0]}'));
transformer.eval();

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Generate embeddings

In [26]:
def generate_embeddings(loader, img_model, text_model, transformer):
    
    model_weights = torch.load("/common/users/kcm161/step3_models_instructions_with_triplet_sgd_e2.pt", map_location="cuda:0")

    text_model = BertModel.from_pretrained("bert-base-uncased")
    img_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
    transformer = CrossModalAttention()

    text_model = nn.DataParallel(text_model, device_ids=[0])
    text_model.load_state_dict(model_weights["text_encoder"])
    text_model.to((f'cuda:{text_model.device_ids[0]}'));
    text_model.eval();

    img_model = nn.DataParallel(img_model, device_ids=[0])
    img_model.load_state_dict(model_weights["image_vit_encoder"])
    img_model.to((f'cuda:{img_model.device_ids[0]}'));
    img_model.eval();

    transformer = nn.DataParallel(transformer, device_ids=[0])
    transformer.load_state_dict(model_weights["cm_transformer"])
    transformer.to((f'cuda:{transformer.device_ids[0]}'));
    transformer.eval();
        
    img_encodings = np.zeros((len(loader),), dtype = object)
    text_encodings = np.zeros((len(loader),), dtype = object)
    text_masks = np.zeros((len(loader),), dtype = object)
    
    idx = 0

    print(img_encodings.shape, len(loader))
        
    with torch.no_grad():
        for img, text in loader:

            # Run forward pass
            image_output = img_model(img.to(f'cuda:{img_model.device_ids[0]}'))
            encoded_ingredients = tokenizer(text, return_tensors='pt', max_length=512, truncation = True, padding = True).to(f'cuda:{text_model.device_ids[0]}')
            output_ingredients = text_model(**encoded_ingredients)
            input_attention_mask = encoded_ingredients.attention_mask
            text_padding_mask = ~input_attention_mask.bool()

            img_encodings[idx] = image_output["last_hidden_state"].cpu().detach().numpy()
            text_encodings[idx] = output_ingredients["last_hidden_state"].cpu().detach().numpy()
            text_masks[idx] = text_padding_mask.cpu().detach().numpy()
        
            idx += 1

            if idx % 2000 == 0:
                print(idx)
                # print(img_encodings[idx-1])
    # print(img_encodings)

    return img_encodings, text_encodings, text_masks
    
    # ranker(img_encodings, text_encodings, text_masks, transformer, "recipe", 1000)

## Ranking extractor embeddings

In [45]:
triplet_ranker(img_encode, text_encode)

Mean median 15.7
Recall {1: 0.09549999999999999, 5: 0.28869999999999996, 10: 0.4130999999999999}


In [None]:
def triplet_ranker(im_vecs, instr_vecs, N = 1000, flag = "image"):
    # Ranker
    idxs = range(N)

    glob_rank = []
    glob_recall = {1:0.0,5:0.0,10:0.0}
    for i in range(10):

        ids = random.sample(range(0,len(im_vecs)), N)
        
        im_sub = im_vecs[ids]
        instr_sub = instr_vecs[ids]

        if flag == "image":
            sims = np.dot(im_sub,instr_sub.reshape((N, 768, 1))) # for im2recipe
        else:
            sims = np.dot(instr_sub,im_sub.T) # for recipe2im

        med_rank = []
        recall = {1:0.0,5:0.0,10:0.0}

        sims = sims.squeeze(1).squeeze(2)

        for ii in idxs:

            # name = ids_sub[ii]
            # get a column of similarities
            sim = sims[ii]

            # sort indices in descending order
            sorting = np.argsort(sim)[::-1].tolist()

            # print(sorting)

            # find where the index of the pair sample ended up in the sorting
            pos = sorting.index(ii)

            if (pos+1) == 1:
                recall[1]+=1
            if (pos+1) <=5:
                recall[5]+=1
            if (pos+1)<=10:
                recall[10]+=1

            # store the position
            med_rank.append(pos+1)

        for i in recall.keys():
            recall[i]=recall[i]/N

        med = np.median(med_rank)
#         print ("median", med)

        for i in recall.keys():
            glob_recall[i]+=recall[i]
        glob_rank.append(med)

    for i in glob_recall.keys():
        glob_recall[i] = glob_recall[i]/10
    
    print ("Mean median", np.average(glob_rank))
    print ("Recall", glob_recall)

## Ranking Transformer outputs

### Ingredients - testing

In [18]:
img_encodings, text_encodings, text_masks = generate_embeddings(test_loader, img_model, text_model, transformer)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


(60740,) 60740
2000
4000
6000
8000
10000
12000
14000
16000
18000
20000
22000
24000
26000
28000
30000
32000
34000
36000
38000
40000
42000
44000
46000
48000
50000
52000
54000
56000
58000
60000


In [22]:
# ingredients im2recipe
ranker(img_encodings, text_encodings, text_masks, transformer)

(10.8, 1.5362291495737217, {1: 0.072, 5: 0.286, 10: 0.487})

In [23]:
# ingredients recipe2im
ranker(img_encodings, text_encodings, text_masks, transformer, "image")

(6.1,
 1.2206555615733703,
 {1: 0.159, 5: 0.4699999999999999, 10: 0.6340000000000001})

### Title - testing

In [31]:
img_encodings, text_encodings, text_masks = generate_embeddings(test_loader, img_model, text_model, transformer)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


(60740,) 60740
2000
4000
6000
8000
10000
12000
14000
16000
18000
20000
22000
24000
26000
28000
30000
32000
34000
36000
38000
40000
42000
44000
46000
48000
50000
52000
54000
56000
58000
60000


In [34]:
# title im2recipe
ranker(img_encodings, text_encodings, text_masks, transformer)

(25.5,
 3.0413812651491097,
 {1: 0.046000000000000006, 5: 0.15599999999999997, 10: 0.26200000000000007})

In [35]:
# title im2recipe
ranker(img_encodings, text_encodings, text_masks, transformer, "recipe")

(27.1,
 2.981610303175115,
 {1: 0.033999999999999996, 5: 0.14200000000000002, 10: 0.24300000000000002})

### Instructions - testing

### Full Text - testing

# Ranking Functions

In [20]:
def ranker(img_encodings, text_encodings, text_masks, transformer, retrieval_type = "recipe", n = 100):
    t = time.time()
    data_size = len(img_encodings)
    
    glob_rank = []
    glob_recall = {1:0.0, 5:0.0, 10:0.0}
    
    # with tqdm (total = n * n * 10) as pbar:
    for i in range(10):
        ids_sub = np.random.choice(data_size, n, replace = False)
        # imgs_sub = img_encodings[ids_sub, :] # numpy 
        # text_sub = text_encodings[ids_sub, :] # numpy
        # attn_sub = text_masks[ids_sub, :] # numpy

        imgs_sub = img_encodings[ids_sub] # numpy 
        text_sub = text_encodings[ids_sub] # numpy
        attn_sub = text_masks[ids_sub] # numpy
                
        probs = torch.zeros((n,n)).detach().cpu()
        
        if retrieval_type == "recipe":
            for x in range(n):
                for y in range(n):
                    temp = transformer(torch.from_numpy(imgs_sub[x]), torch.from_numpy(text_sub[y]), torch.from_numpy(attn_sub[y]))
                    probs[x, y] =  nn.Softmax(dim=1)(temp)[0][1].detach().cpu()
                    # print(probs[x,y])
                
                # pbar.update(n)

        else:
            for x in range(n):
                for y in range(n):
                    temp = transformer(torch.from_numpy(imgs_sub[y]), torch.from_numpy(text_sub[x]), torch.from_numpy(attn_sub[x]))
                    probs[x, y] =  nn.Softmax(dim=1)(temp)[0][1].detach().cpu()
                    # print(probs[x,y])
                

        ranks, _ = compute_ranks(probs.numpy())
        
        recall = {1: 0.0, 5:0.0, 10:0.0}
        for ii in recall.keys():
            recall[ii] = (ranks <= ii).sum() / ranks.shape[0]
        med = int(np.median(ranks))
        # print(med, recall)
        for ii in recall.keys():
            glob_recall[ii] += recall[ii]
        glob_rank.append(med)

        # print(i)
        
    for i in glob_recall.keys():
        glob_recall[i] /= 10
        
    medR = np.mean(glob_rank)
    medR_std = np.std(glob_rank)
            
    return medR, medR_std, glob_recall

In [21]:
def compute_ranks(sims):
    ranks = []
    preds = []

    # print(sims, sims.shape)
    
    for ii in range(sims.shape[0]):
        sim = sims[ii, :]
        sorting = np.argsort(sim)[::-1].tolist()
        pos = sorting.index(ii)
        
        ranks.append(pos + 1.0)
        preds.append(sorting[0])
        
    return np.asarray(ranks), preds