In [162]:
import transformers
import torch
import torchvision
import torchmetrics

import torchtext

# Preferred accelerator and the CPU thread budget handed to torch.set_num_threads.
device, n_workers = 'cuda:0', 8

from tqdm import tqdm
from PIL import Image

import pandas as pd
import numpy as np

from torchinfo import summary
import os
import glob

# Identifiers used to build the checkpoint directory (Models/<model_name>/<algo>/...).
model_name, algo = 'vit_bert_s', 'MAMO'

torch.set_num_threads(n_workers)

# Use the configured accelerator only when CUDA is actually present; otherwise fall back to CPU.
DEVICE = torch.device(device if torch.cuda.is_available() else 'cpu')

import tokenizers
import itertools
import matplotlib.pyplot as plt

import random
import math

import copy

In [163]:
DATASET_SRC = '../Datasets/Flickr30k/'
MODEL_SAVE_PATH = f'Models/{model_name}/{algo}/checkpoint'

VOCAB_PATH = 'Vocabulary/flickr30k.vocab'

n_layers = 2  # presumably the fusion depth for MAMO_mixer — TODO confirm (not referenced in the visible cells)

# Create the checkpoint directory up front; exist_ok=True makes this cell safe to re-run
# and avoids the check-then-create race of an explicit os.path.exists test.
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)

In [164]:
def deleteEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return a deep copy of ``model`` that keeps only its FIRST ``num_layers_to_keep`` encoder layers.

    The input model is left untouched and shares no parameters with the result.

    Args:
        model: full model exposing its encoder layers as ``model.encoder.layer``
            (a ``torch.nn.ModuleList``, as on a Hugging Face ``BertModel``).
        num_layers_to_keep: number of leading layers to keep, ``0 <= n <= len(layers)``.

    Returns:
        An independent copy of ``model`` whose encoder holds layers ``[0, num_layers_to_keep)``.
        If the copy has a ``config.num_hidden_layers`` attribute it is updated to match.

    Raises:
        ValueError: if ``num_layers_to_keep`` is outside ``[0, len(model.encoder.layer)]``.
    """
    n_total = len(model.encoder.layer)
    if not 0 <= num_layers_to_keep <= n_total:
        raise ValueError(f"num_layers_to_keep must be in [0, {n_total}], got {num_layers_to_keep}")

    # Copy first, then slice the *copy's* layers. Selecting layers from the original model
    # and attaching them to the copy would make both models share (and co-train) those weights.
    copyOfModel = copy.deepcopy(model)
    copyOfModel.encoder.layer = torch.nn.ModuleList(copyOfModel.encoder.layer[:num_layers_to_keep])

    # Keep the config consistent with the truncated depth so a saved config matches the weights.
    config = getattr(copyOfModel, 'config', None)
    if config is not None and hasattr(config, 'num_hidden_layers'):
        config.num_hidden_layers = num_layers_to_keep

    return copyOfModel

def deleteLaterEncodingLayers(model, num_layers_to_keep):  # must pass in the full bert model
    """Return a deep copy of ``model`` that keeps only its LAST ``num_layers_to_keep`` encoder layers.

    The kept layers stay in their original order. The input model is left untouched and
    shares no parameters with the result.

    Args:
        model: full model exposing its encoder layers as ``model.encoder.layer``
            (a ``torch.nn.ModuleList``, as on a Hugging Face ``BertModel``).
        num_layers_to_keep: number of trailing layers to keep, ``0 <= n <= len(layers)``.

    Returns:
        An independent copy of ``model`` whose encoder holds the final ``num_layers_to_keep`` layers.
        If the copy has a ``config.num_hidden_layers`` attribute it is updated to match.

    Raises:
        ValueError: if ``num_layers_to_keep`` is outside ``[0, len(model.encoder.layer)]``.
    """
    n_total = len(model.encoder.layer)
    if not 0 <= num_layers_to_keep <= n_total:
        raise ValueError(f"num_layers_to_keep must be in [0, {n_total}], got {num_layers_to_keep}")

    # Copy first, then slice the *copy's* layers, so the result does not share weights with `model`.
    copyOfModel = copy.deepcopy(model)
    # Slice from an explicit start index: `layers[-n:]` would keep *all* layers when n == 0.
    start = n_total - num_layers_to_keep
    copyOfModel.encoder.layer = torch.nn.ModuleList(copyOfModel.encoder.layer[start:])

    # Keep the config consistent with the truncated depth so a saved config matches the weights.
    config = getattr(copyOfModel, 'config', None)
    if config is not None and hasattr(config, 'num_hidden_layers'):
        config.num_hidden_layers = num_layers_to_keep

    return copyOfModel


def get_bert_model(model, num_layers):
    """Return a copy of ``model`` truncated to its first ``num_layers`` encoder layers.

    Thin alias for ``deleteEncodingLayers``.
    """
    truncated_model = deleteEncodingLayers(model, num_layers)
    return truncated_model



class MAMO_mixer(torch.nn.Module):
    """Multimodal fusion encoder over the concatenated sequence ``[vision tokens; text tokens]``.

    Vision token embeddings are projected to the BERT hidden size (bias-free linear layer, or
    identity when the sizes already match), prepended to the text embeddings, and run through
    the final ``n_layers`` encoder layers of ``base_bert`` (taken via ``deleteLaterEncodingLayers``).

    Args:
        base_bert: a full BERT model; must expose ``.pooler.dense`` and ``.encoder.layer``.
        n_layers: number of trailing BERT encoder layers used for fusion.
        n_visual_tokens: expected number of vision tokens (stored for reference; ``forward``
            uses the actual sequence length of ``vision_embedding``).
        vision_embedding_dim: hidden size of the incoming vision embeddings.
    """

    def __init__(self, base_bert, n_layers = 2, n_visual_tokens = 197, vision_embedding_dim = 384):
        super().__init__()
        self.n_visual_tokens = n_visual_tokens
        self.vision_emb_dim = vision_embedding_dim
        self.emb_dimension = base_bert.pooler.dense.out_features
        self.base_model = deleteLaterEncodingLayers(base_bert, n_layers).encoder

        self.pooler = torch.nn.AdaptiveAvgPool1d(1)  # not used by forward yet

        if self.vision_emb_dim == self.emb_dimension:
            self.dimension_caster = torch.nn.Identity()
        else:
            # bias-free projection from the vision width to the BERT width
            self.dimension_caster = torch.nn.Linear(self.vision_emb_dim, self.emb_dimension, bias = False)

    def forward(self, vision_embedding, text_embedding, text_attn_mask):
        """Fuse vision and text token embeddings with the wrapped BERT encoder layers.

        Args:
            vision_embedding: assumed (batch, n_vision_tokens, vision_embedding_dim) — TODO confirm.
            text_embedding: assumed (batch, n_text_tokens, emb_dimension), i.e. already-embedded text.
            text_attn_mask: (batch, n_text_tokens) 1/0 mask, presumably the tokenizer's
                ``attention_mask`` (1 = real token, 0 = padding).

        Returns:
            Whatever the wrapped encoder returns for the concatenated sequence.
        """
        n_batch, n_vision_tokens = vision_embedding.shape[:2]

        # normalize dimensions, then concatenate along the sequence axis
        new_vision_emb = self.dimension_caster(vision_embedding)
        concatenated_emb = torch.cat([new_vision_emb, text_embedding], dim = 1)

        # Vision tokens are never padded, so they are always attended to. Build their mask on the
        # text mask's device/dtype so the concatenation also works for GPU inputs.
        vision_attention_mask = torch.ones(n_batch, n_vision_tokens,
                                           dtype = text_attn_mask.dtype,
                                           device = text_attn_mask.device)
        attn_mask = torch.cat([vision_attention_mask, text_attn_mask], dim = 1)

        # BERT attention layers ADD this mask to the attention scores, so convert the 1/0 mask to the
        # additive form produced by transformers' get_extended_attention_mask:
        # 0 where attended, the dtype's most negative value where masked.
        attn_mask = attn_mask[:, None, None, :].to(concatenated_emb.dtype)
        attn_mask = (1.0 - attn_mask) * torch.finfo(concatenated_emb.dtype).min

        # forward
        return self.base_model(concatenated_emb, attn_mask)

In [165]:
# mamo_model = transformers.BertModel.from_pretrained("prajjwal1/bert-small")
# mamo_model = MAMO_mixer(mamo_model, 2, 197, 384)#.to(DEVICE)

# vision_inp = torch.empty(10, 197, 384)
# text_inp = torch.empty(10, 50, 512)
# text_attn = torch.randint(0, 1, (10, 50))

# mamo_model(vision_inp, text_inp, text_attn)

In [166]:
DIMENSION = 224  # square input resolution fed to the ViT (crop size used by img_transform)

MAX_LEN = 50  # captions are padded/truncated to this many tokens

# ViT config
# image_processor = transformers.AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
# model = transformers.ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
# AutoImageProcessor replaces the deprecated AutoFeatureExtractor for vision checkpoints;
# the variable keeps its old name so existing references to it keep working.
feature_extractor = transformers.AutoImageProcessor.from_pretrained('WinKawaks/vit-small-patch16-224')
vit_model = transformers.ViTModel.from_pretrained('WinKawaks/vit-small-patch16-224').to(DEVICE)

# BERT config: the bert-base alternative is kept for reference; the small BERT below is what's used.
# tokenizer = transformers.AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
# bert_model = transformers.BertModel.from_pretrained("google-bert/bert-base-uncased").to(DEVICE)

tokenizer = transformers.AutoTokenizer.from_pretrained("prajjwal1/bert-small")
bert_model = transformers.BertModel.from_pretrained("prajjwal1/bert-small").to(DEVICE)

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [167]:
# Layer-by-layer output shapes and parameter counts of the ViT for a dummy batch of five RGB images.
summary(
    vit_model,
    input_size=(5, 3, DIMENSION, DIMENSION),
)

Layer (type:depth-idx)                             Output Shape              Param #
ViTModel                                           [5, 384]                  --
├─ViTEmbeddings: 1-1                               [5, 197, 384]             76,032
│    └─ViTPatchEmbeddings: 2-1                     [5, 196, 384]             --
│    │    └─Conv2d: 3-1                            [5, 384, 14, 14]          295,296
│    └─Dropout: 2-2                                [5, 197, 384]             --
├─ViTEncoder: 1-2                                  [5, 197, 384]             --
│    └─ModuleList: 2-3                             --                        --
│    │    └─ViTLayer: 3-2                          [5, 197, 384]             1,774,464
│    │    └─ViTLayer: 3-3                          [5, 197, 384]             1,774,464
│    │    └─ViTLayer: 3-4                          [5, 197, 384]             1,774,464
│    │    └─ViTLayer: 3-5                          [5, 197, 384]             1,774,46

In [168]:
# get the dimension of the BERT model to map ViT and BERT to

with torch.no_grad():
    outs = bert_model(
        torch.zeros(5, MAX_LEN, dtype=torch.int).to(DEVICE),  # input_ids: all-zero dummy tokens
        torch.ones(5, MAX_LEN, dtype=torch.int).to(DEVICE),   # attention_mask: attend everywhere
    )['last_hidden_state']

emb_dim = outs.shape[-1]  # width of BERT's token embeddings, the shared space ViT features map into

In [169]:
import torchvision.transforms.v2 as v2

## image transforms
# Training-time augmentation for the ViT input: random resized crop and flips, then scale to
# float [0, 1] and normalize each channel with mean 0.5 / std 0.5 (i.e. map to [-1, 1]).
img_transform = v2.Compose([
    v2.ToImage(),
    # Keep 8-bit *unsigned*: a scaled uint8 -> int8 conversion has only 7 value bits available,
    # so it right-shifts every pixel and silently drops its lowest bit.
    v2.ToDtype(torch.uint8, scale = True),
    v2.RandomResizedCrop(size = (DIMENSION, DIMENSION),
                                scale = [0.67,1],
                                ratio = [3/4, 4/3],
                                antialias = False),
    v2.RandomVerticalFlip(),
    v2.RandomHorizontalFlip(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(
        mean = [0.5, 0.5, 0.5],
        std =  [0.5, 0.5, 0.5]
    )
])

In [170]:
import nltk

class MaskGenerator:
    """Random block mask for masked image modeling (same scheme as SimMIM's MaskGenerator).

    The ``input_size`` x ``input_size`` image is split into a coarse grid of
    ``mask_patch_size`` blocks. ``ceil(token_count * mask_ratio)`` of those blocks are chosen
    at random, and the selection is upsampled to the model's patch grid
    (``input_size // model_patch_size`` cells per side).
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        assert input_size % mask_patch_size == 0
        assert mask_patch_size % model_patch_size == 0

        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        # side length of the coarse masking grid, and how many model patches one coarse cell spans
        self.rand_size = input_size // mask_patch_size
        self.scale = mask_patch_size // model_patch_size

        self.token_count = self.rand_size * self.rand_size
        self.mask_count = int(np.ceil(self.token_count * mask_ratio))

    def __call__(self):
        """Draw a fresh mask: int array of shape (rand_size*scale, rand_size*scale), 1 = masked."""
        flat = np.zeros(self.token_count, dtype=int)
        flat[np.random.permutation(self.token_count)[:self.mask_count]] = 1
        grid = flat.reshape((self.rand_size, self.rand_size))
        return np.repeat(np.repeat(grid, self.scale, axis=0), self.scale, axis=1)
    
    
class TextMaskGenerator:
    """Replace a random subset of whitespace-separated words with a mask token.

    ``ceil(n_words * masking_ratio)`` words (capped at ``n_words``) are chosen uniformly at random
    and replaced by ``mask_token``. The result is re-joined with single spaces.
    """

    def __init__(self, masking_ratio = 0.25, mask_token = '[MASK]'):
        self.masking_ratio = masking_ratio
        self.mask_token = mask_token

    def __call__(self, text):
        """Return ``text`` with a random ``masking_ratio`` fraction of its words masked.

        Empty input, or a ratio that yields zero words, returns the (whitespace-normalized) text unmasked.
        """
        # Use a plain list: np.array(words) gets a fixed-width '<U{longest word}' dtype and would
        # silently truncate the mask token (e.g. '[MASK]' -> '[MA' for "a dog ran").
        words = text.split()
        n_words = len(words)

        n_to_mask = min(n_words, math.ceil(n_words * self.masking_ratio))
        if n_to_mask <= 0:
            # argpartition with kth=-0 would select *every* index (and fails on empty input).
            return " ".join(words)

        # The n_to_mask largest i.i.d. random scores pick a uniformly random subset of positions.
        rankings = np.random.randn(n_words)
        for i in np.argpartition(rankings, -n_to_mask)[-n_to_mask:]:
            words[i] = self.mask_token

        return " ".join(words)
        

class Flickr30K_MAMO(torchvision.datasets.Flickr30k):
    """Flickr30k image-caption dataset prepared for MAMO-style masked pretraining.

    Each item pairs the (transformed) image with ONE caption drawn at random from its annotations.
    The caption is lower-cased and stripped of English stopwords, then tokenized twice: once as-is
    and once after ``TextMaskGenerator`` replaced 25% of its words (rounded up) with the tokenizer's
    mask token. A random patch mask for the image is drawn from ``MaskGenerator``.

    Args:
        data_path: directory holding the Flickr30k images.
        ann_path: caption annotation file (passed to torchvision's ``Flickr30k``).
        img_transform: callable applied to the PIL image; skipped when ``None``.
        txt_transform: Hugging Face-style tokenizer (callable, exposing ``mask_token`` and
            ``mask_token_id``). Required despite the ``None`` default.
        max_length: pad/truncate length used for both tokenizations.
    """

    def __init__(self,
                 data_path,
                 ann_path,
                 img_transform = None,
                 txt_transform = None,
                 max_length = 100,
                 ):
        super().__init__(data_path, ann_path)
        self.img_transform = img_transform
        self.tokenizer = txt_transform
        self.max_length = max_length

        self.img_masker = MaskGenerator(input_size = 224,
                                        mask_patch_size = 32,
                                        model_patch_size = 16,
                                        mask_ratio = 0.75)       # 0.75 masking ratio with MAMO

        self.txt_masker = TextMaskGenerator(masking_ratio=0.25,
                                            mask_token = self.tokenizer.mask_token)

        # Load the stopword list once (not per item) and keep it as a set for O(1) lookups.
        # Requires the nltk 'stopwords' corpus to be downloaded.
        self._stopwords = frozenset(nltk.corpus.stopwords.words('english'))

    def process_string(self, string):
        """Lower-case ``string``, split it on whitespace and drop English stopwords."""
        return " ".join(word for word in string.lower().split() if word not in self._stopwords)

    def __getitem__(self, idx):
        """Return ``(img, img_mask, toks, attn_mask, masked_toks, masked_attn_mask, mask_indices)``.

        ``toks``/``attn_mask`` come from the clean caption and ``masked_toks``/``masked_attn_mask``
        from the masked one (padded/truncated to ``max_length``). ``mask_indices`` is the boolean
        tensor ``masked_toks == mask_token_id``. ``img_mask`` is the numpy array from ``MaskGenerator``.
        """
        img, captions = super().__getitem__(idx)
        txt = random.choice(captions)

        # get images and texts
        if self.img_transform is not None:
            img = self.img_transform(img)

        # process string
        txt = self.process_string(txt)
        mask_txt = self.txt_masker(txt)

        tok_kwargs = dict(truncation = True, padding = 'max_length',
                          max_length = self.max_length, return_token_type_ids = False)
        tok_text = self.tokenizer(txt, **tok_kwargs)
        tok_masked_txt = self.tokenizer(mask_txt, **tok_kwargs)

        toks = torch.tensor(tok_text['input_ids'])
        attn_mask = torch.tensor(tok_text['attention_mask'])
        masked_toks = torch.tensor(tok_masked_txt['input_ids'])
        masked_attn_mask = torch.tensor(tok_masked_txt['attention_mask'])

        # masked indices — use this dataset's own tokenizer rather than a notebook global
        mask_indices = masked_toks == self.tokenizer.mask_token_id
        # NOTE(review): the two captions are tokenized independently. A masked word that the
        # tokenizer would split into several sub-tokens becomes a single mask token, so positions in
        # `toks` and `masked_toks` can drift apart, and MLM targets read from `toks` at `mask_indices`
        # may be misaligned. Consider masking at the token-id level instead.

        # generate mask for image
        img_mask = self.img_masker()

        return img, img_mask, toks, attn_mask, masked_toks, masked_attn_mask, mask_indices
        
        
        
# Pretraining dataset: Flickr30k images + captions, tokenized to MAX_LEN with the BERT tokenizer.
dataset = Flickr30K_MAMO(
    os.path.join(DATASET_SRC, 'flickr30k-images'),
    os.path.join(DATASET_SRC, 'results_20130124.token'),
    img_transform=img_transform,
    txt_transform=tokenizer,
    max_length=MAX_LEN,
)

In [171]:
# dataloader

# Pretraining loader. Reshuffle every epoch so batch composition (and the in-batch negatives
# of the planned contrastive loss) changes between epochs instead of repeating one fixed order.
pretrain_dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size = 128,
    shuffle = True,
    pin_memory = True,
    # num_workers = n_workers,  # TODO(review): enable for throughput; datasets defined inside a
    #                           # notebook may not pickle under the 'spawn' start method.
)

In [None]:
class MAMO(torch.nn.Module):
    """Placeholder for the full MAMO pretraining model (not implemented yet).

    TODO: combine the ViT, the text BERT and MAMO_mixer, and implement the objectives sketched in
    the training-loop cell below. Until then, calling an instance raises NotImplementedError
    (torch.nn.Module's default forward).
    """

In [None]:


for idx, data in enumerate(pretrain_dataloader):
    # Move every field of the batch to the training device, in batch order.
    (img, img_mask,                                       # vision
     toks, attn_mask, masked_toks, masked_attn_mask,      # language
     mask_indices) = (tensor.to(DEVICE) for tensor in data)

    # MAMO has 5 components now...
    # masked modeling

    # masked joint representation modeling

    # masked image modeling

    # masked text modeling

    # global image-text alignment
    # itc loss : contrastive

    # image-text matching : sample according to similarity measure for each example, then do supervised prediction