In [5]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [6]:
IMAGES_PATH = "../phase_1/data/flickr8k/images"  # Directory with training images
CAPTIONS_PATH = "../phase_1/data/flickr8k/captions_trainImages.txt"  # Caption file
TEST_IMAGES_PATH = "../phase_1/test copy/"  # Directory with test images

In [7]:
import tqdm
import nltk 
from collections import Counter
from vocabulary_class  import Vocabulary
nltk.download('punkt_tab')
import json

tokens = []
counter = Counter()

import csv

def build_vocab(json_path, threshold=5, limit=None):
    
    counter = Counter()
    image_captions = {}
    count =0
    with open(CAPTIONS_PATH, "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader)  # skip header: image,caption

        for row in reader:
            if len(row) < 2:
                continue
            img_name, caption = row
            if img_name not in image_captions:
                    image_captions[img_name] = []
            image_captions[img_name].append(caption)

            caption = caption.lower()
            tokens = nltk.tokenize.word_tokenize(caption)
            counter.update(tokens)
            count +=1
            if limit and count >= limit:
                break
    
    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)
    
    return vocab, image_captions

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\pc\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [8]:
vocab, image_captions = build_vocab(CAPTIONS_PATH, threshold=5)
print("Total vocabulary size:", len(vocab))

Total vocabulary size: 2554


In [9]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

from torch.nn.utils.rnn import pad_sequence
import torch

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def collate_fn(batch):
    images, captions = zip(*batch)

    images = torch.stack(images, 0)

    captions = pad_sequence(
        captions,
        batch_first=True,
        padding_value=0
    )

    return images, captions

from torch.utils.data import DataLoader
from flickr_dataset  import FlickrDataset 

train_dataset = FlickrDataset(
    root=IMAGES_PATH,
    captions_path=CAPTIONS_PATH,
    vocab=vocab,
    transform=transform,
    max_samples=None
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn= collate_fn
)

print(len(train_dataset))
image, caption = train_dataset[0]

print(type(image))
print(image.shape)        # after transform
print(caption)
print(len(caption))

6000
<class 'torch.Tensor'>
torch.Size([3, 224, 224])
tensor([ 1,  3,  5,  6,  4,  7,  8,  9, 10, 11,  4, 12, 13, 14,  6, 15,  3, 16,
        17,  2])
20


In [10]:
import torch.nn as nn
from model  import TransformerEncoderViT
from model  import TransformerDecoder

encoder = TransformerEncoderViT(embed_size=256).to(device)
decoder = TransformerDecoder(embed_size=256, vocab_size=len(vocab)).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])

params = filter(lambda p: p.requires_grad, decoder.parameters())
optimizer = torch.optim.AdamW(
    params,
    lr=3e-4,
    weight_decay=1e-2
)


In [11]:
train_losses = []
for epoch in range(10):
    total_train_loss = 0
    for images, captions in tqdm.tqdm(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        # Teacher forcing:
        # input:  <start> w1 w2 ... w(T-1)
        # target: w1 w2 ... w(T-1) <end>
        captions_in = captions[:, :-1]
        targets     = captions[:, 1:]

        optimizer.zero_grad(set_to_none=True)
        
        memory = encoder(images)
        logits = decoder(memory, captions_in)   # <-- IMPORTANT
        
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1)
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(list(encoder.parameters()) + list(decoder.parameters()), 1.0)
        optimizer.step()

        total_train_loss += loss.item()
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Epoch {epoch}: Train={avg_train_loss:.4f}")

100%|██████████| 188/188 [00:54<00:00,  3.45it/s]


Epoch 0: Train=4.2695


100%|██████████| 188/188 [00:56<00:00,  3.32it/s]


Epoch 1: Train=3.4454


100%|██████████| 188/188 [00:53<00:00,  3.48it/s]


Epoch 2: Train=3.1766


100%|██████████| 188/188 [00:53<00:00,  3.50it/s]


Epoch 3: Train=2.9988


100%|██████████| 188/188 [00:53<00:00,  3.51it/s]


Epoch 4: Train=2.8679


100%|██████████| 188/188 [00:54<00:00,  3.47it/s]


Epoch 5: Train=2.7760


100%|██████████| 188/188 [00:53<00:00,  3.50it/s]


Epoch 6: Train=2.6627


100%|██████████| 188/188 [00:53<00:00,  3.50it/s]


Epoch 7: Train=2.5932


100%|██████████| 188/188 [00:53<00:00,  3.52it/s]


Epoch 8: Train=2.5180


100%|██████████| 188/188 [00:53<00:00,  3.50it/s]

Epoch 9: Train=2.4593





In [12]:
torch.save(encoder.state_dict(), "models/encoder.pth")
torch.save(decoder.state_dict(), "models/decoder.pth")
torch.save(vocab, "models/vocab.pkl")