In [4]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [18]:
IMAGES_PATH = "./data/train2017/train2017"  # Directory with training images
CAPTIONS_PATH = "./data/annotations_trainval2017/annotations/captions_train2017.json"  # Caption file

In [6]:
import tqdm
import nltk 
from collections import Counter
nltk.download('punkt_tab')
import json

tokens = []
counter = Counter()

class Vocabulary:
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

        self.add_word("<pad>")
        self.add_word("<start>")
        self.add_word("<end>")
        self.add_word("<unk>")

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __len__(self):
        return len(self.word2idx)

def build_vocab(json_path, threshold=5, limit =5000):
    with open(json_path, 'r') as f:
        data = json.load(f) 

    counter = Counter()
    count =0

    for ann in tqdm.tqdm(data['annotations']):
        caption = ann['caption'].lower()
        tokens = nltk.tokenize.word_tokenize(caption)
        counter.update(tokens)
        count +=1
        if count >= limit:
            break
    
    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)
    
    return vocab

[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\pc\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [7]:
vocab = build_vocab(CAPTIONS_PATH, threshold=5, limit=5000)
print("Total vocabulary size:", len(vocab))

  1%|          | 4999/591753 [00:00<00:23, 25375.72it/s]

Total vocabulary size: 927





In [8]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [28]:

def collate_fn(data):
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    images = torch.stack(images, 0)

    lengths = [len(cap) for cap in captions]
    max_length = max(lengths)
    padded_captions = torch.zeros(len(captions), max_length).long()

    for i, cap in enumerate(captions):
        end = lengths[i]
        padded_captions[i, :end] = cap[:end]

    return images, padded_captions, lengths

In [32]:
from torch.utils.data import DataLoader
from dataset import CocoDataset

train_dataset = CocoDataset(
    root=IMAGES_PATH,
    json_path=CAPTIONS_PATH,
    vocab=vocab,
    transform=transform,
    max_samples=5000
)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    collate_fn= collate_fn
)



loading annotations into memory...
Done (t=0.46s)
creating index...
index created!


In [34]:
from model import EncoderCNN, DecoderRNN
import torch.nn as nn

encoder = EncoderCNN(embed_size=256).to(device)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(vocab)).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])
params = list(decoder.parameters()) + list(encoder.embed.parameters()) + list(encoder.bn.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

for epoch in range(3):
    for images, captions, lengths in tqdm.tqdm(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        features = encoder(images)
        outputs = decoder(features, captions)

        loss = criterion(outputs.reshape(-1, len(vocab)), captions.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch:", epoch, "Loss:", loss.item())

100%|██████████| 157/157 [00:22<00:00,  7.03it/s]


Epoch: 0 Loss: 3.0615651607513428


100%|██████████| 157/157 [00:17<00:00,  8.73it/s]


Epoch: 1 Loss: 2.253413200378418


100%|██████████| 157/157 [00:18<00:00,  8.51it/s]

Epoch: 2 Loss: 2.868299961090088





In [35]:
def generate_caption(image, encoder, decoder, vocab):
    feature = encoder(image.unsqueeze(0))
    caption_ids = [vocab.word2idx["<start>"]]
    
    for _ in range(20):
        cap_tensor = torch.Tensor(caption_ids).long().unsqueeze(0).to(device)
        outputs = decoder(feature, cap_tensor)
        predicted = outputs.argmax(2)[:, -1].item()
        
        caption_ids.append(predicted)
        if vocab.idx2word[predicted] == "<end>":
            break

    return " ".join([vocab.idx2word[idx] for idx in caption_ids])

In [39]:
from PIL import Image

def load_image(img_path, transform):
    image = Image.open(img_path).convert("RGB")
    image = transform(image)   # shape: [1, 3, 224, 224]
    return image
generate_caption(load_image('000000000034.jpg', transform), encoder, decoder, vocab)

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor