In [1]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [2]:
IMAGES_PATH = "../phase_2/data/train2014/train2014"  # Directory with training images
VAL_IMAGES_PATH = "../phase_2/data/val2014/val2014"  # Directory with validation images
CAPTIONS_PATH = "../phase_2/data/annotations_trainval2014/annotations/captions_train2014.json"  # Caption file
VAL_CAPTIONS_PATH = "../phase_2/data/annotations_trainval2014/annotations/captions_val2014.json"  # Validation caption file

In [3]:
from pycocotools.coco import COCO
coco = COCO(CAPTIONS_PATH)

print(len(list(coco.anns.keys())))  # Total number of annotations
print("Num images:", len(coco.getImgIds()))
print("Num captions:", len(coco.getAnnIds()))

loading annotations into memory...
Done (t=0.38s)
creating index...
index created!
414113
Num images: 82783
Num captions: 414113


In [4]:
import tqdm
import nltk 
from collections import Counter
from vocabulary_class import Vocabulary
nltk.download('punkt_tab')
import json

tokens = []
counter = Counter()

def build_vocab(json_path, threshold=5, limit=None):
    with open(json_path, 'r') as f:
        data = json.load(f) 

    counter = Counter()
    count =0

    for ann in tqdm.tqdm(data['annotations']):
        caption = ann['caption'].lower()
        tokens = nltk.tokenize.word_tokenize(caption)
        counter.update(tokens)
        count +=1
        if limit and count >= limit:
            break
    
    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)
    
    return vocab

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\pc\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [5]:
vocab = build_vocab(CAPTIONS_PATH, threshold=5)
print("Total vocabulary size:", len(vocab))

100%|██████████| 414113/414113 [00:17<00:00, 24247.42it/s]


Total vocabulary size: 8853


In [6]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [7]:

def collate_fn(data):
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    images = torch.stack(images, 0)

    lengths = [len(cap) for cap in captions]
    max_length = max(lengths)
    padded_captions = torch.zeros(len(captions), max_length).long()

    for i, cap in enumerate(captions):
        end = lengths[i]
        padded_captions[i, :end] = cap[:end]

    return images, padded_captions, lengths

In [None]:
from torch.utils.data import DataLoader
from coco_dataset import CocoDatasetClass 

train_dataset = CocoDatasetClass(
    root=IMAGES_PATH,
    json_path=CAPTIONS_PATH,
    vocab=vocab,
    transform=transform,
    max_samples=3000
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn= collate_fn
)



loading annotations into memory...
Done (t=0.55s)
creating index...
index created!


In [9]:
print(len(train_dataset))

82783


In [10]:
from model import EncoderCNN, DecoderRNN
import torch.nn as nn

encoder = EncoderCNN(embed_size=256).to(device)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(vocab)).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])

params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

for epoch in range(15):
    for images, captions, lengths in tqdm.tqdm(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        optimizer.zero_grad()
        
        features = encoder(images)
        outputs = decoder(features, captions, lengths)   # <-- IMPORTANT
        targets = captions[:, 1:]                       # shift left
        outputs = outputs[:, :-1, :]                    # align prediction
        loss = criterion(outputs.reshape(-1, len(vocab)),
                            targets.reshape(-1))

        
        loss.backward()
        optimizer.step()

    print("Epoch:", epoch, "Loss:", loss.item())

100%|██████████| 2587/2587 [05:59<00:00,  7.19it/s]


Epoch: 0 Loss: 3.6718809604644775


100%|██████████| 2587/2587 [05:58<00:00,  7.21it/s]


Epoch: 1 Loss: 3.862231969833374


100%|██████████| 2587/2587 [08:57<00:00,  4.82it/s]   


Epoch: 2 Loss: 3.2863659858703613


100%|██████████| 2587/2587 [05:55<00:00,  7.27it/s]


Epoch: 3 Loss: 3.0443227291107178


100%|██████████| 2587/2587 [05:55<00:00,  7.27it/s]


Epoch: 4 Loss: 3.2176079750061035


100%|██████████| 2587/2587 [31:30<00:00,  1.37it/s]    


Epoch: 5 Loss: 3.414350748062134


100%|██████████| 2587/2587 [07:12<00:00,  5.98it/s]


Epoch: 6 Loss: 3.0578784942626953


100%|██████████| 2587/2587 [05:56<00:00,  7.26it/s]


Epoch: 7 Loss: 3.006190776824951


100%|██████████| 2587/2587 [06:16<00:00,  6.87it/s]


Epoch: 8 Loss: 3.1500911712646484


100%|██████████| 2587/2587 [11:51<00:00,  3.63it/s]


Epoch: 9 Loss: 3.201063871383667


100%|██████████| 2587/2587 [06:16<00:00,  6.87it/s]


Epoch: 10 Loss: 2.845913887023926


100%|██████████| 2587/2587 [06:48<00:00,  6.34it/s]


Epoch: 11 Loss: 2.925001859664917


100%|██████████| 2587/2587 [05:54<00:00,  7.29it/s]


Epoch: 12 Loss: 3.0150487422943115


100%|██████████| 2587/2587 [05:55<00:00,  7.28it/s]


Epoch: 13 Loss: 2.621262311935425


100%|██████████| 2587/2587 [05:56<00:00,  7.26it/s]

Epoch: 14 Loss: 3.049644947052002





In [11]:
def generate_caption(image, encoder, decoder, vocab):
    encoder.eval()
    decoder.eval()
    image = image.unsqueeze(0).to(device)
    feature = encoder(image)              # [1, 256]
    feature = feature.unsqueeze(1)        # [1, 1, 256]

    # 2. Start sequence with <start> token
    start_token = vocab.word2idx["<start>"]
    end_token = vocab.word2idx["<end>"]

    sampled_ids = []
    inputs = torch.LongTensor([[start_token]]).to(image.device)

    # 3. FIRST STEP: concatenate image feature + embedding(<start>)
    embeddings = decoder.embed(inputs)     # [1,1,256]
    lstm_input = torch.cat((feature, embeddings), dim=1)  # [1,2,256]

    hiddens, states = decoder.lstm(lstm_input)

    outputs = decoder.linear(hiddens[:, -1, :])
    predicted = outputs.argmax(dim=1).item()
    sampled_ids.append(predicted)

    # 4. NEXT STEPS: only feed predicted tokens (NO concatenation!)
    inputs = torch.LongTensor([[predicted]]).to(image.device)

    for _ in range(20):
        embeddings = decoder.embed(inputs)  # [1,1,256]

        hiddens, states = decoder.lstm(embeddings, states)
        outputs = decoder.linear(hiddens[:, -1, :])
        
        predicted = outputs.argmax(dim=1).item()
        sampled_ids.append(predicted)
        
        if predicted == end_token:
            break

        inputs = torch.LongTensor([[predicted]]).to(image.device)

    words = [vocab.idx2word[id] for id in sampled_ids]
    return " ".join(words)

In [14]:
from pycocotools.coco import COCO
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

coco_val = COCO(VAL_CAPTIONS_PATH)
val_img_ids = coco_val.getImgIds()

predictions = []
references = {}
smoothie = SmoothingFunction().method4
scores = []

import matplotlib.pyplot as plt


print("Running Evaluation on 200 validation images...")
subset_ids = val_img_ids   # Evaluate on 200 images (faster)

for img_id in tqdm.tqdm(subset_ids):
    # Load the actual image file name
    img_info = coco_val.loadImgs(img_id)[0]
    file_name = img_info["file_name"]

    # Load image
    img_path = f"{VAL_IMAGES_PATH}/{file_name}"
    image = Image.open(img_path).convert("RGB")
    image = transform(image)

    # print(image.shape)

    # Generate prediction
    pred_caption = generate_caption(image, encoder, decoder, vocab)
    pred_tokens = pred_caption.lower().split()

    # Ground truth captions
    ann_ids = coco_val.getAnnIds(imgIds=img_id)
    anns = coco_val.loadAnns(ann_ids)
    gt_caps = [ann["caption"] for ann in anns]

    # Compute BLEU-4
    bleu4 = sentence_bleu(gt_caps, pred_tokens, smoothing_function=smoothie)
    scores.append(bleu4)
    
print("BLEU-4 Score:", sum(scores) / len(scores))

loading annotations into memory...
Done (t=0.17s)
creating index...
index created!
Running Evaluation on 200 validation images...


100%|██████████| 40504/40504 [11:53<00:00, 56.73it/s]

BLEU-4 Score: 0.0012808335140070558





In [13]:
# torch.save({
#     'encoder_state_dict': encoder.state_dict(),
#     'decoder_state_dict': decoder.state_dict(),
#     'vocab': vocab,
#     'embed_size': 256,
#     'hidden_size': 512
# }, 'model.pth')

torch.save(encoder.state_dict(), "models/encoder.pth")
torch.save(decoder.state_dict(), "models/decoder.pth")
torch.save(vocab, "models/vocab.pkl")