In [1]:
import pandas as pd
import numpy as np

from tqdm.notebook import tqdm

import torch

from utils.dataset import FeatureCaptionDataset
from utils.models import CNNEncoder, RNNDecoder
from utils.helpers import get_tokenizer_vocab, collate_fn_pad
from utils.trainer import train

In [2]:
PATH = "dataset/"

In [3]:
df = pd.read_csv(PATH + "captions.txt")

In [4]:
df.head()

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...


In [5]:
tokenizer, vocab = get_tokenizer_vocab(df)

In [6]:
vocab_size = len(vocab)
print(vocab_size)

9213


In [7]:
dataset = FeatureCaptionDataset(PATH + "embeddings/", df, tokenizer, vocab)

In [8]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, collate_fn=collate_fn_pad)
len(dataloader)

2529

In [9]:
encoder_dim, decoder_dim, attention_dim, embed_dim, vocab_size = 512, 256, 128, 64, len(vocab)

In [10]:
# max_len = 0
# for i in tqdm(range(len(dataset))):
#     _, cap = dataset[i]
#     cap_len = len(cap)
#     if cap_len > max_len:
#         max_len = cap_len
max_len = 42

In [11]:
encoder = CNNEncoder(1536, encoder_dim)

In [12]:
decoder = RNNDecoder(encoder_dim=encoder_dim, 
                     decoder_dim=decoder_dim, 
                     attention_dim=attention_dim, 
                     embedding_dim=embed_dim, 
                     vocab_size=vocab_size
                    )

In [13]:
train(dataloader, encoder, decoder, vocab)

0it [00:00, ?it/s]

********** Epoch 1/2 **********
Average Batch 0 Loss: 5.762305218240489
Average Batch 100 Loss: 2.8189464569091798
Average Batch 200 Loss: 3.1434729726690995
Average Batch 300 Loss: 3.1518243153889975
Average Batch 400 Loss: 2.501363927667791
Average Batch 500 Loss: 2.253016789754232
Average Batch 600 Loss: 2.0805463790893555
Average Batch 700 Loss: 2.069033355712891
Average Batch 800 Loss: 2.8799352645874023
Average Batch 900 Loss: 2.181780208240856
Average Batch 1000 Loss: 2.288524322509766
Average Batch 1100 Loss: 2.321051491631402
Average Batch 1200 Loss: 2.2378260963841488
Average Batch 1300 Loss: 2.602543917569247
Average Batch 1400 Loss: 1.8899390880878155
Average Batch 1500 Loss: 2.2468927171495228
Average Batch 1600 Loss: 2.0327493286132814
Average Batch 1700 Loss: 2.1308316124810114
Average Batch 1800 Loss: 2.5826867756090666
Average Batch 1900 Loss: 1.6867263793945313
Average Batch 2000 Loss: 1.924961471557617
Average Batch 2100 Loss: 1.8724252382914226
Average Batch 2200 Lo

In [14]:
real = torch.randn((16, 23))

In [15]:
real.item()

ValueError: only one element tensors can be converted to Python scalars