In [157]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [158]:
# IMAGES_PATH = "./data/train2017/train2017"  # Directory with training images
# VAL_IMAGES_PATH = "./data/val2017/val2017"  # Directory with validation images
# CAPTIONS_PATH = "./data/annotations_trainval2017/annotations/captions_train2017.json"  # Caption file
# VAL_CAPTIONS_PATH = "./data/annotations_trainval2017/annotations/captions_val2017.json"  # Validation caption file

IMAGES_PATH = "./data/train2014/train2014"  # Directory with training images
VAL_IMAGES_PATH = "./data/val2014/val2014"  # Directory with validation images
CAPTIONS_PATH = "./data/annotations_trainval2014/annotations/captions_train2014.json"  # Caption file
VAL_CAPTIONS_PATH = "./data/annotations_trainval2014/annotations/captions_val2014.json"  # Validation caption file

In [159]:
from pycocotools.coco import COCO
coco = COCO(CAPTIONS_PATH)

print(len(list(coco.anns.keys())))  # Total number of annotations
print("Num images:", len(coco.getImgIds()))
print("Num captions:", len(coco.getAnnIds()))

loading annotations into memory...
Done (t=0.33s)
creating index...
index created!
414113
Num images: 82783
Num captions: 414113


In [167]:
import tqdm
import nltk 
from collections import Counter
from vocabulary_class import Vocabulary
nltk.download('punkt_tab')
import json

tokens = []
counter = Counter()

def build_vocab(json_path, threshold=5, limit=None):
    with open(json_path, 'r') as f:
        data = json.load(f) 

    counter = Counter()
    count =0

    for ann in tqdm.tqdm(data['annotations']):
        caption = ann['caption'].lower()
        tokens = nltk.tokenize.word_tokenize(caption)
        counter.update(tokens)
        count +=1
        if limit and count >= limit:
            break
    
    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)
    
    return vocab

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\pc\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [161]:
vocab = build_vocab(CAPTIONS_PATH, threshold=5)
print("Total vocabulary size:", len(vocab))

100%|██████████| 414113/414113 [00:15<00:00, 26824.89it/s]

Total vocabulary size: 8853





In [162]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [163]:

def collate_fn(data):
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    images = torch.stack(images, 0)

    lengths = [len(cap) for cap in captions]
    max_length = max(lengths)
    padded_captions = torch.zeros(len(captions), max_length).long()

    for i, cap in enumerate(captions):
        end = lengths[i]
        padded_captions[i, :end] = cap[:end]

    return images, padded_captions, lengths

In [164]:
from torch.utils.data import DataLoader
from coco_dataset import CocoDatasetClass 

train_dataset = CocoDatasetClass(
    root=IMAGES_PATH,
    json_path=CAPTIONS_PATH,
    vocab=vocab,
    transform=transform,
    max_samples=None
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn= collate_fn
)



loading annotations into memory...
Done (t=0.30s)
creating index...
index created!
IDs example: [57870, 384029, 222016, 520950, 69675, 547471, 122688, 392136, 398494, 90570]
Total IDs: 82783


In [165]:
print(len(train_dataset))

82783


In [168]:
from model import EncoderCNN, DecoderRNN
import torch.nn as nn

encoder = EncoderCNN(embed_size=256).to(device)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(vocab)).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])
params = list(decoder.parameters()) + list(encoder.embed.parameters()) + list(encoder.bn.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

for epoch in range(20):
    for images, captions, lengths in tqdm.tqdm(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        features = encoder(images)
        outputs = decoder(features, captions)

        loss = criterion(outputs.reshape(-1, len(vocab)), captions.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch:", epoch, "Loss:", loss.item())

  0%|          | 0/2587 [00:00<?, ?it/s]


TypeError: 'Vocabulary' object is not callable

In [None]:
def generate_caption(image, encoder, decoder, vocab):
    encoder.eval()
    decoder.eval()
    image = image.unsqueeze(0).to(device)
    feature = encoder(image)              # [1, 256]
    feature = feature.unsqueeze(1)        # [1, 1, 256]

    # 2. Start sequence with <start> token
    start_token = vocab.word2idx["<start>"]
    end_token = vocab.word2idx["<end>"]

    sampled_ids = []
    inputs = torch.LongTensor([[start_token]]).to(image.device)

    # 3. FIRST STEP: concatenate image feature + embedding(<start>)
    embeddings = decoder.embed(inputs)     # [1,1,256]
    lstm_input = torch.cat((feature, embeddings), dim=1)  # [1,2,256]

    hiddens, states = decoder.lstm(lstm_input)

    outputs = decoder.linear(hiddens[:, -1, :])
    predicted = outputs.argmax(dim=1).item()
    sampled_ids.append(predicted)

    # 4. NEXT STEPS: only feed predicted tokens (NO concatenation!)
    inputs = torch.LongTensor([[predicted]]).to(image.device)

    for _ in range(20):
        embeddings = decoder.embed(inputs)  # [1,1,256]

        hiddens, states = decoder.lstm(embeddings, states)
        outputs = decoder.linear(hiddens[:, -1, :])
        
        predicted = outputs.argmax(dim=1).item()
        sampled_ids.append(predicted)
        
        if predicted == end_token:
            break

        inputs = torch.LongTensor([[predicted]]).to(image.device)

    words = [vocab.idx2word[id] for id in sampled_ids]
    return " ".join(words)

In [None]:
from pycocotools.coco import COCO
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

coco_val = COCO(VAL_CAPTIONS_PATH)
val_img_ids = coco_val.getImgIds()

predictions = []
references = {}
smoothie = SmoothingFunction().method4
scores = []

import matplotlib.pyplot as plt


print("Running Evaluation on 200 validation images...")
subset_ids = val_img_ids[:10]   # Evaluate on 200 images (faster)

for img_id in tqdm.tqdm(subset_ids):
    # Load the actual image file name
    img_info = coco_val.loadImgs(img_id)[0]
    file_name = img_info["file_name"]

    # Load image
    img_path = f"{VAL_IMAGES_PATH}/{file_name}"
    image = Image.open(img_path).convert("RGB")
    image = transform(image)

    # print(image.shape)

    # Generate prediction
    pred_caption = generate_caption(image, encoder, decoder, vocab)
    pred_tokens = pred_caption.lower().split()

    # Ground truth captions
    ann_ids = coco_val.getAnnIds(imgIds=img_id)
    anns = coco_val.loadAnns(ann_ids)
    gt_caps = [ann["caption"] for ann in anns]

    # Compute BLEU-4
    bleu4 = sentence_bleu(gt_caps, pred_tokens, smoothing_function=smoothie)
    scores.append(bleu4)
    
print("BLEU-4 Score:", sum(scores) / len(scores))

loading annotations into memory...
Done (t=0.23s)
creating index...
index created!
Running Evaluation on 200 validation images...


100%|██████████| 10/10 [00:00<00:00, 38.63it/s]

BLEU-4 Score: 0.002009378918434691





In [None]:
# torch.save({
#     'encoder_state_dict': encoder.state_dict(),
#     'decoder_state_dict': decoder.state_dict(),
#     'vocab': vocab,
#     'embed_size': 256,
#     'hidden_size': 512
# }, 'model.pth')

torch.save(encoder.state_dict(), "models/encoder.pth")
torch.save(decoder.state_dict(), "models/decoder.pth")
torch.save(vocab, "models/vocab.pkl")