In [1]:
import transformers
import torch
import torchvision

from tqdm import tqdm
from PIL import Image

import pandas as pd
import numpy as np

from torchinfo import summary
import os
import glob

import tokenizers
import itertools
import matplotlib.pyplot as plt

import random
import math
import copy
from timm.scheduler import CosineLRScheduler

import wandb


# Preferred accelerator name; only used when CUDA is actually available (see DEVICE).
device = 'cuda:0'


# Resolved torch device: the GPU named above when CUDA is available, otherwise the CPU.
DEVICE = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')
model_name = 'vit_bert_s'
algo = 'MAMO'


# 0 workers: batches are produced in the main process.
NUM_WORKERS = 0
torch.set_num_threads(12)


# NOTE(review): `id` shadows the Python builtin of the same name for the rest of the notebook.
id = wandb.util.generate_id()
wandb.login()

# set earlier ID
# The freshly generated id above is overwritten by this hard-coded run id, so the
# later wandb.init(id=id, resume='allow') call targets that existing run.
id = 'ci3xdvn5'

print(id)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


In [None]:
# ---- Paths -----------------------------------------------------------------
DATASET_SRC = '../Datasets/Flickr30k/'
# Checkpoint *prefix*: files are written as f'{MODEL_SAVE_PATH}_{epoch}.pth'.
MODEL_SAVE_PATH = f'Models/{model_name}/{algo}/checkpoint'

VOCAB_PATH = 'Vocabulary/flickr30k.vocab'

# ---- Optimisation ----------------------------------------------------------
lr = 3e-4           # peak learning rate
init_lr = 1e-6      # learning rate at the start of warm-up
min_lr = 1e-5       # floor of the cosine schedule
decay = 0.01        # AdamW weight decay
beta1 = 0.9
beta2 = 0.999

warmup_epochs = 5
EPOCHS = 60

BATCH_SIZE = 64

# ---- Masking ---------------------------------------------------------------
MASKING_RATIO_IMG = 0.75
MASKING_RATIO_TXT = 0.25

ALPHA = 0.995               # EWMA


n_layers = 2

# Create the checkpoint directory. exist_ok=True replaces the separate
# exists()/makedirs() pair, which could fail if the directory appeared in between.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)

In [None]:
def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return a deep copy of `model` whose encoder keeps only its first layers.

    Arguments:
    1. model: module exposing `model.encoder.layer` as an indexable module list.
    2. num_layers_to_keep: number of leading encoder layers to retain.

    Returns:
    A deep copy of `model` with `encoder.layer` truncated. `model` itself is
    left untouched and shares no layer modules with the returned copy.

    Raises:
    IndexError if `num_layers_to_keep` exceeds the number of encoder layers.
    """
    # Copy first, then pick layers from the copy. Picking them from the original
    # (as was done before) made the "copy" reuse the original's layer modules,
    # so training one silently updated the other.
    copyOfModel = copy.deepcopy(model)
    copied_layers = copyOfModel.encoder.layer

    # Keep only the relevant (leading) layers.
    newModuleList = torch.nn.ModuleList(
        [copied_layers[i] for i in range(0, num_layers_to_keep)]
    )
    copyOfModel.encoder.layer = newModuleList

    return copyOfModel

def deleteLaterEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return a deep copy of `model` whose encoder keeps only its last layers.

    Arguments:
    1. model: module exposing `model.encoder.layer` as an indexable module list.
    2. num_layers_to_keep: number of trailing encoder layers to retain
       (their original order is preserved).

    Returns:
    A deep copy of `model` with `encoder.layer` reduced to its final
    `num_layers_to_keep` layers. `model` itself is left untouched and shares
    no layer modules with the returned copy.

    Raises:
    IndexError if `num_layers_to_keep` exceeds the number of encoder layers.
    """
    # Copy first, then pick layers from the copy so the result owns its own
    # parameters instead of aliasing the original model's layers.
    copyOfModel = copy.deepcopy(model)
    copied_layers = copyOfModel.encoder.layer

    # Walk -n, -(n-1), ..., -1 so the kept layers stay in forward order.
    newModuleList = torch.nn.ModuleList(
        [copied_layers[-i] for i in range(num_layers_to_keep, 0, -1)]
    )
    copyOfModel.encoder.layer = newModuleList

    return copyOfModel


def get_bert_model(model, num_layers):
    """Build a truncated BERT that keeps the first `num_layers` encoder layers.

    Thin convenience wrapper: the work is delegated to `deleteEncodingLayers`.
    """
    truncated_model = deleteEncodingLayers(model, num_layers)
    return truncated_model



class MAMO_mixer(torch.nn.Module):
    """Fusion encoder that runs BERT encoder layers over [visual tokens ; text tokens].

    Visual embeddings are projected to the text embedding width when the two
    widths differ, concatenated with the text embeddings along the token axis,
    and passed through the last `n_layers` encoder layers of `base_bert`.
    """

    def __init__(self, base_bert, n_layers = 2, n_visual_tokens = 197, vision_embedding_dim = 384, emb_dims = 512):
        """
        Arguments:
        1. base_bert: model whose `.base_model.encoder.layer` supplies the fusion layers
        2. n_layers: number of trailing encoder layers to reuse
        3. n_visual_tokens: nominal number of visual tokens (kept as an attribute;
           the forward pass reads the actual count from its input)
        4. vision_embedding_dim: width of the incoming visual embeddings
        5. emb_dims: width expected by the fusion layers
        """
        # prepare decoder
        super().__init__()
        self.n_visual_tokens = n_visual_tokens
        self.vision_emb_dim = vision_embedding_dim
        self.base_model = deleteLaterEncodingLayers(base_bert.base_model, n_layers).encoder

        self.pooler = torch.nn.AdaptiveAvgPool1d(1)
        self.emb_dimension = emb_dims

        if self.vision_emb_dim == self.emb_dimension:
            self.dimension_caster = torch.nn.Identity()
        else:
            self.dimension_caster = torch.nn.Linear(self.vision_emb_dim, self.emb_dimension, bias = False)  # no bias here


    def forward(self, vision_embedding, text_embedding, text_attn_mask):
        """Fuse visual and text token embeddings.

        Arguments:
        1. vision_embedding: visual tokens, indexed as (batch, tokens, vision_emb_dim)
        2. text_embedding: text tokens, concatenated after the visual ones along dim 1
        3. text_attn_mask: (batch, text tokens) mask, 1 = real token, 0 = padding

        Returns:
        Whatever `self.base_model` returns for the fused sequence.
        """
        # normalize dimensions
        new_vision_emb = self.dimension_caster(vision_embedding)

        # concatenate
        concatenated_emb = torch.cat([new_vision_emb, text_embedding], dim = 1)

        # create attention mask: every visual token is attended to.
        # The token count is read from the input rather than a stored constant
        # so that the mask always matches the concatenated sequence length.
        n_batch, n_visual = new_vision_emb.shape[0], new_vision_emb.shape[1]
        vision_attention_mask = torch.ones(
            n_batch, n_visual,
            dtype = text_attn_mask.dtype,
            device = text_attn_mask.device,
        )
        attn_mask = torch.cat([vision_attention_mask, text_attn_mask], dim = 1)

        # The BERT encoder *adds* the mask to the attention scores, so it must be
        # 0 for kept positions and a very large negative number for padding.
        # Passing the raw 0/1 mask (as before) left padding tokens unmasked.
        mask_dtype = concatenated_emb.dtype
        attn_mask = attn_mask[:, None, None, :].to(mask_dtype)
        attn_mask = (1.0 - attn_mask) * torch.finfo(mask_dtype).min

        # forward
        return self.base_model(concatenated_emb, attn_mask)

In [None]:
# Side length (pixels) that images are resized to by the image transform.
DIMENSION = 224

# Maximum tokenised caption length (captions are padded / truncated to this).
MAX_LEN = 50

# NOTE(review): SAVE_PATH ends in "checkpoint_", which reads like a file-name
# prefix, yet the call below creates it as a *directory*. It is not referenced
# again in the visible cells (checkpoints use MODEL_SAVE_PATH) — confirm intent.
SAVE_PATH = 'Models/Mamo_S/checkpoint_'
os.makedirs(SAVE_PATH, exist_ok=True)

# ViT config
# Pretrained preprocessing objects; both calls load from the Hugging Face hub or a local cache.
feature_extractor = transformers.AutoFeatureExtractor.from_pretrained('WinKawaks/vit-small-patch16-224')
tokenizer = transformers.AutoTokenizer.from_pretrained("prajjwal1/bert-small")



In [None]:
import torchvision.transforms.v2 as v2

## image transforms
# Pipeline: PIL image -> uint8 tensor -> resize -> RandAugment -> float in [0, 1]
# -> normalise to roughly [-1, 1].
img_transform = v2.Compose([
    v2.ToImage(),
    # uint8 (not int8): it is the native 8-bit image dtype. Scaling to the signed
    # int8 range dropped one bit of every pixel before augmentation.
    v2.ToDtype(torch.uint8, scale = True),
    v2.Resize(size = (DIMENSION, DIMENSION), antialias = False),
    v2.RandAugment(),
    # v2.RandomVerticalFlip(),
    # v2.RandomHorizontalFlip(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(
        mean = [0.5, 0.5, 0.5],
        std =  [0.5, 0.5, 0.5]
    )
])

In [None]:
import nltk

class MaskGenerator:
    """Random block-wise binary mask generator for square images.

    The image is divided into a coarse grid of `mask_patch_size` blocks; a fixed
    share (`mask_ratio`, rounded up) of those blocks is selected at random and
    the coarse grid is then upsampled so that each entry corresponds to one
    `model_patch_size` patch.
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        # Both grids must tile exactly.
        assert self.input_size % self.mask_patch_size == 0
        assert self.mask_patch_size % self.model_patch_size == 0

        # Coarse grid side length and the upsampling factor to the model grid.
        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size

        # Number of coarse cells, and how many of them get masked.
        self.token_count = self.rand_size ** 2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        """Draw a new mask; returns a 2-D integer array with 1 = masked, 0 = visible."""
        chosen_cells = np.random.permutation(self.token_count)[:self.mask_count]

        coarse_mask = np.zeros(self.token_count, dtype=int)
        coarse_mask[chosen_cells] = 1
        coarse_mask = coarse_mask.reshape((self.rand_size, self.rand_size))

        # Upsample rows, then columns, by the same factor.
        rows_expanded = np.repeat(coarse_mask, self.scale, axis=0)
        return np.repeat(rows_expanded, self.scale, axis=1)

# Image side length divided by 16 (truncated to int).
unfold_dim = int(DIMENSION/16)
# Keyword arguments shared by the unfold/fold calls in get_patches and
# get_img_from_patches below.
fold_params = {'kernel_size' : unfold_dim,
              'dilation': 16}
# utils
def get_patches(batch, num_patches, size_patch):
    '''Split a batch of images into a patch-wise representation.

    Arguments:
    1. batch: batch of tensor images (N, C, H, W)
    2. num_patches: number of patches per side
    3. size_patch: number of pixels along each side of the patch

    Returns:
    The unfolded batch reshaped to (N, -1, num_patches**2, size_patch**2) and
    then permuted to (N, num_patches**2, -1, size_patch**2).
    The unfold itself is configured by the module-level `fold_params`.
    '''
    n_images = batch.size(0)
    unfolded = torch.nn.functional.unfold(batch, **fold_params)
    grouped = unfolded.reshape(n_images, -1, num_patches**2, size_patch**2)
    return grouped.permute(0,2,1,3)

def get_img_from_patches(batch, num_patches, size_patch):
    '''Rebuild images from a patch-wise tensor (counterpart of get_patches).

    Arguments:
    1. batch: 4-D patch-wise tensor; axes 1 and 2 are swapped and the result is
       reshaped to (N, -1, size_patch**2) before folding
       (presumably the output layout of get_patches — TODO confirm)
    2. num_patches: number of patches per side (not used by this function)
    3. size_patch: number of pixels along each side of the patch

    Returns:
    tensor containing reconstructed images of spatial size (DIMENSION, DIMENSION)'''
    n_images = batch.size(0)
    # Undo the permute applied when the patches were extracted, then flatten
    # back to the layout torch.nn.functional.fold expects.
    columns = batch.permute(0,2,1,3).reshape(n_images, -1, size_patch**2)
    return torch.nn.functional.fold(
        columns,
        output_size = (DIMENSION, DIMENSION),
        **fold_params
    )
    
class TextMaskGenerator:
    """Replace a random share of whitespace-separated words with a mask token."""

    def __init__(self, masking_ratio = 0.25, mask_token = '[MASK]'):
        """
        Arguments:
        1. masking_ratio: fraction of words to mask (rounded up)
        2. mask_token: string substituted for every masked word
        """
        self.masking_ratio = masking_ratio
        self.mask_token = mask_token

    def __call__(self, text):
        """Return `text` with ceil(len(words) * masking_ratio) words replaced.

        Words are obtained with `str.split()` and re-joined with single spaces.
        Empty input, or a ratio that selects no word, is returned unmasked.
        """
        # A plain Python list is used on purpose: a numpy string array has a
        # fixed character width, so writing '[MASK]' into an array of shorter
        # words silently truncated the token (e.g. to '[MA').
        words = text.split()  # tokenized
        len_txt = len(words)

        # Never ask for more words than exist; nothing to do when zero are selected.
        n_to_mask = min(len_txt, math.ceil(len_txt * self.masking_ratio))
        if n_to_mask <= 0:
            return " ".join(words)

        # Random scores; the n highest-scoring positions get masked.
        rankings = np.random.randn(len_txt)
        indices = np.argpartition(rankings, -n_to_mask)[-n_to_mask:]
        for position in indices:
            words[int(position)] = self.mask_token

        return " ".join(words)
        

class Flickr30K_MAMO(torchvision.datasets.Flickr30k):
    """Flickr30k wrapper that yields clean and masked image/text inputs.

    Each item is a 7-tuple:
    (img, img_mask, toks, attn_mask, masked_toks, masked_attn_mask, mask_indices)
    """

    def __init__(self,
                 data_path,
                 ann_path,
                 img_transform = None,
                 txt_transform = None,
                 max_length = 100,
                 ):
        """
        Arguments:
        1. data_path: image directory (forwarded to the parent dataset)
        2. ann_path: annotation file (forwarded to the parent dataset)
        3. img_transform: callable applied to every image
        4. txt_transform: tokenizer; must expose `mask_token` and `mask_token_id`
        5. max_length: captions are padded / truncated to this many tokens
        """
        super().__init__(data_path, ann_path)
        self.img_transform = img_transform
        self.tokenizer = txt_transform
        self.max_length = max_length

        self.img_masker = MaskGenerator(input_size = 224,
                                        mask_patch_size = 32,
                                        model_patch_size = 16,
                                        mask_ratio = MASKING_RATIO_IMG)       #0.75 masking ratio with MAMO

        self.txt_masker = TextMaskGenerator(masking_ratio= MASKING_RATIO_TXT,
                                            mask_token = self.tokenizer.mask_token)

        # Load the stopword list once, as a set: it was previously rebuilt as a
        # list (linear-time lookups) for every single sample.
        self.stopwords = frozenset(nltk.corpus.stopwords.words('english'))


    def process_string(self, string):
        """Lower-case `string`, drop stopwords and drop every non-alphabetic word."""
        tok_str = string.lower().split() # separated by spaces
        # Note: words with attached punctuation (e.g. "dog.") fail isalpha()
        # and are removed entirely, not merely stripped.
        proc_str = [
            word for word in tok_str
            if word not in self.stopwords and word.isalpha()
        ]

        return " ".join(proc_str)

    def __getitem__(self, idx):
        """Return one randomly chosen caption of image `idx` with its masked variants."""
        img, txt = super().__getitem__(idx)
        txt = random.choice(txt)


        # get images and texts
        img = self.img_transform(img)

        # process string
        txt = self.process_string(txt)
        mask_txt = self.txt_masker(txt)

        # NOTE(review): masking happens on whole words *before* tokenisation, so
        # a word that splits into several word-pieces becomes a single [MASK].
        # The clean and masked token sequences can then differ in length, and
        # position-wise comparisons between them may be misaligned — confirm.
        tok_text = self.tokenizer(txt, truncation = True, padding = 'max_length', max_length = self.max_length, return_token_type_ids=False)
        tok_masked_txt = self.tokenizer(mask_txt, truncation = True, padding = 'max_length', max_length = self.max_length, return_token_type_ids=False)
        toks, attn_mask = tok_text['input_ids'], tok_text['attention_mask']
        masked_toks, masked_attn_mask = tok_masked_txt['input_ids'], tok_masked_txt['attention_mask']

        toks, attn_mask, masked_toks, masked_attn_mask = torch.tensor(toks), torch.tensor(attn_mask), torch.tensor(masked_toks), torch.tensor(masked_attn_mask)

        # masked indices — use this dataset's own tokenizer, not the notebook-level
        # global `tokenizer`, so the class works with whatever tokenizer it was given.
        mask_indices = (masked_toks == self.tokenizer.mask_token_id)

        # generate mask for image and text
        img_mask = self.img_masker()

        return img, img_mask, toks, attn_mask, masked_toks, masked_attn_mask, mask_indices
        

def create_masked_image(img, mask, patch_size = 16):
    """Zero out the masked patches of a batch of images.

    Arguments:
    1. img: image batch; multiplied element-wise with the expanded mask
    2. mask: per-patch mask batch indexed as (N, rows, cols), 1 = masked.
       Must be numeric (not bool), since `1 - mask` is computed.
    3. patch_size: side length in pixels of one mask cell (default 16, the
       value that was previously hard-coded)

    Returns:
    `img` with every masked patch set to zero.
    """
    # Blow each mask cell up to patch_size x patch_size pixels, then add a
    # channel axis so the mask broadcasts over the image channels.
    pixel_mask = (mask .repeat_interleave(patch_size, 1)
                       .repeat_interleave(patch_size, 2)
                       .unsqueeze(1)
                       .contiguous()
                  )
    return (img * (1 - pixel_mask))
      
        
# Pre-training dataset: Flickr30k images plus the caption token file, using the
# notebook-level image transform and tokenizer; captions are padded/truncated to MAX_LEN.
dataset = Flickr30K_MAMO(DATASET_SRC + 'flickr30k-images',
               DATASET_SRC + 'results_20130124.token',
               img_transform=img_transform,
               txt_transform=tokenizer,
               max_length = MAX_LEN)

In [None]:
# dataloader
# shuffle=True: without it every epoch visits the samples in the same fixed
# order, so each batch always contains the same images (and therefore the same
# in-batch negatives for the contrastive / matching losses).
pretrain_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    pin_memory = True,
    num_workers = NUM_WORKERS,
)

In [None]:
# Show the tokenizer's [MASK] token id and its vocabulary size.
tokenizer.mask_token_id, tokenizer.vocab_size

(103, 30522)

In [None]:
class MAMO(torch.nn.Module):
    """Vision-language model with masked-modelling, contrastive and matching heads.

    Components:
    - `vit`: image encoder
    - `bert`: text encoder (first `bert_layers` layers of the given BERT)
    - `mamo`: fusion encoder over [image tokens ; text tokens]
    - projection heads for the MRM / MIM / MLM / ITC / ITM objectives

    The order in which sub-modules and parameters are registered in `__init__`
    is kept stable on purpose: optimiser state saved in checkpoints is matched
    to parameters by position.
    """

    def __init__(self,
                 vit,
                 bert,
                 vit_num_patches = 196,
                 vit_emb_dim = 384,
                 bert_emb_dim = 512,
                 bert_layers = 2,
                 vocab_size = 30522,
                 mask_token_id = 103):
        """
        Arguments:
        1. vit: image encoder whose output exposes 'last_hidden_state'
        2. bert: masked-LM BERT exposing `.base_model` and `.cls`
        3. vit_num_patches: number of image patch tokens (excluding the CLS token)
        4. vit_emb_dim: width of the image encoder's embeddings
        5. bert_emb_dim: width of the text encoder's embeddings
        6. bert_layers: layers kept for the text encoder and for the fusion encoder
        7. vocab_size: tokenizer vocabulary size (used to flatten MLM scores)
        8. mask_token_id: id of the tokenizer's mask token
        """
        super().__init__()
        self.vit = vit
        self.bert = deleteEncodingLayers(bert.base_model, bert_layers)
        # +1 for the CLS token; widths are passed explicitly instead of relying
        # on the mixer's defaults so non-default sizes stay consistent.
        self.mamo = MAMO_mixer(bert, bert_layers, vit_num_patches + 1, vit_emb_dim, bert_emb_dim)

        # vit patches data
        self.vit_num_patches = vit_num_patches

        # vocab size
        self.vocab_size = vocab_size
        # mask token
        self.mask_token_id = mask_token_id

        # learnable temperature parameter
        # Initialised to a small positive constant. A standard-normal draw (as
        # before) could be negative or close to zero, which flips or explodes
        # the similarity logits that are divided by it.
        self.tau = torch.nn.Parameter(torch.full((1,), 0.07))

        # joint representation
        self.pooler = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool1d(1),
            torch.nn.Flatten()
        )
        self.img_proj = torch.nn.Linear(vit_emb_dim, min(vit_emb_dim, bert_emb_dim))
        self.txt_proj = torch.nn.Linear(bert_emb_dim, min(vit_emb_dim, bert_emb_dim))


        # masked representation modeling
        self.mrm_proj = torch.nn.Sequential(
            torch.nn.Linear(bert_emb_dim, bert_emb_dim),
            torch.nn.Tanh(),
        )

        # head for masked image modeling
        self.mim_proj = torch.nn.Sequential(
            torch.nn.Linear(bert_emb_dim, vit_emb_dim),
        )

        # head for masked language modeling
        self.mlm_head = bert.cls

        # (attribute name kept as-is: it is part of the checkpoint's state-dict keys)
        self.itm__head = torch.nn.Sequential(
            torch.nn.Linear(bert_emb_dim, bert_emb_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(bert_emb_dim, 1)
        )



    def forward(self, image, text, attn_mask,
                masked_modeling = True,
                masked_image = None,
                masked_text = None,
                image_text_matching = False
                ):
        """Run one of three modes.

        - image_text_matching=True: returns the ITM head output computed from
          the first token of the fused representation.
        - masked_modeling=False: returns (img_rep, txt_rep, joint_rep), the
          'last_hidden_state' of the image encoder, text encoder and fusion encoder.
        - otherwise: returns (c_img_m_txt, m_img_c_txt, mask_img_rep,
          txt_prediction, img_rep, txt_rep) where the first two are fused
          representations of (clean image, masked text) and (masked image, clean
          text), `txt_prediction` is the MLM head output on the masked text, and
          the last two are pooled + projected clean image / text features.

        NOTE(review): `attn_mask` (computed for the clean text) is also used
        for `masked_text`; confirm both always share the same padding.
        """

        if image_text_matching == True:
            img_rep = self.vit(image)['last_hidden_state']
            txt_rep = self.bert(text, attn_mask)['last_hidden_state']
            joint_rep = self.mamo(img_rep, txt_rep, attn_mask)['last_hidden_state'][:, 0, :]

            return self.itm__head(joint_rep)


        if not masked_modeling:
            img_rep = self.vit(image)['last_hidden_state']
            txt_rep = self.bert(text, attn_mask)['last_hidden_state']
            joint_rep = self.mamo(img_rep, txt_rep, attn_mask)['last_hidden_state']

            return (img_rep, txt_rep, joint_rep)

        else:
            # return mask_img-clean_txt, clean_img,-mask_txt,
            img_rep = self.vit(image)['last_hidden_state']              # clean image
            txt_rep = self.bert(text, attn_mask)['last_hidden_state']   # clean text

            mask_img_rep = self.vit(masked_image)['last_hidden_state']
            mask_txt_rep = self.bert(masked_text, attn_mask)['last_hidden_state']

            # multimodal prediction
            c_img_m_txt = self.mamo(img_rep, mask_txt_rep, attn_mask)['last_hidden_state']
            m_img_c_txt = self.mamo(mask_img_rep, txt_rep, attn_mask)['last_hidden_state']

            # pure txt
            txt_prediction = self.mlm_head(mask_txt_rep)

            # pool and flatten text and visual features obtained before fusion
            img_rep = self.img_proj(self.pooler(img_rep.transpose(1,2)))
            txt_rep = self.txt_proj(self.pooler(txt_rep.transpose(1,2)))

            return (c_img_m_txt, m_img_c_txt, mask_img_rep, txt_prediction, img_rep, txt_rep)

    def mrm_projection(self, rep):
        """Apply the masked-representation-modelling projection head."""
        return self.mrm_proj(rep)

    def mim_projection(self, rep):
        """Apply the masked-image-modelling projection head."""
        return self.mim_proj(rep)

    def get_mrm_loss(self, online_representation, target_representation, mask):
        """Masked MSE between projected online and target fused representations.

        The first (CLS) token of both inputs is dropped. The element-wise
        squared error is multiplied by `mask`, summed, and divided by
        `mask.sum()` (plus a small epsilon).
        """
        # remove cls token
        on_rep = self.mrm_projection(online_representation[:, 1:, :])
        tr_rep = target_representation[:, 1:, :]

        loss = torch.nn.functional.mse_loss(on_rep, tr_rep, reduction = 'none')
        mrm_loss = (loss * mask).sum()/(mask.sum() + 1e-5)              # add for 0 division errors
        return mrm_loss

    def get_mim_loss(self, online_representation, target_representation, mask):
        """Masked L1 loss between projected online image tokens and target image tokens.

        Only tokens 1 .. vit_num_patches are compared (the CLS token and
        anything after the image patches are skipped). A 2-D `mask` gets a
        trailing axis so it broadcasts over the embedding dimension.
        """
        on_rep = self.mim_projection(online_representation[:, 1:self.vit_num_patches+1, :]) # omit cls token
        tr_rep = target_representation[:, 1:self.vit_num_patches+1, :]

        loss = torch.nn.functional.l1_loss(on_rep, tr_rep, reduction = 'none')
        if mask.ndim == 2:
            mask = mask[:, :, None]
        mim_loss = (loss * mask).sum() / (mask.sum() + 1e-5)
        return mim_loss


    def get_mlm_loss(self, scores, sen, masked_sen):
        """Cross-entropy over the masked token positions only.

        Arguments:
        1. scores: vocabulary scores, flattened to (-1, vocab_size)
        2. sen: clean token ids (the labels)
        3. masked_sen: token ids of the masked input; positions equal to
           `mask_token_id` are the ones that contribute to the loss

        Returns a zero loss (still connected to `scores`) when nothing is masked.
        """
        labels = torch.where(masked_sen == self.mask_token_id, sen, -100)

        # With every label ignored, the mean-reduced cross-entropy is NaN and
        # would poison the whole training step; return an explicit zero instead.
        if not bool((labels != -100).any()):
            return scores.sum() * 0.0

        loss = torch.nn.functional.cross_entropy(scores.view(-1, self.vocab_size), labels.view(-1), ignore_index=-100)

        return loss


    def get_itc_loss(self, img_feats, txt_feats):
        """Symmetric image-text contrastive loss.

        Returns:
        (sim, loss) where `sim[i, j]` is the temperature-scaled cosine
        similarity (a logit) between image i and text j, and `loss` is the
        mean of the image->text and text->image cross-entropies with the
        matching pair as the target class.
        """
        # Calculate cosine similarity: L2-normalise so the dot product is a cosine.
        img_feats = torch.nn.functional.normalize(img_feats, dim = -1)
        txt_feats = torch.nn.functional.normalize(txt_feats, dim = -1)

        # Keep the temperature strictly positive.
        tau = self.tau.clamp(min = 1e-3)

        # cross_entropy applies (log-)softmax itself, so it must receive raw
        # logits. Exponentiating first (as before) double-exponentiated the
        # similarities and overflowed.
        sim = (img_feats @ txt_feats.T) / tau
        targets = torch.arange(sim.shape[0], device = sim.device)

        loss_i2t = torch.nn.functional.cross_entropy(sim, targets)
        loss_t2i = torch.nn.functional.cross_entropy(sim.T, targets)

        return sim, (loss_i2t + loss_t2i)/2.

    def get_samples(self, similarities):
        """Sample one partner index per row / column from the similarity matrix.

        Returns:
        (txt_indices, img_indices): `txt_indices[i]` is a text index drawn for
        image i from the row-wise softmax of `similarities`; `img_indices[j]`
        is an image index drawn for text j from the transposed probabilities.
        The matching (diagonal) partner can be drawn too.
        """
        # Sampling is not differentiable; keep it out of the autograd graph.
        with torch.no_grad():
            probs = torch.nn.functional.softmax(similarities, dim = 1)
            txt_indices = torch.multinomial(probs, num_samples=1, replacement=True).squeeze(1)
            img_indices = torch.multinomial(probs.T, num_samples=1, replacement=True).squeeze(1)

        return txt_indices, img_indices

In [None]:

# Pretrained backbones: a ViT image encoder (moved to DEVICE here) and a
# masked-LM BERT (moved together with the whole network below).
vit_model = transformers.ViTModel.from_pretrained('WinKawaks/vit-small-patch16-224').to(DEVICE)
bert_model = transformers.BertForMaskedLM.from_pretrained("prajjwal1/bert-small")

# Online (trainable) network; vocabulary size and mask id come from the tokenizer.
online_network = MAMO(
                    vit = vit_model,
                    bert = bert_model,
                    vit_num_patches= 196,
                    vit_emb_dim=384,
                    bert_emb_dim=512,
                    bert_layers=n_layers,
                    vocab_size=tokenizer.vocab_size,
                    mask_token_id= tokenizer.mask_token_id
                ).to(DEVICE)

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# utils for target network

# freeze weights
def freeze_weights(nw):
    """Disable gradient tracking for every parameter of `nw`.

    The network is modified in place and also returned, so the call can be chained.
    """
    for weight in nw.parameters():
        weight.requires_grad_(False)

    return nw
    
def ewma_weights(target, current, alpha = 0.995):
    """Exponentially-weighted moving average update of `target` towards `current`.

    For every floating-point entry of the state dict:
        target = alpha * target + (1 - alpha) * current
    Non-floating entries (e.g. integer buffers) are copied from `current`
    unchanged: averaging them in floating point and loading the result back
    into an integer buffer truncates the value.

    Arguments:
    1. target: network that is updated in place (and returned)
    2. current: network providing the new weights; must have the same state-dict keys
    3. alpha: weight given to the existing target values

    Returns:
    `target`
    """
    with torch.no_grad():
        target_state = target.state_dict()
        current_state = current.state_dict()

        blended = {}
        for key, target_value in target_state.items():
            current_value = current_state[key]
            if torch.is_floating_point(target_value):
                blended[key] = alpha*target_value + (1-alpha)*current_value
            else:
                blended[key] = current_value.clone()

        target.load_state_dict(blended)
    return target



In [None]:
# Target network starts as an independent deep copy of the online network.
target_network = copy.deepcopy(online_network)

In [None]:
#optimiser
# Betas come from the config cell (beta1 / beta2) so the optimiser cannot
# drift from the values that are logged to wandb.
optim = torch.optim.AdamW(online_network.parameters(),
                          lr = lr,
                          weight_decay = decay,
                          betas = (beta1, beta2),
                          )

# The schedule is stepped per optimiser update, so everything is expressed in steps.
epoch_steps = math.ceil(len(dataset)/BATCH_SIZE)
num_steps = int(EPOCHS * epoch_steps)
warmup_steps = int(warmup_epochs * epoch_steps)

# Single cosine cycle from `lr` down to `min_lr`, preceded by a warm-up that
# starts at `init_lr`.
lr_scheduler = CosineLRScheduler(
        optim,
        t_initial=num_steps,
        # t_mul=1.,
        lr_min=min_lr,
        warmup_lr_init = init_lr,
        warmup_t=warmup_steps,
        cycle_limit=1,
        t_in_epochs=False,
    )


# wandB init
# resume='allow' together with the fixed run id continues that run if it exists.
wandb.init(
    id = id,# id,
    resume =  'allow',
    project = 'MAMO - Pretrain',
    name = 'MAMO - ViT-S, BERT-S',

    config = {
        'architecture': model_name,
        # The data in this notebook is Flickr30k (see DATASET_SRC and the
        # dataset class); the previous value 'ImageNet1K' mislabelled the run.
        'dataset':'Flickr30k',
        'warmup_epochs': warmup_epochs,
        'epochs' : EPOCHS,
        'batch_size': BATCH_SIZE,
        'masking_ratio_img' : MASKING_RATIO_IMG,
        'masking_ratio_itxt' : MASKING_RATIO_TXT,
        # NOTE(review): 196 here vs. 'patch_size_mask': 32 below — the image
        # masker is built with mask_patch_size=32; confirm what 196 should mean.
        'mask_patch_size': 196,
        'image_size' : DIMENSION,
        'optim_params':{
            'optim': 'AdamW',
            'beta1': beta1,
            'beta2': beta2,
            'weight_decay': decay,
            'learning_rate': lr,
        },
        'accumulation_iters': 1,
        'patch_size_mask' : 32,
        'alpha_ewma': ALPHA,
    },
)

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

In [None]:
import re

# Collect the epoch numbers of all saved checkpoints. Files that match the glob
# but not the numbered pattern are skipped instead of raising AttributeError.
checkpoint_pattern = re.compile(r'.*checkpoint_(\d+).*')
nums = []
for checkpoint_file in glob.glob(MODEL_SAVE_PATH+'*.pth'):
    found = checkpoint_pattern.match(checkpoint_file)
    if found is not None:
        nums.append(int(found.group(1)))

# Epoch of the latest checkpoint; -1 means "no checkpoint, start from scratch".
CHKPT = -1

if len(nums) != 0:
    CHKPT = max(nums)

    load_path = '{}_{}.pth'.format(MODEL_SAVE_PATH, CHKPT)
    # Map every stored tensor onto the resolved DEVICE. The previous mapping
    # always targeted the CUDA device name, which fails on a CPU-only machine.
    chkpt = torch.load(load_path, map_location = DEVICE)

    online_network.load_state_dict(chkpt['online_model_state_dict'])
    target_network.load_state_dict(chkpt['target_model_state_dict'])
    optim.load_state_dict(chkpt['optim_state_dict'])
    # lr_scheduler.load_state_dict(chkpt['scheduler_state_dict'])

    print(load_path)

    print("loaded earlier settings")

In [None]:
# The target network is never optimised directly; it only follows the online
# network through the EWMA update at the end of every step.
target_network = freeze_weights(target_network).to(DEVICE)

# from_pretrained() returns models in eval mode, so training mode has to be
# switched on explicitly; the frozen target stays deterministic.
online_network.train()
target_network.eval()

# Loss scaling is only meaningful on CUDA; disabled elsewhere so the loop also runs on CPU.
scaler = torch.cuda.amp.grad_scaler.GradScaler(enabled = (DEVICE.type == 'cuda'))
itm_loss_fn = torch.nn.BCEWithLogitsLoss()

# Resume after the last saved epoch (CHKPT is -1 when there is no checkpoint).
# Starting from 0 on a resumed run would overwrite existing checkpoints and
# replay the learning-rate warm-up.
for epoch in range(CHKPT + 1, EPOCHS):
    num_samples = 0
    pretrain_loss = 0
    for idx, data in (pbar := tqdm(enumerate(pretrain_dataloader), total = len(pretrain_dataloader))):
        img, img_mask, txt, attn_mask, masked_toks, masked_attn_mask, mask_indices = data

        # vision
        img = img.to(DEVICE)
        img_mask = img_mask.to(DEVICE)

        # language
        txt = txt.to(DEVICE)
        attn_mask = attn_mask.to(DEVICE)
        masked_toks = masked_toks.to(DEVICE)
        masked_attn_mask = masked_attn_mask.to(DEVICE)

        # indices for masked text: will be used for masked modeling
        mask_indices = mask_indices.float().to(DEVICE)

        # masked image
        masked_image = create_masked_image(img, img_mask)
        flattened_img_mask = img_mask.float().flatten(1)

        # create masks for joint representation modeling:
        # [image-patch positions ; text positions], one flag per fused token
        # (the CLS token is removed inside get_mrm_loss).
        img_rep_masks = torch.cat([flattened_img_mask, torch.zeros_like(mask_indices)], axis = 1).unsqueeze(-1)
        txt_rep_masks = torch.cat([torch.zeros_like(flattened_img_mask), mask_indices], axis = 1).unsqueeze(-1)



        # masked modeling pretraining
        # Autocast follows the resolved device instead of a hard-coded 'cuda'.
        with torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16):                           # casting to bf16


            # forward step for target network
            with torch.no_grad():
                target_img_rep, target_txt_rep, target_mm_rep = target_network(img, txt, attn_mask, masked_modeling = False)

            # forward step for online network
            c_img_m_txt, m_img_c_txt, mask_img_rep, txt_prediction, img_rep, txt_rep = online_network(img, txt, attn_mask,
                                                                                        masked_modeling = True,
                                                                                        masked_image = masked_image,
                                                                                        masked_text = masked_toks)


            # MRM loss
            mrm_loss_txt = online_network.get_mrm_loss(c_img_m_txt, target_mm_rep, txt_rep_masks)
            mrm_loss_img = online_network.get_mrm_loss(m_img_c_txt, target_mm_rep, img_rep_masks)

            # MIM loss
            mim_loss = online_network.get_mim_loss(m_img_c_txt, target_img_rep, flattened_img_mask)

            # MLM loss
            mlm_loss = online_network.get_mlm_loss(txt_prediction, txt, masked_toks)


            # ITC loss
            sim, itc_loss = online_network.get_itc_loss(img_rep, txt_rep)

            #itm loss
            # sample for each image and each text separately.
            # get_samples returns (text index per image, image index per text);
            # the names below follow that order so each label is computed from
            # the same indices that built its pair (they were swapped before).
            txt_for_img, img_for_txt = online_network.get_samples(sim)
            right_samples = torch.arange(0, txt_for_img.size(0)).to(DEVICE)

            # label is 1 exactly when the sampled partner is the true partner
            labs_img = (txt_for_img == right_samples).float().unsqueeze(1)
            labs_txt = (img_for_txt == right_samples).float().unsqueeze(1)

            outs_img = online_network(img, txt[txt_for_img], attn_mask[txt_for_img], image_text_matching = True)
            outs_txt = online_network(img[img_for_txt], txt, attn_mask, image_text_matching = True)

            # binary matching losses on the raw logits
            itm_1 = itm_loss_fn(outs_img, labs_img)
            itm_2 = itm_loss_fn(outs_txt, labs_txt)
            itm = itm_1 + itm_2

            # TOTAL LOSS
            net_loss = (mrm_loss_img + mrm_loss_txt) + (mim_loss) + (mlm_loss) + (itc_loss) + (itm)

        scaler.scale(net_loss).backward()

        # BACKPROP
        scaler.step(optim)
        scaler.update()
        optim.zero_grad(set_to_none = True)
        lr_scheduler.step_update(epoch * epoch_steps + idx)

        # update and calc loss
        num_samples+=1
        pretrain_loss+= net_loss.item()
        pbar.set_description(f"Train Loss: {pretrain_loss/num_samples}")


        # EWMA for weights
        target_network = ewma_weights(target_network, online_network, alpha = ALPHA)


    # The weights are saved as they are after this epoch. The previous code
    # reloaded the start-up checkpoint here, which discarded the epoch's
    # training (and raised NameError when no checkpoint had been loaded).
    save_path = '{}_{}.pth'.format(MODEL_SAVE_PATH, epoch)

    wandb.log({
            'epoch': epoch,
            'pretrain_loss': pretrain_loss/num_samples
    })

    torch.save(
            {
            'epoch': epoch,
            'online_model_state_dict': online_network.state_dict(),
            'target_model_state_dict': target_network.state_dict(),
            'optim_state_dict': optim.state_dict()
            # 'scheduler_state_dict': lr_scheduler.state_dict(),
            },
        save_path
        )
    # Upload every 10th checkpoint, counted from the end of warm-up.
    if (epoch-warmup_epochs) % 10 == 0:
        wandb.save(save_path)

In [None]:
# Close the wandb run started by wandb.init above.
wandb.finish()

tensor(1429.7030, device='cuda:0', grad_fn=<AddBackward0>)