# GRU decoder for image captioning

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from img_caption import *
from torch.nn.utils.rnn import pad_sequence
from torchvision import transforms
from functools import partial
from tqdm import tqdm

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/nitinrajesh/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
def collate_fn(batch, pad_idx):
    imgs, caps = zip(*batch)
    imgs = torch.stack(imgs, 0)
    caps = pad_sequence(caps, batch_first=True, padding_value=pad_idx)
    return imgs, caps

In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

vocab = Vocabulary()
dataset = Flickr8kDataset(
    image_dir="flickr8k/Images/",
    captions_file="flickr8k/captions.txt",
    vocab=vocab,
    transform=transform
)
pad_idx = vocab.stoi["<PAD>"]
custom_collate = partial(collate_fn, pad_idx=pad_idx)

data_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=custom_collate)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = EncoderCNN(embed_size=256).to(device=device)
decoder = DecoderGRU(embed_size=256, hidden_size=256, vocab_size=len(vocab)).to(device=device)

criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = optim.Adam(list(decoder.parameters()) + list(encoder.fc.parameters()), lr=1e-3)


for epoch in range(10):
    for idx, (imgs, captions) in enumerate(tqdm(data_loader,desc="Training progress ",leave=True)):
        imgs, captions = imgs.to(device), captions.to(device)
        features = encoder(imgs)
        outputs = decoder(features, captions[:, :-1])
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")

torch.save(decoder.state_dict(), "decoder_gru.pth")




256


TypeError: 'module' object is not callable. Did you mean: 'tqdm.tqdm(...)'?