In [1]:
import random

import torch

# Seed every RNG the notebook relies on: `random` picks one caption per image in the
# dataset, and torch drives DataLoader shuffling and the per-worker seeds it derives.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Set device = "cpu" here to force CPU execution.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [None]:
# Root of the COCO 2014 download; change this single value to relocate the dataset.
DATA_DIR = "../phase_2/data"
IMAGES_PATH = f"{DATA_DIR}/train2014/train2014"  # Directory with training images
CAPTIONS_PATH = f"{DATA_DIR}/annotations_trainval2014/annotations/captions_train2014.json"  # Caption file

In [None]:
import torch
from torch.utils.data import Dataset
from pycocotools.coco import COCO
from PIL import Image
import nltk
import os
import random

class CocoDatasetClass(Dataset):
    """COCO captions dataset yielding ``(image, caption_indices)`` pairs.

    Each item is one image plus ONE caption chosen at random (via ``random.choice``)
    from that image's annotations, so repeated epochs can see different captions
    for the same image.

    Args:
        root: directory containing the image files named in the COCO json.
        json_path: path to a COCO captions annotation file (loaded with ``COCO``).
        vocab: object exposing a ``word2idx`` dict that contains the
            ``"<start>"``, ``"<end>"`` and ``"<unk>"`` tokens.
        transform: optional callable applied to the RGB PIL image.
        max_samples: if truthy, keep only the first ``max_samples`` image ids.
    """

    def __init__(self, root, json_path, vocab, transform=None, max_samples=None):
        self.root = root
        self.coco = COCO(json_path)
        self.vocab = vocab
        self.transform = transform

        self.ids = self.coco.getImgIds()
        if max_samples:
            self.ids = self.ids[:max_samples]

    def __len__(self):
        return len(self.ids)

    def _load_image(self, img_id):
        """Open the image for ``img_id`` as RGB; raises OSError/ValueError if unreadable."""
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info['file_name'])
        # convert() forces the pixel data to load, so the file can be closed right after.
        with Image.open(img_path) as img:
            return img.convert("RGB")

    def _encode_caption(self, caption):
        """Lower-case and tokenize ``caption``, wrap it in <start>/<end>, map OOV words to <unk>."""
        word2idx = self.vocab.word2idx
        unk_idx = word2idx["<unk>"]
        caption_tokens = nltk.tokenize.word_tokenize(str(caption).lower())

        caption_indices = [word2idx["<start>"]]
        caption_indices.extend(word2idx.get(token, unk_idx) for token in caption_tokens)
        caption_indices.append(word2idx["<end>"])
        return torch.tensor(caption_indices)

    def __getitem__(self, index):
        """Return ``(image, caption_tensor)`` for ``index``.

        If the image at ``index`` cannot be read, the following ids are tried in
        order (wrapping around), and the caption is taken from whichever image was
        actually loaded.

        Raises:
            IndexError: if ``index`` is out of range (keeps the sequence protocol intact).
            RuntimeError: if no image in the dataset can be read.
        """
        num_ids = len(self.ids)
        img_id = self.ids[index]  # out-of-range index still raises IndexError

        # Bounded, iterative skip over unreadable files (at most one full pass), instead
        # of unbounded recursion. Only I/O / decode errors are treated as "corrupted";
        # anything else (including KeyboardInterrupt) propagates.
        for attempt in range(1, num_ids + 1):
            try:
                image = self._load_image(img_id)
                break
            except (OSError, ValueError):
                img_id = self.ids[(index + attempt) % num_ids]
        else:
            raise RuntimeError(f"no readable image found under {self.root!r}")

        if self.transform:
            image = self.transform(image)

        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        caption = random.choice(anns)['caption']

        return image, self._encode_caption(caption)
    

In [None]:
import tqdm
import nltk 
from collections import Counter
from vocabulary_class import Vocabulary
nltk.download('punkt')
import json

import csv
import string

# NOTE(review): build_vocab below keeps its own local Counter and never reads these
# module-level names; they were previously assigned twice. Kept once only in case a
# later cell relies on them — delete if nothing does.
tokens = []
counter = Counter()

def build_vocab(json_path, threshold=5, limit=None):
    """Build a Vocabulary from the captions in a COCO annotation file.

    Captions are lower-cased and tokenized with ``nltk.tokenize.word_tokenize``;
    every token occurring at least ``threshold`` times is added via
    ``Vocabulary.add_word`` in first-seen order.

    Args:
        json_path: path to a COCO captions json with an ``"annotations"`` list,
            each entry carrying a ``"caption"`` string.
        threshold: minimum token count for inclusion.
        limit: if not None, only the first ``limit`` annotations are counted
            (``limit=0`` counts none; previously 0 was treated as "no limit").

    Returns:
        The populated ``Vocabulary`` instance.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    annotations = data['annotations']
    if limit is not None:
        annotations = annotations[:limit]

    counter = Counter()
    for ann in tqdm.tqdm(annotations):
        counter.update(nltk.tokenize.word_tokenize(ann['caption'].lower()))

    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)

    return vocab

[nltk_data] Downloading package punkt to /home/madhu-
[nltk_data]     thiramdas/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
vocab = build_vocab(CAPTIONS_PATH, threshold=3)
print("Total vocabulary size:", len(vocab))

indices = list(vocab.word2idx.values())
print("max idx:", max(indices))
print("len(vocab):", len(vocab))

# The decoder is later built with vocab_size=len(vocab) (presumably sizing its embedding
# and output layers), so every token index must lie in [0, len(vocab)).
assert 0 <= min(indices) and max(indices) < len(vocab), "vocab indices fall outside [0, len(vocab))"


100%|██████████| 414113/414113 [00:16<00:00, 25461.04it/s]

Total vocabulary size: 11489
max idx: 11488
len(vocab): 11489





In [None]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

from torch.nn.utils.rnn import pad_sequence
import torch

# Preprocessing pipeline: resize to a fixed 224x224 (presumably the encoder's input
# resolution — confirm against TransformerEncoderViT), convert to a float tensor, then
# normalize each channel with the usual ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)


def collate_fn(batch, pad_idx=0):
    """Stack images and right-pad variable-length captions into one batch.

    Args:
        batch: sequence of ``(image, caption)`` pairs as returned by the dataset's
            ``__getitem__``; images must share a shape so they can be stacked.
        pad_idx: token index used to fill short captions. It should equal the index
            the loss ignores (``vocab.word2idx["<pad>"]``). Defaults to 0.

    Returns:
        ``(images, captions)``: images stacked along a new dim 0, and captions
        padded with ``pad_idx`` to the longest one — ``(batch, max_len)`` for 1-D
        caption tensors.
    """
    images, captions = zip(*batch)
    stacked_images = torch.stack(images, dim=0)
    padded_captions = pad_sequence(captions, batch_first=True, padding_value=pad_idx)
    return stacked_images, padded_captions

from torch.utils.data import DataLoader

# collate_fn pads captions with 0 while the loss uses ignore_index=vocab.word2idx["<pad>"];
# if these disagree, padding positions would silently be trained on.
assert vocab.word2idx["<pad>"] == 0, "collate_fn pads with 0 but <pad> has a different index"

train_dataset = CocoDatasetClass(
    root=IMAGES_PATH,
    json_path=CAPTIONS_PATH,
    vocab=vocab,
    transform=transform,
    max_samples=None
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=12,
    pin_memory=torch.cuda.is_available(),  # page-locked host memory speeds up copies to the GPU
    collate_fn=collate_fn
)

# Sanity check: one sample after the transform and caption encoding.
print(len(train_dataset))
image, caption = train_dataset[0]

print(type(image))
print(image.shape)        # after transform
print(caption)
print(len(caption))

loading annotations into memory...
Done (t=0.28s)
creating index...
index created!
82783
<class 'torch.Tensor'>
torch.Size([3, 224, 224])
tensor([   1,    4,  905,  113,   22,    4, 1646,   40,  259,   14,   47,  678,
          22,  266,  620,    2])
16


In [None]:
import torch.nn as nn
from model  import TransformerEncoderViT
from model  import TransformerDecoder

encoder = TransformerEncoderViT(embed_size=256).to(device)
decoder = TransformerDecoder(embed_size=256, vocab_size=len(vocab)).to(device)

# The optimizer below only owns decoder parameters, so encoder weights are never updated.
# Make that freeze explicit. Otherwise, if any encoder parameter has requires_grad=True,
# backward() keeps accumulating into encoder .grad tensors that optimizer.zero_grad()
# never clears, and a clip_grad_norm_ over encoder+decoder parameters then sees an
# ever-growing total norm that scales the decoder updates down more and more.
# To fine-tune the encoder instead, remove this line and add its parameters to `params`.
encoder.requires_grad_(False)

criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])

# A list (not a one-shot filter iterator) so it can be inspected or reused safely.
params = [p for p in decoder.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    params,
    lr=3e-4,
    weight_decay=1e-2
)


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
NUM_EPOCHS = 5
MAX_GRAD_NORM = 1.0

# Clip exactly the parameters the optimizer updates. Including parameters it does not
# own would fold their never-zeroed .grad into the total norm and over-clip the rest.
trainable_params = [p for group in optimizer.param_groups for p in group["params"]]

train_losses = []
for epoch in range(NUM_EPOCHS):
    encoder.train()
    decoder.train()
    total_train_loss = 0.0

    for images, captions in tqdm.tqdm(train_loader, desc=f"epoch {epoch}"):
        images = images.to(device)
        captions = captions.to(device)

        # Teacher forcing: the decoder reads tokens [0, T-1) and predicts tokens [1, T).
        # input:  <start> w1 w2 ... w(T-1)
        # target: w1 w2 ... w(T-1) <end>   (padding positions are ignored by the loss)
        captions_in = captions[:, :-1]
        targets = captions[:, 1:]

        optimizer.zero_grad(set_to_none=True)

        memory = encoder(images)
        # Pass the shifted input, not the full caption, so each position predicts the next
        # token. logits presumably (batch, T-1, vocab_size) — confirm in model.py.
        logits = decoder(memory, captions_in)

        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1)
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, MAX_GRAD_NORM)
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Epoch {epoch}: Train={avg_train_loss:.4f}")

100%|██████████| 2587/2587 [09:28<00:00,  4.55it/s]

Epoch 0: Train=2.7131





In [None]:
import os

MODELS_DIR = "models"
os.makedirs(MODELS_DIR, exist_ok=True)  # torch.save does not create missing directories

torch.save(encoder.state_dict(), os.path.join(MODELS_DIR, "encoder.pth"))
torch.save(decoder.state_dict(), os.path.join(MODELS_DIR, "decoder.pth"))
# NOTE: this pickles the whole Vocabulary object. Loading it back requires the
# `vocabulary_class` module to be importable and, on recent PyTorch versions,
# torch.load(..., weights_only=False). Only load such files from trusted sources.
torch.save(vocab, os.path.join(MODELS_DIR, "vocab.pkl"))