# GRU decoder for image captioning

In [1]:
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

Mounted at /content/drive


In [2]:
import sys
sys.path.append('/content/drive/My Drive/Datasets')

from img_caption import *

In [3]:
import torch.optim as optim
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from img_caption import *
from torchvision import transforms
from functools import partial
from tqdm import tqdm

In [4]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

gec = GloveEmbeddingConverter('/content/drive/My Drive/Datasets/glove.6B.50d.txt')

dataset = Flickr8kDataset(
    image_dir="/content/drive/My Drive/Datasets/flickr8k/Images/",
    captions_file="/content/drive/My Drive/Datasets/flickr8k/captions.txt",
    glove_ec=gec,
    transform=transform
)

vocab_size = gec.get_vocab_size()
print(len(dataset), vocab_size)

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


40455 30005


In [5]:
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8, collate_fn=collate_fn_with_padding)
hidden_dim = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = EncoderCNN(embed_size=gec.embedding_dim).to(device=device)
decoder = FastDecoderGRU(embedding_dim=gec.embedding_dim, hidden_dim=hidden_dim,
                     vocab_size=vocab_size, embeddings=gec.build_embedding_matrix()).to(device=device)

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(list(decoder.parameters()) + list(encoder.fc.parameters()), lr=1e-3)

  self.embedding.weight.data.copy_(torch.tensor(embeddings, dtype=torch.float))


In [6]:
print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 110917 KiB | 110917 KiB | 111096 KiB | 183808 B   |
|       from large pool |  92388 KiB |  92388 KiB |  92388 KiB |      0 B   |
|       from small pool |  18528 KiB |  18590 KiB |  18708 KiB | 183808 B   |
|---------------------------------------------------------------------------|
| Active memory         | 110917 KiB | 110917 KiB | 111096 KiB | 183808 B   |
|       from large pool |  92388 KiB |  92388 KiB |  92388 KiB |      0 B   |
|       from small pool |  18528 KiB |  18590 KiB |  18708 KiB | 183808 B   |
|---------------------------------------------------------------

In [None]:
torch.cuda.empty_cache()

for epoch in range(10):
    for idx, (imgs, captions) in enumerate(tqdm(data_loader,desc="Training progress ",leave=True)):

        imgs, captions = imgs.to(device), captions.to(device)
        with torch.no_grad():
          features = encoder(imgs)
        hidden = torch.zeros(1, features.size(0), hidden_dim).to(features.device)
        predictions, _ = decoder(captions[:, :-1], hidden)  # Input sequence excluding <EOS>
        target = captions[:, 1:]                   # Target sequence excluding <SOS>
        loss = criterion(predictions.view(-1, vocab_size), target.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # print(torch.cuda.memory_summary())

    print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")
    # Save everything in one checkpoint
    torch.save({
        'epoch': epoch,
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),
    }, 'checkpoint.pth')


Training progress : 100%|██████████| 1265/1265 [05:51<00:00,  3.60it/s]


Epoch [1/10], Loss: 3.2475


Training progress : 100%|██████████| 1265/1265 [05:30<00:00,  3.83it/s]


Epoch [2/10], Loss: 3.3767


Training progress : 100%|██████████| 1265/1265 [05:27<00:00,  3.87it/s]


Epoch [3/10], Loss: 3.0532


Training progress : 100%|██████████| 1265/1265 [05:24<00:00,  3.89it/s]


Epoch [4/10], Loss: 2.7555


Training progress : 100%|██████████| 1265/1265 [05:24<00:00,  3.90it/s]


Epoch [5/10], Loss: 3.1134


Training progress : 100%|██████████| 1265/1265 [05:26<00:00,  3.88it/s]


Epoch [6/10], Loss: 2.7778


Training progress : 100%|██████████| 1265/1265 [05:26<00:00,  3.87it/s]


Epoch [7/10], Loss: 2.4197


Training progress : 100%|██████████| 1265/1265 [05:28<00:00,  3.86it/s]


Epoch [8/10], Loss: 2.4746


Training progress : 100%|██████████| 1265/1265 [05:22<00:00,  3.92it/s]


Epoch [9/10], Loss: 3.6285


Training progress :  81%|████████  | 1027/1265 [04:27<00:52,  4.49it/s]