In [48]:
CSV_PATH = '/home/thnhan301/Documents/GIT/image_captioning/small_results.csv'
IMG_PATH = '/home/thnhan301/Documents/GIT/image_captioning/flickr30k_images/flickr30k_images/'

In [49]:
import torch
import pandas as pd
from spacy.lang.en import English
import torchtext

In [50]:
df = pd.read_csv(CSV_PATH, delimiter='|')

In [51]:
df.head(5)

Unnamed: 0,image_name,comment
0,1000092795.jpg,Two young guys with shaggy hair look at their...
1,1000092795.jpg,"Two young , White males are outside near many..."
2,1000092795.jpg,Two men in green shirts are standing in a yard .
3,1000092795.jpg,A man in a blue shirt standing in a garden .
4,1000092795.jpg,Two friends enjoy time spent together .


In [52]:
def load_data(df):
    return df.to_dict(orient='list')

In [53]:
train_data = load_data(df)

In [54]:
train_data['image_name'][0], train_data[' comment'][0]

('1000092795.jpg',
 ' Two young guys with shaggy hair look at their hands while hanging out in the yard .')

In [55]:
tokenizer = English()

In [56]:
def tokenize_sample(comment, tokenizer, max_length, sos_token, eos_token, lower):
    tokens = [token.text for token in tokenizer.tokenizer(comment.strip())][:max_length]
    if lower:
        tokens = [token.lower() for token in tokens]
    tokens = [sos_token] + tokens + [eos_token]
    return tokens

In [57]:
sos_token = "<sos>"
eos_token = "<eos>"
max_length = 200
lower = True
train_data["tokens"] = [tokenize_sample(cm,tokenizer,max_length,sos_token,eos_token,lower) for cm in train_data[" comment"]]

In [58]:
train_data[" comment"][0],train_data["tokens"][0], 

(' Two young guys with shaggy hair look at their hands while hanging out in the yard .',
 ['<sos>',
  'two',
  'young',
  'guys',
  'with',
  'shaggy',
  'hair',
  'look',
  'at',
  'their',
  'hands',
  'while',
  'hanging',
  'out',
  'in',
  'the',
  'yard',
  '.',
  '<eos>'])

In [59]:
def yield_token(data, key):
    for token in data[key]:
        yield token

In [60]:
import torchtext.vocab


min_freq = 2
unk_token = "<unk>"
pad_token = "<pad>"
special_tokens = [
    unk_token,
    pad_token,
    sos_token,
    eos_token
]
vocab = torchtext.vocab.build_vocab_from_iterator(
    yield_token(train_data,"tokens"),
    min_freq=min_freq,
    specials=special_tokens
)

In [61]:
vocab.get_itos()[:10]

['<unk>', '<pad>', '<sos>', '<eos>', 'a', '.', 'in', 'on', 'the', 'of']

In [62]:
len(vocab)

160

In [63]:
unk_index = vocab[unk_token]
pad_index = vocab[pad_token]
unk_index, pad_index

(0, 1)

In [64]:
vocab.set_default_index(unk_index)

In [65]:
def token2ids(tokens, vocab):
    ids = vocab.lookup_indices(tokens)
    return ids

In [66]:
train_data["ids"] = [token2ids(t,vocab) for t in train_data["tokens"]]

In [67]:
train_data[" comment"][0],train_data["tokens"][0],train_data["ids"][0]

(' Two young guys with shaggy hair look at their hands while hanging out in the yard .',
 ['<sos>',
  'two',
  'young',
  'guys',
  'with',
  'shaggy',
  'hair',
  'look',
  'at',
  'their',
  'hands',
  'while',
  'hanging',
  'out',
  'in',
  'the',
  'yard',
  '.',
  '<eos>'],
 [2, 14, 25, 62, 11, 0, 130, 139, 20, 98, 132, 37, 0, 0, 6, 8, 159, 5, 3])

In [68]:
import torch.utils
import torch.utils.data
from PIL import Image

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data, transform, path_prefix):
        self.img_paths = data["image_name"]
        self.ids = data["ids"]
        self.transform = transform
        self.path_prefix = path_prefix

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        image = Image.open(self.path_prefix+self.img_paths[index])
        torch_img = self.transform(image)
        return torch_img, torch.tensor(self.ids[index],dtype=torch.int64)

In [69]:
from torchvision.transforms import Compose, transforms

In [70]:
transform = Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])

In [71]:
train_dataset = CustomDataset(train_data, transform, IMG_PATH)

In [72]:
def get_collate_fn(pad_index):
    def collate_fn(batch):
        batch_img = torch.stack([b[0] for b in batch])
        batch_ids = [b[1] for b in batch]
        batch_ids = torch.nn.utils.rnn.pad_sequence(batch_ids,batch_first=True,padding_value=pad_index)
        return batch_img, batch_ids
    return collate_fn

In [73]:
def get_dataloader(dataset, batch_size, pad_index, shuffle=False):
    collate_fn = get_collate_fn(pad_index)
    data_loader =  torch.utils.data.DataLoader(dataset,batch_size,shuffle,collate_fn=collate_fn)
    return data_loader

In [74]:
batch_size = 10
shuffle = True
train_dataloader = get_dataloader(train_dataset,batch_size,pad_index,shuffle)

In [124]:
from torchvision.models import resnet50, ResNet50_Weights
import random

In [75]:
class Encoder(torch.nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        resnet.fc = torch.nn.Linear(2048, hidden_dim)
        self.model = resnet
    
    def forward(self,x):
        return self.model(x)

In [121]:
class Decoder(torch.nn.Module):
    def __init__(self, hidden_dim, embedding_dim, vocab_size, padding_idx):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.embedding = torch.nn.Embedding(vocab_size,embedding_dim,padding_idx=padding_idx)
        self.rnn = torch.nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim,vocab_size)
    
    def forward(self, input, hidden):
        input = input.unsqueeze(1)
        embedded = self.embedding(input)
        output, hidden = self.rnn(embedded,(hidden))
        prediction = self.fc(output.squeeze(1))
        return prediction, hidden

In [130]:
class Img2Seq(torch.nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        assert(self.encoder.hidden_dim == self.decoder.hidden_dim), "Hidden dimensions of encoder and decoder must be equal"
    
    def forward(self, batch_imgs, batch_ids, teacher_forcing_ratio):
        batch_size = batch_imgs.shape[0]
        vocab_size = self.decoder.vocab_size
        seq_length = batch_ids.shape[1]
        outputs = torch.zeros(batch_size,seq_length,vocab_size).to(self.device)
        hidden = self.encoder(batch_imgs).unsqueeze(0) # D * numlayer x batch x hidden_dim
        inputs = batch_ids[:,0]
        for i in range(1, seq_length):
            output, hidden = self.decoder(inputs,hidden)
            outputs[:,i,:] = output
            teacher_force = random.random() < teacher_forcing_ratio
            top1 = output.argmax(1)
            inputs = batch_ids[:,i] if teacher_force else top1
        return outputs

In [131]:
vocab_size = len(vocab)
hidden_dim = 256
embedding_dim = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(hidden_dim)
decoder = Decoder(hidden_dim,embedding_dim,vocab_size,pad_index)
img2seq = Img2Seq(encoder,decoder,device).to(device)

In [132]:
img2seq.forward(torch.zeros(5,3,224,224),torch.ones((5,20),dtype=torch.int64),0.5).shape

torch.Size([5, 20, 160])

In [133]:
def init_weights(m):
    for name, param in m.named_parameters():
        torch.nn.init.uniform_(param.data, -0.08, 0.08)
img2seq.apply(init_weights)

Img2Seq(
  (encoder): Encoder(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=Tr

In [134]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"The model has {count_parameters(img2seq):,} trainable parameters")

The model has 24,192,992 trainable parameters


In [135]:
optimizer = torch.optim.Adam(img2seq.parameters())
criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_index)

In [137]:
def train_fn(model, data_loader, optimizer, criterion, teacher_forcing_ratio, device):
    model.train()
    epoch_loss = 0
    for batch_imgs, batch_ids in data_loader:
        batch_imgs, batch_ids = batch_imgs.to(device), batch_ids.to(device)
        optimizer.zero_grad()
        output = model(batch_imgs,batch_ids,teacher_forcing_ratio)
        output_dim = output.shape[-1]
        output = output[:,1:,].reshape(-1,output_dim)
        target = batch_ids[:,1:].reshape(-1)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

In [138]:
def evaluate_fn(model, data_loader, criterion, device):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for batch_imgs, batch_ids in data_loader:
            batch_imgs, batch_ids = batch_imgs.to(device), batch_ids.to(device)
            output = model(batch_imgs,batch_ids,0)
            output_dim = output.shape[-1]
            output = output[:,1:,].reshape(-1,output_dim)
            target = batch_ids[:,1:].reshape(-1)
            loss = criterion(output, target)
            epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

In [139]:
n_epochs = 1
teacher_forcing_ratio = 0.7
for epoch in range(n_epochs):
    train_loss = train_fn(
        img2seq,
        train_dataloader,
        optimizer,
        criterion,
        teacher_forcing_ratio,
        device,
    )
    print(train_loss)

4.565027046203613
