In [1]:
# torchtext makes many things easy, but sometimes we need to create a custom dataset.
# Here we create an image-captioning dataset.

In [1]:
# Plan:
# 1. Convert caption text to numerical values, which requires a vocabulary (word <-> index mapping).
# 2. Set up a PyTorch Dataset.
# 3. Pad the captions of every batch, because all data points in a batch must have the same
#    sequence length, and set up the DataLoader.

In [2]:
import os
import pandas as pd
import spacy #for tokenization
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms

In [3]:
# spaCy English pipeline; Vocabulary.tokenizer_eng only uses its tokenizer.
SPACY_MODEL_NAME = 'en_core_web_sm'
spacy_eng = spacy.load(SPACY_MODEL_NAME)

In [20]:
class Vocabulary:
    """Bidirectional word <-> index mapping built from a caption corpus.

    Indices 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and
    <UNK>. Corpus words get the following indices, in the order in which they
    first reach ``freq_threshold`` occurrences.
    """

    def __init__(self, freq_threshold):
        """
        Args:
            freq_threshold: minimum number of occurrences a word needs in the
                sentences passed to ``build_vocabulary`` to be added. Rarer
                words are mapped to <UNK> by ``numericalize``.
        """
        self.itos = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}  # index -> string
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}  # string -> index
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.itos)

    @staticmethod  # uses no instance or class state: a plain utility function
    def tokenizer_eng(text):
        """Split ``text`` into lower-cased tokens using spaCy's tokenizer.

        spaCy is used instead of ``str.split`` because it also separates
        punctuation from words.
        """
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every word occurring at least ``freq_threshold`` times.

        Args:
            sentence_list: iterable of raw caption strings (typically every
                caption in the dataset).

        Frequencies are counted over ``sentence_list`` only. New words take the
        next free index, so calling this more than once extends the vocabulary
        instead of overwriting indices assigned by an earlier call.
        """
        frequencies = {}
        next_idx = len(self.itos)  # 4 on a freshly constructed vocabulary

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1
                # `>=` plus the membership check adds each word exactly once, when it
                # first reaches the threshold -- this also works for thresholds <= 1.
                if frequencies[word] >= self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = next_idx
                    self.itos[next_idx] = word
                    next_idx += 1

    def numericalize(self, text):
        """Convert ``text`` into a list of token indices; unknown words map to <UNK>."""
        unk_idx = self.stoi['<UNK>']
        return [self.stoi.get(word, unk_idx) for word in self.tokenizer_eng(text)]




In [21]:
class FlickerDataset(Dataset):
    """Image-captioning dataset with one sample per row of a captions CSV.

    The CSV must have an ``image`` column (file name relative to ``root_dir``)
    and a ``caption`` column. A ``Vocabulary`` is built from all captions at
    construction time and exposed as ``self.vocab``.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_theshold=5):
        """
        Args:
            root_dir: directory containing the image files.
            captions_file: path to the CSV with ``image`` and ``caption`` columns.
            transform: optional callable applied to each RGB PIL image.
            freq_theshold: minimum word count for inclusion in the vocabulary
                (the misspelt name is kept so existing keyword callers still work).
        """
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Image-file and caption columns.
        self.imgs = self.df['image']
        self.captions = self.df['caption']

        # Build the vocabulary from every caption in the file.
        self.vocab = Vocabulary(freq_theshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        """Number of (image, caption) rows in the CSV."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_indices)`` for row ``index``.

        ``caption_indices`` is a 1-D tensor: <SOS>, the numericalized caption,
        then <EOS>.
        """
        # Positional lookup, independent of the DataFrame's index labels.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]

        # The context manager closes the file handle; convert() returns a new
        # in-memory image that remains valid after the file is closed.
        with Image.open(os.path.join(self.root_dir, img_id)) as raw_img:
            img = raw_img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.stoi['<SOS>']]  # stoi: string -> index
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi['<EOS>'])

        return img, torch.tensor(numericalized_caption)
        



In [22]:
# All captions in a batch must have the same length, so we pad them in a custom collate function.
class MyCollate:
    """DataLoader collate function that stacks images and pads captions.

    Captions in a batch have different lengths; ``pad_sequence`` right-pads
    them with ``pad_idx`` up to the longest caption in that batch.
    """

    def __init__(self, pad_idx, batch_first=False):
        """
        Args:
            pad_idx: value written into padded positions (normally the <PAD> index).
            batch_first: layout of the returned caption tensor, chosen to match
                what the model expects. False (default, same as before) gives
                ``(max_seq_len, batch_size)``; True gives
                ``(batch_size, max_seq_len)``.
        """
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    def __call__(self, batch):
        """
        Args:
            batch: list of samples where ``sample[0]`` is an image tensor and
                ``sample[1]`` is a caption index tensor. All images must have
                the same shape; captions are expected to be 1-D.

        Returns:
            ``(imgs, targets)``. ``imgs`` holds the images stacked along a new
            leading batch dimension. ``targets`` holds the padded captions laid
            out according to ``batch_first``.
        """
        imgs = torch.stack([sample[0] for sample in batch], dim=0)
        targets = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=self.batch_first,
            padding_value=self.pad_idx,
        )
        return imgs, targets

In [26]:
def get_loader(
        root_folder,
        annotation_file,
        transform,
        batch_size=32,
        num_workers=0,
        shuffle=True,
        pin_memory=True,
        freq_threshold=5,
):
    """Build a FlickerDataset and wrap it in a DataLoader that pads captions.

    Args:
        root_folder: directory containing the images.
        annotation_file: CSV with ``image`` and ``caption`` columns.
        transform: callable applied to each PIL image.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes. Defaults to 0 (load in the
            main process), which was the effective behaviour before: the
            argument used to be accepted but never passed to the DataLoader.
            NOTE(review): worker processes may fail to pickle classes defined
            inside a notebook on platforms using the ``spawn`` start method.
        shuffle: reshuffle every epoch; disable for order-dependent data such
            as time series.
        pin_memory: forwarded to DataLoader (pinned host memory for faster GPU copies).
        freq_threshold: minimum word count for the caption vocabulary.

    Returns:
        A DataLoader yielding ``(images, captions)`` batches whose captions are
        padded with the vocabulary's <PAD> index. The vocabulary is reachable
        as ``loader.dataset.vocab``.
    """
    dataset = FlickerDataset(
        root_folder,
        annotation_file,
        transform,
        freq_theshold=freq_threshold,
    )

    pad_idx = dataset.vocab.stoi['<PAD>']

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader

In [27]:
# Resize every image to a fixed 224x224 so a batch can be stacked into one
# tensor, then convert the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

DATA_ROOT = 'data/flicker8k'
dataloader = get_loader(
    f'{DATA_ROOT}/Images',
    f'{DATA_ROOT}/captions.txt',
    transform,
)

In [30]:
# Sanity check: print the image-batch and caption-batch shapes of the first 6 batches.
batches_to_show = 6
for batch_idx, batch in enumerate(dataloader):
    print(batch[0].shape, batch[1].shape)
    if batch_idx == batches_to_show - 1:
        break

torch.Size([32, 3, 224, 224]) torch.Size([21, 32])
torch.Size([32, 3, 224, 224]) torch.Size([23, 32])
torch.Size([32, 3, 224, 224]) torch.Size([22, 32])
torch.Size([32, 3, 224, 224]) torch.Size([24, 32])
torch.Size([32, 3, 224, 224]) torch.Size([24, 32])
torch.Size([32, 3, 224, 224]) torch.Size([21, 32])
