In [1]:
import transformers
import torch
import torchvision

from tqdm import tqdm
from PIL import Image

import pandas as pd
import numpy as np

from torchinfo import summary
import os
import glob

import tokenizers
import itertools

import random
import math
import copy
from timm.scheduler import CosineLRScheduler

import wandb 
import nltk
nltk.download('stopwords')

# from utils.RegVLM import RegVLM
from utils.hog import HOGLayer

from utils.dataset import pretrain_dataset
from utils.mim_utils import create_masked_image

device = 'cuda:0'


DEVICE = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')
model_name = 'ViT-S BERT-S (fixed everything)'
algo = 'RegVLM'


NUM_WORKERS = 8
torch.set_num_threads(12)


id = wandb.util.generate_id()
wandb.login()

# set earlier ID
# id = ''

print(id)

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package stopwords to /home/ml/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmadhava20217[0m. Use [1m`wandb login --relogin`[0m to force relogin


xa027m34


In [2]:
DATASET_JSON = 'Jsons/flickr30k_train.json'
MODEL_SAVE_PATH = f'Models/{model_name}/{algo}/checkpoint'

lr = 2.5e-4
init_lr = 1e-6
min_lr = 1e-5
decay = 0.01
beta1 = 0.9
beta2 = 0.999

warmup_epochs = 3
EPOCHS = 12

BATCH_SIZE = 96

MASKING_RATIO_IMG = 0.75
MASKING_RATIO_TXT = 0.25

HOG_BINS = 9

ALPHA = 0.995               # EWMA


n_layers = 2

if os.path.exists(os.path.dirname(MODEL_SAVE_PATH)) == False:
    os.makedirs(os.path.dirname(MODEL_SAVE_PATH))

In [3]:
DIMENSION = 224

MAX_LEN = 30

# ViT config
tokenizer = transformers.AutoTokenizer.from_pretrained("prajjwal1/bert-small")

In [4]:
import torchvision.transforms.v2 as v2

## image transforms
img_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.int8, scale = True),
    v2.Resize(size = (DIMENSION, DIMENSION), antialias = False),
    v2.RandAugment(),
    # v2.RandomVerticalFlip(),
    # v2.RandomHorizontalFlip(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(
        mean = [0.5, 0.5, 0.5],
        std =  [0.5, 0.5, 0.5]
    )
])

In [5]:
dataset = pretrain_dataset(
               ann_file = [DATASET_JSON],
               transform = img_transform,
               tokenizer = tokenizer,
               max_words = MAX_LEN,
               input_size = DIMENSION,
               mask_patch_size = 32,
               model_patch_size = 16,
               masking_ratio = MASKING_RATIO_IMG,
               txt_masking_ratio = MASKING_RATIO_TXT,
               mask_token = tokenizer.mask_token,
               mask_token_id = tokenizer.mask_token_id,
               max_length = MAX_LEN + 5,
               )

In [6]:
# dataloader

pretrain_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size = BATCH_SIZE,
    pin_memory = True,
    num_workers = NUM_WORKERS,
    shuffle = True
)

In [7]:
def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    oldModuleList = model.encoder.layer
    newModuleList = torch.nn.ModuleList()

    # Now iterate over all layers, only keepign only the relevant layers.
    for i in range(0, num_layers_to_keep):
        newModuleList.append(oldModuleList[i])

    # create a copy of the model, modify it with the new list, and return
    copyOfModel = copy.deepcopy(model)
    copyOfModel.encoder.layer = newModuleList

    return copyOfModel

def deleteLaterEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    oldModuleList = model.encoder.layer
    newModuleList = torch.nn.ModuleList()

    # Now iterate over all layers, only keepign only the relevant layers.
    for i in range(num_layers_to_keep, 0, -1):
        newModuleList.append(oldModuleList[-i])

    # create a copy of the model, modify it with the new list, and return
    copyOfModel = copy.deepcopy(model)
    copyOfModel.encoder.layer = newModuleList

    return copyOfModel

class RegVLM_Mixer(torch.nn.Module):
    def __init__(self, base_bert,
                 n_layers = 2,
                 n_visual_tokens = 196,
                 vision_embedding_dim = 384,
                 emb_dims = 512,
                 cls_token_id = 101):
        # prepare decoder
        super().__init__()
        self.cls_token_id = cls_token_id
        self.embedding_module = base_bert.embeddings
        self.n_visual_tokens = n_visual_tokens
        self.vision_emb_dim = vision_embedding_dim
        self.base_model = deleteLaterEncodingLayers(base_bert, n_layers).encoder
        
        self.pooler = torch.nn.AdaptiveAvgPool1d(1)
        self.emb_dimension = emb_dims
        
        if self.vision_emb_dim == self.emb_dimension:
            self.dimension_caster = torch.nn.Identity()
        else:
            self.dimension_caster = torch.nn.Linear(self.vision_emb_dim, self.emb_dimension, bias = False)  # no bias here
        
        
    def forward(self, vision_embedding, text_embedding, text_attn_mask):
        # assert len(vision_embedding) == len(text_embedding)
        n_batch = len(vision_embedding)
        
        cls_emb = self.embedding_module(torch.tensor([[self.cls_token_id]]*n_batch, 
                                                     device = vision_embedding.device),
                                        torch.tensor([[1]]*n_batch,
                                                     device = vision_embedding.device))
        
        # normalize dimensions
        new_vision_emb = self.dimension_caster(vision_embedding[:, 1:, :])   # remove cls token here
        
        # concatenate
        concatenated_emb = torch.cat([cls_emb, new_vision_emb, text_embedding], dim = 1)
        
        # create attention mask
        vision_attention_mask = torch.ones(n_batch, self.n_visual_tokens + 1).to(text_attn_mask.device) # add a cls token here
        attn_mask = torch.cat([vision_attention_mask, text_attn_mask], dim = 1)
        
        attn_mask = attn_mask[:, None, None, :]
        
        # forward
        return self.base_model(concatenated_emb, attn_mask)

In [56]:
class RegVLM(torch.nn.Module):
    def __init__(self,
                 vit,
                 bert,
                 vit_num_patches = 196,
                 vit_emb_dim = 384,
                 bert_emb_dim = 512,
                 bert_layers = 2,
                 vocab_size = 30522,
                 mask_token_id = 103,
                 cls_token_id = 101,
                 tau = None):
       super().__init__()
       self.vit = vit.vit
       self.mim_reconstruction = vit.decoder
       self.bert = bert.base_model
       self.bert = deleteEncodingLayers(self.bert.base_model, bert_layers)
       self.fusion = RegVLM_Mixer(bert.base_model,
                              n_layers = bert_layers,
                              n_visual_tokens=vit_num_patches,
                              vision_embedding_dim=vit_emb_dim,
                              emb_dims = bert_emb_dim,
                              cls_token_id = cls_token_id)
       
       # vit patches data
       self.vit_num_patches = vit_num_patches
       
       # vocab size
       self.vocab_size = vocab_size
       # mask token
       self.mask_token_id = mask_token_id
       
       # learnable temperature parameter
       self.tau = torch.nn.Parameter(torch.FloatTensor([0.07]))      # uniform in range 1 to 5
       if tau is not None:
           self.tau = torch.nn.Parameter(torch.FloatTensor([tau]))      # uniform in range 1 to 5
       
       self.tau.requires_grad = True

       # joint representation
       self.pooler = torch.nn.Sequential(
           torch.nn.AdaptiveAvgPool1d(1),
           torch.nn.Flatten()
       )
       self.img_proj = torch.nn.Linear(vit_emb_dim, min(vit_emb_dim, bert_emb_dim))
       self.txt_proj = torch.nn.Linear(bert_emb_dim, min(vit_emb_dim, bert_emb_dim))

       
       # masked representation modeling
       self.mrm_proj = torch.nn.Sequential(
            torch.nn.Linear(bert_emb_dim, bert_emb_dim),
            torch.nn.Tanh(),
       )
       
       # head for masked image modeling
       self.mim_proj = torch.nn.Sequential(
           torch.nn.Linear(bert_emb_dim, vit_emb_dim),
       )
        
       # head for masked language modeling
       self.mlm_head_joint = copy.deepcopy(bert.cls)
       self.mlm_head = bert.cls
       
       self.itc_head = torch.nn.Linear(bert_emb_dim, bert_emb_dim)
       
       self.itm__head = torch.nn.Sequential(
           torch.nn.Linear(bert_emb_dim, bert_emb_dim),
           torch.nn.LeakyReLU(),
           torch.nn.Linear(bert_emb_dim, 1)
       )
    
       
       
    def forward(self, image, text, attn_mask,
                masked_pos_img = None,
                masked_text = None,
                image_text_matching = False,
                retrieval = False,
                ):
        
        if retrieval is True:
            img_rep = self.vit(image)['last_hidden_state']
            txt_rep = self.bert(text, attn_mask)['last_hidden_state']
            joint_rep = self.fusion(img_rep, txt_rep, attn_mask)['last_hidden_state']
            
            
            # img_rep = self.img_proj(self.pooler(img_rep.transpose(1,2)))
            # txt_rep = self.txt_proj(self.pooler(txt_rep.transpose(1,2)))
            return img_rep, txt_rep, joint_rep, self.itm__head(joint_rep[:, 0, :])
        
        if image_text_matching == True:
            img_rep = self.vit(image)['last_hidden_state']
            txt_rep = self.bert(text, attn_mask)['last_hidden_state']
            joint_rep = self.fusion(img_rep, txt_rep, attn_mask)['last_hidden_state']
            
            return img_rep, txt_rep, joint_rep, self.itm__head(joint_rep[:, 0, :])
        
        else:
            # return mask_img-clean_txt, clean_img,-mask_txt, 
            img_rep = self.vit(image)['last_hidden_state']              # clean image
            txt_rep = self.bert(text, attn_mask)['last_hidden_state']   # clean text
            
            mask_img_rep = self.vit(image, bool_masked_pos = masked_pos_img)['last_hidden_state']
            mask_txt_rep = self.bert(masked_text, attn_mask)['last_hidden_state']
            
            # multimodal prediction
            c_img_m_txt = self.fusion(img_rep, mask_txt_rep, attn_mask)['last_hidden_state']
            m_img_c_txt = self.fusion(mask_img_rep, txt_rep, attn_mask)['last_hidden_state']
            
            # pure txt
            txt_prediction = self.mlm_head(mask_txt_rep)
            
            # pure fusion
            img_txt_joint = self.fusion(img_rep, txt_rep, attn_mask)['last_hidden_state']
        
            return (c_img_m_txt, m_img_c_txt, img_txt_joint, mask_img_rep, txt_prediction, img_rep, txt_rep)
    
    def get_mrm_loss(self, online_representation, target_representation, mask):
        # remove cls token
        on_rep = torch.nn.functional.normalize(self.mrm_proj(online_representation[:, 1:, :]), dim = -1)
        tr_rep = torch.nn.functional.normalize(self.mrm_proj(target_representation[:, 1:, :]), dim = -1)

        
        loss = torch.nn.functional.mse_loss(on_rep, tr_rep, reduction = 'none')
        mrm_loss = (loss * mask).sum()/(mask.sum() + 1e-5)              # add for 0 division errors
        return mrm_loss
    
    def get_joint_mim_loss(self, joint_rep, img, mask):
        joint_rep = joint_rep[:, :self.vit_num_patches+1, :]
        joint_rep = self.mim_proj(joint_rep)                    # bs x vit_npatch x vit dim
        sequence_output = joint_rep

        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output[:, 1:]
        batch_size, sequence_length, num_channels = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = torch.nn.functional.l1_loss(img, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / 3

        return masked_im_loss
    
    
    def get_joint_mlm_loss(self, )
    
    def get_mlm_loss(self, scores, sen, masked_sen):
        labels = torch.where(masked_sen == self.mask_token_id, sen, -100)
        loss = torch.nn.functional.cross_entropy(scores.view(-1, self.vocab_size), labels.view(-1), ignore_index=-100)
        
        return loss
    
    
    def get_itc_loss(self, img_feats, txt_feats):
        # Calculate similarity
        with torch.no_grad():
            self.tau.clamp_(0.001,0.5)

        # pool and flatten text and visual features obtained before fusion
        img_feats = self.img_proj(self.pooler(img_feats.transpose(1,2)))
        txt_feats = self.txt_proj(self.pooler(txt_feats.transpose(1,2)))
        
        
        sim = (img_feats@txt_feats.T)/self.tau
        # sim = torch.clip(sim, max = 1e4, min = 1e-4)
        self_mask = torch.eye(sim.shape[0], device=sim.device)
        
        loss_i2t = -torch.sum(torch.nn.functional.log_softmax(sim, dim = 1)*self_mask, dim = 1).mean()
        loss_t2i = -torch.sum(torch.nn.functional.log_softmax(sim.T, dim = 1)*self_mask, dim = 1).mean()

        return sim, (loss_i2t+loss_t2i)/2.0
    
    def get_samples(self, similarities):
        probs = torch.nn.functional.softmax(similarities, dim = 1)
        probs = probs.fill_diagonal_(0)         # eliminate full samples
        txt_indices = torch.multinomial(probs, num_samples=1, replacement=True).squeeze(1)
        img_indices = torch.multinomial(probs.T, num_samples=1, replacement=True).squeeze(1)
        
        return txt_indices, img_indices

SyntaxError: expected ':' (3336982678.py, line 142)

In [42]:
vit_model = transformers.ViTForMaskedImageModeling.from_pretrained('WinKawaks/vit-small-patch16-224').to(DEVICE)
bert_model = transformers.BertForMaskedLM.from_pretrained("prajjwal1/bert-small")

online_network = RegVLM(
                    vit = vit_model,
                    bert = bert_model,
                    vit_num_patches= 196,
                    vit_emb_dim=384,
                    bert_emb_dim=512,
                    bert_layers=n_layers,
                    vocab_size=tokenizer.vocab_size,
                    mask_token_id= tokenizer.mask_token_id,
                    # cls_token_id=tokenizer.cls_token_id
                ).train().to(DEVICE)

Some weights of ViTForMaskedImageModeling were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['decoder.0.bias', 'decoder.0.weight', 'vit.embeddings.mask_token']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [43]:
for data in pretrain_dataloader:
    img, img_mask, txt, attn_mask, masked_toks, masked_attn_mask, mask_indices = data
    img, img_mask, txt, attn_mask, masked_toks, masked_attn_mask, mask_indices = data
        
    # vision
    img = img.to(DEVICE)
    img_mask = img_mask.to(DEVICE)
        
    # language
    txt = txt.to(DEVICE)
    attn_mask = attn_mask.to(DEVICE)
    masked_toks = masked_toks.to(DEVICE)
    masked_attn_mask = masked_attn_mask.to(DEVICE)

    # indices for masked text: will be used for masked modeling
    mask_indices = mask_indices.float().to(DEVICE)
    
    # # masked image
    # masked_image = create_masked_image(img, img_mask)
    flattened_img_mask = img_mask.flatten(1).to(DEVICE)
    
    break

In [44]:
flattened_img_mask.bool()

tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True, False, False],
        ...,
        [ True,  True,  True,  ..., False,  True,  True],
        [False, False, False,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True,  True]], device='cuda:0')

In [54]:
hog_transform = HOGLayer(HOG_BINS, pool = 16, groupp).to(DEVICE)

In [62]:
with torch.no_grad():
    # need to do masked image modeling with text supervision 
    outs = online_network.vit(img, bool_masked_pos =  flattened_img_mask)[0]
    txt_rep = online_network.bert(txt, attn_mask)[0]  # clean text
    joint = online_network.fusion(outs, txt_rep, attn_mask)['last_hidden_state']

joint[:, :online_network.vit_num_patches+1, :].shape

torch.Size([96, 197, 512])

In [52]:
outs.shape

torch.Size([96, 197, 384])

In [20]:
outs

BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 3.3762e+00,  1.1898e+00,  4.9747e+00,  ..., -9.4658e-01,
          -6.3289e-01,  9.4220e-01],
         [ 9.2341e-02,  1.5397e+00,  1.6549e+00,  ...,  7.8462e-01,
          -2.2992e+00,  3.3514e+00],
         [-1.7680e-01,  1.8329e+00,  2.6082e+00,  ...,  1.5202e+00,
          -2.7215e+00,  4.5280e+00],
         ...,
         [-1.1351e+00,  9.9755e-01,  4.0025e+00,  ...,  4.1630e-01,
          -7.8341e-02,  3.6474e+00],
         [-1.7787e+00,  1.5458e+00,  3.9983e+00,  ...,  5.5988e-01,
           2.3629e-01,  3.4860e+00],
         [-7.9056e-01,  8.4277e-01,  4.3316e+00,  ...,  6.4973e-02,
           1.6623e-01,  3.6845e+00]],

        [[ 2.2669e-01,  3.1071e-01, -2.0778e-02,  ..., -7.0768e-01,
           5.6413e-01,  2.4698e+00],
         [-2.5199e+00, -2.2083e+00,  4.4909e-01,  ...,  1.0419e-01,
           3.6589e-01,  3.7860e+00],
         [ 6.1990e-01, -6.0655e-01,  3.0824e-01,  ...,  3.1788e+00,
          -3.6679e-01,  2.2453e

In [25]:
hoglayer = HOGLayer(nbins = HOG_BINS, pool = 16)

In [None]:
# utils for target network

# freeze weights
def freeze_weights(nw):
    for param in nw.parameters():
        param.requires_grad = False
        
    return nw
    
def ewma_weights(target, current, alpha = 0.995):
    sdA = target.state_dict()
    sdB = current.state_dict()
    
    for key in sdA:
        sdA[key] = alpha*sdA[key] + (1-alpha)*sdB[key]
    
    target.load_state_dict(sdA)
    return target



In [None]:
target_network = copy.deepcopy(online_network)

In [None]:
#optimiser
optim = torch.optim.AdamW(online_network.parameters(),
                          lr = lr,
                          weight_decay = decay,
                          betas = [0.9, 0.999],
                          )

epoch_steps = math.ceil(len(dataset)/BATCH_SIZE)
num_steps = int(EPOCHS * epoch_steps)
warmup_steps = int(warmup_epochs * epoch_steps)

lr_scheduler = CosineLRScheduler(
        optim,
        t_initial=num_steps,
        # t_mul=1.,
        lr_min=min_lr,
        warmup_lr_init = init_lr,
        warmup_t=warmup_steps,
        cycle_limit=1,
        t_in_epochs=False,
    )


# # wandB init
# wandb.init(
#     id = id,# id,
#     resume =  'allow',
#     project = 'MAMO - Pretrain',
#     name = 'MAMO - ViT-S, BERT-S',

#     config = {
#         'architecture': model_name,
#         'dataset':'ImageNet1K',
#         'warmup_epochs': warmup_epochs,
#         'epochs' : EPOCHS,
#         'batch_size': BATCH_SIZE,
#         'masking_ratio_img' : MASKING_RATIO_IMG,
#         'masking_ratio_itxt' : MASKING_RATIO_TXT,
#         'mask_patch_size': 196,
#         'image_size' : DIMENSION,
#         'optim_params':{
#             'optim': 'AdamW',
#             'beta1': beta1,
#             'beta2': beta2,
#             'weight_decay': decay,
#             'learning_rate': lr,
#         },
#         'accumulation_iters': 1,
#         'patch_size_mask' : 32,
#         'alpha_ewma': ALPHA,
#     },
# )

In [None]:
import re
nums = [re.match(r'.*checkpoint_(.*).pth', x).group(1) for x in glob.glob(MODEL_SAVE_PATH+'*.pth')]
# if len(nums) > 0:
#     nums.remove("final")
nums = [int(x) for x in nums]

CHKPT = -1

if len(nums) != 0:
    CHKPT = max(nums)

    load_path = '{}_{}.pth'.format(MODEL_SAVE_PATH, CHKPT)
    chkpt = torch.load(load_path, map_location = {'cuda:1': device, 
                                                  'cuda:0': device})

    online_network.load_state_dict(chkpt['online_model_state_dict'])
    target_network.load_state_dict(chkpt['target_model_state_dict'])
    optim.load_state_dict(chkpt['optim_state_dict'])
    # lr_scheduler.load_state_dict(chkpt['scheduler_state_dict'])
    
    print(load_path)
    
    print("loaded earlier settings")

In [None]:
target_network = freeze_weights(target_network).to(DEVICE).eval()
scaler = torch.cuda.amp.grad_scaler.GradScaler()
itm_loss_fn = torch.nn.BCEWithLogitsLoss()

for epoch in range(CHKPT+1, EPOCHS + warmup_epochs):
    num_samples = 0
    pretrain_loss = 0
    # net all losses
    net_mrm_loss = 0
    net_mim_loss = 0
    net_mlm_loss = 0
    net_itc_loss = 0
    net_itm_loss = 0
    for idx, data in (pbar := tqdm(enumerate(pretrain_dataloader), total = len(pretrain_dataloader))):
        img, img_mask, txt, attn_mask, masked_toks, masked_attn_mask, mask_indices = data
        
        # vision
        img = img.to(DEVICE)
        img_mask = img_mask.to(DEVICE)
        
        # language
        txt = txt.to(DEVICE)
        attn_mask = attn_mask.to(DEVICE)
        masked_toks = masked_toks.to(DEVICE)
        masked_attn_mask = masked_attn_mask.to(DEVICE)

        # indices for masked text: will be used for masked modeling
        mask_indices = mask_indices.float().to(DEVICE)
        
        # # masked image
        # masked_image = create_masked_image(img, img_mask)
        flattened_img_mask = img_mask.float().flatten(1)
        
        # create masks for joint representation modeling
        img_rep_masks = torch.cat([flattened_img_mask, torch.zeros_like(mask_indices)], axis = 1).unsqueeze(-1)
        txt_rep_masks = torch.cat([torch.zeros_like(flattened_img_mask), mask_indices], axis = 1).unsqueeze(-1)
        
        
        
        # masked modeling pretraining
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):                           # casting to f16
            
            
            # forward step for target network
            with torch.no_grad():
                target_img_rep, target_txt_rep, target_mm_rep, target_itm = target_network(img,
                                                                                           txt,
                                                                                           attn_mask,
                                                                                           image_text_matching = True)
            
            # forward step for online network
            c_img_m_txt, m_img_c_txt, img_txt_joint, mask_img_rep, txt_prediction, img_rep, txt_rep = online_network(img,
                                                                                                      txt,
                                                                                                      attn_mask,
                                                                                                      image_text_matching = False,
                                                                                                      masked_pos = mask_indices,
                                                                                                      masked_text = masked_toks)


            # MRM loss
            mrm_loss_txt = online_network.get_mrm_loss(c_img_m_txt, target_mm_rep, txt_rep_masks)
            mrm_loss_img = online_network.get_mrm_loss(m_img_c_txt, target_mm_rep, img_rep_masks)
            
            # MIM loss
            mim_loss = online_network.get_mim_loss(m_img_c_txt, target_img_rep, flattened_img_mask)
            
            # MLM loss
            mlm_loss = online_network.get_mlm_loss(txt_prediction, txt, masked_toks)
            
            
            # ITC loss
            sim, itc_loss = online_network.get_itc_loss(img_rep, txt_rep)
            
            #itm loss
            # sample for each image and each text separately
            neg_txt, neg_img = online_network.get_samples(sim)
            
            itm_labels = torch.cat([torch.ones(len(img)),torch.zeros(2*len(img))],
                               dim=0).unsqueeze(1).float().to(DEVICE)
            # stack 
            itm_img_feats = torch.vstack([img_rep, img_rep[neg_img]])
            itm_txt_feats = torch.vstack([txt_rep[neg_txt], txt_rep])
            itm_txt_attn = torch.vstack([attn_mask[neg_txt], attn_mask])

            joint_rep_negs = online_network.mamo(itm_img_feats, itm_txt_feats, itm_txt_attn)['last_hidden_state']
            combined_mamo_reps = torch.vstack([img_txt_joint, joint_rep_negs])
            
            itm_outputs = online_network.itm__head(combined_mamo_reps[:, 0, :])
            
            
            # softmax probabilities
            itm_loss = itm_loss_fn(itm_outputs, itm_labels)
            
            # TOTAL LOSS
            net_loss = (mrm_loss_img + mrm_loss_txt) + (mim_loss) + (mlm_loss) + (itc_loss) + (itm_loss)
            
        scaler.scale(net_loss).backward()
        scaler.unscale_(optim)
        torch.nn.utils.clip_grad_norm_(online_network.parameters(), 1.)
            
        # BACKPROP
        scaler.step(optim)        # fp16
        scaler.update()           # fp16
        optim.zero_grad(set_to_none = True)
        lr_scheduler.step_update(epoch * epoch_steps + idx)
        
        # update and calc loss
        num_samples+=1
        net_mrm_loss+= mrm_loss_img.item() + mrm_loss_txt.item()
        net_mim_loss+= mim_loss.item()
        net_mlm_loss+= mlm_loss.item()
        net_itc_loss+= itc_loss.item()
        net_itm_loss+= itm_loss.item()
        pretrain_loss+= net_loss.item()
        pbar.set_description(f"Train Loss: {pretrain_loss/num_samples}")
        
        
        # EWMA for weights
        target_network = ewma_weights(target_network, online_network, alpha = ALPHA)
        
        



    wandb.log({
            'epoch': epoch,
            'pretrain_loss': pretrain_loss/num_samples,
            'mrm_loss': net_mrm_loss/num_samples,
            'mim_loss': net_mim_loss/num_samples,
            'mlm_loss': net_mlm_loss/num_samples,
            'itc_loss': net_itc_loss/num_samples,
            'itm_loss': net_itm_loss/num_samples,
    })
    save_path = '{}_{}.pth'.format(MODEL_SAVE_PATH, epoch)
    torch.save(
            {
            'epoch': epoch,
            'online_model_state_dict': online_network.state_dict(),
            'target_model_state_dict': target_network.state_dict(),
            'optim_state_dict': optim.state_dict()
            },
        save_path
        )
    if (epoch-warmup_epochs+1) % 10 == 0:
        wandb.save(save_path)

In [None]:
save_path = '{}_{}.pth'.format(MODEL_SAVE_PATH, 'final')
torch.save(online_network.state_dict(), save_path)
wandb.save(save_path)

In [None]:
wandb.finish()