# GRU decoder for image captioning

In [4]:
import torch.optim as optim
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from img_caption import *
from torch.nn.utils.rnn import pad_sequence
from torchvision import transforms

In [5]:
def collate_fn(batch, pad_idx):
    """Combine a list of (image, caption) samples into two batched tensors.

    Args:
        batch: Sequence of ``(image_tensor, caption_tensor)`` pairs. The image
            tensors are stacked, so they must all have the same shape; the
            caption tensors may differ in length.
        pad_idx: Value used to fill captions shorter than the longest one in
            the batch.

    Returns:
        A tuple ``(images, captions)`` where ``images`` has a new leading
        batch dimension and ``captions`` is padded to the longest caption,
        with the batch dimension first.
    """
    # Transpose the list of pairs into one tuple of images and one of captions.
    image_tensors, caption_tensors = zip(*batch)
    stacked_images = torch.stack(image_tensors, dim=0)
    padded_captions = pad_sequence(
        caption_tensors,
        padding_value=pad_idx,
        batch_first=True,
    )
    return stacked_images, padded_captions

In [None]:
# Training script for the GRU caption decoder: builds the Flickr8k data
# pipeline, trains the decoder together with the encoder's `fc` layer, and
# writes the trained weights to disk.

NUM_EPOCHS = 10

# Use the GPU when one is available; otherwise fall back to the CPU so the
# script does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 rescales them to [-1, 1].
    transforms.Normalize((0.5,), (0.5,))
])

vocab = Vocabulary()
dataset = FlickrDataset(
    root_folder="flickr8k",
    captions_file="flickr8k/captions.txt",
    vocab=vocab,
    transform=transform
)
pad_idx = vocab.stoi["<PAD>"]
data_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=lambda batch: collate_fn(batch, pad_idx),
)

encoder = EncoderCNN(embed_size=256).to(device)
decoder = DecoderGRU(embed_size=256, hidden_size=512, vocab_size=len(vocab)).to(device)

# Padding positions are excluded from the loss so caption length does not bias it.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
# Only the decoder and the encoder's `fc` layer are handed to the optimizer;
# any other encoder parameters are never updated by it.
optimizer = optim.Adam(list(decoder.parameters()) + list(encoder.fc.parameters()), lr=1e-3)

for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    num_batches = 0
    for imgs, captions in data_loader:
        imgs, captions = imgs.to(device), captions.to(device)
        features = encoder(imgs)
        # Teacher forcing: feed every token except the last and predict every
        # token except the first.
        # NOTE(review): assumes the decoder returns (batch, seq_len - 1, vocab)
        # so predictions line up with the shifted targets -- confirm in DecoderGRU.
        outputs = decoder(features, captions[:, :-1])
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than the last batch's loss,
    # which is noisy and would be undefined if the loader yielded no batches.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {mean_loss:.4f}")

torch.save(decoder.state_dict(), "decoder_gru.pth")
# The encoder's `fc` layer was trained alongside the decoder, so its weights
# must be saved too; reload with `encoder.fc.load_state_dict(...)`.
torch.save(encoder.fc.state_dict(), "encoder_fc_gru.pth")


FileNotFoundError: [Errno 2] No such file or directory: 'captions.txt'