In [1]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [2]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = torch.device('cpu')
print(device)

cuda


In [None]:
datasets = "..\\..\\datasets\\"
models = "..\\..\\models\\"

IMAGES_PATH = f"{datasets}/flickr30k/flickr30k_images/flickr30k_images"  # Directory with training images
CAPTIONS_PATH = f"{datasets}/flickr30k/flickr30k_images/results.csv"  # Caption file
TEST_IMAGES_PATH = "..\\test_images"  # Directory with test images

In [4]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import nltk
import os
import random

# Update the folowing datset call class to handle flikr8k dataset formate


import torch
from torch.utils.data import Dataset
from PIL import Image
import nltk
import os
import random
import csv

class FlickrDataset30K(Dataset):
    def __init__(self, root, captions_path, vocab, transform=None, max_samples=None):
        self.root = root
        self.vocab = vocab
        self.transform = transform

        image_captions = {}
        with open(captions_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # Skip header
                if line.startswith("image_name"):
                    continue

                # Split by '|'
                parts = [p.strip() for p in line.split("|")]
                if len(parts) < 3:
                    continue

                img_name,_, caption = parts
                if img_name not in image_captions:
                        image_captions[img_name] = []
                image_captions[img_name].append(caption)
        # Load captions
        self.image_captions = image_captions

        self.image_ids = list(self.image_captions.keys())

        if max_samples:
            self.image_ids = self.image_ids[:max_samples]

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        vocab = self.vocab
        img_name = self.image_ids[index]
        img_path = os.path.join(self.root, img_name)

        try:
            image = Image.open(img_path).convert("RGB")
            if self.transform:
                image = self.transform(image)
        except:
            # Skip corrupted image
            return self.__getitem__((index + 1) % len(self.image_ids))
        

        # Randomly select one caption per image
        caption = random.choice(self.image_captions[img_name])

        tokens = nltk.tokenize.word_tokenize(caption)

        caption_indices = [vocab.word2idx["<start>"]]
        caption_indices += [
            vocab.word2idx.get(token, vocab.word2idx["<unk>"])
            for token in tokens
        ]
        caption_indices.append(vocab.word2idx["<end>"])

        caption_tensor = torch.tensor(caption_indices)

        return image, caption_tensor


In [5]:
import tqdm
import nltk 
from collections import Counter
from vocabulary_class import Vocabulary
nltk.download('punkt')
import json

tokens = []
counter = Counter()

import csv
import string

def build_vocab(captions_path, threshold=3, limit=None):
    
    counter = Counter()
    image_captions = {}
    count =0
    with open(captions_path, "r", encoding="utf-8") as f:
        # reader = csv.reader(f)
        # next(reader)  # skip header: image,caption

        for line in f:
            line = line.strip()
            # Skip header
            if line.startswith("image_name"):
                continue

            # Split by '|'
            parts = [p.strip() for p in line.split("|")]
            if len(parts) < 3:
                continue

            img_name,_, caption = parts
            if img_name not in image_captions:
                    image_captions[img_name] = []
            image_captions[img_name].append(caption)

            caption = caption.lower()
            tokens = [
                token for token in nltk.tokenize.word_tokenize(caption)
                if token not in string.punctuation
            ]
            if len(tokens) < 1:
                continue
            counter.update(tokens)
            if limit and len(image_captions) >= limit:
                break
    
    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)
    

    return vocab, image_captions

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\pc\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [6]:
vocab, image_captions = build_vocab(CAPTIONS_PATH, threshold=3)
print("Total vocabulary size:", len(vocab))

indices = list(vocab.word2idx.values())
print("max idx:", max(indices))
print("len(vocab):", len(vocab))


Total vocabulary size: 8323
max idx: 8322
len(vocab): 8323


In [7]:
len(image_captions)

22248

In [8]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

from torch.nn.utils.rnn import pad_sequence
import torch

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def collate_fn(batch):
    images, captions = zip(*batch)

    images = torch.stack(images, 0)

    captions = pad_sequence(
        captions,
        batch_first=True,
        padding_value=0
    )

    return images, captions

from torch.utils.data import DataLoader
# from flickr_dataset  import FlickrDataset 

train_dataset = FlickrDataset30K(
    root=IMAGES_PATH,
    captions_path=CAPTIONS_PATH,
    vocab=vocab,
    transform=transform,
    max_samples=None
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    collate_fn= collate_fn
)

print(len(train_dataset))
image, caption = train_dataset[0]

print(type(image))
print(image.shape)        # after transform
print(caption)
print(len(caption))

22248
<class 'torch.Tensor'>
torch.Size([3, 224, 224])
tensor([ 1,  3, 27, 17, 28, 29, 22, 30, 17, 31, 19,  3,  2])
13


In [9]:
import torch.nn as nn
from model  import TransformerEncoderViT
from model  import TransformerDecoder

encoder = TransformerEncoderViT(embed_size=256).to(device)
decoder = TransformerDecoder(embed_size=256, vocab_size=len(vocab)).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])

params = filter(lambda p: p.requires_grad, decoder.parameters())
optimizer = torch.optim.AdamW(
    params,
    lr=3e-4,
    weight_decay=1e-2
)


In [10]:
train_losses = []
for epoch in range(10):
    total_train_loss = 0
    for images, captions in tqdm.tqdm(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        # Teacher forcing:
        # input:  <start> w1 w2 ... w(T-1)
        # target: w1 w2 ... w(T-1) <end>
        captions_in = captions[:, :-1]
        targets     = captions[:, 1:]

        optimizer.zero_grad(set_to_none=True)
        
        memory = encoder(images)
        logits = decoder(memory, captions_in)   # <-- IMPORTANT
        
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1)
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(list(encoder.parameters()) + list(decoder.parameters()), 1.0)
        optimizer.step()

        total_train_loss += loss.item()
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Epoch {epoch}: Train={avg_train_loss:.4f}")

100%|██████████| 696/696 [03:36<00:00,  3.21it/s]


Epoch 0: Train=4.1688


100%|██████████| 696/696 [03:36<00:00,  3.22it/s]


Epoch 1: Train=3.3918


100%|██████████| 696/696 [03:36<00:00,  3.21it/s]


Epoch 2: Train=3.1435


100%|██████████| 696/696 [03:36<00:00,  3.21it/s]


Epoch 3: Train=2.9874


100%|██████████| 696/696 [03:36<00:00,  3.21it/s]


Epoch 4: Train=2.8855


100%|██████████| 696/696 [03:36<00:00,  3.21it/s]


Epoch 5: Train=2.7909


100%|██████████| 696/696 [03:37<00:00,  3.21it/s]


Epoch 6: Train=2.7174


100%|██████████| 696/696 [03:37<00:00,  3.21it/s]


Epoch 7: Train=2.6634


100%|██████████| 696/696 [03:37<00:00,  3.20it/s]


Epoch 8: Train=2.6137


100%|██████████| 696/696 [03:36<00:00,  3.21it/s]

Epoch 9: Train=2.5650





In [11]:
torch.save(encoder.state_dict(), "models/encoder.pth")
torch.save(decoder.state_dict(), "models/decoder.pth")
torch.save(vocab, "models/vocab.pkl")