In [1]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [2]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = torch.device('cpu')
print(device)

cuda


In [3]:
IMAGES_PATH = "../phase_1/data/flickr30k_images/flickr30k_images"  # Directory with training images
CAPTIONS_PATH = "../phase_1/data/flickr30k_images/results_training.csv"  # Caption file
TEST_IMAGES_PATH = "../phase_1/test copy/"  # Directory with test images

In [4]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import nltk
import os
import random

# Update the folowing datset call class to handle flikr8k dataset formate


import torch
from torch.utils.data import Dataset
from PIL import Image
import nltk
import os
import random
import csv

class FlickrDataset30K(Dataset):
    def __init__(self, root, captions_path, vocab, transform=None, max_samples=None):
        self.root = root
        self.vocab = vocab
        self.transform = transform

        image_captions = {}
        with open(captions_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # Skip header
                if line.startswith("image_name"):
                    continue

                # Split by '|'
                parts = [p.strip() for p in line.split("|")]
                if len(parts) < 3:
                    continue

                img_name,_, caption = parts
                if img_name not in image_captions:
                        image_captions[img_name] = []
                image_captions[img_name].append(caption)
        # Load captions
        self.image_captions = image_captions

        self.image_ids = list(self.image_captions.keys())

        if max_samples:
            self.image_ids = self.image_ids[:max_samples]

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        vocab = self.vocab
        img_name = self.image_ids[index]
        img_path = os.path.join(self.root, img_name)

        try:
            image = Image.open(img_path).convert("RGB")
            if self.transform:
                image = self.transform(image)
        except:
            # Skip corrupted image
            return self.__getitem__((index + 1) % len(self.image_ids))
        

        # Randomly select one caption per image
        caption = random.choice(self.image_captions[img_name])

        tokens = nltk.tokenize.word_tokenize(caption)

        caption_indices = [vocab.word2idx["<start>"]]
        caption_indices += [
            vocab.word2idx.get(token, vocab.word2idx["<unk>"])
            for token in tokens
        ]
        caption_indices.append(vocab.word2idx["<end>"])

        caption_tensor = torch.tensor(caption_indices)

        return image, caption_tensor


In [5]:
import tqdm
import nltk 
from collections import Counter
from vocabulary_class import Vocabulary
nltk.download('punkt')
import json

tokens = []
counter = Counter()

import csv
import string

def build_vocab(captions_path, threshold=3, limit=None):
    
    counter = Counter()
    image_captions = {}
    count =0
    with open(captions_path, "r", encoding="utf-8") as f:
        # reader = csv.reader(f)
        # next(reader)  # skip header: image,caption

        for line in f:
            line = line.strip()
            # Skip header
            if line.startswith("image_name"):
                continue

            # Split by '|'
            parts = [p.strip() for p in line.split("|")]
            if len(parts) < 3:
                continue

            img_name,_, caption = parts
            if img_name not in image_captions:
                    image_captions[img_name] = []
            image_captions[img_name].append(caption)

            caption = caption.lower()
            tokens = [
                token for token in nltk.tokenize.word_tokenize(caption)
                if token not in string.punctuation
            ]
            if len(tokens) < 1:
                continue
            counter.update(tokens)
            if limit and len(image_captions) >= limit:
                break
    
    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)
    

    return vocab, image_captions

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\pc\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [6]:
vocab, image_captions = build_vocab(CAPTIONS_PATH, threshold=3)
print("Total vocabulary size:", len(vocab))

indices = list(vocab.word2idx.values())
print("max idx:", max(indices))
print("len(vocab):", len(vocab))


Total vocabulary size: 8323
max idx: 8322
len(vocab): 8323


In [7]:
len(image_captions)

22248

In [8]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

from torch.nn.utils.rnn import pad_sequence
import torch

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def collate_fn(batch):
    images, captions = zip(*batch)

    images = torch.stack(images, 0)

    captions = pad_sequence(
        captions,
        batch_first=True,
        padding_value=0
    )

    return images, captions

from torch.utils.data import DataLoader
# from flickr_dataset  import FlickrDataset 

train_dataset = FlickrDataset30K(
    root=IMAGES_PATH,
    captions_path=CAPTIONS_PATH,
    vocab=vocab,
    transform=transform,
    max_samples=None
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    collate_fn= collate_fn
)

print(len(train_dataset))
image, caption = train_dataset[0]

print(type(image))
print(image.shape)        # after transform
print(caption)
print(len(caption))

22248
<class 'torch.Tensor'>
torch.Size([3, 224, 224])
tensor([ 1,  3, 27, 17, 28, 29, 22, 30, 17, 31, 19,  3,  2])
13


In [9]:
import torch.nn as nn
from model  import TransformerEncoderViT
from model  import TransformerDecoder

encoder = TransformerEncoderViT(embed_size=256).to(device)
decoder = TransformerDecoder(embed_size=256, vocab_size=len(vocab)).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])

params = filter(lambda p: p.requires_grad, decoder.parameters())
optimizer = torch.optim.AdamW(
    params,
    lr=3e-4,
    weight_decay=1e-2
)


In [None]:
train_losses = []
for epoch in range(10):
    total_train_loss = 0
    for images, captions in tqdm.tqdm(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        # Teacher forcing:
        # input:  <start> w1 w2 ... w(T-1)
        # target: w1 w2 ... w(T-1) <end>
        captions_in = captions[:, :-1]
        targets     = captions[:, 1:]

        optimizer.zero_grad(set_to_none=True)
        
        memory = encoder(images)
        logits = decoder(memory, captions_in)   # <-- IMPORTANT
        
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1)
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(list(encoder.parameters()) + list(decoder.parameters()), 1.0)
        optimizer.step()

        total_train_loss += loss.item()
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Epoch {epoch}: Train={avg_train_loss:.4f}")

  0%|          | 1/696 [00:00<06:56,  1.67it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7724
pos_max_len: 80 T: 33


  0%|          | 2/696 [00:00<04:57,  2.34it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7918
pos_max_len: 80 T: 30


  0%|          | 3/696 [00:01<04:16,  2.71it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7812
pos_max_len: 80 T: 29


  1%|          | 4/696 [00:01<04:05,  2.82it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6739
pos_max_len: 80 T: 26


  1%|          | 5/696 [00:01<03:54,  2.94it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6975
pos_max_len: 80 T: 24


  1%|          | 6/696 [00:02<03:47,  3.03it/s]

vocab_size: 8323 cap_min: 0 cap_max: 5821
pos_max_len: 80 T: 29


  1%|          | 7/696 [00:02<03:44,  3.07it/s]

vocab_size: 8323 cap_min: 0 cap_max: 5962
pos_max_len: 80 T: 35


  1%|          | 8/696 [00:02<03:42,  3.10it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7301
pos_max_len: 80 T: 22


  1%|▏         | 9/696 [00:03<03:39,  3.12it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8066
pos_max_len: 80 T: 30


  1%|▏         | 10/696 [00:03<03:39,  3.13it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7349
pos_max_len: 80 T: 30


  2%|▏         | 11/696 [00:03<03:39,  3.12it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7581
pos_max_len: 80 T: 29


  2%|▏         | 12/696 [00:04<03:37,  3.14it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7585
pos_max_len: 80 T: 28


  2%|▏         | 13/696 [00:04<03:35,  3.16it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6794
pos_max_len: 80 T: 29


  2%|▏         | 14/696 [00:04<03:33,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7805
pos_max_len: 80 T: 24


  2%|▏         | 15/696 [00:04<03:32,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7927
pos_max_len: 80 T: 33


  2%|▏         | 16/696 [00:05<03:32,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7388
pos_max_len: 80 T: 50


  2%|▏         | 17/696 [00:05<03:32,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7592
pos_max_len: 80 T: 25


  3%|▎         | 18/696 [00:05<03:31,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7458
pos_max_len: 80 T: 23


  3%|▎         | 19/696 [00:06<03:30,  3.22it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7995
pos_max_len: 80 T: 18


  3%|▎         | 20/696 [00:06<03:31,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8074
pos_max_len: 80 T: 32


  3%|▎         | 21/696 [00:06<03:31,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6948
pos_max_len: 80 T: 30


  3%|▎         | 22/696 [00:07<03:29,  3.22it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7141
pos_max_len: 80 T: 28


  3%|▎         | 23/696 [00:07<03:29,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7962
pos_max_len: 80 T: 26


  3%|▎         | 24/696 [00:07<03:28,  3.23it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8195
pos_max_len: 80 T: 26


  4%|▎         | 25/696 [00:08<03:28,  3.22it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7688
pos_max_len: 80 T: 29


  4%|▎         | 26/696 [00:08<03:26,  3.25it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6948
pos_max_len: 80 T: 25


  4%|▍         | 27/696 [00:08<03:26,  3.24it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6356
pos_max_len: 80 T: 27


  4%|▍         | 28/696 [00:09<03:26,  3.24it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7744
pos_max_len: 80 T: 30


  4%|▍         | 29/696 [00:09<03:26,  3.22it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6985
pos_max_len: 80 T: 21


  4%|▍         | 30/696 [00:09<03:26,  3.23it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6969
pos_max_len: 80 T: 19


  4%|▍         | 31/696 [00:09<03:26,  3.22it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8115
pos_max_len: 80 T: 29


  5%|▍         | 32/696 [00:10<03:26,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7517
pos_max_len: 80 T: 36


  5%|▍         | 33/696 [00:10<03:26,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7254
pos_max_len: 80 T: 31


  5%|▍         | 34/696 [00:10<03:28,  3.17it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6806
pos_max_len: 80 T: 25


  5%|▌         | 35/696 [00:11<03:28,  3.16it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8092
pos_max_len: 80 T: 34


  5%|▌         | 36/696 [00:11<03:28,  3.17it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8048
pos_max_len: 80 T: 27


  5%|▌         | 37/696 [00:11<03:28,  3.16it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7633
pos_max_len: 80 T: 48


  5%|▌         | 38/696 [00:12<03:29,  3.15it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7140
pos_max_len: 80 T: 34


  6%|▌         | 39/696 [00:12<03:29,  3.14it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7471
pos_max_len: 80 T: 28


  6%|▌         | 40/696 [00:12<03:28,  3.14it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8055
pos_max_len: 80 T: 31


  6%|▌         | 41/696 [00:13<03:27,  3.16it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6558
pos_max_len: 80 T: 25


  6%|▌         | 42/696 [00:13<03:26,  3.17it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7863
pos_max_len: 80 T: 26


  6%|▌         | 43/696 [00:13<03:24,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7978
pos_max_len: 80 T: 23


  6%|▋         | 44/696 [00:14<03:23,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7805
pos_max_len: 80 T: 26


  6%|▋         | 45/696 [00:14<03:23,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8012
pos_max_len: 80 T: 40


  7%|▋         | 46/696 [00:14<03:21,  3.22it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8267
pos_max_len: 80 T: 27


  7%|▋         | 47/696 [00:14<03:21,  3.23it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6786
pos_max_len: 80 T: 41


  7%|▋         | 48/696 [00:15<03:22,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8160
pos_max_len: 80 T: 25


  7%|▋         | 49/696 [00:15<03:22,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6535
pos_max_len: 80 T: 34


  7%|▋         | 50/696 [00:15<03:21,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7385
pos_max_len: 80 T: 23


  7%|▋         | 51/696 [00:16<03:21,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6995
pos_max_len: 80 T: 33


  7%|▋         | 52/696 [00:16<03:21,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8253
pos_max_len: 80 T: 25


  8%|▊         | 53/696 [00:16<03:20,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7927
pos_max_len: 80 T: 36


  8%|▊         | 54/696 [00:17<03:19,  3.22it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7000
pos_max_len: 80 T: 22


  8%|▊         | 55/696 [00:17<03:18,  3.23it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7146
pos_max_len: 80 T: 23


  8%|▊         | 56/696 [00:17<03:19,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8266
pos_max_len: 80 T: 58


  8%|▊         | 57/696 [00:18<03:20,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7727
pos_max_len: 80 T: 25


  8%|▊         | 58/696 [00:18<03:19,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7775
pos_max_len: 80 T: 32


  8%|▊         | 59/696 [00:18<03:18,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6749
pos_max_len: 80 T: 21


  9%|▊         | 60/696 [00:19<03:19,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8060
pos_max_len: 80 T: 40


  9%|▉         | 61/696 [00:19<03:18,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7606
pos_max_len: 80 T: 22


  9%|▉         | 62/696 [00:19<03:20,  3.16it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7998
pos_max_len: 80 T: 22


  9%|▉         | 63/696 [00:20<03:19,  3.18it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7399
pos_max_len: 80 T: 24


  9%|▉         | 64/696 [00:20<03:18,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7672
pos_max_len: 80 T: 19


  9%|▉         | 65/696 [00:20<03:17,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7885
pos_max_len: 80 T: 23


  9%|▉         | 66/696 [00:20<03:15,  3.22it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8047
pos_max_len: 80 T: 23


 10%|▉         | 67/696 [00:21<03:16,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7954
pos_max_len: 80 T: 38


 10%|▉         | 68/696 [00:21<03:16,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7755
pos_max_len: 80 T: 36


 10%|▉         | 69/696 [00:21<03:16,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7036
pos_max_len: 80 T: 23


 10%|█         | 70/696 [00:22<03:16,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6996
pos_max_len: 80 T: 26


 10%|█         | 71/696 [00:22<03:15,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7037
pos_max_len: 80 T: 29


 10%|█         | 72/696 [00:22<03:14,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6655
pos_max_len: 80 T: 24


 10%|█         | 73/696 [00:23<03:13,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6297
pos_max_len: 80 T: 24


 11%|█         | 74/696 [00:23<03:13,  3.22it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6895
pos_max_len: 80 T: 23


 11%|█         | 75/696 [00:23<03:14,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7729
pos_max_len: 80 T: 37


 11%|█         | 76/696 [00:24<03:15,  3.17it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7950
pos_max_len: 80 T: 43


 11%|█         | 77/696 [00:24<03:14,  3.18it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7376
pos_max_len: 80 T: 28


 11%|█         | 78/696 [00:24<03:13,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7547
pos_max_len: 80 T: 24


 11%|█▏        | 79/696 [00:25<03:12,  3.20it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7358
pos_max_len: 80 T: 23


 11%|█▏        | 80/696 [00:25<03:12,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7679
pos_max_len: 80 T: 19


 12%|█▏        | 81/696 [00:25<03:11,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6407
pos_max_len: 80 T: 23


 12%|█▏        | 82/696 [00:25<03:11,  3.21it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8159
pos_max_len: 80 T: 27


 12%|█▏        | 83/696 [00:26<03:10,  3.22it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7458
pos_max_len: 80 T: 25


 12%|█▏        | 84/696 [00:26<03:11,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8038
pos_max_len: 80 T: 37


 12%|█▏        | 85/696 [00:26<03:12,  3.18it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7943
pos_max_len: 80 T: 28


 12%|█▏        | 86/696 [00:27<03:11,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 8206
pos_max_len: 80 T: 33


 12%|█▎        | 87/696 [00:27<03:10,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7562
pos_max_len: 80 T: 33


 13%|█▎        | 88/696 [00:27<03:11,  3.18it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7595
pos_max_len: 80 T: 29


 13%|█▎        | 89/696 [00:28<03:10,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6670
pos_max_len: 80 T: 34


 13%|█▎        | 90/696 [00:28<03:10,  3.19it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7723
pos_max_len: 80 T: 24


 13%|█▎        | 91/696 [00:28<03:11,  3.17it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7656
pos_max_len: 80 T: 46


 13%|█▎        | 92/696 [00:29<03:11,  3.16it/s]

vocab_size: 8323 cap_min: 0 cap_max: 7869
pos_max_len: 80 T: 39


 13%|█▎        | 93/696 [00:29<03:12,  3.14it/s]

vocab_size: 8323 cap_min: 0 cap_max: 6934
pos_max_len: 80 T: 23


 13%|█▎        | 93/696 [00:29<03:12,  3.13it/s]


KeyboardInterrupt: 

In [None]:
torch.save(encoder.state_dict(), "models/encoder.pth")
torch.save(decoder.state_dict(), "models/decoder.pth")
torch.save(vocab, "models/vocab.pkl")