In [1]:
import transformers
import torch
import torchvision

from tqdm import tqdm
from PIL import Image

import pandas as pd
import numpy as np

from torchinfo import summary
import os
import glob

import tokenizers
import itertools

import random
import math
import copy
from timm.scheduler import CosineLRScheduler

from utils.MAMO import MAMO
from utils.dataset import re_train_dataset, re_eval_dataset

import wandb
import nltk

os.environ['TOKENIZERS_PARALLELISM'] = 'true'
nltk.download('stopwords')

device = 'cuda:0'

DEVICE = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')
model_name = 'vit_bert_s - normalized'
algo = 'MAMO'

# fix the seed for reproducibility
seed = 6969
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = True

id = wandb.util.generate_id()
wandb.login()

NUM_WORKERS = 8
torch.set_num_threads(12)

id = 'vnxhkmir'
id

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package stopwords to /home/ml/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmadhava20217[0m. Use [1m`wandb login --relogin`[0m to force relogin


'vnxhkmir'

In [2]:
MAX_LENGTH = 30
BATCH_SIZE = 80
DIMENSION = 224
EPOCHS = 10
warmup_epochs = 1

lr = 1e-4
init_lr = 1e-6
min_lr = 1e-5
decay = 0.01
beta1 = 0.9
beta2 = 0.999

n_layers = 3

In [3]:
weights_path = 'Models/vit_bert_s - normalized/MAMO/checkpoint_final.pth'

MODEL_SAVE_PATH = MODEL_SAVE_PATH = f'Finetuning/{model_name}/{algo}/checkpoint'
if os.path.exists(os.path.dirname(MODEL_SAVE_PATH)) == False:
    os.makedirs(os.path.dirname(MODEL_SAVE_PATH))

In [4]:
@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, k=10, max_len = 30):
    # test; k for top-k; tokenizer is model.bert
    model.eval()

    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 64
    text_feats = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i: min(num_text, i+text_bs)]
        text_input = tokenizer(text, padding='max_length', truncation=True, max_length=max_len, return_tensors="pt").to(device)
        text_output = model.bert(text_input.input_ids, text_input.attention_mask)
        text_feat = torch.nn.functional.normalize(text_output['last_hidden_state'], dim = 2)
        text_embed = model.txt_proj(model.pooler(text_feat.transpose(1,2)))
        text_embeds.append(text_embed)
        text_feats.append(text_feat)
        text_atts.append(text_input.attention_mask)
    text_embeds = torch.cat(text_embeds,dim=0)
    text_feats = torch.cat(text_feats,dim=0)
    text_atts = torch.cat(text_atts,dim=0)

    image_feats = []
    image_embeds = []
    for image, img_id in data_loader:
        image = image.to(device)
        image_feat = torch.nn.functional.normalize(model.vit(image)['last_hidden_state'], dim = 2)
        image_embed = model.img_proj(model.pooler(image_feat.transpose(1, 2)))

        image_feats.append(image_feat)
        image_embeds.append(image_embed)

    image_feats = torch.cat(image_feats,dim=0)
    image_embeds = torch.cat(image_embeds,dim=0)

    sims_matrix = image_embeds @ text_embeds.t()
    score_matrix_i2t = torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device)


    for i,sims in enumerate(sims_matrix):
        topk_sim, topk_idx = sims.topk(k=k, dim=0)

        encoder_output = image_feats[i].repeat(k,1,1)
        # encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
        output = model.mamo(encoder_output,
                            text_feats[topk_idx],
                            text_atts[topk_idx])['last_hidden_state']
        
        score = 1 - model.itm__head(output[:,0,:])#[:,1]          # take output for prediction head 1
        score_matrix_i2t[i,topk_idx] = score.squeeze(1)

    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device)

    for i,sims in enumerate(sims_matrix):
        topk_sim, topk_idx = sims.topk(k=k, dim=0)
        encoder_output = image_feats[topk_idx]
        # encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
        output = model.mamo(encoder_output,
                                    text_feats[i].repeat(k,1,1),
                                    text_atts[i].repeat(k,1))['last_hidden_state']
        score = 1 - model.itm__head(output[:,0,:])#[:,1]
        score_matrix_t2i[i,topk_idx] = score.squeeze(1)

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()



@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):

    #Images->Text
    ranks = np.zeros(scores_i2t.shape[0])
    for index,score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        # Score
        rank = 1e20
        for i in img2txt[index]:
            tmp = np.where(inds == i)[0][0]
            if tmp < rank:
                rank = tmp
        ranks[index] = rank

    # Compute metrics
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    #Text->Images
    ranks = np.zeros(scores_t2i.shape[0])

    for index,score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == txt2img[index])[0][0]

    # Compute metrics
    ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result =  {'txt_r1': tr1,
                    'txt_r5': tr5,
                    'txt_r10': tr10,
                    'txt_r_mean': tr_mean,
                    'img_r1': ir1,
                    'img_r5': ir5,
                    'img_r10': ir10,
                    'img_r_mean': ir_mean,
                    'r_mean': r_mean}
    return eval_result


In [5]:
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from PIL import Image

def create_dataset(dataset, config):

    ## image transforms
    train_transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.int8, scale = True),
        v2.Resize(size = (DIMENSION, DIMENSION), antialias = False),
        v2.RandAugment(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(
            mean = [0.5, 0.5, 0.5],
            std =  [0.5, 0.5, 0.5]
        )
    ])
    test_transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.int8, scale = True),
        v2.Resize(size = (DIMENSION, DIMENSION), antialias = False),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(
            mean = [0.5, 0.5, 0.5],
            std =  [0.5, 0.5, 0.5]
        )
        ])

    if dataset=='re':
        train_dataset = re_train_dataset(config['train_file'], train_transform, config['image_root'])
        val_dataset = re_eval_dataset(config['val_file'], test_transform, config['image_root'])
        test_dataset = re_eval_dataset(config['test_file'], test_transform, config['image_root'])
        return train_dataset, val_dataset, test_dataset


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    loaders = []
    for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
        if is_train:
            shuffle = (sampler is None)
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=True,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
        )
        loaders.append(loader)
    return loaders

In [6]:
config = {'train_file': ['Jsons/flickr30k_train.json'],
          'val_file': 'Jsons/flickr30k_val.json',
          'test_file': 'Jsons/flickr30k_test.json',
          'image_root': './',
          'image_res': DIMENSION}


train_dataset, val_dataset, test_dataset = create_dataset('re', config)

In [7]:
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size = BATCH_SIZE,
                                               num_workers = NUM_WORKERS,
                                               shuffle = True,
                                               drop_last = True)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size = BATCH_SIZE,
                                         num_workers = NUM_WORKERS,
                                         shuffle = False,
                                         drop_last = False)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                         batch_size = BATCH_SIZE,
                                         num_workers = NUM_WORKERS,
                                         shuffle = False,
                                         drop_last = False)

In [8]:
vit_model = transformers.ViTModel.from_pretrained('WinKawaks/vit-small-patch16-224').to(DEVICE)
bert_model = transformers.BertForMaskedLM.from_pretrained("prajjwal1/bert-small").to(DEVICE)
tokenizer = transformers.AutoTokenizer.from_pretrained("prajjwal1/bert-small")

model = MAMO(
            vit = vit_model,
            bert = bert_model,
            vit_num_patches= 196,
            vit_emb_dim=384,
            bert_emb_dim=512,
            bert_layers=3,
            vocab_size=tokenizer.vocab_size,
            mask_token_id= tokenizer.mask_token_id,
            # cls_token_id=tokenizer.cls_token_id
            ).train().to(DEVICE)

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
chkpt = torch.load(weights_path, map_location=DEVICE)
model.load_state_dict(chkpt)

model = model.to(DEVICE)

In [10]:
#optimiser
optim = torch.optim.AdamW(model.parameters(),
                          lr = lr,
                          weight_decay = decay,
                          betas = [beta1, beta2],
                          )

epoch_steps = math.ceil(len(train_dataset)/BATCH_SIZE)
num_steps = int(EPOCHS * epoch_steps)
warmup_steps = int(warmup_epochs * epoch_steps)

lr_scheduler = CosineLRScheduler(
        optim,
        t_initial=num_steps,
        # t_mul=1.,
        lr_min=min_lr,
        warmup_lr_init = init_lr,
        warmup_t=warmup_steps,
        cycle_limit=1,
        t_in_epochs=False,
    )

In [11]:
# wandB init
wandb.init(
    id = id,# id,
    resume =  'allow',
    project = 'MAMO - Finetuning',
    name = 'MAMO - ViT-S, BERT-S',

    config = {
        'architecture': model_name,
        'dataset':'ImageNet1K',
        'warmup_epochs': warmup_epochs,
        'epochs' : EPOCHS,
        'batch_size': BATCH_SIZE,
        'masking_ratio_img' : 0.25,
        'masking_ratio_itxt' : 0.75,
        'mask_patch_size': 196,
        'image_size' : DIMENSION,
        'optim_params':{
            'optim': 'AdamW',
            'beta1': beta1,
            'beta2': beta2,
            'weight_decay': decay,
            'learning_rate': lr,
        },
        'accumulation_iters': 1,
        'patch_size_mask' : 32,
    },
)

In [12]:
import re
nums = [re.match(r'.*checkpoint_(.*).pth', x).group(1) for x in glob.glob(MODEL_SAVE_PATH+'*.pth')]
# if len(nums) > 0:
#     nums.remove("final")
nums = [int(x) for x in nums]

CHKPT = -1

if len(nums) != 0:
    CHKPT = max(nums)

    load_path = '{}_{}.pth'.format(MODEL_SAVE_PATH, CHKPT)
    chkpt = torch.load(load_path, map_location = {'cuda:1': device, 
                                                  'cuda:0': device})

    model.load_state_dict(chkpt['model'])
    optim.load_state_dict(chkpt['optimizer'])
    # lr_scheduler.load_state_dict(chkpt['scheduler_state_dict'])
    
    print(load_path)
    
    print("loaded earlier settings")

Finetuning/vit_bert_s - normalized/MAMO/checkpoint_6.pth
loaded earlier settings


In [13]:
scaler = torch.cuda.amp.grad_scaler.GradScaler()
itm_loss_fn = torch.nn.BCEWithLogitsLoss()

for epoch in range(CHKPT+1, EPOCHS + warmup_epochs):
    num_samples = 0
    ft_loss = 0
    # net all losses
    net_itc_loss = 0
    net_itm_loss = 0
    for idx, data in (pbar := tqdm(enumerate(train_dataloader), total = len(train_dataloader))):
        img, txt, img_idx= data
        text_input = tokenizer(txt, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
        txt, attn_mask = text_input.input_ids, text_input.attention_mask
        # vision
        img = img.to(DEVICE)

        # language
        txt = txt.to(DEVICE)
        attn_mask = attn_mask.to(DEVICE)



        # masked modeling real training
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):                           # casting to f16
            # forward step for online network
            img_rep, txt_rep, joint_rep, img_txt_matching = model(img,
                                                                txt,
                                                                attn_mask,
                                                                retrieval = True)
            # ITC loss
            sim, itc_loss = model.get_itc_loss(img_rep, txt_rep)

            #itm loss
            # sample for each image and each text separately
            img_maps, txt_maps = model.get_samples(sim)
            right_samples = torch.arange(0, img_maps.size(0)).to(DEVICE)

            labs_img = (img_maps == right_samples).float().unsqueeze(1)
            labs_txt = (txt_maps == right_samples).float().unsqueeze(1)

            outs_img = model(img, txt[txt_maps], attn_mask[txt_maps], image_text_matching = True)[-1]
            outs_txt = model(img[img_maps], txt, attn_mask, image_text_matching = True)[-1]

            # softmax probabilities
            itm_1 = itm_loss_fn(outs_img, labs_img)
            itm_2 = itm_loss_fn(outs_txt, labs_txt)
            itm_loss = itm_1 + itm_2

            # TOTAL LOSS
            net_loss = (itc_loss) + (itm_loss)

        scaler.scale(net_loss).backward()

        # BACKPROP
        scaler.step(optim)        # fp16
        scaler.update()           # fp16
        optim.zero_grad(set_to_none = True)
        lr_scheduler.step_update(epoch * epoch_steps + idx)

        # update and calc loss
        num_samples+=1

        net_itc_loss+= itc_loss.item()
        net_itm_loss+= itm_loss.item()
        ft_loss+= net_loss.item()
        pbar.set_description(f"Train Loss: {ft_loss/num_samples}")

    train_stats = {'train_loss': ft_loss/num_samples,
                   'itc_loss': net_itc_loss/num_samples,
                   'itm_loss': net_itm_loss/num_samples}    

    # VALIDATION
    score_val_i2t, score_val_t2i, = evaluation(model, val_loader, tokenizer, DEVICE, k=64, max_len = 35)


    val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)


    log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                    **{f'val_{k}': v for k, v in val_result.items()},
                    'epoch': epoch,
                }
    
    save_path = '{}_{}.pth'.format(MODEL_SAVE_PATH, epoch)
    save_obj = {
        'model': model.state_dict(),
        'optimizer': optim.state_dict(),
        # 'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
    }
    torch.save(save_obj, save_path)
    if (epoch-warmup_epochs+1) % 15 == 0:
        wandb.save(save_path)
    wandb.log(log_stats)


  0%|          | 0/1812 [00:00<?, ?it/s]

Train Loss: 0.08660175968177419: 100%|██████████| 1812/1812 [14:50<00:00,  2.04it/s]
Train Loss: 0.049685174257239575: 100%|██████████| 1812/1812 [14:34<00:00,  2.07it/s]
Train Loss: 0.0430078198229966: 100%|██████████| 1812/1812 [14:35<00:00,  2.07it/s]  
Train Loss: 0.04789213587159381: 100%|██████████| 1812/1812 [14:35<00:00,  2.07it/s] 


In [14]:
# testing
score_test_i2t, score_test_t2i = evaluation(model, test_loader, tokenizer, DEVICE, k=64, max_len = 35)
test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)

log_stats = {**{f'test_{k}': v for k, v in test_result.items()}}
wandb.log(log_stats)

In [15]:
wandb.finish()



0,1
epoch,▁▃▆█
test_img_r1,▁
test_img_r10,▁
test_img_r5,▁
test_img_r_mean,▁
test_r_mean,▁
test_txt_r1,▁
test_txt_r10,▁
test_txt_r5,▁
test_txt_r_mean,▁

0,1
epoch,10.0
test_img_r1,0.38
test_img_r10,3.08
test_img_r5,1.56
test_img_r_mean,1.67333
test_r_mean,4.32
test_txt_r1,1.0
test_txt_r10,13.7
test_txt_r5,6.2
test_txt_r_mean,6.96667
