In [1]:
import transformers
import torch
import torchvision

from tqdm import tqdm
from PIL import Image

import pandas as pd
import numpy as np

from torchinfo import summary
import os
import glob

import tokenizers
import itertools

import random
import math
import copy
from timm.scheduler import CosineLRScheduler

import wandb 
import nltk
# Stopword list used by Flickr30K_Finetune.process_string.
nltk.download('stopwords')


# Preferred GPU. DEVICE falls back to CPU only when CUDA is entirely unavailable
# (there is no check here that index 1 exists).
device = 'cuda:1'


DEVICE = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')
# Used to build MODEL_SAVE_PATH (Models/{model_name}/{algo}/checkpoint).
model_name = 'vit_bert_s - randn temperature'
algo = 'MAMO'


# NOTE(review): the DataLoader built later uses num_workers=0, so NUM_WORKERS is unused in the visible cells.
NUM_WORKERS = 14
torch.set_num_threads(12)


# NOTE(review): `id` shadows the Python builtin. The freshly generated value is overwritten a few
# lines below, so the hardcoded run id is always the one used. Presumably this is for resuming an
# existing wandb run; confirm.
id = wandb.util.generate_id()
wandb.login()

# set earlier ID
# id = '8z7zcst9'
id = 'yod1ogj5'

print(id)

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package stopwords to /home/ml/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmadhava20217[0m. Use [1m`wandb login --relogin`[0m to force relogin


yod1ogj5


In [2]:
DATASET_SRC = '../Datasets/Flickr30k/'
# Presumably the prefix for checkpoint files (cf. weights_path below: checkpoint_44.pth).
MODEL_SAVE_PATH = f'Models/{model_name}/{algo}/checkpoint'

VOCAB_PATH = 'Vocabulary/flickr30k.vocab'

# Optimiser / learning-rate schedule
lr = 1e-4          # AdamW learning rate
init_lr = 1e-6     # warmup_lr_init of the cosine scheduler
min_lr = 1e-5      # lr_min of the cosine scheduler
decay = 0.01       # AdamW weight decay
beta1 = 0.9
beta2 = 0.999

warmup_epochs = 5
EPOCHS = 40

BATCH_SIZE = 60

# Presumably the fraction of image patches / text tokens masked for the masked-modelling losses.
MASKING_RATIO_IMG = 0.75
MASKING_RATIO_TXT = 0.25

ALPHA = 0.995               # EWMA


# Number of BERT layers kept for the text encoder AND for the fusion encoder (passed as bert_layers).
n_layers = 3

# exist_ok=True is a no-op when the folder already exists and avoids the check-then-create race.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)

weights_path = 'Models/vit_bert_s - randn temperature/MAMO/checkpoint_44.pth'

In [3]:
def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return an independent deep copy of ``model`` that keeps only its first encoder layers.

    The layers are taken from the deep copy itself, so the result shares no parameters
    with ``model``. Selecting from the original layer list would alias the source model's
    layers, and two truncations of the same model would then silently share weights.

    Args:
        model: module exposing ``model.encoder.layer`` (a ``ModuleList``).
        num_layers_to_keep: how many leading layers to keep, ``0 <= n <= len(layers)``.

    Returns:
        The truncated copy.

    Raises:
        ValueError: if ``num_layers_to_keep`` is outside ``[0, len(model.encoder.layer)]``.
    """
    total_layers = len(model.encoder.layer)
    if not 0 <= num_layers_to_keep <= total_layers:
        raise ValueError(
            f"num_layers_to_keep must be in [0, {total_layers}], got {num_layers_to_keep}"
        )

    # Copy first, then slice the copy's own layers so nothing is shared with `model`.
    copyOfModel = copy.deepcopy(model)
    copyOfModel.encoder.layer = torch.nn.ModuleList(copyOfModel.encoder.layer[:num_layers_to_keep])

    return copyOfModel

def deleteLaterEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return an independent deep copy of ``model`` that keeps only its last encoder layers.

    The kept layers stay in their original order and are taken from the deep copy, so the
    result shares no parameters with ``model``. This matters when the same model is also
    truncated with ``deleteEncodingLayers`` and the two selections overlap.

    Args:
        model: module exposing ``model.encoder.layer`` (a ``ModuleList``).
        num_layers_to_keep: how many trailing layers to keep, ``0 <= n <= len(layers)``.

    Returns:
        The truncated copy.

    Raises:
        ValueError: if ``num_layers_to_keep`` is outside ``[0, len(model.encoder.layer)]``.
    """
    total_layers = len(model.encoder.layer)
    if not 0 <= num_layers_to_keep <= total_layers:
        raise ValueError(
            f"num_layers_to_keep must be in [0, {total_layers}], got {num_layers_to_keep}"
        )

    copyOfModel = copy.deepcopy(model)
    # Slice from an explicit start index: `[-0:]` would wrongly keep every layer when n == 0.
    first_kept = total_layers - num_layers_to_keep
    copyOfModel.encoder.layer = torch.nn.ModuleList(copyOfModel.encoder.layer[first_kept:])

    return copyOfModel


def get_bert_model(model, num_layers):
    """Truncate ``model`` to its first ``num_layers`` encoder layers.

    Delegates to ``deleteEncodingLayers`` and returns whatever it produces.
    """
    truncated = deleteEncodingLayers(model, num_layers_to_keep=num_layers)
    return truncated



class MAMO_mixer(torch.nn.Module):
    """Fusion encoder that runs the last ``n_layers`` BERT encoder layers over the
    concatenation of visual tokens and text tokens.

    Args:
        base_bert: model exposing ``.base_model`` whose ``encoder.layer`` holds the BERT layers.
        n_layers: number of trailing encoder layers to keep.
        n_visual_tokens: nominal number of visual tokens. It is stored on the instance,
            but ``forward`` reads the actual count from its input.
        vision_embedding_dim: hidden size of the visual tokens.
        emb_dims: hidden size of the text tokens and of the fusion encoder.
    """

    def __init__(self, base_bert, n_layers = 2, n_visual_tokens = 197, vision_embedding_dim = 384, emb_dims = 512):
        # prepare decoder
        super().__init__()
        self.n_visual_tokens = n_visual_tokens
        self.vision_emb_dim = vision_embedding_dim
        # Only the encoder stack of the truncated copy is kept.
        self.base_model = deleteLaterEncodingLayers(base_bert.base_model, n_layers).encoder

        # NOTE(review): not used by forward(). It has no parameters, so it does not affect state_dict.
        self.pooler = torch.nn.AdaptiveAvgPool1d(1)
        self.emb_dimension = emb_dims

        # Project visual tokens to the text hidden size when the two sizes differ.
        if self.vision_emb_dim == self.emb_dimension:
            self.dimension_caster = torch.nn.Identity()
        else:
            self.dimension_caster = torch.nn.Linear(self.vision_emb_dim, self.emb_dimension, bias = False)  # no bias here

    def forward(self, vision_embedding, text_embedding, text_attn_mask):
        """Fuse visual and text tokens.

        Args:
            vision_embedding: (batch, n_visual, vision_embedding_dim) visual tokens.
            text_embedding: (batch, n_text, emb_dims) text tokens.
            text_attn_mask: (batch, n_text) mask, assumed 1 for real tokens and 0 for
                padding (as produced by the tokenizer's ``attention_mask``).

        Returns:
            The encoder output over the ``n_visual + n_text`` concatenated tokens.
        """
        n_batch, n_visual = vision_embedding.shape[0], vision_embedding.shape[1]

        # normalize dimensions, then concatenate along the token axis
        new_vision_emb = self.dimension_caster(vision_embedding)
        concatenated_emb = torch.cat([new_vision_emb, text_embedding], dim = 1)
        dtype = concatenated_emb.dtype

        # Every visual token is attended; text tokens follow the caller's mask.
        vision_attention_mask = torch.ones(n_batch, n_visual, dtype = dtype, device = text_attn_mask.device)
        attn_mask = torch.cat([vision_attention_mask, text_attn_mask.to(dtype)], dim = 1)

        # BertEncoder adds attention_mask to the attention scores, so it needs an additive
        # mask: 0 = attend, very negative = ignore. This is the same conversion that
        # BertModel.get_extended_attention_mask performs. A raw 1/0 mask would leave
        # padding tokens fully attended.
        extended_attn_mask = (1.0 - attn_mask[:, None, None, :]) * torch.finfo(dtype).min

        # forward
        return self.base_model(concatenated_emb, extended_attn_mask)

In [4]:
class MAMO(torch.nn.Module):
    """Image-text model: ViT image encoder, truncated BERT text encoder and a BERT-based
    fusion encoder (``MAMO_mixer``), plus the heads and losses used for pre-training
    (masked representation / image / language modelling, image-text contrastive and
    image-text matching).

    Args:
        vit: vision backbone, called as ``vit(image)`` and indexed with ``'last_hidden_state'``.
        bert: BERT masked-LM model exposing ``.base_model`` and ``.cls``.
        vit_num_patches: number of image patch tokens (the ViT CLS token is extra).
        vit_emb_dim: hidden size of ``vit``.
        bert_emb_dim: hidden size of ``bert``.
        bert_layers: number of BERT layers kept for the text encoder and for the fusion encoder.
        vocab_size: number of classes of the MLM scores.
        mask_token_id: token id marking masked positions when building MLM labels.
    """

    def __init__(self,
                 vit,
                 bert,
                 vit_num_patches = 196,
                 vit_emb_dim = 384,
                 bert_emb_dim = 512,
                 bert_layers = 2,
                 vocab_size = 30522,
                 mask_token_id = 103):
        super().__init__()
        self.vit = vit
        # Text encoder: the first `bert_layers` layers of the BERT backbone.
        self.bert = deleteEncodingLayers(bert.base_model, bert_layers)
        # Fusion encoder over [CLS + patch tokens] ++ text tokens (last `bert_layers` layers).
        self.mamo = MAMO_mixer(bert, bert_layers, vit_num_patches + 1, vit_emb_dim, bert_emb_dim)

        # vit patches data
        self.vit_num_patches = vit_num_patches

        # vocab size
        self.vocab_size = vocab_size
        # mask token
        self.mask_token_id = mask_token_id

        # Learnable temperature for the contrastive loss (Parameters require grad by default).
        self.tau = torch.nn.Parameter(torch.Tensor([1.]))

        # joint representation: (B, D, L) -> average over tokens -> (B, D)
        self.pooler = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool1d(1),
            torch.nn.Flatten()
        )
        self.img_proj = torch.nn.Linear(vit_emb_dim, min(vit_emb_dim, bert_emb_dim))
        self.txt_proj = torch.nn.Linear(bert_emb_dim, min(vit_emb_dim, bert_emb_dim))

        # masked representation modeling
        self.mrm_proj = torch.nn.Sequential(
            torch.nn.Linear(bert_emb_dim, bert_emb_dim),
            torch.nn.Tanh(),
        )

        # head for masked image modeling: fusion space -> ViT space
        self.mim_proj = torch.nn.Sequential(
            torch.nn.Linear(bert_emb_dim, vit_emb_dim),
        )

        # head for masked language modeling
        # NOTE(review): in HF BertForMaskedLM the decoder weight of `cls` is normally tied to
        # the backbone's word embeddings. self.bert is a deep copy, so that tie presumably no
        # longer holds. Confirm that this is intended.
        self.mlm_head = bert.cls

        self.itc_head = torch.nn.Linear(bert_emb_dim, bert_emb_dim)

        self.itm__head = torch.nn.Sequential(
            torch.nn.Linear(bert_emb_dim, bert_emb_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(bert_emb_dim, 1)
        )

    def forward(self, image, text, attn_mask,
                masked_image = None,
                masked_text = None,
                image_text_matching = False,
                ):
        """Run the model in matching mode or in pre-training mode.

        Matching mode (``image_text_matching`` truthy) returns
        ``(img_rep, txt_rep, joint_rep, itm_logit)``. The image and text token features are
        L2-normalised per token, ``joint_rep`` is the fused sequence, and ``itm_logit`` is
        the ITM head applied to the first fused token.

        Pre-training mode returns
        ``(c_img_m_txt, m_img_c_txt, mask_img_rep, txt_prediction, img_rep, txt_rep)``:
        fused (clean image, masked text), fused (masked image, clean text), masked-image
        ViT tokens, MLM scores for the masked text, and the pooled, projected,
        L2-normalised clean image / text features.
        """
        if image_text_matching:
            # NOTE(review): tokens are L2-normalised before fusion here but not in the
            # pre-training branch below. Confirm that this asymmetry is intended.
            img_rep = torch.nn.functional.normalize(self.vit(image)['last_hidden_state'], p = 2, dim = 2)
            txt_rep = torch.nn.functional.normalize(self.bert(text, attn_mask)['last_hidden_state'], p = 2, dim = 2)
            joint_rep = self.mamo(img_rep, txt_rep, attn_mask)['last_hidden_state']

            return img_rep, txt_rep, joint_rep, self.itm__head(joint_rep)[:, 0, :]

        img_rep = self.vit(image)['last_hidden_state']              # clean image
        txt_rep = self.bert(text, attn_mask)['last_hidden_state']   # clean text

        mask_img_rep = self.vit(masked_image)['last_hidden_state']
        mask_txt_rep = self.bert(masked_text, attn_mask)['last_hidden_state']

        # multimodal prediction
        c_img_m_txt = self.mamo(img_rep, mask_txt_rep, attn_mask)['last_hidden_state']
        m_img_c_txt = self.mamo(mask_img_rep, txt_rep, attn_mask)['last_hidden_state']

        # pure txt
        txt_prediction = self.mlm_head(mask_txt_rep)

        # pool and flatten text and visual features obtained before fusion
        img_rep = torch.nn.functional.normalize(self.img_proj(self.pooler(img_rep.transpose(1,2))), 2)
        txt_rep = torch.nn.functional.normalize(self.txt_proj(self.pooler(txt_rep.transpose(1,2))), 2)

        return (c_img_m_txt, m_img_c_txt, mask_img_rep, txt_prediction, img_rep, txt_rep)

    def mrm_projection(self, rep):
        """Apply the MRM head and L2-normalise along the feature axis (dim 2)."""
        return torch.nn.functional.normalize(self.mrm_proj(rep), dim = 2)

    def mim_projection(self, rep):
        """Apply the MIM head and L2-normalise along the feature axis (dim 2)."""
        return torch.nn.functional.normalize(self.mim_proj(rep), dim = 2)

    def get_mrm_loss(self, online_representation, target_representation, mask):
        """Masked MSE between projected online tokens and target tokens, CLS token excluded.

        ``mask`` is assumed to align with the tokens after the CLS token. A 2-D mask
        ``(batch, tokens)`` is broadcast over the feature axis, as in ``get_mim_loss``.
        """
        # remove cls token
        on_rep = self.mrm_projection(online_representation[:, 1:, :])
        tr_rep = target_representation[:, 1:, :]

        # normalize (on_rep is already unit-norm from mrm_projection)
        on_rep = torch.nn.functional.normalize(on_rep, dim = 2)
        tr_rep = torch.nn.functional.normalize(tr_rep, dim = 2)

        loss = torch.nn.functional.mse_loss(on_rep, tr_rep, reduction = 'none')
        if mask.ndim == 2:
            mask = mask[:, :, None]
        mrm_loss = (loss * mask).sum()/(mask.sum() + 1e-5)              # add for 0 division errors
        return mrm_loss

    def get_mim_loss(self, online_representation, target_representation, mask):
        """Masked L1 loss between projected online patch tokens and target patch tokens.

        Only positions ``1 .. vit_num_patches`` are compared (the CLS token is omitted).
        """
        on_rep = self.mim_projection(online_representation[:, 1:self.vit_num_patches+1, :]) # omit cls token
        tr_rep = target_representation[:, 1:self.vit_num_patches+1, :]

        # normalize
        on_rep = torch.nn.functional.normalize(on_rep, dim = 2)
        tr_rep = torch.nn.functional.normalize(tr_rep, dim = 2)

        loss = torch.nn.functional.l1_loss(on_rep, tr_rep, reduction = 'none')
        if mask.ndim == 2:
            mask = mask[:, :, None]
        mim_loss = (loss * mask).sum() / (mask.sum() + 1e-5)
        return mim_loss

    def get_mlm_loss(self, scores, sen, masked_sen):
        """Cross-entropy on positions where ``masked_sen`` holds the mask token.

        All other positions are labelled -100 and ignored.
        """
        labels = torch.where(masked_sen == self.mask_token_id, sen, -100)
        # reshape (not view) so that non-contiguous score tensors are accepted as well
        loss = torch.nn.functional.cross_entropy(scores.reshape(-1, self.vocab_size), labels.reshape(-1), ignore_index=-100)

        return loss

    def get_itc_loss(self, img_feats, txt_feats):
        """Symmetric InfoNCE loss between matched image/text feature rows.

        Returns ``(sim, loss)``. ``sim = exp(img @ txt.T / tau)`` is returned unchanged for
        callers such as ``get_samples``. The loss is computed on the logits
        ``img @ txt.T / tau``, because ``cross_entropy`` applies its own softmax. Feeding it
        the exponentiated values would exponentiate twice.
        """
        logits = (img_feats @ txt_feats.T) / self.tau
        # Row i matches column i. Class-index targets are equivalent to one-hot eye() targets.
        targets = torch.arange(logits.shape[0], device=logits.device)
        loss = (torch.nn.functional.cross_entropy(logits, targets) + torch.nn.functional.cross_entropy(logits.T, targets)) / 2.

        return torch.exp(logits), loss

    def get_samples(self, similarities):
        """Sample one text index per image row and one image index per text column.

        The probabilities are the softmax of ``similarities`` along each row (and along
        each row of the transpose).
        """
        probs = torch.nn.functional.softmax(similarities, dim = 1)
        txt_indices = torch.multinomial(probs, num_samples=1, replacement=True).squeeze(1)
        img_indices = torch.multinomial(probs.T, num_samples=1, replacement=True).squeeze(1)

        return txt_indices, img_indices

In [5]:
DIMENSION = 224   # image side length; matches the 224px ViT checkpoint loaded below

MAX_LEN = 50      # captions are truncated / padded to this many tokens

# ViT config
feature_extractor = transformers.AutoFeatureExtractor.from_pretrained('WinKawaks/vit-small-patch16-224')
tokenizer = transformers.AutoTokenizer.from_pretrained("prajjwal1/bert-small")

import torchvision.transforms.v2 as v2

## image transforms
img_transform = v2.Compose([
    v2.ToImage(),
    # uint8 is the standard integer image dtype. torch.int8 is signed (-128..127), so it
    # cannot represent 0..255 pixel values and would degrade the image before RandAugment.
    v2.ToDtype(torch.uint8, scale = True),
    v2.Resize(size = (DIMENSION, DIMENSION), antialias = False),
    v2.RandAugment(),
    # v2.RandomVerticalFlip(),
    # v2.RandomHorizontalFlip(),
    v2.ToDtype(torch.float32, scale=True),
    # maps [0, 1] to [-1, 1]
    v2.Normalize(
        mean = [0.5, 0.5, 0.5],
        std =  [0.5, 0.5, 0.5]
    )
])



In [6]:
# Pretrained backbones: ViT-S/16 image encoder and BERT-small with its masked-LM head.
vit_model = transformers.ViTModel.from_pretrained('WinKawaks/vit-small-patch16-224').to(DEVICE)
bert_model = transformers.BertForMaskedLM.from_pretrained("prajjwal1/bert-small")

online_network = MAMO(
                    vit = vit_model,
                    bert = bert_model,
                    vit_num_patches= 196,
                    vit_emb_dim=384,
                    bert_emb_dim=512,
                    bert_layers=n_layers,
                    vocab_size=tokenizer.vocab_size,
                    mask_token_id= tokenizer.mask_token_id
                ).to(DEVICE)


# Map onto DEVICE rather than the raw `device` string, so loading also works when DEVICE
# fell back to CPU.
# NOTE(review): torch.load unpickles the file. Only load checkpoints from trusted sources.
chkpt = torch.load(weights_path, map_location=DEVICE)['online_model_state_dict']
online_network.load_state_dict(chkpt)

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [7]:
class Flickr30K_Finetune(torchvision.datasets.Flickr30k):
    """Flickr30k dataset returning a transformed image and its cleaned, tokenised captions.

    Args:
        data_path: image directory passed to ``torchvision.datasets.Flickr30k``.
        ann_path: caption annotation file passed to ``torchvision.datasets.Flickr30k``.
        img_transform: callable applied to the image; skipped when ``None``.
        txt_transform: tokenizer called on the list of cleaned captions. It must return a
            mapping with ``'input_ids'`` and ``'attention_mask'``.
        max_length: captions are truncated / padded to this many tokens.
    """
    def __init__(self,
                 data_path,
                 ann_path,
                 img_transform = None,
                 txt_transform = None,
                 max_length = 50,
                 ):
        super().__init__(data_path, ann_path)
        self.img_transform = img_transform
        self.tokenizer = txt_transform
        self.max_length = max_length

    def _english_stopwords(self):
        """Return the NLTK English stopwords as a frozenset, loaded once per instance."""
        stopwords = getattr(self, '_stopwords', None)
        if stopwords is None:
            stopwords = frozenset(nltk.corpus.stopwords.words('english'))
            self._stopwords = stopwords
        return stopwords

    def process_string(self, txts):
        """Clean each caption in ``txts``.

        Each caption is lower-cased and split on whitespace. English stopwords are dropped,
        and so is every token that is not purely alphabetic, so a token such as "shirt," is
        removed entirely. Returns one space-joined string per input caption.
        """
        stopwords = self._english_stopwords()   # hoisted: previously reloaded for every caption
        ret = []
        for string in txts:
            tokens = string.lower().split() # separated by spaces
            kept = [word for word in tokens if word not in stopwords and word.isalpha()]
            ret.append(" ".join(kept))

        return ret

    def __getitem__(self, idx):
        """Return ``(img, toks, attn_mask)``. ``toks`` and ``attn_mask`` have one row per caption."""
        img, txt = super().__getitem__(idx)
        # get images and texts
        if self.img_transform is not None:
            img = self.img_transform(img)

        # process string
        txt = self.process_string(txt)
        tok_text = self.tokenizer(txt, truncation = True, padding = 'max_length', max_length = self.max_length, return_token_type_ids=False)
        toks, attn_mask = torch.tensor(tok_text['input_ids']), torch.tensor(tok_text['attention_mask'])

        return img, toks, attn_mask

In [8]:
# Flickr30k images with cleaned, tokenised captions.
dataset = Flickr30K_Finetune(
    data_path = DATASET_SRC + 'flickr30k-images',
    ann_path = DATASET_SRC + 'results_20130124.token',
    img_transform = img_transform,
    txt_transform = tokenizer,
    max_length = MAX_LEN,
)

# Shuffled training loader (single-process loading).
pretrain_dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle = True,
    batch_size = BATCH_SIZE,
    num_workers = 0,
    pin_memory = True,
)

# Pull the first batch so its tensors can be inspected.
for idx, data in enumerate(pretrain_dataloader):
    img, txt, attn_mask = data
    break

In [9]:
# Token ids of one batch: (BATCH_SIZE, captions per image, MAX_LEN); the saved output is torch.Size([60, 5, 50]).
txt.shape

torch.Size([60, 5, 50])

In [10]:
#optimiser
# AdamW over every parameter of the online network, using the betas from the config cell.
optim = torch.optim.AdamW(online_network.parameters(),
                          lr = lr,
                          weight_decay = decay,
                          betas = (beta1, beta2),
                          )

# t_in_epochs=False: schedule lengths are measured in optimiser steps, so convert epochs to steps.
epoch_steps = math.ceil(len(dataset)/BATCH_SIZE)   # batches per epoch (the loader keeps the last partial batch)
num_steps = int(EPOCHS * epoch_steps)
warmup_steps = int(warmup_epochs * epoch_steps)

lr_scheduler = CosineLRScheduler(
        optim,
        t_initial=num_steps,
        # t_mul=1.,
        lr_min=min_lr,
        warmup_lr_init = init_lr,
        warmup_t=warmup_steps,
        cycle_limit=1,
        t_in_epochs=False,
    )