In [1]:
CSV_PATH = '/kaggle/input/flickr-image-dataset/flickr30k_images/results.csv'
IMG_PATH = '/kaggle/input/flickr-image-dataset/flickr30k_images/flickr30k_images/'

In [47]:
import torch
import pandas as pd
from spacy.lang.en import English
import torchtext
from sklearn.model_selection import train_test_split
import tqdm
seed = 1234
import random
import numpy as np
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [3]:
df = pd.read_csv(CSV_PATH, delimiter='|')

In [4]:
df.head(5)

Unnamed: 0,image_name,comment_number,comment
0,1000092795.jpg,0,Two young guys with shaggy hair look at their...
1,1000092795.jpg,1,"Two young , White males are outside near many..."
2,1000092795.jpg,2,Two men in green shirts are standing in a yard .
3,1000092795.jpg,3,A man in a blue shirt standing in a garden .
4,1000092795.jpg,4,Two friends enjoy time spent together .


In [5]:
df.dropna(inplace=True)

In [6]:
df = df.iloc[:2000]

In [7]:
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 2000 entries, 0 to 1999
Data columns (total 3 columns):
 #   Column           Non-Null Count  Dtype 
---  ------           --------------  ----- 
 0   image_name       2000 non-null   object
 1    comment_number  2000 non-null   object
 2    comment         2000 non-null   object
dtypes: object(3)
memory usage: 62.5+ KB


In [8]:
train, remaining = train_test_split(df, test_size=0.2, random_state=42)

test, val = train_test_split(remaining, test_size=0.5, random_state=42)

print("Size of train set:", len(train))
print("Size of test set:", len(test))
print("Size of validation set:", len(val))

Size of train set: 1600
Size of test set: 200
Size of validation set: 200


In [9]:
def load_data(df):
    return df.to_dict(orient='list')

In [10]:
train_data = load_data(train[["image_name"," comment"]])
test_data = load_data(test[["image_name"," comment"]])
val_data = load_data(val[["image_name"," comment"]])

In [11]:
print(train_data['image_name'][0], train_data[' comment'][0])
print(test_data['image_name'][0], test_data[' comment'][0])
print(val_data['image_name'][0], val_data[' comment'][0])

1065323785.jpg  Kids are riding bikes near a dirt mound in the street .
1083240835.jpg  A woman applies mascara to her eyelash .
1000366164.jpg  Two men are at the stove preparing food .


In [12]:
tokenizer = English()

In [13]:
def tokenize_sample(comment, tokenizer, max_length, sos_token, eos_token, lower):
    tokens = [token.text for token in tokenizer.tokenizer(comment.strip())][:max_length]
    if lower:
        tokens = [token.lower() for token in tokens]
    tokens = [sos_token] + tokens + [eos_token]
    return tokens

In [14]:
sos_token = "<sos>"
eos_token = "<eos>"
max_length = 200
lower = True
train_data["tokens"] = [tokenize_sample(cm,tokenizer,max_length,sos_token,eos_token,lower) for cm in train_data[" comment"]]
val_data["tokens"] = [tokenize_sample(cm,tokenizer,max_length,sos_token,eos_token,lower) for cm in val_data[" comment"]]
test_data["tokens"] = [tokenize_sample(cm,tokenizer,max_length,sos_token,eos_token,lower) for cm in test_data[" comment"]]

In [15]:
train_data[" comment"][0],train_data["tokens"][0], 

(' Kids are riding bikes near a dirt mound in the street .',
 ['<sos>',
  'kids',
  'are',
  'riding',
  'bikes',
  'near',
  'a',
  'dirt',
  'mound',
  'in',
  'the',
  'street',
  '.',
  '<eos>'])

In [16]:
def yield_token(data, key):
    for token in data[key]:
        yield token

In [17]:
import torchtext.vocab


min_freq = 2
unk_token = "<unk>"
pad_token = "<pad>"
special_tokens = [
    unk_token,
    pad_token,
    sos_token,
    eos_token
]
vocab = torchtext.vocab.build_vocab_from_iterator(
    yield_token(train_data,"tokens"),
    min_freq=min_freq,
    specials=special_tokens
)

In [18]:
vocab.get_itos()[:10]

['<unk>', '<pad>', '<sos>', '<eos>', 'a', '.', 'in', 'the', 'on', 'man']

In [19]:
len(vocab)

1150

In [20]:
unk_index = vocab[unk_token]
pad_index = vocab[pad_token]
unk_index, pad_index

(0, 1)

In [21]:
vocab.set_default_index(unk_index)

In [22]:
def token2ids(tokens, vocab):
    ids = vocab.lookup_indices(tokens)
    return ids

In [23]:
train_data["ids"] = [token2ids(t,vocab) for t in train_data["tokens"]]
val_data["ids"] = [token2ids(t,vocab) for t in val_data["tokens"]]
test_data["ids"] = [token2ids(t,vocab) for t in test_data["tokens"]]

In [24]:
train_data[" comment"][0],train_data["tokens"][0],train_data["ids"][0]

(' Kids are riding bikes near a dirt mound in the street .',
 ['<sos>',
  'kids',
  'are',
  'riding',
  'bikes',
  'near',
  'a',
  'dirt',
  'mound',
  'in',
  'the',
  'street',
  '.',
  '<eos>'],
 [2, 225, 14, 72, 236, 55, 4, 222, 708, 6, 7, 40, 5, 3])

In [25]:
import torch.utils
import torch.utils.data
from PIL import Image

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data, transform, path_prefix):
        self.img_paths = data["image_name"]
        self.ids = data["ids"]
        self.transform = transform
        self.path_prefix = path_prefix

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        image = Image.open(self.path_prefix+self.img_paths[index])
        torch_img = self.transform(image)
        return torch_img, torch.tensor(self.ids[index],dtype=torch.int64)

In [26]:
from torchvision.transforms import Compose, transforms

In [27]:
transform = Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

In [28]:
train_dataset = CustomDataset(train_data, transform, IMG_PATH)
val_dataset = CustomDataset(val_data, transform, IMG_PATH)
test_dataset = CustomDataset(test_data, transform, IMG_PATH)

In [29]:
def get_collate_fn(pad_index):
    def collate_fn(batch):
        batch_img = torch.stack([b[0] for b in batch])
        batch_ids = [b[1] for b in batch]
        batch_ids = torch.nn.utils.rnn.pad_sequence(batch_ids,batch_first=True,padding_value=pad_index)
        return batch_img, batch_ids
    return collate_fn

In [30]:
def get_dataloader(dataset, batch_size, pad_index, shuffle=False):
    collate_fn = get_collate_fn(pad_index)
    data_loader =  torch.utils.data.DataLoader(dataset,batch_size,shuffle,collate_fn=collate_fn)
    return data_loader

In [31]:
batch_size = 128
shuffle = True
train_dataloader = get_dataloader(train_dataset,batch_size,pad_index,shuffle)
val_dataloader = get_dataloader(val_dataset,batch_size,pad_index)
test_dataloader = get_dataloader(test_dataset,batch_size,pad_index)

In [32]:
from torchvision.models import resnet50, ResNet50_Weights
import random

In [33]:
class Encoder(torch.nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        resnet.fc = torch.nn.Linear(2048, hidden_dim)
        self.model = resnet
    
    def forward(self,x):
        return self.model(x)

In [34]:
class Decoder(torch.nn.Module):
    def __init__(self, hidden_dim, embedding_dim, vocab_size, padding_idx):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.embedding = torch.nn.Embedding(vocab_size,embedding_dim,padding_idx=padding_idx)
        self.rnn = torch.nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim,vocab_size)
    
    def forward(self, input, hidden):
        input = input.unsqueeze(1)
        embedded = self.embedding(input)
        output, hidden = self.rnn(embedded,(hidden))
        prediction = self.fc(output.squeeze(1))
        return prediction, hidden

In [35]:
class Img2Seq(torch.nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        assert(self.encoder.hidden_dim == self.decoder.hidden_dim), "Hidden dimensions of encoder and decoder must be equal"
    
    def forward(self, batch_imgs, batch_ids, teacher_forcing_ratio):
        batch_size = batch_imgs.shape[0]
        vocab_size = self.decoder.vocab_size
        seq_length = batch_ids.shape[1]
        outputs = torch.zeros(batch_size,seq_length,vocab_size).to(self.device)
        hidden = self.encoder(batch_imgs).unsqueeze(0) # D * numlayer x batch x hidden_dim
        inputs = batch_ids[:,0]
        for i in range(1, seq_length):
            output, hidden = self.decoder(inputs,hidden)
            outputs[:,i,:] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            inputs = batch_ids[:,i] if teacher_force else top1
        return outputs

In [36]:
vocab_size = len(vocab)
hidden_dim = 256
embedding_dim = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(hidden_dim)
decoder = Decoder(hidden_dim,embedding_dim,vocab_size,pad_index)
img2seq = Img2Seq(encoder,decoder,device).to(device)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 65.2MB/s]


In [37]:
def init_weights(m):
    for name, param in m.named_parameters():
        torch.nn.init.uniform_(param.data, -0.08, 0.08)
img2seq.apply(init_weights)

Img2Seq(
  (encoder): Encoder(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=Tr

In [38]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"The model has {count_parameters(img2seq):,} trainable parameters")

The model has 24,574,142 trainable parameters


In [39]:
optimizer = torch.optim.Adam(img2seq.parameters())
criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_index)

In [40]:
def train_fn(model, data_loader, optimizer, criterion, teacher_forcing_ratio, device):
    model.train()
    epoch_loss = 0
    for batch_imgs, batch_ids in data_loader:
        batch_imgs, batch_ids = batch_imgs.to(device), batch_ids.to(device)
        optimizer.zero_grad()
        output = model(batch_imgs,batch_ids,teacher_forcing_ratio)
        output_dim = output.shape[-1]
        output = output[:,1:,].reshape(-1,output_dim)
        target = batch_ids[:,1:].reshape(-1)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

In [41]:
def evaluate_fn(model, data_loader, criterion, device):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for batch_imgs, batch_ids in data_loader:
            batch_imgs, batch_ids = batch_imgs.to(device), batch_ids.to(device)
            output = model(batch_imgs,batch_ids,0)
            output_dim = output.shape[-1]
            output = output[:,1:,].reshape(-1,output_dim)
            target = batch_ids[:,1:].reshape(-1)
            loss = criterion(output, target)
            epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

In [48]:
n_epochs = 50
teacher_forcing_ratio = 0.7
best_valid_loss = float("inf")
for epoch in tqdm.tqdm(range(n_epochs)):
    train_loss = train_fn(
        img2seq,
        train_dataloader,
        optimizer,
        criterion,
        teacher_forcing_ratio,
        device,
    )
    valid_loss = evaluate_fn(
        img2seq, 
        val_dataloader, 
        criterion, 
        device)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(img2seq.state_dict(), "tut1-model.pt")
    print(f"\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}")
    print(f"\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}")

  2%|▏         | 1/50 [00:40<32:58, 40.38s/it]

	Train Loss:   4.623 | Train PPL: 101.751
	Valid Loss:   4.461 | Valid PPL:  86.532


  4%|▍         | 2/50 [01:20<32:13, 40.29s/it]

	Train Loss:   4.558 | Train PPL:  95.392
	Valid Loss:   4.458 | Valid PPL:  86.273


  6%|▌         | 3/50 [02:01<31:36, 40.35s/it]

	Train Loss:   4.476 | Train PPL:  87.925
	Valid Loss:   4.477 | Valid PPL:  87.933


  8%|▊         | 4/50 [02:41<30:58, 40.41s/it]

	Train Loss:   4.407 | Train PPL:  82.021
	Valid Loss:   4.542 | Valid PPL:  93.921


 10%|█         | 5/50 [03:21<30:16, 40.36s/it]

	Train Loss:   4.325 | Train PPL:  75.593
	Valid Loss:   4.559 | Valid PPL:  95.454


 12%|█▏        | 6/50 [04:02<29:36, 40.37s/it]

	Train Loss:   4.247 | Train PPL:  69.894
	Valid Loss:   4.571 | Valid PPL:  96.662


 14%|█▍        | 7/50 [04:42<28:52, 40.30s/it]

	Train Loss:   4.172 | Train PPL:  64.842
	Valid Loss:   4.570 | Valid PPL:  96.584


 16%|█▌        | 8/50 [05:22<28:14, 40.35s/it]

	Train Loss:   4.079 | Train PPL:  59.095
	Valid Loss:   4.652 | Valid PPL: 104.788


 18%|█▊        | 9/50 [06:03<27:34, 40.36s/it]

	Train Loss:   4.038 | Train PPL:  56.701
	Valid Loss:   4.594 | Valid PPL:  98.850


 20%|██        | 10/50 [06:43<26:49, 40.25s/it]

	Train Loss:   3.918 | Train PPL:  50.293
	Valid Loss:   5.156 | Valid PPL: 173.442


 22%|██▏       | 11/50 [07:23<26:08, 40.21s/it]

	Train Loss:   3.861 | Train PPL:  47.496
	Valid Loss:   4.843 | Valid PPL: 126.911


 24%|██▍       | 12/50 [08:03<25:29, 40.24s/it]

	Train Loss:   3.894 | Train PPL:  49.095
	Valid Loss:   4.710 | Valid PPL: 111.081


 26%|██▌       | 13/50 [08:43<24:47, 40.21s/it]

	Train Loss:   3.835 | Train PPL:  46.272
	Valid Loss:   4.700 | Valid PPL: 109.994


 28%|██▊       | 14/50 [09:24<24:09, 40.27s/it]

	Train Loss:   3.794 | Train PPL:  44.440
	Valid Loss:   4.785 | Valid PPL: 119.700


 30%|███       | 15/50 [10:04<23:29, 40.27s/it]

	Train Loss:   3.760 | Train PPL:  42.939
	Valid Loss:   4.874 | Valid PPL: 130.906


 32%|███▏      | 16/50 [10:44<22:48, 40.25s/it]

	Train Loss:   3.698 | Train PPL:  40.366
	Valid Loss:   4.943 | Valid PPL: 140.163


 34%|███▍      | 17/50 [11:24<22:08, 40.25s/it]

	Train Loss:   3.778 | Train PPL:  43.709
	Valid Loss:   4.938 | Valid PPL: 139.517


 36%|███▌      | 18/50 [12:05<21:26, 40.21s/it]

	Train Loss:   3.561 | Train PPL:  35.185
	Valid Loss:   4.902 | Valid PPL: 134.592


 38%|███▊      | 19/50 [12:45<20:45, 40.17s/it]

	Train Loss:   3.634 | Train PPL:  37.863
	Valid Loss:   4.845 | Valid PPL: 127.109


 40%|████      | 20/50 [13:25<20:06, 40.21s/it]

	Train Loss:   3.582 | Train PPL:  35.953
	Valid Loss:   5.008 | Valid PPL: 149.676


 42%|████▏     | 21/50 [14:05<19:26, 40.21s/it]

	Train Loss:   3.545 | Train PPL:  34.636
	Valid Loss:   4.900 | Valid PPL: 134.320


 44%|████▍     | 22/50 [14:45<18:44, 40.16s/it]

	Train Loss:   3.575 | Train PPL:  35.703
	Valid Loss:   5.064 | Valid PPL: 158.263


 46%|████▌     | 23/50 [15:25<18:03, 40.13s/it]

	Train Loss:   3.470 | Train PPL:  32.146
	Valid Loss:   4.875 | Valid PPL: 131.007


 48%|████▊     | 24/50 [16:06<17:25, 40.20s/it]

	Train Loss:   3.469 | Train PPL:  32.106
	Valid Loss:   5.192 | Valid PPL: 179.774


 50%|█████     | 25/50 [16:47<16:50, 40.44s/it]

	Train Loss:   3.530 | Train PPL:  34.133
	Valid Loss:   4.772 | Valid PPL: 118.174


 52%|█████▏    | 26/50 [17:27<16:12, 40.54s/it]

	Train Loss:   3.445 | Train PPL:  31.343
	Valid Loss:   4.750 | Valid PPL: 115.588


 54%|█████▍    | 27/50 [18:08<15:32, 40.54s/it]

	Train Loss:   3.459 | Train PPL:  31.777
	Valid Loss:   4.952 | Valid PPL: 141.405


 56%|█████▌    | 28/50 [18:48<14:49, 40.45s/it]

	Train Loss:   3.494 | Train PPL:  32.912
	Valid Loss:   4.745 | Valid PPL: 115.020


 58%|█████▊    | 29/50 [19:28<14:07, 40.34s/it]

	Train Loss:   3.380 | Train PPL:  29.362
	Valid Loss:   5.178 | Valid PPL: 177.318


 60%|██████    | 30/50 [20:08<13:25, 40.27s/it]

	Train Loss:   3.407 | Train PPL:  30.161
	Valid Loss:   4.864 | Valid PPL: 129.566


 62%|██████▏   | 31/50 [20:48<12:44, 40.24s/it]

	Train Loss:   3.306 | Train PPL:  27.274
	Valid Loss:   4.926 | Valid PPL: 137.852


 64%|██████▍   | 32/50 [21:29<12:03, 40.21s/it]

	Train Loss:   3.307 | Train PPL:  27.301
	Valid Loss:   4.808 | Valid PPL: 122.503


 66%|██████▌   | 33/50 [22:09<11:24, 40.25s/it]

	Train Loss:   3.206 | Train PPL:  24.672
	Valid Loss:   4.918 | Valid PPL: 136.771


 68%|██████▊   | 34/50 [22:49<10:44, 40.27s/it]

	Train Loss:   3.401 | Train PPL:  29.988
	Valid Loss:   4.837 | Valid PPL: 126.044


 70%|███████   | 35/50 [23:30<10:04, 40.33s/it]

	Train Loss:   3.174 | Train PPL:  23.906
	Valid Loss:   4.928 | Valid PPL: 138.158


 72%|███████▏  | 36/50 [24:10<09:24, 40.29s/it]

	Train Loss:   3.181 | Train PPL:  24.075
	Valid Loss:   4.902 | Valid PPL: 134.512


 74%|███████▍  | 37/50 [24:50<08:42, 40.23s/it]

	Train Loss:   3.137 | Train PPL:  23.031
	Valid Loss:   4.902 | Valid PPL: 134.559


 76%|███████▌  | 38/50 [25:30<08:02, 40.20s/it]

	Train Loss:   3.068 | Train PPL:  21.502
	Valid Loss:   4.942 | Valid PPL: 140.012


 78%|███████▊  | 39/50 [26:10<07:22, 40.23s/it]

	Train Loss:   3.224 | Train PPL:  25.133
	Valid Loss:   4.985 | Valid PPL: 146.211


 80%|████████  | 40/50 [26:51<06:42, 40.24s/it]

	Train Loss:   3.134 | Train PPL:  22.960
	Valid Loss:   5.271 | Valid PPL: 194.674


 82%|████████▏ | 41/50 [27:31<06:01, 40.19s/it]

	Train Loss:   3.103 | Train PPL:  22.254
	Valid Loss:   4.969 | Valid PPL: 143.924


 84%|████████▍ | 42/50 [28:11<05:21, 40.17s/it]

	Train Loss:   3.122 | Train PPL:  22.697
	Valid Loss:   4.891 | Valid PPL: 133.121


 86%|████████▌ | 43/50 [28:51<04:41, 40.19s/it]

	Train Loss:   3.184 | Train PPL:  24.146
	Valid Loss:   4.988 | Valid PPL: 146.621


 88%|████████▊ | 44/50 [29:31<04:01, 40.22s/it]

	Train Loss:   2.932 | Train PPL:  18.758
	Valid Loss:   5.182 | Valid PPL: 178.084


 90%|█████████ | 45/50 [30:11<03:20, 40.16s/it]

	Train Loss:   3.069 | Train PPL:  21.529
	Valid Loss:   5.036 | Valid PPL: 153.912


 92%|█████████▏| 46/50 [30:52<02:40, 40.15s/it]

	Train Loss:   2.958 | Train PPL:  19.255
	Valid Loss:   5.167 | Valid PPL: 175.326


 94%|█████████▍| 47/50 [31:32<02:00, 40.25s/it]

	Train Loss:   2.966 | Train PPL:  19.421
	Valid Loss:   5.083 | Valid PPL: 161.178


 96%|█████████▌| 48/50 [32:12<01:20, 40.24s/it]

	Train Loss:   2.879 | Train PPL:  17.791
	Valid Loss:   5.039 | Valid PPL: 154.278


 98%|█████████▊| 49/50 [32:52<00:40, 40.20s/it]

	Train Loss:   2.897 | Train PPL:  18.127
	Valid Loss:   5.018 | Valid PPL: 151.086


100%|██████████| 50/50 [33:33<00:00, 40.26s/it]

	Train Loss:   2.792 | Train PPL:  16.311
	Valid Loss:   5.058 | Valid PPL: 157.326



