In [1]:
import transformers
import torch
import torchvision
import torchmetrics

import torchtext

device = 'cuda:1'
n_workers = 8

from tqdm import tqdm
from PIL import Image

import pandas as pd
import numpy as np

from torchinfo import summary
import os
import glob

torch.set_num_threads(n_workers)

DEVICE = torch.device(device) if torch.cuda.is_available() else torch.device('cpu')
model_name = 'vit_bert_b'
algo = 'MAMO'

import tokenizers
import itertools
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATASET_SRC = '../Datasets/Flickr30k/'
MODEL_SAVE_PATH = f'Models/{model_name}/{algo}/checkpoint'

VOCAB_PATH = 'Vocabulary/flickr30k.vocab'

if os.path.exists(os.path.dirname(MODEL_SAVE_PATH)) == False:
    os.makedirs(os.path.dirname(MODEL_SAVE_PATH))

In [3]:
DIMENSION = 224

# ViT config


In [4]:
# loading dataset

dataset = torchvision.datasets.Flickr30k(DATASET_SRC + 'flickr30k-images',
                                         DATASET_SRC + 'results_20130124.token')
VOCAB_SIZE = 30000
MAX_LEN = 100

dataset.target_transform = None

In [15]:
import torchvision.transforms.v2 as v2

## image transforms
img_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.int8, scale = True),
    v2.RandomResizedCrop(size = (DIMENSION, DIMENSION), 
                                scale = [0.67,1], 
                                ratio = [3/4, 4/3],
                                antialias = False),
    v2.RandomVerticalFlip(),
    v2.RandomHorizontalFlip(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(
        mean = [0.485, 0.456, 0.406],
        std =  [0.229, 0.224, 0.225]
    )
])



## text processing

list_of_strings = list(itertools.chain.from_iterable(list(dataset.annotations.values())))
#tokenizer
tokenizer = torchtext.data.utils.get_tokenizer(None, language = 'en')
all_toks = [tokenizer(x) for x in list_of_strings]

#vocabulary
vocab = torchtext.vocab.build_vocab_from_iterator(
    all_toks,
    min_freq = 2,
    specials = ['[PAD]', '[UNK]', '[MASK]'],
    special_first=False,
    max_tokens = VOCAB_SIZE
)
vocab.set_default_index = vocab['[UNK]']

with open(VOCAB_PATH, 'w') as f:
    for token in vocab.get_itos():
        f.write(token + '\n')

#text transform
text_transform = torchtext.transforms.Sequential(
    # torchtext.transforms.RegexTokenizer([]),
    torchtext.transforms.BERTTokenizer(VOCAB_PATH, 
                                       do_lower_case=False,
                                       strip_accents=False,
                                       return_tokens = True,
                                       ),
    torchtext.transforms.VocabTransform(vocab),
    torchtext.transforms.Truncate(MAX_LEN),
    torchtext.transforms.ToTensor(padding_value = vocab['[PAD]']),
    torchtext.transforms.PadTransform(MAX_LEN, vocab['[PAD]']),
)

In [33]:
class Flickr30K_MAMO(torchvision.datasets.Flickr30k):
    def __init__(self, data_path, ann_path):
        super().__init__(data_path, ann_path)
        
    def __getitem__(self, idx):
        img, cap = super().__getitem__(idx)
        
Flickr30K_MAMO(DATASET_SRC + 'flickr30k-images',
               DATASET_SRC + 'results_20130124.token')

Dataset Flickr30K_MAMO
    Number of datapoints: 31783
    Root location: ../Datasets/Flickr30k/flickr30k-images

In [16]:
# setting transformations

dataset.transform = img_transform
dataset.target_transform = text_transform

In [17]:
dataset[1]

(Image([[[1.4925, 1.4925, 1.4925,  ..., 1.4925, 1.5612, 1.5268],
         [1.5268, 1.4925, 1.4925,  ..., 1.4237, 1.5268, 1.4925],
         [1.4925, 1.4925, 1.5268,  ..., 1.5268, 1.5612, 1.5268],
         ...,
         [1.4925, 1.4581, 1.4581,  ..., 1.4581, 1.5268, 1.4925],
         [1.4581, 1.4581, 1.4581,  ..., 1.4581, 1.4925, 1.5268],
         [1.4581, 1.4581, 1.4925,  ..., 1.5268, 1.4925, 1.4925]],
 
        [[1.6552, 1.6552, 1.6552,  ..., 1.7255, 1.7255, 1.6904],
         [1.6904, 1.6552, 1.6552,  ..., 1.6201, 1.6904, 1.6552],
         [1.6904, 1.6552, 1.6904,  ..., 1.7255, 1.7255, 1.6904],
         ...,
         [1.6552, 1.6201, 1.6201,  ..., 1.6201, 1.6904, 1.6552],
         [1.6201, 1.6201, 1.6201,  ..., 1.6201, 1.6552, 1.6904],
         [1.6201, 1.6201, 1.6552,  ..., 1.6904, 1.6552, 1.6552]],
 
        [[1.8351, 1.8351, 1.8351,  ..., 1.8701, 1.9051, 1.8701],
         [1.8701, 1.8351, 1.8351,  ..., 1.7651, 1.8701, 1.8351],
         [1.8351, 1.8351, 1.8701,  ..., 1.8701, 1.9051, 

In [None]:
tokenizer.encode(dataset[1][1][1]).tokens