In [4]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [44]:
IMAGES_PATH = "./data/train2017/train2017"  # Directory with training images
VAL_IMAGES_PATH = "./data/val2017/val2017"  # Directory with validation images
CAPTIONS_PATH = "./data/annotations_trainval2017/annotations/captions_train2017.json"  # Caption file
VAL_CAPTIONS_PATH = "./data/annotations_trainval2017/annotations/captions_val2017.json"  # Validation caption file

In [6]:
import tqdm
import nltk 
from collections import Counter
nltk.download('punkt_tab')
import json

tokens = []
counter = Counter()

class Vocabulary:
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

        self.add_word("<pad>")
        self.add_word("<start>")
        self.add_word("<end>")
        self.add_word("<unk>")

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)

def build_vocab(json_path, threshold=5, limit =5000):
    with open(json_path, 'r') as f:
        data = json.load(f) 

    counter = Counter()
    count =0

    for ann in tqdm.tqdm(data['annotations']):
        caption = ann['caption'].lower()
        tokens = nltk.tokenize.word_tokenize(caption)
        counter.update(tokens)
        count +=1
        if count >= limit:
            break
    
    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)
    
    return vocab

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\pc\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [7]:
vocab = build_vocab(CAPTIONS_PATH, threshold=5, limit=5000)
print("Total vocabulary size:", len(vocab))

  1%|          | 4999/591753 [00:00<00:23, 25375.72it/s]

Total vocabulary size: 927





In [8]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [28]:

def collate_fn(data):
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    images = torch.stack(images, 0)

    lengths = [len(cap) for cap in captions]
    max_length = max(lengths)
    padded_captions = torch.zeros(len(captions), max_length).long()

    for i, cap in enumerate(captions):
        end = lengths[i]
        padded_captions[i, :end] = cap[:end]

    return images, padded_captions, lengths

In [32]:
from torch.utils.data import DataLoader
from dataset import CocoDataset

train_dataset = CocoDataset(
    root=IMAGES_PATH,
    json_path=CAPTIONS_PATH,
    vocab=vocab,
    transform=transform,
    max_samples=5000
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    collate_fn= collate_fn
)



loading annotations into memory...
Done (t=0.46s)
creating index...
index created!


In [34]:
from model import EncoderCNN, DecoderRNN
import torch.nn as nn

encoder = EncoderCNN(embed_size=256).to(device)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(vocab)).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])
params = list(decoder.parameters()) + list(encoder.embed.parameters()) + list(encoder.bn.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

for epoch in range(3):
    for images, captions, lengths in tqdm.tqdm(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        features = encoder(images)
        outputs = decoder(features, captions)

        loss = criterion(outputs.reshape(-1, len(vocab)), captions.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch:", epoch, "Loss:", loss.item())

100%|██████████| 157/157 [00:22<00:00,  7.03it/s]


Epoch: 0 Loss: 3.0615651607513428


100%|██████████| 157/157 [00:17<00:00,  8.73it/s]


Epoch: 1 Loss: 2.253413200378418


100%|██████████| 157/157 [00:18<00:00,  8.51it/s]

Epoch: 2 Loss: 2.868299961090088





In [None]:
def generate_caption(image, encoder, decoder, vocab):
    feature = encoder(image.unsqueeze(0))
    caption_ids = [vocab.word2idx["<start>"]]
    
    for _ in range(20):
        cap_tensor = torch.Tensor(caption_ids).long().unsqueeze(0).to(device)
        outputs = decoder(feature, cap_tensor)
        predicted = outputs.argmax(2)[:, -1].item()
        
        caption_ids.append(predicted)
        if vocab.idx2word[predicted] == "<end>":
            break

    return " ".join([vocab.idx2word[idx] for idx in caption_ids])

In [56]:
from pycocotools.coco import COCO
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

coco_val = COCO(VAL_CAPTIONS_PATH)
val_img_ids = coco_val.getImgIds()

predictions = []
references = {}
smoothie = SmoothingFunction().method4
scores = []

print("Running Evaluation on 200 validation images...")
subset_ids = val_img_ids[:200]   # Evaluate on 200 images (faster)

for img_id in tqdm.tqdm(subset_ids):
    # Load the actual image file name
    img_info = coco_val.loadImgs(img_id)[0]
    file_name = img_info["file_name"]

    # Load image
    img_path = f"{VAL_IMAGES_PATH}/{file_name}"
    image = Image.open(img_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)

    encoder.eval()
    decoder.eval()

    with torch.no_grad():
    # Generate prediction
        pred_caption = generate_caption(image, encoder, decoder, vocab)
        pred_tokens = pred_caption.lower().split()

    # Ground truth captions
    ann_ids = coco_val.getAnnIds(imgIds=img_id)
    anns = coco_val.loadAnns(ann_ids)
    gt_caps = [ann["caption"] for ann in anns]

    # Compute BLEU-4
    bleu4 = sentence_bleu(gt_caps, pred_tokens, smoothing_function=smoothie)
    scores.append(bleu4)
    
    print("BLEU-4 Score:", sum(scores) / len(scores))

loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
Running Evaluation on 200 validation images...


  3%|▎         | 6/200 [00:00<00:03, 58.29it/s]

BLEU-4 Score: 0.007645216259060621
BLEU-4 Score: 0.007069411604074036
BLEU-4 Score: 0.005841428231392098
BLEU-4 Score: 0.005266906912128012
BLEU-4 Score: 0.005942541281982352
BLEU-4 Score: 0.0058988706930019645
BLEU-4 Score: 0.00642915056108395
BLEU-4 Score: 0.006135962497535574
BLEU-4 Score: 0.005981926260745811
BLEU-4 Score: 0.0058495886817115286
BLEU-4 Score: 0.005787306011515994
BLEU-4 Score: 0.005832103405165886
BLEU-4 Score: 0.005773080366492997


 10%|█         | 20/200 [00:00<00:02, 62.14it/s]

BLEU-4 Score: 0.005632438025464427
BLEU-4 Score: 0.005665132759096542
BLEU-4 Score: 0.0055220814266781154
BLEU-4 Score: 0.005450313940761436
BLEU-4 Score: 0.005680366306768723
BLEU-4 Score: 0.005936642905858208
BLEU-4 Score: 0.00587882069222515
BLEU-4 Score: 0.005740761820610403
BLEU-4 Score: 0.005758813398848159
BLEU-4 Score: 0.00576375311601057
BLEU-4 Score: 0.005806239434351788


 17%|█▋        | 34/200 [00:00<00:02, 63.21it/s]

BLEU-4 Score: 0.0059043530050078095
BLEU-4 Score: 0.005878390714681325
BLEU-4 Score: 0.005837958990894469
BLEU-4 Score: 0.005775304957387414
BLEU-4 Score: 0.005813648918667704
BLEU-4 Score: 0.005713162637916735
BLEU-4 Score: 0.005689263245256285
BLEU-4 Score: 0.005616983501354579
BLEU-4 Score: 0.0055660568942119525
BLEU-4 Score: 0.005597354588377741
BLEU-4 Score: 0.005664401020550422
BLEU-4 Score: 0.005671656856558059
BLEU-4 Score: 0.005641684673333946
BLEU-4 Score: 0.0056329306532367585


 24%|██▍       | 48/200 [00:00<00:02, 65.33it/s]

BLEU-4 Score: 0.00562806785366163
BLEU-4 Score: 0.005602737921769093
BLEU-4 Score: 0.005678767549738776
BLEU-4 Score: 0.005671455716492145
BLEU-4 Score: 0.005680080181290595
BLEU-4 Score: 0.005622055385852087
BLEU-4 Score: 0.005598513845661283
BLEU-4 Score: 0.005537656164149921
BLEU-4 Score: 0.005562754586263541
BLEU-4 Score: 0.00553170988635969
BLEU-4 Score: 0.00553221824281733
BLEU-4 Score: 0.005520060183860314
BLEU-4 Score: 0.005509059026575469
BLEU-4 Score: 0.005457923510234087


 31%|███       | 62/200 [00:00<00:02, 63.32it/s]

BLEU-4 Score: 0.0055293273777921
BLEU-4 Score: 0.005493711423483077
BLEU-4 Score: 0.005475225269913842
BLEU-4 Score: 0.005497783815777634
BLEU-4 Score: 0.005507739256277934
BLEU-4 Score: 0.005497706120767825
BLEU-4 Score: 0.005526702103290928
BLEU-4 Score: 0.005579922946044847
BLEU-4 Score: 0.005611178267959602
BLEU-4 Score: 0.005578837745010332
BLEU-4 Score: 0.005601537870338334
BLEU-4 Score: 0.005597879326577301
BLEU-4 Score: 0.005574162298155445


 38%|███▊      | 76/200 [00:01<00:01, 62.87it/s]

BLEU-4 Score: 0.005562291595278043
BLEU-4 Score: 0.0055193797795453315
BLEU-4 Score: 0.005518260686755058
BLEU-4 Score: 0.005497469170319433
BLEU-4 Score: 0.005549580094933809
BLEU-4 Score: 0.005539573760955785
BLEU-4 Score: 0.0055536469924298525
BLEU-4 Score: 0.005540072185372812
BLEU-4 Score: 0.005513936964558379
BLEU-4 Score: 0.005527002220899977
BLEU-4 Score: 0.005496125175146383
BLEU-4 Score: 0.005453999834667964
BLEU-4 Score: 0.005489416763497725
BLEU-4 Score: 0.005484896600824635
BLEU-4 Score: 0.005483481275231581
BLEU-4 Score: 0.005487056560018174


 45%|████▌     | 90/200 [00:01<00:01, 65.47it/s]

BLEU-4 Score: 0.005463971294894642
BLEU-4 Score: 0.005502250218722172
BLEU-4 Score: 0.005505474374962926
BLEU-4 Score: 0.005501084369520111
BLEU-4 Score: 0.005520170048535338
BLEU-4 Score: 0.005530379581475
BLEU-4 Score: 0.005521801035986439
BLEU-4 Score: 0.005515104590601104
BLEU-4 Score: 0.005516184816588856
BLEU-4 Score: 0.0054864096664109085
BLEU-4 Score: 0.005482560908214308
BLEU-4 Score: 0.005485482572797691
BLEU-4 Score: 0.005475665588802672


 52%|█████▏    | 104/200 [00:01<00:01, 65.36it/s]

BLEU-4 Score: 0.005490524686125967
BLEU-4 Score: 0.005481533566556801
BLEU-4 Score: 0.005496025796123388
BLEU-4 Score: 0.005506207483817825
BLEU-4 Score: 0.005499469331279202
BLEU-4 Score: 0.005493732811933326
BLEU-4 Score: 0.005490154505996597
BLEU-4 Score: 0.0054560212596976975
BLEU-4 Score: 0.005469916974016319
BLEU-4 Score: 0.0054431600794698776
BLEU-4 Score: 0.005436565511922248
BLEU-4 Score: 0.005460067734924589
BLEU-4 Score: 0.0054732674346981616
BLEU-4 Score: 0.005489326175983576


 59%|█████▉    | 118/200 [00:01<00:01, 64.57it/s]

BLEU-4 Score: 0.005491080227103909
BLEU-4 Score: 0.005473923179518936
BLEU-4 Score: 0.005506998050814493
BLEU-4 Score: 0.005509373776637092
BLEU-4 Score: 0.005508788888966026
BLEU-4 Score: 0.005509758148675574
BLEU-4 Score: 0.005513513903443731
BLEU-4 Score: 0.005507282713829964
BLEU-4 Score: 0.005484570999838457
BLEU-4 Score: 0.00548998022043259
BLEU-4 Score: 0.005475985011729413
BLEU-4 Score: 0.005465730447561956
BLEU-4 Score: 0.00551048502502126
BLEU-4 Score: 0.005486148236259535


 66%|██████▌   | 132/200 [00:02<00:01, 64.46it/s]

BLEU-4 Score: 0.005499813876795339
BLEU-4 Score: 0.005510883406907447
BLEU-4 Score: 0.005526071146838374
BLEU-4 Score: 0.005544377172698616
BLEU-4 Score: 0.005542524060391075
BLEU-4 Score: 0.005534400123427863
BLEU-4 Score: 0.005556082416467232
BLEU-4 Score: 0.005587122217332042
BLEU-4 Score: 0.005564715540714202
BLEU-4 Score: 0.005579824773399798
BLEU-4 Score: 0.0055860546579177385
BLEU-4 Score: 0.005622384476127127
BLEU-4 Score: 0.005645329425700452


 74%|███████▎  | 147/200 [00:02<00:00, 64.38it/s]

BLEU-4 Score: 0.005671064017374168
BLEU-4 Score: 0.005658402276120691
BLEU-4 Score: 0.005666554612932858
BLEU-4 Score: 0.005666214540148122
BLEU-4 Score: 0.005661918749878624
BLEU-4 Score: 0.00568085135396109
BLEU-4 Score: 0.005694934451828853
BLEU-4 Score: 0.005677499747921062
BLEU-4 Score: 0.0056793342631925375
BLEU-4 Score: 0.005706879170124211
BLEU-4 Score: 0.005699399659218979
BLEU-4 Score: 0.0057087934865018295
BLEU-4 Score: 0.005705649647568841
BLEU-4 Score: 0.005711335531995147
BLEU-4 Score: 0.005706009989507591


 80%|████████  | 161/200 [00:02<00:00, 63.96it/s]

BLEU-4 Score: 0.0056988667260504935
BLEU-4 Score: 0.005704096755419755
BLEU-4 Score: 0.005677785785804889
BLEU-4 Score: 0.005692396275183887
BLEU-4 Score: 0.005685107822294093
BLEU-4 Score: 0.005690267614595502
BLEU-4 Score: 0.00568786840423499
BLEU-4 Score: 0.005690769796235628
BLEU-4 Score: 0.005685958774854848
BLEU-4 Score: 0.005688209922595682
BLEU-4 Score: 0.0057103524664885354
BLEU-4 Score: 0.005721316325231477
BLEU-4 Score: 0.005722913753965762


 88%|████████▊ | 175/200 [00:02<00:00, 65.73it/s]

BLEU-4 Score: 0.005715743948117527
BLEU-4 Score: 0.005706915048101703
BLEU-4 Score: 0.005713637166202801
BLEU-4 Score: 0.005718308986705648
BLEU-4 Score: 0.005723345064955866
BLEU-4 Score: 0.005759209797460292
BLEU-4 Score: 0.005763531099770397
BLEU-4 Score: 0.005756432920317989
BLEU-4 Score: 0.005752804379573836
BLEU-4 Score: 0.005767045511965833
BLEU-4 Score: 0.005756906795374418
BLEU-4 Score: 0.005755114650656889
BLEU-4 Score: 0.005745462138318819
BLEU-4 Score: 0.0057476995632909965
BLEU-4 Score: 0.005752999883119868


 94%|█████████▍| 189/200 [00:02<00:00, 65.65it/s]

BLEU-4 Score: 0.005751269612654431
BLEU-4 Score: 0.005767731133254509
BLEU-4 Score: 0.005756717343725194
BLEU-4 Score: 0.005775869053324554
BLEU-4 Score: 0.005772410535148137
BLEU-4 Score: 0.005761831467488435
BLEU-4 Score: 0.005757312472511156
BLEU-4 Score: 0.0057508201603654425
BLEU-4 Score: 0.0057475126937983735
BLEU-4 Score: 0.0057328326224230595
BLEU-4 Score: 0.005743490313897788
BLEU-4 Score: 0.005740273648626345
BLEU-4 Score: 0.005748463664966327
BLEU-4 Score: 0.005763940170193201


100%|██████████| 200/200 [00:03<00:00, 63.88it/s]


BLEU-4 Score: 0.005783770824901488
BLEU-4 Score: 0.005780412851125313
BLEU-4 Score: 0.005777089318105817
BLEU-4 Score: 0.00579287528618159
BLEU-4 Score: 0.005786195579029273
BLEU-4 Score: 0.005798964532244941
BLEU-4 Score: 0.005800592142786154
BLEU-4 Score: 0.005813161423650881


In [None]:
torch.save({
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
    'vocab': vocab,
    'embed_size': 256,
    'hidden_size': 512
}, 'model.pth')