# Image Captioning Training Notebook
This notebook downloads the Flickr8k dataset from Kaggle, prepares the data, trains the image captioning model, and reports training progress.

In [1]:
# Step 1: Upload kaggle.json for Kaggle API authentication
# `google.colab` is only importable inside a Google Colab runtime.
from google.colab import files
# Opens Colab's interactive upload dialog. The return value is kept in
# `uploaded`, which no other cell shown in this notebook reads.
# NOTE(review): the next cell copies /content/kaggle.json, so this presumably
# relies on the upload being saved under /content with that exact name - confirm.
uploaded = files.upload()

In [2]:
# Step 2: Configure Kaggle API
# Create the credentials directory (-p: no error if it already exists).
!mkdir -p ~/.kaggle
# Copy the uploaded credentials file into that directory.
!cp /content/kaggle.json ~/.kaggle/
# Restrict the credentials file to owner read/write only.
!chmod 600 ~/.kaggle/kaggle.json

In [3]:
# Step 3: Download Flickr8k dataset from Kaggle
!kaggle datasets download -d adityajn105/flickr8k
!unzip -q flickr8k.zip -d data/

Downloading flickr8k.zip to /content
 99% 1.04G/1.05G [00:20<00:00, 56.3MB/s]
100% 1.05G/1.05G [00:20<00:00, 53.3MB/s]

  End-of-central-directory signature not found.  Either this file is not
  a zipfile, or it constitutes one disk of a multi-part archive.  In the
  latter case the central directory and zipfile comment will be found on
  the last disk(s) of this archive.
unzip:  cannot find zipfile directory in one of flickr8k.zip or
        flickr8k.zip.zip, and cannot find flickr8k.zip.ZIP, period.


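### Note:
- The output above shows `unzip` rejecting `flickr8k.zip` as an invalid zip archive. If the unzip command fails like this, delete the archive and download it again, check the Kaggle API setup, or manually upload the dataset. The later cells expect the extracted files under `data/`.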

In [4]:
# Step 4: Install required packages
!pip install torch torchvision pytorch-lightning nltk pycocoevalcap

In [5]:
# Step 5: Import necessary modules
# Standard library
import json
import os
import random
from collections import defaultdict

# Third-party
import nltk
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
# The COCO annotation helper lives in pycocotools (normally installed as a
# dependency of pycocoevalcap); pycocoevalcap itself only provides the caption
# metrics, so `pycocoevalcap.coco` cannot be imported.
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from pytorch_lightning.callbacks import EarlyStopping
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

# Tokenizer data used by nltk.word_tokenize. Older NLTK releases read 'punkt',
# newer ones read 'punkt_tab'; fetching both keeps the notebook working with
# whichever version the unpinned install resolved to.
nltk.download('punkt')
nltk.download('punkt_tab')

In [6]:
# Step 6: Define Vocabulary class and build_vocab function
class Vocabulary(object):
    """Bidirectional word <-> integer-index mapping.

    The four special tokens '<pad>', '<start>', '<end>' and '<unk>' are
    registered first, in that order, so they receive indices 0 to 3.
    Calling the instance with a word returns its index, or the index of
    '<unk>' when the word has not been added.
    """
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0  # next free index; also the number of stored words
        for special in ('<pad>', '<start>', '<end>', '<unk>'):
            self.add_word(special)
    def add_word(self, word):
        """Register ``word`` under the next free index; known words are left untouched."""
        if word in self.word2idx:
            return
        slot = self.idx
        self.word2idx[word] = slot
        self.idx2word[slot] = word
        self.idx = slot + 1
    def __call__(self, word):
        """Return the index of ``word``, falling back to the '<unk>' index."""
        fallback = self.word2idx['<unk>']
        if word in self.word2idx:
            return self.word2idx[word]
        return fallback
    def __len__(self):
        """Number of distinct words stored, special tokens included."""
        return self.idx
def build_vocab(captions_file, freq_threshold=5):
    """Build a Vocabulary from a tab-separated captions file.

    Each usable line consists of exactly two tab-separated fields; the second
    field is the caption. Lines with any other number of fields are skipped.
    Captions are lower-cased and tokenized with ``nltk.word_tokenize``.

    Args:
        captions_file: Path to the captions text file (read as UTF-8).
        freq_threshold: Minimum number of occurrences a token needs in order
            to be added to the vocabulary.

    Returns:
        A Vocabulary holding the special tokens plus every token whose count
        is at least ``freq_threshold``, in order of first appearance.
    """
    counter = {}
    # Explicit UTF-8 so decoding does not depend on the platform default, and
    # iterate the handle directly instead of loading every line into memory.
    with open(captions_file, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) != 2:
                continue
            caption = parts[1]
            for token in nltk.word_tokenize(caption.lower()):
                counter[token] = counter.get(token, 0) + 1
    vocab = Vocabulary()
    for word, freq in counter.items():
        if freq >= freq_threshold:
            vocab.add_word(word)
    return vocab

In [7]:
# Step 7: Define Dataset class
class ImageCaptionDataset(Dataset):
    """Dataset of (image, caption) pairs described by a tab-separated captions file.

    Each usable line of ``captions_file`` has two tab-separated fields: an
    image reference (everything before the first '#' is taken as the file
    name) and the caption text. Lines with any other number of fields are
    skipped.

    Args:
        image_dir: Directory containing the image files.
        captions_file: Path to the captions text file (read as UTF-8).
        vocab: Callable mapping a token string to its integer index.
        transform: Optional callable applied to the loaded PIL image.
    """
    def __init__(self, image_dir, captions_file, vocab, transform=None):
        self.image_dir = image_dir
        self.captions_file = captions_file
        self.vocab = vocab
        self.transform = transform
        self.img_captions = self._load_data()
    def _load_data(self):
        """Parse the captions file into a list of ``(image_name, caption)`` tuples."""
        img_captions = []
        # Explicit UTF-8 so parsing does not depend on the platform default encoding.
        with open(self.captions_file, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) == 2:
                    img_info, caption = parts
                    # Drop the '#<suffix>' part of the image reference to get the file name.
                    img_name = img_info.split('#')[0]
                    img_captions.append((img_name, caption))
        return img_captions
    def __len__(self):
        """Number of (image, caption) pairs."""
        return len(self.img_captions)
    def __getitem__(self, idx):
        """Return ``(image, token_indices)`` for the pair at position ``idx``.

        The image is opened as RGB and passed through ``transform`` when one
        was given. The caption is lower-cased, tokenized, wrapped in
        '<start>' / '<end>' and mapped to indices through ``vocab``.
        """
        img_name, caption = self.img_captions[idx]
        img_path = os.path.join(self.image_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        tokens = ['<start>'] + nltk.word_tokenize(caption.lower()) + ['<end>']
        indices = [self.vocab(token) for token in tokens]
        return image, torch.tensor(indices)
    @staticmethod
    def custom_collate_fn(batch):
        """Stack images and right-pad captions with zeros to a common length.

        Samples whose image is ``None`` are dropped first.

        Returns:
            ``(images, padded_captions)`` where ``padded_captions`` is a
            ``torch.long`` tensor of shape (batch, longest caption).

        Raises:
            ValueError: If no sample remains after dropping ``None`` images.
        """
        batch = [item for item in batch if item[0] is not None]
        if not batch:
            # Without this guard the unpacking below fails with an unrelated-looking message.
            raise ValueError("custom_collate_fn received an empty batch after dropping samples without an image")
        images, captions = zip(*batch)
        images = torch.stack(images, 0)
        max_len = max(len(cap) for cap in captions)
        # Zero-filled so unused positions hold index 0 (the first token Vocabulary registers, '<pad>').
        padded_captions = torch.zeros((len(captions), max_len), dtype=torch.long)
        for i, cap in enumerate(captions):
            padded_captions[i, :len(cap)] = cap
        return images, padded_captions

In [8]:
# Step 8: Define Model
import torch.nn as nn
import pytorch_lightning as pl
import torchvision
class EncoderCNN(nn.Module):
    """Frozen pretrained ResNet-50 backbone followed by a trainable projection.

    The backbone is every child of the ResNet except its last one (the
    classifier); its parameters have ``requires_grad=False`` and it is kept in
    eval mode. Only ``linear`` and ``bn`` are trained.

    Args:
        embed_size: Size of the feature vector produced for each image.
    """
    def __init__(self, embed_size):
        super().__init__()
        resnet = torchvision.models.resnet50(weights="IMAGENET1K_V1")
        for param in resnet.parameters():
            param.requires_grad = False
        # Keep everything except the final classification layer.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.resnet.eval()
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size)
    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen backbone in eval mode.

        ``no_grad`` and ``requires_grad=False`` do not stop BatchNorm layers
        from updating their running statistics in train mode, so the backbone
        is forced back to eval after the regular mode switch.
        """
        super().train(mode)
        self.resnet.eval()
        return self
    def forward(self, images):
        """Encode a batch of images into ``(batch, embed_size)`` features."""
        with torch.no_grad():
            features = self.resnet(images)
        # Flatten per sample. A bare squeeze() would also remove the batch
        # dimension when the batch holds a single image, which BatchNorm1d rejects.
        features = features.reshape(features.size(0), -1)
        return self.bn(self.linear(features))
class DecoderRNN(nn.Module):
    """LSTM caption decoder that is fed the image feature as its first input step.

    Args:
        embed_size: Size of the token embeddings (and of the image feature).
        hidden_size: Size of the LSTM hidden state.
        vocab_size: Number of tokens; also the size of each output score vector.
        num_layers: Number of stacked LSTM layers.
    """
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)
    def forward(self, features, captions):
        """Return vocabulary scores for every input step.

        The input sequence is the image feature followed by the embeddings of
        all caption tokens except the last, so the output has as many time
        steps as ``captions`` has columns.
        """
        token_vectors = self.embed(captions[:, :-1])
        image_step = features.unsqueeze(1)
        sequence = torch.cat((image_step, token_vectors), dim=1)
        hidden_states, _ = self.lstm(sequence)
        return self.linear(hidden_states)
class ImageCaptioningModel(pl.LightningModule):
    """LightningModule combining the CNN encoder and the LSTM decoder.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        embed_size: Size of the image feature / token embeddings.
        hidden_size: Size of the decoder's LSTM hidden state.
        learning_rate: Learning rate passed to the Adam optimizer.
    """
    def __init__(self, vocab_size, embed_size=256, hidden_size=512, learning_rate=0.001):
        super().__init__()
        self.encoder = EncoderCNN(embed_size)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
        # Index 0 is excluded from the loss; the collate function pads captions with zeros.
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=0)
        self.learning_rate = learning_rate
    def forward(self, images, captions):
        """Return the decoder's per-step vocabulary scores for a batch."""
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs
    def _caption_loss(self, batch):
        """Compute the next-token cross-entropy loss for one batch."""
        images, captions = batch
        outputs = self(images, captions)
        # The decoder emits one step per input: the image feature plus
        # captions[:, :-1], i.e. as many steps as captions has columns.
        # Step 0 has only seen the image; step t >= 1 has seen tokens up to
        # t-1 and is scored against captions[:, t]. Dropping step 0 makes the
        # scores line up with the shifted targets (previously the two sides
        # differed by one step per sample and CrossEntropyLoss rejected them).
        logits = outputs[:, 1:, :]
        targets = captions[:, 1:]
        return self.loss_fn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    def training_step(self, batch, batch_idx):
        """Log and return the training loss for one batch."""
        loss = self._caption_loss(batch)
        self.log("train_loss", loss)
        return loss
    def configure_optimizers(self):
        """Adam over the module's parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss for one batch."""
        loss = self._caption_loss(batch)
        self.log("val_loss", loss)
        return loss

In [9]:
# Step 9: Define training parameters and transforms
# Training settings.
batch_size = 32
max_epochs = 10
learning_rate = 0.001
# Dataset locations, relative to the working directory.
# NOTE(review): these assume the archive extracted under data/ contains
# Flicker8k_Dataset/ and Flickr8k_text/Flickr8k.token.txt - confirm against
# the actual contents of the extracted download.
image_dir = "data/Flicker8k_Dataset"
captions_file = "data/Flickr8k_text/Flickr8k.token.txt"
# Per-image preprocessing: resize to 224x224, convert to a tensor, then
# normalize each channel with the given mean / standard deviation.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

In [10]:
# Step 10: Build vocabulary, dataset, dataloader, and model
vocab = build_vocab(captions_file)
dataset = ImageCaptionDataset(image_dir, captions_file, vocab, transform)
# The caption parser silently skips lines it does not recognise, so a file in
# an unexpected format yields an empty dataset. Fail here with a clear message
# instead of a less obvious error from the data loader or trainer.
if len(dataset) == 0:
    raise ValueError(
        f"No captions were parsed from {captions_file!r}; expected lines with an "
        "image reference and a caption separated by a single tab."
    )
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=ImageCaptionDataset.custom_collate_fn)
model = ImageCaptioningModel(vocab_size=len(vocab), embed_size=256, hidden_size=512)

In [11]:
# Step 11: Initialize PyTorch Lightning Trainer and train
# Single-GPU run in 16-bit precision for `max_epochs` epochs.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=max_epochs,
    precision=16,
)
# Only a training dataloader is supplied, so no validation loop runs here.
trainer.fit(model=model, train_dataloaders=dataloader)

In [None]:
# Step 12: After training, you can save the model checkpoint
# Written relative to the current working directory.
trainer.save_checkpoint(filepath="image_captioning_model.ckpt")

# Step 13: To generate captions, load the model and run inference (not shown here but can be added)