In [5]:
import os
import pandas as pd
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image

In [30]:
# Load spaCy's 'en_core_web_sm' English pipeline once at module level;
# the Vocabulary class below relies on it to tokenize captions.
tokenizer_eng = spacy.load('en_core_web_sm')

class Vocabulary:
    """Word-level vocabulary mapping caption tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``,
    ``<EOS>`` and ``<UNK>``; words registered by :meth:`build_vocab` are
    numbered after them.

    Args:
        freq_threshold: Occurrence count at which a word enters the
            vocabulary. A word is registered the moment its running count
            equals this value, so values below 1 never admit any word.
    """

    def __init__(self, freq_threshold):
        self.freq_threshold = freq_threshold
        self.itos = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}
        self.stoi = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.itos)

    def build_vocab(self, sentence_list):
        """Count tokens over ``sentence_list`` and register the frequent ones.

        Args:
            sentence_list: Iterable of caption strings.
        """
        freq = {}
        # Continue after whatever is already registered (4 on a fresh vocab).
        idx = len(self.itos)
        for sentence in sentence_list:
            # Tokenize through the same lower-casing string tokenizer that
            # numericalize() uses. Counting raw spaCy Token objects would
            # never merge repeated words and would produce keys that string
            # lookups can never match.
            for word in self.tokenizer_eng(sentence):
                freq[word] = freq.get(word, 0) + 1

                if freq[word] == self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Convert ``text`` into a list of token ids.

        Tokens missing from the vocabulary map to the ``<UNK>`` id.
        """
        unk_idx = self.stoi['<UNK>']
        return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` into lower-cased token strings."""
        # Inside this function the bare name `tokenizer_eng` resolves to the
        # module-level spaCy pipeline (class attributes are not part of a
        # method's name lookup), not to this static method.
        return [tok.text.lower() for tok in tokenizer_eng.tokenizer(text)]


In [41]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from a captions CSV and an image folder.

    The CSV must provide an ``image`` column (file name relative to
    ``root_dir``) and a ``caption`` column. A ``Vocabulary`` is built from
    all captions when the dataset is constructed.

    Args:
        root_dir: Directory containing the image files.
        captions_file: Path to the CSV file read with ``pd.read_csv``.
        transforms: Optional callable applied to each loaded RGB image.
        freq_threshold: Forwarded to ``Vocabulary``.
    """

    def __init__(self, root_dir, captions_file, transforms=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transforms = transforms
        self.freq_threshold = freq_threshold

        # Column views used for per-item lookup.
        self.imgs, self.captions = self.df['image'], self.df['caption']

        # The vocabulary is derived from every caption in the file.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.captions.tolist())

    def __len__(self):
        """Number of rows (image/caption pairs) in the captions file."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(image, caption_ids)`` for row ``index``.

        The caption ids are wrapped in ``<SOS>`` ... ``<EOS>`` and returned
        as a 1-D tensor.
        """
        caption_text = self.captions[index]
        image_path = os.path.join(self.root_dir, self.imgs[index])
        img = Image.open(image_path).convert("RGB")

        if self.transforms is not None:
            img = self.transforms(img)

        word_to_id = self.vocab.stoi
        token_ids = [
            word_to_id['<SOS>'],
            *self.vocab.numericalize(caption_text),
            word_to_id['<EOS>'],
        ]
        return img, torch.tensor(token_ids)

In [49]:
class CollateForPadding:
    """Collate callable that batches images and pads captions to equal length.

    Args:
        pad_idx: Value used to fill captions shorter than the longest one
            in the batch.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        """Turn a list of ``(image, caption)`` samples into two tensors.

        Returns:
            ``(imgs, targets)`` where ``imgs`` has a new leading batch
            dimension and ``targets`` is sequence-first
            (``batch_first=False``): shape ``(max_len, batch)``.
        """
        # Stacking along a new dim 0 adds the batch dimension in one step.
        images = torch.stack([sample[0] for sample in batch], dim=0)
        captions = pad_sequence(
            [sample[1] for sample in batch],
            batch_first=False,
            padding_value=self.pad_idx,
        )
        return images, captions

In [50]:
def get_loader(root_dir, annotation_file, transform, batch_size, num_workers, shuffle, pin_memory):
    """Create a DataLoader over a FlickrDataset with padded caption batches.

    Args:
        root_dir: Directory containing the images.
        annotation_file: Captions CSV passed to ``FlickrDataset``.
        transform: Image transform passed to the dataset.
        batch_size, num_workers, shuffle, pin_memory: Forwarded unchanged
            to ``DataLoader``.

    Returns:
        A ``DataLoader`` whose batches are assembled by ``CollateForPadding``
        using the dataset vocabulary's ``<PAD>`` id.
    """
    flickr_data = FlickrDataset(root_dir, annotation_file, transforms=transform)
    # Captions are padded with the id the vocabulary assigns to '<PAD>'.
    collate = CollateForPadding(pad_idx=flickr_data.vocab.stoi['<PAD>'])
    return DataLoader(
        dataset=flickr_data,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate,
    )
    

In [51]:
# Resize every image to 500x500 and convert it to a tensor.
transform = transforms.Compose([
        transforms.Resize((500,500)),
        transforms.ToTensor()
    ])
dataloader = get_loader(
        root_dir='Data/flickr8k/images/',
        annotation_file='Data/flickr8k/captions.txt',
        transform=transform,
        num_workers=0,
        batch_size=32,
        shuffle=True,
        pin_memory=True,
    )
checkpoint = {'data_loader': dataloader}
checkpoint_path = 'Custom Utils/Flickr8k_Data_Loader.pth.tar'
# torch.save does not create missing parent directories; make sure the
# target folder exists so a fresh "Restart & Run All" does not fail here.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
# NOTE(review): this pickles the whole DataLoader object (dataset,
# vocabulary, collate callable). Loading it back presumably requires the
# same class definitions to be available and a torch.load call that allows
# arbitrary objects (weights_only=False on recent PyTorch) - confirm with
# the notebook that consumes this file.
torch.save(checkpoint, checkpoint_path)
print("Loader Saved Successfully!")


Loader Saved Successfully!


In [52]:
# Sanity check: show the tensor shapes of the first batch only.
for idx, (imgs, captions) in enumerate(dataloader):
    # One shape per line, followed by a blank line.
    print(imgs.shape, captions.shape, '', sep='\n')
    break

torch.Size([32, 3, 500, 500])
torch.Size([27, 32])

