In [None]:
import transformers
import torch
import torchvision

from tqdm import tqdm
from PIL import Image

import pandas as pd
import numpy as np

from torchinfo import summary
import os
import glob

import tokenizers
import itertools

import random
import math
import copy
from timm.scheduler import CosineLRScheduler

import wandb 
import nltk
nltk.download('stopwords')

from utils.MAMO import MAMO

from utils.dataset import pretrain_dataset
from utils.mim_utils import create_masked_image

# --- Run identity (used in the checkpoint path and the W&B config) ---
algo = 'MAMO'
model_name = 'ViT-S BERT-S (fixed everything)'

# --- Compute resources ---
# `device` is the preferred accelerator; DEVICE falls back to the CPU when
# CUDA is not available.
device = 'cuda:0'
DEVICE = torch.device(device if torch.cuda.is_available() else 'cpu')

torch.set_num_threads(12)
NUM_WORKERS = 8

# --- Weights & Biases ---
wandb.login()
id = wandb.util.generate_id()

# To resume an earlier run, overwrite the freshly generated id here:
# id = ''

print(id)

In [None]:
# --- Paths ---
DATASET_JSON = 'Jsons/flickr30k_train.json'
# Prefix of every checkpoint file; "_<epoch>.pth" / "_final.pth" is appended.
MODEL_SAVE_PATH = f'Models/{model_name}/{algo}/checkpoint'

# --- Optimiser / learning-rate schedule ---
lr = 2.5e-4          # base learning rate handed to AdamW
init_lr = 1e-6       # learning rate at the start of warmup
min_lr = 1e-5        # floor of the cosine schedule
decay = 0.01         # AdamW weight decay
beta1 = 0.9          # Adam beta coefficients
beta2 = 0.999

warmup_epochs = 3
EPOCHS = 12

BATCH_SIZE = 96

# --- Masking ---
MASKING_RATIO_IMG = 0.75
MASKING_RATIO_TXT = 0.25

ALPHA = 0.995               # EWMA coefficient for the target network

# Passed to MAMO as `bert_layers`.
n_layers = 2

# Create the checkpoint directory up front. `exist_ok=True` replaces the
# exists()-then-makedirs() check and is safe to re-run.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)

In [None]:
# Side length (pixels) images are resized to, and the caption word limit
# handed to the dataset.
MAX_LEN = 30
DIMENSION = 224

# Tokenizer matching the BERT-small text encoder checkpoint.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "prajjwal1/bert-small"
)

In [None]:
import torchvision.transforms.v2 as v2

## image transforms
# Augmentation runs on integer images; the result is then converted to
# float32 and normalised with mean/std 0.5 per channel.
img_transform = v2.Compose([
    v2.ToImage(),
    # uint8 (not int8): with scale=True a conversion to int8 would rescale
    # the 8-bit pixel values into int8's smaller positive range and discard
    # precision before augmentation.
    v2.ToDtype(torch.uint8, scale = True),
    v2.Resize(size = (DIMENSION, DIMENSION), antialias = False),
    v2.RandAugment(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(
        mean = [0.5, 0.5, 0.5],
        std =  [0.5, 0.5, 0.5]
    )
])

In [None]:
# Pre-training dataset built from the annotation JSON. It receives the image
# transform, the tokenizer and both masking configurations.
dataset = pretrain_dataset(
    ann_file=[DATASET_JSON],
    transform=img_transform,
    input_size=DIMENSION,
    # image masking
    mask_patch_size=32,
    model_patch_size=16,
    masking_ratio=MASKING_RATIO_IMG,
    # text tokenisation and masking
    tokenizer=tokenizer,
    max_words=MAX_LEN,
    max_length=MAX_LEN + 5,
    txt_masking_ratio=MASKING_RATIO_TXT,
    mask_token=tokenizer.mask_token,
    mask_token_id=tokenizer.mask_token_id,
)

In [None]:
# Shuffled mini-batch loader over the pre-training dataset.
loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
pretrain_dataloader = torch.utils.data.DataLoader(dataset, **loader_options)

In [None]:
# Pretrained unimodal backbones (vision first, then text).
vit_model = transformers.ViTModel.from_pretrained('WinKawaks/vit-small-patch16-224')
vit_model = vit_model.to(DEVICE)
bert_model = transformers.BertForMaskedLM.from_pretrained("prajjwal1/bert-small")

# Online (trainable) network wrapping both backbones.
mamo_config = dict(
    vit=vit_model,
    bert=bert_model,
    vit_num_patches=196,
    vit_emb_dim=384,
    bert_emb_dim=512,
    bert_layers=n_layers,
    vocab_size=tokenizer.vocab_size,
    mask_token_id=tokenizer.mask_token_id,
)
online_network = MAMO(**mamo_config)
online_network = online_network.train().to(DEVICE)

In [None]:
# utils for target network

def freeze_weights(nw):
    """Turn off gradient tracking for every parameter of ``nw``.

    Args:
        nw: a ``torch.nn.Module``; it is modified in place.

    Returns:
        The same module object, so the call can be chained.
    """
    nw.requires_grad_(False)
    return nw
    
def ewma_weights(target, current, alpha = 0.995):
    """Move ``target``'s weights towards ``current`` with an exponential moving average.

    Every floating-point entry of ``target.state_dict()`` becomes
    ``alpha * target + (1 - alpha) * current``. Non-floating entries
    (integer or boolean buffers such as index tables or counters) are copied
    from ``current`` unchanged: blending them in floating point and casting
    back can truncate values.

    Args:
        target: module that is updated in place.
        current: module providing the new weights; it is not modified.
        alpha: weight kept from ``target`` (``1 - alpha`` comes from ``current``).

    Returns:
        ``target``, after the update.

    Raises:
        KeyError: if ``current`` lacks a state-dict key that ``target`` has.
    """
    current_state = current.state_dict()

    with torch.no_grad():
        # state_dict() tensors share storage with the module's parameters and
        # buffers, so in-place writes update `target` directly and no
        # load_state_dict() round trip is needed.
        for key, target_tensor in target.state_dict().items():
            source_tensor = current_state[key]
            if target_tensor.is_floating_point():
                target_tensor.mul_(alpha).add_(source_tensor, alpha = 1 - alpha)
            else:
                target_tensor.copy_(source_tensor)

    return target



In [None]:
# The target network starts as an independent deep copy of the online network
# (same architecture and weights, no shared tensors).
target_network = copy.deepcopy(online_network)

In [None]:
# Optimiser: AdamW over the online network only. The betas come from the
# hyper-parameter cell so the values logged to W&B are the ones actually used.
optim = torch.optim.AdamW(online_network.parameters(),
                          lr = lr,
                          weight_decay = decay,
                          betas = (beta1, beta2),
                          )

# Schedule lengths, counted in optimiser updates.
epoch_steps = math.ceil(len(dataset)/BATCH_SIZE)
num_steps = int(EPOCHS * epoch_steps)
warmup_steps = int(warmup_epochs * epoch_steps)

# Warmup from `init_lr`, followed by a single cosine cycle down to `min_lr`.
# Training runs for EPOCHS + warmup_epochs epochs, so the warmup is treated as
# a prefix (`warmup_prefix=True`): the cosine phase then covers exactly the
# `num_steps` updates that follow warmup. Without it the cosine curve would
# include the warmup steps, making the learning rate drop at the end of warmup
# and sit at `min_lr` for the last `warmup_epochs` epochs.
lr_scheduler = CosineLRScheduler(
        optim,
        t_initial=num_steps,
        lr_min=min_lr,
        warmup_lr_init = init_lr,
        warmup_t=warmup_steps,
        warmup_prefix=True,
        cycle_limit=1,
        t_in_epochs=False,
    )


# wandB init
# `resume='allow'` continues the run when `id` names an existing one and
# starts a new run otherwise. The config mirrors the settings defined in this
# notebook so that the dashboard reflects what was actually trained.
wandb.init(
    id = id,
    resume =  'allow',
    project = 'MAMO - Pretrain',
    name = 'MAMO - ViT-S, BERT-S',

    config = {
        'architecture': model_name,
        # Annotation file the dataset is built from (previously a hard-coded
        # label that did not match it).
        'dataset': DATASET_JSON,
        'warmup_epochs': warmup_epochs,
        'epochs' : EPOCHS,
        'batch_size': BATCH_SIZE,
        'masking_ratio_img' : MASKING_RATIO_IMG,
        # Key spelling kept as-is so earlier runs stay comparable.
        'masking_ratio_itxt' : MASKING_RATIO_TXT,
        # Same value the dataset receives as `mask_patch_size`.
        'mask_patch_size': 32,
        'image_size' : DIMENSION,
        'optim_params':{
            'optim': 'AdamW',
            'beta1': beta1,
            'beta2': beta2,
            'weight_decay': decay,
            'learning_rate': lr,
        },
        'accumulation_iters': 1,
        'patch_size_mask' : 32,
        'alpha_ewma': ALPHA,
    },
)

In [None]:
import re

# Resume from the most recent per-epoch checkpoint, if one exists.
# Only files named "<MODEL_SAVE_PATH>_<epoch>.pth" count. The exported
# "<MODEL_SAVE_PATH>_final.pth" matches the same glob but has no epoch number
# (and holds only the online weights), so it has to be skipped rather than
# passed to int().
checkpoint_name = re.compile(r'checkpoint_(\d+)\.pth')
nums = []
for path in glob.glob(glob.escape(MODEL_SAVE_PATH) + '*.pth'):
    found = checkpoint_name.fullmatch(os.path.basename(path))
    if found is not None:
        nums.append(int(found.group(1)))

# Index of the last completed epoch; -1 means "start from scratch".
CHKPT = -1

if nums:
    CHKPT = max(nums)

    load_path = '{}_{}.pth'.format(MODEL_SAVE_PATH, CHKPT)
    # Map onto DEVICE (which falls back to the CPU) instead of the literal
    # 'cuda:0' string, so loading also works on a machine without CUDA.
    chkpt = torch.load(load_path, map_location = DEVICE)

    online_network.load_state_dict(chkpt['online_model_state_dict'])
    target_network.load_state_dict(chkpt['target_model_state_dict'])
    optim.load_state_dict(chkpt['optim_state_dict'])
    # The scheduler is not checkpointed; the training loop passes it the
    # absolute update index on every step.

    print(load_path)

    print("loaded earlier settings")

In [None]:
# The target network is never optimised directly: freeze it, move it to the
# training device and keep it in eval mode.
target_network = freeze_weights(target_network)
target_network = target_network.to(DEVICE).eval()

# Image-text matching is scored as a binary decision on raw logits.
itm_loss_fn = torch.nn.BCEWithLogitsLoss()

# Gradient scaler used around backward() and the optimiser step.
scaler = torch.cuda.amp.GradScaler()

# Main pre-training loop. Starts after the last restored epoch (CHKPT) and runs
# until EPOCHS + warmup_epochs epochs have been completed. Each step computes
# the MRM, MIM, MLM, ITC and ITM losses on the online network, applies one
# optimiser update and then moves the target network towards the online one.
for epoch in range(CHKPT+1, EPOCHS + warmup_epochs):
    # NOTE: `num_samples` is incremented once per batch, so the running
    # averages below are per-batch means.
    num_samples = 0
    pretrain_loss = 0
    # running sums of the individual loss terms for this epoch
    net_mrm_loss = 0
    net_mim_loss = 0
    net_mlm_loss = 0
    net_itc_loss = 0
    net_itm_loss = 0
    for idx, data in (pbar := tqdm(enumerate(pretrain_dataloader), total = len(pretrain_dataloader))):
        img, img_mask, txt, attn_mask, masked_toks, masked_attn_mask, mask_indices = data
        
        # vision
        img = img.to(DEVICE)
        img_mask = img_mask.to(DEVICE)
        
        # language
        txt = txt.to(DEVICE)
        attn_mask = attn_mask.to(DEVICE)
        masked_toks = masked_toks.to(DEVICE)
        # NOTE(review): masked_attn_mask is moved to the device but not used below.
        masked_attn_mask = masked_attn_mask.to(DEVICE)

        # indices for masked text: will be used for masked modeling
        mask_indices = mask_indices.float().to(DEVICE)
        
        # masked image
        masked_image = create_masked_image(img, img_mask)
        flattened_img_mask = img_mask.float().flatten(1)
        
        # create masks for joint representation modeling:
        # image part first, text part second; each mask is zero on the other modality
        img_rep_masks = torch.cat([flattened_img_mask, torch.zeros_like(mask_indices)], axis = 1).unsqueeze(-1)
        txt_rep_masks = torch.cat([torch.zeros_like(flattened_img_mask), mask_indices], axis = 1).unsqueeze(-1)
        
        
        
        # masked modeling pretraining
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):                           # autocast to bfloat16
            
            
            # forward step for target network (unmasked inputs, no gradients)
            # NOTE(review): target_txt_rep and target_itm are not used below.
            with torch.no_grad():
                target_img_rep, target_txt_rep, target_mm_rep, target_itm = target_network(img,
                                                                                           txt,
                                                                                           attn_mask,
                                                                                           image_text_matching = True)
            
            # forward step for online network (receives the masked image and masked text as well)
            # NOTE(review): mask_img_rep is not used below.
            c_img_m_txt, m_img_c_txt, img_txt_joint, mask_img_rep, txt_prediction, img_rep, txt_rep = online_network(img,
                                                                                                      txt,
                                                                                                      attn_mask,
                                                                                                      image_text_matching = False,
                                                                                                      masked_image = masked_image,
                                                                                                      masked_text = masked_toks)


            # MRM loss: online joint outputs against the target's joint representation,
            # once with the text mask and once with the image mask
            mrm_loss_txt = online_network.get_mrm_loss(c_img_m_txt, target_mm_rep, txt_rep_masks)
            mrm_loss_img = online_network.get_mrm_loss(m_img_c_txt, target_mm_rep, img_rep_masks)
            
            # MIM loss
            mim_loss = online_network.get_mim_loss(m_img_c_txt, target_img_rep, flattened_img_mask)
            
            # MLM loss
            mlm_loss = online_network.get_mlm_loss(txt_prediction, txt, masked_toks)
            
            
            # ITC loss (also returns the similarity matrix used for negative sampling)
            sim, itc_loss = online_network.get_itc_loss(img_rep, txt_rep)
            
            #itm loss
            # sample for each image and each text separately
            # presumably index tensors of negative texts / images drawn from `sim` — TODO confirm
            neg_txt, neg_img = online_network.get_samples(sim)
            
            # labels: len(img) positives followed by 2*len(img) negatives,
            # matching the order of `combined_mamo_reps` below
            itm_labels = torch.cat([torch.ones(len(img)),torch.zeros(2*len(img))],
                               dim=0).unsqueeze(1).float().to(DEVICE)
            # stack: (image, negative text) pairs first, then (negative image, text) pairs
            itm_img_feats = torch.vstack([img_rep, img_rep[neg_img]])
            itm_txt_feats = torch.vstack([txt_rep[neg_txt], txt_rep])
            itm_txt_attn = torch.vstack([attn_mask[neg_txt], attn_mask])

            joint_rep_negs = online_network.mamo(itm_img_feats, itm_txt_feats, itm_txt_attn)['last_hidden_state']
            combined_mamo_reps = torch.vstack([img_txt_joint, joint_rep_negs])
            
            # the ITM head reads the first position of each joint sequence
            itm_outputs = online_network.itm__head(combined_mamo_reps[:, 0, :])
            
            
            # binary cross-entropy on the raw logits (BCEWithLogitsLoss)
            itm_loss = itm_loss_fn(itm_outputs, itm_labels)
            
            # TOTAL LOSS: unweighted sum of all terms
            net_loss = (mrm_loss_img + mrm_loss_txt) + (mim_loss) + (mlm_loss) + (itc_loss) + (itm_loss)
            
        # gradients are unscaled before clipping so the norm threshold applies to true gradients
        scaler.scale(net_loss).backward()
        scaler.unscale_(optim)
        torch.nn.utils.clip_grad_norm_(online_network.parameters(), 1.)
            
        # BACKPROP
        scaler.step(optim)        # optimiser step through the grad scaler
        scaler.update()           # adjust the loss scale for the next step
        optim.zero_grad(set_to_none = True)
        # the scheduler is driven by the absolute update index across epochs
        lr_scheduler.step_update(epoch * epoch_steps + idx)
        
        # update and calc loss
        num_samples+=1
        net_mrm_loss+= mrm_loss_img.item() + mrm_loss_txt.item()
        net_mim_loss+= mim_loss.item()
        net_mlm_loss+= mlm_loss.item()
        net_itc_loss+= itc_loss.item()
        net_itm_loss+= itm_loss.item()
        pretrain_loss+= net_loss.item()
        pbar.set_description(f"Train Loss: {pretrain_loss/num_samples}")
        
        
        # EWMA for weights: the target network follows the online network after every step
        target_network = ewma_weights(target_network, online_network, alpha = ALPHA)
        
        



    # per-epoch averages of every loss term
    wandb.log({
            'epoch': epoch,
            'pretrain_loss': pretrain_loss/num_samples,
            'mrm_loss': net_mrm_loss/num_samples,
            'mim_loss': net_mim_loss/num_samples,
            'mlm_loss': net_mlm_loss/num_samples,
            'itc_loss': net_itc_loss/num_samples,
            'itm_loss': net_itm_loss/num_samples,
    })
    # checkpoint both networks and the optimiser after every epoch
    # (the scheduler and grad-scaler states are not saved)
    save_path = '{}_{}.pth'.format(MODEL_SAVE_PATH, epoch)
    torch.save(
            {
            'epoch': epoch,
            'online_model_state_dict': online_network.state_dict(),
            'target_model_state_dict': target_network.state_dict(),
            'optim_state_dict': optim.state_dict()
            },
        save_path
        )
    # upload the checkpoint to W&B only on every 10th epoch counted from the end of warmup
    if (epoch-warmup_epochs+1) % 10 == 0:
        wandb.save(save_path)

In [None]:
# Export only the online network's weights (no optimiser or target state)
# under the "final" suffix and upload the file to W&B.
final_weights = online_network.state_dict()
save_path = '{}_{}.pth'.format(MODEL_SAVE_PATH, 'final')

torch.save(final_weights, save_path)
wandb.save(save_path)

In [None]:
# Mark the W&B run as finished.
wandb.finish()