In [3]:
import os  
import pandas as pd  
import spacy  
import torch
import torch.nn as nn
import numpy as np
import statistics
import torch.optim as optim
from tqdm import tqdm
import torchvision.models as models
from torch.nn.utils.rnn import pad_sequence  
from torch.utils.data import DataLoader, Dataset
from PIL import Image 
import torchvision.transforms as transforms

In [4]:
# Report which accelerator is attached. Fall back to "cpu" so this cell does not
# raise on a machine (or Restart & Run All) without CUDA.
device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu"
device_name

'Tesla T4'

In [5]:
# spaCy pipeline whose tokenizer is used by Vocabulary.tokenizer_en.
# spaCy v3 removed shortcut names such as "en", so load the package by its full
# name (install it first with `python -m spacy download en_core_web_sm`).
spacy_en = spacy.load("en_core_web_sm")

In [6]:
class Vocabulary:
    """Bidirectional word <-> index mapping built from caption text.

    Indices 0-3 are reserved for the special tokens ``<pad>``, ``<sos>``,
    ``<eos>`` and ``<unk>``. Every other word gets the next free index at the
    moment its count (within one ``build_vocab`` call) reaches ``freq_threshold``.

    Args:
        freq_threshold: minimum number of occurrences before a word is added.
            Values below 1 are treated as 1.
    """

    def __init__(self, freq_threshold):
        self.itos = {0: "<pad>", 1: "<sos>", 2: "<eos>", 3: "<unk>"}
        self.stoi = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Number of known tokens, including the four special tokens."""
        return len(self.itos)

    @staticmethod
    def tokenizer_en(text):
        """Split ``text`` with the module-level spaCy tokenizer and lowercase every token."""
        return [tok.text.lower() for tok in spacy_en.tokenizer(text)]

    def build_vocab(self, sentences):
        """Add every word that occurs at least ``freq_threshold`` times in ``sentences``.

        Indices are assigned in the order words first reach the threshold.
        New indices start after the largest one already in use, and words
        already in the vocabulary are never re-indexed. This makes repeated
        calls safe: they extend the mapping instead of overwriting it.

        Args:
            sentences: iterable of raw caption strings.
        """
        frequencies = {}
        # Counts start at 1, so a threshold below 1 could never be hit; clamp it.
        threshold = max(self.freq_threshold, 1)
        next_idx = len(self.itos)

        for sentence in sentences:
            for word in self.tokenizer_en(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1
                # `==` fires once per word per call; the membership check stops an
                # existing entry (special token or earlier call) being clobbered.
                if frequencies[word] == threshold and word not in self.stoi:
                    self.stoi[word] = next_idx
                    self.itos[next_idx] = word
                    next_idx += 1

    def numericalize(self, text):
        """Tokenize ``text`` and map each token to its index (``<unk>`` if unknown).

        Returns:
            list[int] of token indices, without ``<sos>``/``<eos>``.
        """
        unk_idx = self.stoi["<unk>"]
        return [self.stoi.get(tok, unk_idx) for tok in self.tokenizer_en(text)]

In [7]:
class LoadDataset(Dataset):
    """(image, caption) pairs read from a CSV with ``image`` and ``caption`` columns.

    A Vocabulary is built once from every caption in the file. Each item is the
    transformed RGB image plus its caption as a 1-D tensor of token indices,
    wrapped in ``<sos>`` ... ``<eos>``.

    Args:
        root_dir: directory that the ``image`` column's file names are relative to.
        captions_file: path of the CSV file.
        transform: callable applied to each PIL image.
        freq_threshold: minimum word count for inclusion in the vocabulary.
    """

    def __init__(self, root_dir, captions_file, transform, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.captions.tolist())

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        text = self.captions[index]
        path = os.path.join(self.root_dir, self.imgs[index])
        image = self.transform(Image.open(path).convert("RGB"))

        stoi = self.vocab.stoi
        token_ids = [stoi["<sos>"], *self.vocab.numericalize(text), stoi["<eos>"]]
        return image, torch.tensor(token_ids)

In [8]:
class CollateFn:
    """DataLoader ``collate_fn`` that batches images and pads variable-length captions.

    Each batch element is expected to be ``(image_tensor, caption_tensor)``.
    Returns ``(imgs, captions)``:

    - ``imgs`` stacks the images along a new leading batch dimension.
    - ``captions`` is padded with ``pad_idx`` into shape ``(max_len, batch)``
      (``batch_first=False``).
    """

    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        images = torch.stack([sample[0] for sample in batch], dim=0)
        token_seqs = [sample[1] for sample in batch]
        padded = pad_sequence(token_seqs, batch_first=False, padding_value=self.pad_idx)
        return images, padded

In [9]:
class CNN(nn.Module):
    """Image encoder: pretrained Inception-v3 with a new ``fc`` projection, then ReLU + dropout(0.5).

    Args:
        embed_size: length of the feature vector produced for each image.
        train_CNN: if False (default), only the replacement ``fc`` layer is
            trainable. If True, the whole Inception backbone is fine-tuned as well.
            This flag is applied in the constructor; call ``set_trainable`` to
            change it later.

    NOTE(review): ``requires_grad=False`` does not stop BatchNorm layers from
    updating their running statistics while the module is in train mode.
    Confirm that this is intended when the backbone is frozen.
    """

    def __init__(self, embed_size, train_CNN=False):
        super().__init__()
        self.train_CNN = train_CNN
        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        # Previously `train_CNN` was stored but never applied; honour it here.
        self.set_trainable(train_CNN)

    def set_trainable(self, train_CNN):
        """Make the ``fc`` head trainable, and the rest of Inception trainable only if ``train_CNN``."""
        self.train_CNN = train_CNN
        for name, param in self.inception.named_parameters():
            is_head = "fc.weight" in name or "fc.bias" in name
            param.requires_grad = True if is_head else train_CNN

    def forward(self, images):
        """Encode a batch of images into ``embed_size``-dim features (ReLU, then dropout)."""
        features = self.inception(images)
        return self.dropout(self.relu(features))

In [10]:
class RNN(nn.Module):
    """LSTM caption decoder.

    The image feature vector is fed as the first time step. It is followed by
    the (dropout-regularised) embeddings of the caption tokens, and every LSTM
    output is projected to vocabulary logits.

    Args:
        embed_size: size of both the image feature and the word embeddings.
        hidden_size: LSTM hidden size.
        vocab_size: number of output classes / embedding rows.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(0.5)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Return logits with one more time step than ``captions`` (the leading image step).

        Sequence-first layout is assumed (``nn.LSTM`` default): ``captions`` is
        ``(seq_len, batch)`` and ``features`` is ``(batch, embed_size)``.
        """
        word_steps = self.dropout(self.embedding(captions))
        image_step = features.unsqueeze(0)
        sequence = torch.cat([image_step, word_steps], dim=0)
        lstm_out = self.lstm(sequence)[0]
        return self.linear(lstm_out)

In [11]:
class Net(nn.Module):
    """Captioning model: ``CNN`` image encoder feeding an ``RNN`` (LSTM) decoder."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.cnn = CNN(embed_size)
        self.rnn = RNN(embed_size, hidden_size, vocab_size, num_layers)

    def forward(self, images, captions):
        """Teacher-forced pass: encode ``images``, then decode ``captions`` after the image step."""
        return self.rnn(self.cnn(images), captions)

    def caption_image(self, image, vocab, max_len=50):
        """Greedily decode a caption for ``image`` and return it as a list of words.

        Starting from the image feature, the most likely token at each step is
        fed back in. Decoding stops after ``<eos>`` is produced or after
        ``max_len`` tokens. The ``.item()`` calls require a batch of exactly
        one image.

        Args:
            image: image tensor already on the model's device.
            vocab: Vocabulary used to map predicted indices back to words.
            max_len: maximum number of tokens to generate.
        """
        token_ids = []
        with torch.no_grad():
            step_input = self.cnn(image).unsqueeze(0)
            lstm_state = None

            for _ in range(max_len):
                lstm_out, lstm_state = self.rnn.lstm(step_input, lstm_state)
                logits = self.rnn.linear(lstm_out.squeeze(0))
                best = logits.argmax(1)
                best_id = best.item()
                token_ids.append(best_id)
                step_input = self.rnn.embedding(best).unsqueeze(0)

                if vocab.itos[best_id] == "<eos>":
                    break

        return [vocab.itos[idx] for idx in token_ids]

In [12]:
# Training-time preprocessing:
# - resize to 356x356 and take a random 299x299 crop;
# - convert to a [0, 1] tensor;
# - map each channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): torchvision's pretrained inception_v3 may rescale its input
# internally (transform_input); confirm this normalisation matches the weights.
transform = transforms.Compose([
    transforms.Resize((356, 356)),
    transforms.RandomCrop((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [13]:
# Load the captions and build the vocabulary, then a shuffling loader whose
# collate function pads captions with the <pad> index.
dataset = LoadDataset("flickr8k/images/", "flickr8k/captions.txt", transform)
pad_idx = dataset.vocab.stoi["<pad>"]
batches = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    collate_fn=CollateFn(pad_idx=pad_idx),
)

In [14]:
# Let cuDNN benchmark convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model / optimisation hyperparameters.
train_CNN = False        # also train the Inception backbone (not just its fc head)?
embed_size = 256         # image-feature and word-embedding size
hidden_size = 256        # LSTM hidden size
vocab_size = len(dataset.vocab)
num_layers = 1           # stacked LSTM layers
lr = 3e-4
epochs = 100

In [15]:
# Confirm which device the model and batches will be moved to.
device

device(type='cuda')

In [16]:
# Build the encoder-decoder captioning model and move it to the selected device.
net = Net(embed_size=embed_size, hidden_size=hidden_size,
          vocab_size=vocab_size, num_layers=num_layers)
net = net.to(device)

In [17]:
# Display the model's module tree.
net

Net(
  (cnn): CNN(
    (inception): Inception3(
      (Conv2d_1a_3x3): BasicConv2d(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (Conv2d_2a_3x3): BasicConv2d(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (Conv2d_2b_3x3): BasicConv2d(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (Conv2d_3b_1x1): BasicConv2d(
        (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_

In [18]:
# Padded positions (the <pad> index) contribute nothing to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
opt = optim.Adam(params=net.parameters(), lr=lr)

In [19]:
# The replacement fc head is always trainable; every other Inception parameter
# is trainable only when train_CNN is set.
for param_name, parameter in net.cnn.inception.named_parameters():
    is_head = any(key in param_name for key in ("fc.weight", "fc.bias"))
    parameter.requires_grad = is_head or train_CNN

In [20]:
def save_checkpoint(net, opt, filename):
    """Serialise the model and optimizer state dicts to ``filename``.

    The file holds a dict with keys ``"net_dict"`` (``net.state_dict()``) and
    ``"opt_dict"`` (``opt.state_dict()``), as read back by ``load_checkpoint``.
    """
    payload = dict(net_dict=net.state_dict(), opt_dict=opt.state_dict())
    torch.save(payload, filename)

def load_checkpoint(net, opt, filename, map_location="cpu"):
    """Restore model and optimizer state written by ``save_checkpoint``.

    Args:
        net: module whose parameters receive ``checkpoint["net_dict"]``.
        opt: optimizer that receives ``checkpoint["opt_dict"]``.
        filename: path of the checkpoint file.
        map_location: forwarded to ``torch.load``. The ``"cpu"`` default lets a
            checkpoint saved on a GPU be opened on a CPU-only machine.
            ``load_state_dict`` then copies the values into ``net``'s existing
            parameters, wherever they live.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    net.load_state_dict(checkpoint["net_dict"])
    opt.load_state_dict(checkpoint["opt_dict"])

In [21]:
def CaptionImage(net, device, image, dataset):
    """Print the greedy caption ``net`` produces for the image file at ``image``.

    The image is resized to 299x299 (no random crop) and normalised with
    mean/std 0.5 per channel. ``net`` is switched to eval mode while decoding.
    Afterwards it is returned to whatever mode it was in before the call,
    even if decoding raises.

    Args:
        net: model exposing ``caption_image(image_tensor, vocab)``.
        device: device the input tensor is moved to. It presumably must match
            ``net``'s device.
        image: path to the image file.
        dataset: object whose ``vocab`` maps predicted indices back to words.
    """
    transform = transforms.Compose([transforms.Resize((299, 299)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # Close the underlying file once the RGB copy has been made.
    with Image.open(image) as raw:
        img = transform(raw.convert("RGB")).unsqueeze(0)

    was_training = net.training
    net.eval()
    try:
        words = net.caption_image(img.to(device), dataset.vocab)
    finally:
        # Restore the caller's mode instead of forcing train mode.
        net.train(was_training)

    print(" ".join(words))

In [22]:
# Sanity check before training: caption the two reference images.
for test_image in ("test_images/dog.jpg", "test_images/child.jpg"):
    CaptionImage(net, device, test_image, dataset)

bale several busy babies lab playhouse painting combat crouched cliffs bat shop eyed swimsuit festive woodland & bench spraying huddled bows stair raising fan candy strewn outdoor after wand sky flowers wades vehicles headband life kind costumes european turning gravel gated gated gated show lit box handbag retrieving into sister 
display bib youth drenched food handbag retrieving beret skyscraper stairs safety sky pinata gym cardboard music sailboat tightrope frozen dancers carpeted woods outfit tackled camping double greenery telescope unseen themselves much bulls snowbank watery football except incoming shoulder canoes bathroom kicks snowball fellow stop somthing chip poodle chairs plaid bouquet 


In [24]:
# Training loop: one pass over `batches` per epoch. After each epoch, print the
# mean batch loss and sample captions; save a checkpoint every 5 epochs.
net.train()
for epoch in range(epochs):
    batch_losses = []
    loop = tqdm(batches, total=len(batches))
    for imgs, captions in loop:
        imgs = imgs.to(device)
        captions = captions.to(device)

        # Teacher forcing. `captions` is (seq_len, batch) (CollateFn uses
        # batch_first=False). The decoder input is [image feature, captions[:-1]],
        # so output step t is scored against captions[t].
        outputs = net(imgs, captions[:-1])
        loss = loss_fn(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

        opt.zero_grad()
        # No argument: passing `loss` as the `gradient` argument would scale every
        # parameter gradient by the current loss value.
        loss.backward()
        opt.step()
        batch_losses.append(loss.item())

    print(f" Epoch: {epoch} | Loss: {np.round(sum(batch_losses)/len(batch_losses), 4)}")
    CaptionImage(net, device, "test_images/dog.jpg", dataset)
    CaptionImage(net, device, "test_images/child.jpg", dataset)
    print("")

    if epoch % 5 == 0:
        save_checkpoint(net, opt, f"checkpoint_{epoch}.pth.tar")

100%|██████████| 1265/1265 [06:24<00:00,  3.29it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 21 | Loss: 2.2553
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a toy . <eos> 



100%|██████████| 1265/1265 [06:24<00:00,  3.29it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 22 | Loss: 2.2435
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a toy . <eos> 



100%|██████████| 1265/1265 [06:25<00:00,  3.28it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 23 | Loss: 2.2308
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is running on a grassy field . <eos> 



100%|██████████| 1265/1265 [06:25<00:00,  3.28it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 24 | Loss: 2.2187
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a toy . <eos> 



100%|██████████| 1265/1265 [06:25<00:00,  3.28it/s]


 Epoch: 25 | Loss: 2.2067
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a toy . <eos> 



100%|██████████| 1265/1265 [06:26<00:00,  3.27it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 26 | Loss: 2.1966
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a toy . <eos> 



100%|██████████| 1265/1265 [06:27<00:00,  3.27it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 27 | Loss: 2.1845
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a toy . <eos> 



100%|██████████| 1265/1265 [06:28<00:00,  3.26it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 28 | Loss: 2.1756
<sos> a dog is running through the water . <eos> 
<sos> a young boy in a blue shirt is running on a grassy field . <eos> 



100%|██████████| 1265/1265 [06:28<00:00,  3.26it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 29 | Loss: 2.1643
<sos> a dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a toy . <eos> 



100%|██████████| 1265/1265 [06:28<00:00,  3.26it/s]


 Epoch: 30 | Loss: 2.1575
<sos> a dog is running through the water . <eos> 
<sos> a young boy wearing a blue shirt is running through a field of grass . <eos> 



100%|██████████| 1265/1265 [06:24<00:00,  3.29it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 31 | Loss: 2.1472
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a toy . <eos> 



100%|██████████| 1265/1265 [06:26<00:00,  3.27it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 32 | Loss: 2.1376
<sos> a brown dog is running through the water . <eos> 
<sos> a little boy in a blue shirt is playing with a soccer ball . <eos> 



100%|██████████| 1265/1265 [06:27<00:00,  3.26it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 33 | Loss: 2.1302
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a toy . <eos> 



100%|██████████| 1265/1265 [06:27<00:00,  3.26it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 34 | Loss: 2.1219
<sos> a brown dog is running through the water . <eos> 
<sos> a little boy in a blue shirt is playing with a soccer ball . <eos> 



100%|██████████| 1265/1265 [06:26<00:00,  3.28it/s]


 Epoch: 35 | Loss: 2.1143
<sos> a dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a toy . <eos> 



100%|██████████| 1265/1265 [06:24<00:00,  3.29it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 36 | Loss: 2.1064
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a toy . <eos> 



100%|██████████| 1265/1265 [06:26<00:00,  3.28it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 37 | Loss: 2.0989
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is running on a grassy lawn . <eos> 



100%|██████████| 1265/1265 [06:26<00:00,  3.28it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 38 | Loss: 2.0923
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a ball . <eos> 



100%|██████████| 1265/1265 [06:27<00:00,  3.26it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 39 | Loss: 2.0848
<sos> a brown dog is running through the water . <eos> 
<sos> a little boy in a red shirt is playing with a ball . <eos> 



100%|██████████| 1265/1265 [06:28<00:00,  3.26it/s]


 Epoch: 40 | Loss: 2.0781
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is running on a grassy field . <eos> 



100%|██████████| 1265/1265 [06:26<00:00,  3.27it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 41 | Loss: 2.0717
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is running through a grassy field . <eos> 



100%|██████████| 1265/1265 [06:28<00:00,  3.26it/s]
  0%|          | 0/1265 [00:00<?, ?it/s]

 Epoch: 42 | Loss: 2.0644
<sos> a brown dog is running through the water . <eos> 
<sos> a little girl in a pink shirt is playing with a ball in a field . <eos> 



 21%|██        | 263/1265 [01:20<04:40,  3.57it/s]

KeyboardInterrupt: ignored

In [25]:
# Training was interrupted part-way through an epoch. Save the current weights,
# named after epoch - 1, the last epoch whose summary was printed.
save_checkpoint(net, opt, "checkpoint_{}.pth.tar".format(epoch - 1))

In [26]:
# Caption test_images/dog.jpg with the current weights.
CaptionImage(net=net, device=device, image="test_images/dog.jpg", dataset=dataset)

<sos> a dog is running through the water . <eos> 


In [27]:
# Caption test_images/child.jpg with the current weights.
CaptionImage(net=net, device=device, image="test_images/child.jpg", dataset=dataset)

<sos> a little girl in a pink shirt is playing with a ball in a field . <eos> 


In [28]:
# Caption test_images/boat.jpg with the current weights.
CaptionImage(net=net, device=device, image="test_images/boat.jpg", dataset=dataset)

<sos> a man in a wetsuit is surfing . <eos> 


In [29]:
# Caption test_images/street.jpg with the current weights.
CaptionImage(net=net, device=device, image="test_images/street.jpg", dataset=dataset)

<sos> a man in a red shirt is standing in front of a large crowd of people . <eos> 


In [30]:
# Caption test_images/elon_musk.jpg with the current weights.
CaptionImage(net=net, device=device, image="test_images/elon_musk.jpg", dataset=dataset)

<sos> a man in a white shirt and a woman in a white shirt and black pants . <eos> 
