In [1]:
import pandas as pd
import numpy as np

from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from utils.dataset import FeatureCaptionDataset
from utils.models import CNNEncoder, RNNDecoder
from utils.helpers import get_tokenizer_vocab, collate_fn_pad
from utils.trainer import Trainer

In [2]:
PATH = "dataset/"

In [3]:
df = pd.read_csv(PATH + "captions.txt")

In [4]:
df.head()

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...


In [5]:
tokenizer, vocab = get_tokenizer_vocab(df)

In [6]:
dataset = FeatureCaptionDataset(PATH + "embeddings/", df, tokenizer, vocab)

In [7]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, collate_fn=collate_fn_pad)
len(dataloader)

2529

In [8]:
encoder_dim, decoder_dim, attention_dim, embed_dim, vocab_size = 512, 256, 128, 64, len(vocab)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
# max_len = 0
# for i in tqdm(range(len(dataset))):
#     _, cap = dataset[i]
#     cap_len = len(cap)
#     if cap_len > max_len:
#         max_len = cap_len
max_len = 42

In [10]:
encoder = CNNEncoder(1536, encoder_dim)

In [11]:
decoder = RNNDecoder(encoder_dim=encoder_dim, 
                     decoder_dim=decoder_dim, 
                     attention_dim=attention_dim, 
                     embedding_dim=embed_dim, 
                     vocab_size=vocab_size
                    )

In [12]:
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))
ce_loss = nn.CrossEntropyLoss(reduction="none")

In [13]:
trainer = Trainer(encoder, decoder, optimizer, ce_loss, vocab, device)

In [14]:
trainer.train(dataloader, 2)

0it [00:00, ?it/s]

********** Epoch 1/2 **********
Average Batch 0 Loss: 5.746832142705503
Average Batch 100 Loss: 2.8956720352172853
Average Batch 200 Loss: 3.210963399786698
Average Batch 300 Loss: 3.152499198913574
Average Batch 400 Loss: 2.514632138338956
Average Batch 500 Loss: 2.224023183186849
Average Batch 600 Loss: 2.1356488863627114
Average Batch 700 Loss: 2.0676881408691408
Average Batch 800 Loss: 2.8970696131388345
Average Batch 900 Loss: 2.2200374603271484
Average Batch 1000 Loss: 2.3499270629882814
Average Batch 1100 Loss: 2.323818418714735
Average Batch 1200 Loss: 2.312026575991982
Average Batch 1300 Loss: 2.5658302307128906
Average Batch 1400 Loss: 1.861580335176908
Average Batch 1500 Loss: 2.323731952243381
Average Batch 1600 Loss: 2.0501084899902344
Average Batch 1700 Loss: 2.1074271731906467
Average Batch 1800 Loss: 2.6225140220240544
Average Batch 1900 Loss: 1.724917449951172
Average Batch 2000 Loss: 1.9071004867553711
Average Batch 2100 Loss: 1.8728458086649578
Average Batch 2200 Los