## Dataset

*We need to conver text to numerical value*
* We need a vocabulary mapping for each word(or character) to int
* We need to setup a pytorch dataset
* Make sure that each sentence (input) is same size (padding) and dataloader

In [58]:
import os
import pandas as pd
import spacy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from torchvision.transforms import transforms

In [None]:
!python -m spacy download en_core_web_sm

In [77]:
class Vocabulary():
    def __init__(self, freq_threshold):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    
    def __len__(self):
        return len(self.itos)


    @staticmethod
    def tokenizer_eng(text):
        spacy_eng = spacy.load("en_core_web_sm")
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
    
    def build_vocabulary(self, sentence_list):
        frequency = {}
        idx = 4
        N = 50

        for sentence in sentence_list:
            for token in self.tokenizer_eng(sentence):
                print(idx)
                if idx - 4 == N:
                    break

                frequency[token] = 1 + frequency.get(token, 0)

                if frequency[token] == self.freq_threshold:
                    self.stoi[token] = idx
                    self.itos[idx] = token
                    idx += 1
    
    def numericalize(self, text):
        print("normalize")
        token_sent = self.tokenizer_eng(text)

        return [self.stoi[token] if token in self.stoi else self.stoi['<UNK>']
                for token in token_sent
            ]

In [78]:
class FlickrDataset(Dataset):

    def __init__(self, root_dir, caption_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(caption_file)
        self.transform = transform

        # get the image and caption
        self.images = self.df['image']
        self.caption = self.df['caption']

        # Create our own vocabulary
        self.vocabulary = Vocabulary(freq_threshold)
        self.vocabulary.build_vocabulary(self.caption.to_list())
        print('done __int__')
    
    def __len__(self):
        return len(self.df)


    def __getitem__(self, index):
        # get image
        image_path = os.path.join(self.root_dir, self.images[index])
        img = Image.open(image_path).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # get caption
        caption = self.caption[index]
        num_caption = self.vocabulary.stoi['<SOS>']
        num_caption += self.vocabulary.numericalize(caption)
        num_caption.append(self.vocabulary.stoi['<EOS>'])
        

        return img, torch.tensor(num_caption)
    


In [79]:
class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx
    
    def __call__(self, batch):
        img = [item[0].unsqueeze(0) for item in batch]
        img = torch.cat(img, 0)
        target = [item[1] for item in batch]
        target = pad_sequence(target, batch_first=False, padding_value=self.pad_idx)
        return img, target

In [80]:
def get_loader(
        root_folder,
        annotation_file,
        transform,
        batch_size=32,
        num_worker=1,
        shuffle=True,
        pin_memory=False
):
    dataset = FlickrDataset(root_dir=root_folder,
                            caption_file=annotation_file, transform=transform)
    pad_idx = dataset.vocabulary.stoi["<PAD>"]
    
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=MyCollate(pad_idx=pad_idx)
    )

    return loader

In [88]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])


4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
6
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
7
8
8
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
9
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
10
11
11
11
11
11
11
11
12
12
12
12
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
13
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
14
15
15
15
16
16
16
16
16
16
16
16
16
16
16
16
16
16
16
16
16
16
16
16
16
16
16
16
17
17
17
17
17
17
17
17
17
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
18
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
19
20
20
20
20
20
20
20
20
20
20
20
20
20
20
20
20
2

KeyboardInterrupt: 

## Model

In [87]:
text = pd.read_csv('data/captions.txt')
text = text['caption'].to_list()
spacy_eng = spacy.load("en_core_web_sm")
output = []
ele = set()
for i in text:
    for tok in spacy_eng.tokenizer(i):
        ele.add(tok.text.lower())


print(len(ele))

8504


## Training