In [1]:
import os
import time

import matplotlib.pyplot as plt
import pandas as pd
import spacy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
%matplotlib inline

In [8]:
class Vocabulary:
    """Word <-> index mapping built from a corpus of caption strings.

    Indices 0-3 are reserved for <PAD>, <SOS>, <EOS> and <UNK>. Every other
    token receives the next free index once its count in the corpus reaches
    ``freq_threshold``.
    """

    def __init__(self, freq_threshold=2):
        self.freq_threshold = freq_threshold
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}

    def __len__(self):
        return len(self.itos)

    @staticmethod
    def tokenize_en(text):
        """Return the lower-cased spaCy tokens of ``text``.

        Relies on a module-level ``spacy_en`` pipeline having been loaded
        before this is called.
        """
        return [tok.text.lower() for tok in spacy_en.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every token that occurs at least ``freq_threshold`` times.

        New tokens get indices after those already in use. Calling this more
        than once therefore extends the vocabulary instead of overwriting
        existing entries.
        """
        frequencies = {}
        # Continue after the highest existing index (4 on a fresh instance)
        # instead of a hard-coded 4, which would reassign taken indices.
        idx = len(self.itos)

        for sentence in tqdm(sentence_list):
            for word in self.tokenize_en(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                # `>=` plus the membership check adds each word exactly once,
                # also for thresholds <= 1 and for words already present.
                if frequencies[word] >= self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Map ``text`` to token indices; out-of-vocabulary tokens become <UNK>."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenize_en(text)]

In [9]:
# One caption per image: keep the raw frame and a plain list of caption strings.
df = pd.read_csv('one_desc.txt')
captions = df.caption.tolist()

In [10]:
# spaCy v3 removed the 'en' shortcut link; load the small English pipeline by package name
# (also valid on spaCy v2 when en_core_web_sm is installed). Only its tokenizer is used.
spacy_en = spacy.load('en_core_web_sm')

In [11]:
# Standalone vocabulary over every caption (2 is also the class's default threshold).
vocab = Vocabulary(freq_threshold=2)
vocab.build_vocabulary(captions)

100%|██████████| 8091/8091 [03:30<00:00, 38.37it/s]


In [12]:
# Round-trip the first caption through the vocabulary; rare words come back as <UNK>.
[vocab.itos[token_id] for token_id in vocab.numericalize(captions[0])]

['a',
 'child',
 'in',
 'a',
 'pink',
 'dress',
 'is',
 'climbing',
 'up',
 'a',
 'set',
 'of',
 'stairs',
 'in',
 'an',
 '<UNK>',
 'way',
 '.']

In [13]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from a CSV with ``image`` and ``caption`` columns.

    Images are loaded from ``<root_dir>/Images/<image>``. Each item is
    ``(image, caption_ids)``, where ``caption_ids`` is a 1-D tensor of token
    indices wrapped in <SOS> ... <EOS>.

    Args:
        root_dir: directory that contains the ``Images`` folder.
        captions_file: path to the CSV file.
        transform: optional callable applied to the RGB PIL image.
        freq_threshold: minimum count for a word to enter the vocabulary
            (ignored when ``vocab`` is given).
        vocab: optional pre-built ``Vocabulary`` to reuse. Building one
            re-tokenizes every caption, which is slow.
    """

    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=2, vocab=None):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        self.imgs = self.df['image']
        self.captions = self.df['caption']

        if vocab is None:
            vocab = Vocabulary(freq_threshold)
            vocab.build_vocabulary(self.captions.tolist())
        self.vocab = vocab

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        # .iloc: positional lookup, independent of the DataFrame's index labels.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]
        img_path = os.path.join(self.root_dir, 'Images', img_id)

        # The context manager closes the file; convert() returns a separate,
        # fully loaded image that stays valid afterwards.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)

In [14]:
ds = FlickrDataset('.', 'one_desc.txt')
# Fetch the sample once: every ds[i] call re-reads the image from disk.
sample_img, sample_caption = ds[0]
plt.imshow(sample_img);
' '.join(ds.vocab.itos[i.item()] for i in sample_caption)

 10%|█         | 814/8091 [00:28<04:13, 28.72it/s]


KeyboardInterrupt: 

In [10]:
# Padding index from the dataset's own vocabulary, used by the batch collator.
pad_idx = ds.vocab.stoi['<PAD>']

In [15]:
class MyCollate:
    """Collate ``(image, caption_ids)`` pairs into batched tensors.

    Images are stacked along a new leading batch dimension. Captions are
    right-padded with ``pad_idx`` into a time-major ``[max_len, batch]``
    tensor.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = torch.stack([item[0] for item in batch], dim=0)
        targets = [item[1] for item in batch]
        # Use the value given to the constructor, not a module-level `pad_idx`.
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
        return imgs, targets

In [None]:
class MyCollate:
    """Collate ``(image, caption_ids)`` pairs into batched tensors.

    Images are stacked along a new leading batch dimension. Captions are
    right-padded with ``pad_idx`` into a time-major ``[max_len, batch]``
    tensor.
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = torch.stack([item[0] for item in batch], dim=0)
        targets = [item[1] for item in batch]
        # Use the value given to the constructor, not a module-level `pad_idx`.
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)
        return imgs, targets

class EncoderCNN(nn.Module):
    """Inception-v3 image encoder that projects images to ``hid_dim`` features.

    The classifier head (``fc``) is replaced by a new ``Linear`` layer to
    ``hid_dim``; its output passes through ReLU and dropout.

    Args:
        hid_dim: size of the output feature vector.
        dropout: dropout probability applied to the features.
        train_cnn: when False, the pretrained backbone is frozen and only the
            new ``fc`` layer receives gradients.
    """

    def __init__(self, hid_dim, dropout, train_cnn=False):
        super().__init__()

        self.hid_dim = hid_dim
        self.train_cnn = train_cnn

        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, hid_dim)

        # Honour `train_cnn` (it used to be stored but never applied): the
        # projection layer always trains, the backbone only if requested.
        for name, param in self.inception.named_parameters():
            param.requires_grad = train_cnn or name.startswith('fc.')

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, images):
        # NOTE(review): images presumably [batch, 3, H, W], normalized; output [batch, hid_dim].
        features = self.dropout(self.relu(self.inception(images)))
        return features
    
    

In [16]:
class EncoderCNN(nn.Module):
    """Inception-v3 image encoder that projects images to ``hid_dim`` features.

    The classifier head (``fc``) is replaced by a new ``Linear`` layer to
    ``hid_dim``; its output passes through ReLU and dropout.

    Args:
        hid_dim: size of the output feature vector.
        dropout: dropout probability applied to the features.
        train_cnn: when False, the pretrained backbone is frozen and only the
            new ``fc`` layer receives gradients.
    """

    def __init__(self, hid_dim, dropout, train_cnn=False):
        super().__init__()

        self.hid_dim = hid_dim
        self.train_cnn = train_cnn

        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, hid_dim)

        # Honour `train_cnn` (it used to be stored but never applied): the
        # projection layer always trains, the backbone only if requested.
        for name, param in self.inception.named_parameters():
            param.requires_grad = train_cnn or name.startswith('fc.')

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, images):
        # NOTE(review): images presumably [batch, 3, H, W], normalized; output [batch, hid_dim].
        features = self.dropout(self.relu(self.inception(images)))
        return features

In [17]:
class DecoderRNN(nn.Module):
    """Single-step GRU decoder conditioned on a fixed image context vector.

    At every step the GRU consumes ``[embedding ; context]``. The vocabulary
    logits are computed from ``[embedding ; new hidden ; context]``.

    Args:
        emb_dim: token embedding size.
        hid_dim: GRU hidden size; must equal the context vector size.
        vocab_sz: number of output classes (vocabulary size).
        dropout: dropout probability applied to the embeddings.
    """

    def __init__(self, emb_dim, hid_dim, vocab_sz, dropout):
        super().__init__()

        self.hid_dim = hid_dim
        self.vocab_sz = vocab_sz

        self.embedding = nn.Embedding(vocab_sz, emb_dim)
        # GRU input = token embedding concatenated with the context vector.
        self.rnn = nn.GRU(emb_dim + hid_dim, hid_dim)
        # Output layer input = embedding + new hidden state + context vector.
        self.fc_out = nn.Linear(emb_dim + 2 * hid_dim, vocab_sz)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, context):
        """Decode one time step.

        Args:
            input: token ids, shape [batch size].
            hidden: previous GRU hidden state, shape [1, batch size, hid dim].
            context: encoder features, shape [1, batch size, hid dim].

        Returns:
            ``(predictions, hidden)``: logits of shape [batch size, vocab size]
            and the updated hidden state of shape [1, batch size, hid dim].
        """
        step_emb = self.dropout(self.embedding(input[None]))  # [1, batch, emb_dim]

        _, hidden = self.rnn(torch.cat([step_emb, context], dim=2), hidden)

        joined = torch.cat([step_emb[0], hidden[0], context[0]], dim=1)
        return self.fc_out(joined), hidden

In [18]:
###
### Teacher forcing is off by default (teacher_forcing_ratio=0.0)
###
class Img2Seq(nn.Module):
    """Image-captioning model: a CNN encoder feeding a step-wise RNN decoder.

    The encoder output serves both as the decoder's initial hidden state and
    as a context vector that is re-fed at every decoding step.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.device = device

        assert self.encoder.hid_dim == self.decoder.hid_dim,\
            'The hid dim of the context vector must equal the hid dim of the decoder !!'

    def forward(self, img, trg, teacher_forcing_ratio=0.0):
        """Decode a caption for each image in the batch.

        Args:
            img: image batch, passed unchanged to the encoder.
            trg: target token ids, shape [trg len, batch size]. Row 0 is the
                first decoder input (<SOS> for captions built by FlickrDataset).
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth token instead of the model's own prediction.
                The default 0.0 always feeds the prediction.

        Returns:
            Logits of shape [trg len, batch size, vocab size]. Row 0 stays all
            zeros because no prediction is made for the first position.
        """
        trg_len, batch_size = trg.shape[0], trg.shape[1]
        trg_vocab_size = self.decoder.vocab_sz

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=self.device)

        context = self.encoder(img).unsqueeze(0)
        hidden = context

        dec_input = trg[0, :]

        for t in range(1, trg_len):
            output, hidden = self.decoder(dec_input, hidden, context)
            outputs[t] = output

            top1 = output.argmax(1)
            # Draw a random number only when teacher forcing is enabled, so the
            # default path leaves torch's global RNG untouched.
            use_teacher = teacher_forcing_ratio > 0 and torch.rand(1).item() < teacher_forcing_ratio
            dec_input = trg[t] if use_teacher else top1

        return outputs

In [26]:
HID_DIM = 256
EMB_DIM = 256
DROPOUT = .5
# NOTE(review): must equal the size of the vocabulary the training dataset builds;
# here it comes from the standalone `vocab`, not from the dataset's own vocab.
VOCAB_LENGTH = len(vocab)
TRAIN_CNN = False
bs = 128
lr = 3e-3
SEED = 42

# Seed before building the model so weight initialisation (and later torch-based
# randomness) is reproducible across kernel restarts.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

enc = EncoderCNN(HID_DIM, DROPOUT, train_cnn=TRAIN_CNN)
dec = DecoderRNN(EMB_DIM, HID_DIM, VOCAB_LENGTH, DROPOUT)

model = Img2Seq(enc, dec, device).to(device)

In [22]:
# Padding positions contribute nothing to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<PAD>"])

# Freeze the pretrained Inception backbone unless TRAIN_CNN is set;
# the new `fc` projection layer is always trained.
for name, param in model.encoder.inception.named_parameters():
    if "fc.weight" in name or "fc.bias" in name:
        param.requires_grad = True
    else:
        param.requires_grad = TRAIN_CNN

# Build the optimizer after freezing and hand it only the trainable parameters,
# so the set of optimised weights is explicit.
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)

In [23]:
def count_parameters(model):
    """Return the total number of elements in parameters with ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print('The model has {:,} trainable parameters'.format(count_parameters(model)))

The model has 3,534,904 trainable parameters


In [24]:
# transforms
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [19]:
# Shuffled training batches; captions are padded by MyCollate with `pad_idx`.
# Note: constructing FlickrDataset here builds its own vocabulary again.
train_loader = DataLoader(
    FlickrDataset('.', 'one_desc.txt', transform),
    batch_size=bs,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
    collate_fn=MyCollate(pad_idx=pad_idx),
)

In [20]:
# Peek at a single batch to check the collated shapes, then stop iterating.
for xb, yb in train_loader:
    print(xb.shape, yb.shape)
    break

torch.Size([128, 3, 224, 224]) torch.Size([27, 128])


In [21]:
xb[0].size()

torch.Size([3, 224, 224])

In [22]:
test_encoder = EncoderCNN(HID_DIM, DROPOUT)
# Shape check only: don't build an autograd graph for a full Inception pass over the batch.
with torch.no_grad():
    preds = test_encoder(xb)
preds.shape

torch.Size([128, 256])

In [23]:
test_decoder = DecoderRNN(EMB_DIM, HID_DIM, VOCAB_LENGTH, DROPOUT)
# One decoding step on the first caption row, with the encoder features used as both
# the initial hidden state and the context; shape check only, so no autograd graph.
with torch.no_grad():
    outs, hidden = test_decoder(yb[0], preds.unsqueeze(0), preds.unsqueeze(0))

In [24]:
(outs.shape, hidden.shape)

(torch.Size([128, 8091]), torch.Size([1, 128, 256]))

In [25]:
# Full forward pass of the real model on one batch — shape check only, so no autograd graph.
with torch.no_grad():
    outputs = model(xb.to(device), yb.to(device))
outputs.shape

torch.Size([27, 128, 8091])

In [26]:
# Free the exploratory batch, the preview dataset and the throw-away test modules/outputs
# before training; test_encoder alone holds a second full Inception-v3.
del xb, yb, ds, preds, test_encoder, test_decoder, outs, hidden, outputs

In [27]:
def train(model, iterator, optimizer, criterion, clip):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: captioning model with a ``device`` attribute, called as
            ``model(imgs, captions)`` and returning logits of shape
            [trg len, batch size, output dim].
        iterator: sized iterable of ``(imgs, captions)`` batches, with
            captions of shape [trg len, batch size].
        optimizer: optimizer stepping the model's parameters.
        criterion: loss over flattened logits and targets.
        clip: maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        Mean loss over the batches, or NaN if ``iterator`` is empty.
    """
    model.train()

    n_batches = len(iterator)
    if n_batches == 0:
        return float('nan')

    epoch_loss = 0.0

    for imgs, captions in tqdm(iterator, total=n_batches, leave=False):
        optimizer.zero_grad()

        imgs = imgs.to(model.device)
        captions = captions.to(model.device)

        output = model(imgs, captions)

        # Skip position 0 (<SOS>; no prediction is made for it) and flatten
        # time x batch so the loss sees one row per target token.
        output_dim = output.shape[-1]
        output = output[1:].reshape(-1, output_dim)
        trg = captions[1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / n_batches

In [28]:
def evaluate(model, iterator, criterion):
    """Compute the mean batch loss over ``iterator`` without updating the model.

    Uses the same batch format as the training loop: each batch is
    ``(imgs, captions)`` with captions of shape [trg len, batch size]. The
    model runs in eval mode (dropout off) and under ``torch.no_grad()``.

    Returns:
        Mean loss over the batches, or NaN if ``iterator`` yields nothing.
    """
    model.eval()

    epoch_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for imgs, captions in iterator:
            imgs = imgs.to(model.device)
            captions = captions.to(model.device)

            # Img2Seq feeds back its own predictions by default (no teacher forcing).
            output = model(imgs, captions)

            # Skip position 0 and flatten time x batch into one row per target token.
            output_dim = output.shape[-1]
            output = output[1:].reshape(-1, output_dim)
            trg = captions[1:].reshape(-1)

            epoch_loss += criterion(output, trg).item()
            n_batches += 1

    return epoch_loss / n_batches if n_batches else float('nan')

In [29]:
def epoch_time(start_time, end_time):
    """Split the elapsed seconds between two timestamps into whole (minutes, seconds)."""
    total = end_time - start_time
    whole_minutes = int(total / 60)
    leftover = total - 60 * whole_minutes
    return whole_minutes, int(leftover)

In [30]:
import math
N_EPOCHS = 50
CLIP = 1

# No validation split yet, so the checkpoint tracks the best *training* loss.
best_train_loss = float('inf')

for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss = train(model, train_loader, optimizer, criterion, CLIP)
    # valid_loss = evaluate(model, valid_iterator, criterion)
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Overwrite the checkpoint only when this epoch beats the best loss so far.
    if train_loss < best_train_loss:
        best_train_loss = train_loss
        torch.save(model.state_dict(), 'GRU-model.pt')

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    #print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

                                               

Epoch: 01 | Time: 0m 30s
	Train Loss: 5.149 | Train PPL: 172.326


                                               

Epoch: 02 | Time: 0m 29s
	Train Loss: 4.588 | Train PPL:  98.336


                                               

Epoch: 03 | Time: 0m 29s
	Train Loss: 4.514 | Train PPL:  91.281


                                               

Epoch: 04 | Time: 0m 28s
	Train Loss: 4.465 | Train PPL:  86.902


                                               

Epoch: 05 | Time: 0m 27s
	Train Loss: 4.434 | Train PPL:  84.276


                                               

Epoch: 06 | Time: 0m 29s
	Train Loss: 4.399 | Train PPL:  81.397


                                               

Epoch: 07 | Time: 0m 29s
	Train Loss: 4.363 | Train PPL:  78.526


                                               

Epoch: 08 | Time: 0m 29s
	Train Loss: 4.345 | Train PPL:  77.107


                                               

Epoch: 09 | Time: 0m 29s
	Train Loss: 4.330 | Train PPL:  75.914


                                               

Epoch: 10 | Time: 0m 28s
	Train Loss: 4.310 | Train PPL:  74.476


                                               

Epoch: 11 | Time: 0m 27s
	Train Loss: 4.299 | Train PPL:  73.643


                                               

Epoch: 12 | Time: 0m 28s
	Train Loss: 4.287 | Train PPL:  72.725


                                               

Epoch: 13 | Time: 0m 29s
	Train Loss: 4.277 | Train PPL:  72.029


                                               

Epoch: 14 | Time: 0m 29s
	Train Loss: 4.266 | Train PPL:  71.270


                                               

Epoch: 15 | Time: 0m 29s
	Train Loss: 4.256 | Train PPL:  70.530


                                               

Epoch: 16 | Time: 0m 29s
	Train Loss: 4.251 | Train PPL:  70.155


                                               

Epoch: 17 | Time: 0m 28s
	Train Loss: 4.246 | Train PPL:  69.856


                                               

Epoch: 18 | Time: 0m 30s
	Train Loss: 4.233 | Train PPL:  68.917


                                               

Epoch: 19 | Time: 0m 32s
	Train Loss: 4.232 | Train PPL:  68.849


                                               

Epoch: 20 | Time: 0m 32s
	Train Loss: 4.225 | Train PPL:  68.382


                                               

Epoch: 21 | Time: 0m 32s
	Train Loss: 4.213 | Train PPL:  67.530


                                               

Epoch: 22 | Time: 0m 31s
	Train Loss: 4.210 | Train PPL:  67.369


                                               

Epoch: 23 | Time: 0m 32s
	Train Loss: 4.204 | Train PPL:  66.962


                                               

Epoch: 24 | Time: 0m 32s
	Train Loss: 4.197 | Train PPL:  66.468


                                               

Epoch: 25 | Time: 0m 31s
	Train Loss: 4.195 | Train PPL:  66.338


                                               

Epoch: 26 | Time: 0m 29s
	Train Loss: 4.180 | Train PPL:  65.351


                                               

Epoch: 27 | Time: 0m 28s
	Train Loss: 4.184 | Train PPL:  65.597


                                               

Epoch: 28 | Time: 0m 28s
	Train Loss: 4.186 | Train PPL:  65.767


                                               

Epoch: 29 | Time: 0m 28s
	Train Loss: 4.177 | Train PPL:  65.147


                                               

Epoch: 30 | Time: 0m 28s
	Train Loss: 4.173 | Train PPL:  64.887


                                               

Epoch: 31 | Time: 0m 28s
	Train Loss: 4.167 | Train PPL:  64.528


                                               

Epoch: 32 | Time: 0m 27s
	Train Loss: 4.162 | Train PPL:  64.175


                                               

Epoch: 33 | Time: 0m 28s
	Train Loss: 4.160 | Train PPL:  64.087


                                               

Epoch: 34 | Time: 0m 29s
	Train Loss: 4.154 | Train PPL:  63.702


                                               

Epoch: 35 | Time: 0m 28s
	Train Loss: 4.158 | Train PPL:  63.925


                                               

Epoch: 36 | Time: 0m 29s
	Train Loss: 4.147 | Train PPL:  63.267


                                               

Epoch: 37 | Time: 0m 28s
	Train Loss: 4.150 | Train PPL:  63.435


                                               

Epoch: 38 | Time: 0m 29s
	Train Loss: 4.149 | Train PPL:  63.391


                                               

Epoch: 39 | Time: 0m 29s
	Train Loss: 4.141 | Train PPL:  62.859


                                               

Epoch: 40 | Time: 0m 28s
	Train Loss: 4.141 | Train PPL:  62.870


                                               

Epoch: 41 | Time: 0m 27s
	Train Loss: 4.136 | Train PPL:  62.558


                                               

Epoch: 42 | Time: 0m 29s
	Train Loss: 4.134 | Train PPL:  62.439


                                               

Epoch: 43 | Time: 0m 28s
	Train Loss: 4.126 | Train PPL:  61.934


                                               

Epoch: 44 | Time: 0m 28s
	Train Loss: 4.121 | Train PPL:  61.592


                                               

Epoch: 45 | Time: 0m 27s
	Train Loss: 4.120 | Train PPL:  61.572


                                               

Epoch: 46 | Time: 0m 28s
	Train Loss: 4.120 | Train PPL:  61.540


                                               

Epoch: 47 | Time: 0m 29s
	Train Loss: 4.119 | Train PPL:  61.475


                                               

Epoch: 48 | Time: 0m 29s
	Train Loss: 4.116 | Train PPL:  61.318


                                               

Epoch: 49 | Time: 0m 28s
	Train Loss: 4.122 | Train PPL:  61.699


                                               

Epoch: 50 | Time: 0m 27s
	Train Loss: 4.113 | Train PPL:  61.151


In [27]:
# model = Img2Seq(encoder, decoder, device).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): the checkpoint must come from a model built with the same vocabulary
# (VOCAB_LENGTH). The recorded RuntimeError is a vocab-size mismatch (4461 in the file
# vs 2360 here); the training loop above saves to 'GRU-model.pt' — confirm which file is intended.
model.load_state_dict(torch.load("../models/gru_no_tf.pth", map_location=device))

RuntimeError: Error(s) in loading state_dict for Img2Seq:
	size mismatch for decoder.embedding.weight: copying a param with shape torch.Size([4461, 256]) from checkpoint, the shape in current model is torch.Size([2360, 256]).
	size mismatch for decoder.fc_out.weight: copying a param with shape torch.Size([4461, 768]) from checkpoint, the shape in current model is torch.Size([2360, 768]).
	size mismatch for decoder.fc_out.bias: copying a param with shape torch.Size([4461]) from checkpoint, the shape in current model is torch.Size([2360]).

In [None]:
# NOTE(review): get_test_data, predict_test, print_scores, DF_PATH and IMAGES_PATH are
# not defined in the visible notebook — presumably from a helper module; confirm.
test_dict = get_test_data(DF_PATH)

preds, trgs = predict_test(
    test_dict,
    IMAGES_PATH,
    model,
    vocab,
    n_images=1000,
)
print_scores(preds, trgs)