In [1]:
# Step 0: Install Kaggle library
!pip install kaggle



In [2]:
# Step 0b: Configure Kaggle to use the uploaded kaggle.json
import os

# Create a directory for Kaggle and copy the kaggle.json file
!mkdir -p ~/.kaggle
!cp /content/kaggle.json ~/.kaggle/

# Set permissions for the kaggle.json file
!chmod 600 ~/.kaggle/kaggle.json

print("✅ Kaggle API configured.")

✅ Kaggle API configured.


In [3]:
# Step 0d: Download the Flickr8k dataset
# The dataset will be downloaded to the 'data' directory
!kaggle datasets download -d adityajn105/flickr8k -p data --unzip

print("✅ Flickr8k dataset downloaded and unzipped.")

Dataset URL: https://www.kaggle.com/datasets/adityajn105/flickr8k
License(s): CC0-1.0
Downloading flickr8k.zip to data
 98% 1.02G/1.04G [00:05<00:00, 239MB/s]
100% 1.04G/1.04G [00:06<00:00, 185MB/s]
✅ Flickr8k dataset downloaded and unzipped.


In [4]:
# Step 1: Install required packages
!pip install torch torchvision pytorch-lightning nltk pycocotools timm sentencepiece

# Download spaCy model if you want to use spaCy alternative (optional)
# !pip install spacy
# !python -m spacy download en_core_web_sm

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.2-py3-none-any.whl.metadata (21 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1

In [5]:
# Step 2: Import necessary modules and set seed
import os
import random
import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
import timm
import sentencepiece as spm
import nltk

# Download nltk punkt tokenizer if not already present
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

# Set seed for reproducibility
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(torch.cuda.device_count())
    pl.seed_everything(seed)

seed_everything()

# Enable cudnn benchmark for fixed input size
torch.backends.cudnn.benchmark = True

print("✅ Necessary modules imported and seed set.")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
INFO:lightning_fabric.utilities.seed:Seed set to 42


✅ Necessary modules imported and seed set.


In [6]:
# Step 3: Prepare captions file for SentencePiece training
captions_file = "data/captions.txt"

# Read captions and write to a separate file for SentencePiece
with open(captions_file, 'r') as f:
    lines = f.readlines()

captions_only_file = "captions.txt"
with open(captions_only_file, 'w') as f:
    for line in lines:
        line = line.strip()
        if not line:
            continue
        # Extract caption part
        if ',' in line:
            parts = line.split(',')
            caption = ','.join(parts[1:]).strip()
            if caption.lower() != 'caption':
                f.write(caption + '\n')
        elif '\t' in line:
            parts = line.split('\t')
            if len(parts) == 2:
                caption = parts[1].strip()
                f.write(caption + '\n')

print(f"✅ Captions extracted to {captions_only_file} for SentencePiece training.")

✅ Captions extracted to captions.txt for SentencePiece training.


In [7]:
# Step 4: Train SentencePiece BPE tokenizer
spm.SentencePieceTrainer.Train(
    f'--input={captions_only_file} --model_prefix=bpe --vocab_size=4000 --model_type=bpe --pad_id=0 --unk_id=1 --bos_id=2 --eos_id=3'
)

# Load the trained tokenizer
sp = spm.SentencePieceProcessor()
sp.load("bpe.model")

print("✅ SentencePiece BPE tokenizer trained and loaded.")

✅ SentencePiece BPE tokenizer trained and loaded.


In [8]:
# Step 5: Define Dataset class using SentencePiece tokenizer
class ImageCaptionDataset(Dataset):
    def __init__(self, image_dir, samples, transform=None):
        self.image_dir = image_dir
        self.samples = samples  # list of (img_name, caption) tuples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, caption = self.samples[idx]
        img_path = os.path.join(self.image_dir, img_name)
        try:
            image = Image.open(img_path).convert('RGB')
        except FileNotFoundError:
            print(f"Warning: Image file not found at {img_path}. Skipping sample.")
            return None

        if self.transform:
            image = self.transform(image)

        # Encode caption using SentencePiece tokenizer
        # Add BOS (<start>) and EOS (<end>) tokens
        encoded = [sp.bos_id()] + sp.encode(caption, out_type=int) + [sp.eos_id()]
        return image, torch.tensor(encoded)

    @staticmethod
    def custom_collate_fn(batch):
        batch = [item for item in batch if item is not None]
        if not batch:
            return None, None
        images, captions = zip(*batch)
        images = torch.stack(images, 0)
        max_len = max(len(cap) for cap in captions)
        padded_captions = torch.zeros((len(captions), max_len), dtype=torch.long)
        for i, cap in enumerate(captions):
            padded_captions[i, :len(cap)] = cap
        return images, padded_captions

print("✅ Dataset class defined with SentencePiece tokenizer.")

✅ Dataset class defined with SentencePiece tokenizer.


In [9]:
# Step 6: Define training parameters and transforms
batch_size = 64
max_epochs = 15
learning_rate = 1e-4

image_dir = "data/Images"
captions_file = "data/captions.txt"

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(300),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

val_transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

print(f"✅ Training parameters and transforms defined.")

✅ Training parameters and transforms defined.


In [10]:
# Step 7: Load and prepare data
all_samples = []
with open(captions_file, 'r') as f:
    first_line = f.readline().strip()
    if first_line.lower().startswith("image"):
        pass  # skip header
    else:
        f.seek(0)
    for line in f:
        line = line.strip()
        if not line:
            continue
        if ',' in line:
            parts = line.split(',')
            if len(parts) >= 2:
                img_name = parts[0].strip()
                caption = ','.join(parts[1:]).strip()
                if img_name.lower() != 'image':
                    all_samples.append((img_name, caption))
        elif '\t' in line:
            parts = line.split('\t')
            if len(parts) == 2:
                img_info, caption = parts
                img_name = img_info.split('#')[0].strip()
                all_samples.append((img_name, caption.strip()))
        else:
            print(f"⚠️ Skipping malformed line: {line}")

print(f"🔍 Parsed {len(all_samples)} total (img, caption) pairs")

# Train/val split
train_samples, val_samples = train_test_split(all_samples, test_size=0.10, random_state=42, shuffle=True)
print(f"→ {len(train_samples)} train samples, {len(val_samples)} val samples")

# Create datasets
train_dataset = ImageCaptionDataset(image_dir=image_dir, samples=train_samples, transform=train_transform)
val_dataset = ImageCaptionDataset(image_dir=image_dir, samples=val_samples, transform=val_transform)

# DataLoader setup
num_workers = min(8, os.cpu_count() or 1)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, collate_fn=ImageCaptionDataset.custom_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True, collate_fn=ImageCaptionDataset.custom_collate_fn)

print(f"✅ DataLoaders created with batch_size={batch_size}, num_workers={num_workers}")

🔍 Parsed 40455 total (img, caption) pairs
→ 36409 train samples, 4046 val samples
✅ DataLoaders created with batch_size=64, num_workers=2


In [11]:
# Step 8: Define Attention module
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        att1 = self.encoder_att(encoder_out)
        att2 = self.decoder_att(decoder_hidden).unsqueeze(1)
        att = self.full_att(self.relu(att1 + att2)).squeeze(2)
        alpha = self.softmax(att)
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return attention_weighted_encoding, alpha

print("✅ Attention module defined.")

✅ Attention module defined.


In [12]:
# Step 9: Define EncoderCNN with EfficientNet-B3 backbone
class EncoderCNN(nn.Module):
    def __init__(self, embed_size=512, fine_tune=True):
        super().__init__()
        self.backbone = timm.create_model("efficientnet_b3", pretrained=True, features_only=True)
        self.out_channels = self.backbone.feature_info[-1]['num_chs']
        if not fine_tune:
            for param in self.backbone.parameters():
                param.requires_grad = False
        self.conv = nn.Conv2d(self.out_channels, embed_size, kernel_size=1)
        self.bn = nn.BatchNorm2d(embed_size)

    def forward(self, x):
        feats = self.backbone(x)[-1]
        feats = self.bn(self.conv(feats))
        b, c, h, w = feats.shape
        return feats.permute(0, 2, 3, 1).reshape(b, h * w, c)

print("✅ EncoderCNN module with EfficientNet-B3 defined.")

✅ EncoderCNN module with EfficientNet-B3 defined.


In [13]:
# Step 10: Define DecoderRNN with Attention
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, attention_dim=256, num_layers=1, dropout=0.5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(embed_size, hidden_size, attention_dim)
        self.lstm = nn.LSTM(embed_size + embed_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)
        self.relu = nn.ReLU()

    def forward(self, features, captions):
        embeddings = self.embed(captions[:, :-1])
        batch_size = features.size(0)
        seq_len = embeddings.size(1)
        avg_features = features.mean(dim=1)
        hidden_state = self.init_h(avg_features).unsqueeze(0)
        cell_state = self.init_c(avg_features).unsqueeze(0)
        outputs = []
        alphas = []
        for t in range(seq_len):
            attention_weighted_encoding, alpha = self.attention(features, hidden_state.squeeze(0))
            alphas.append(alpha)
            lstm_input = torch.cat((embeddings[:, t, :], attention_weighted_encoding), dim=1).unsqueeze(1)
            output, (hidden_state, cell_state) = self.lstm(lstm_input, (hidden_state, cell_state))
            output = self.linear(self.dropout(output.squeeze(1)))
            outputs.append(output)
        outputs = torch.stack(outputs, dim=1)
        alphas = torch.stack(alphas, dim=1)
        return outputs

print("✅ DecoderRNN module defined.")

✅ DecoderRNN module defined.


In [14]:
# Step 11: Define ImageCaptioningModel (PyTorch Lightning Module)
class ImageCaptioningModel(pl.LightningModule):
    def __init__(self, vocab_size, embed_size=512, hidden_size=1024, learning_rate=1e-4):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = EncoderCNN(embed_size, fine_tune=True)
        self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=0)

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    def training_step(self, batch, batch_idx):
        images, captions = batch
        if images is None:
            return None
        outputs = self(images, captions)
        loss = self.loss_fn(outputs.view(-1, outputs.size(-1)), captions[:, 1:].reshape(-1))
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, captions = batch
        if images is None:
            return None
        with torch.no_grad():
            outputs = self(images, captions)
            loss = self.loss_fn(outputs.view(-1, outputs.size(-1)), captions[:, 1:].reshape(-1))
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=self.hparams.learning_rate)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss'
            },
        }

print("✅ ImageCaptioningModel (PyTorch Lightning Module) defined.")

✅ ImageCaptioningModel (PyTorch Lightning Module) defined.


In [15]:
# Step 12: Initialize model and trainer with callbacks
model = ImageCaptioningModel(vocab_size=sp.get_piece_size(), embed_size=512, hidden_size=1024, learning_rate=learning_rate)

early_stop_callback = EarlyStopping(monitor='val_loss', patience=3, verbose=True, mode='min')
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, filename='best-checkpoint')
lr_monitor = LearningRateMonitor(logging_interval='epoch')

trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1 if torch.cuda.is_available() else 'auto',
    max_epochs=max_epochs,
    precision=16 if torch.cuda.is_available() else 32,
    callbacks=[early_stop_callback, checkpoint_callback, lr_monitor],
    gradient_clip_val=1.0
)

print("✅ Model and Trainer initialized with callbacks.")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/49.3M [00:00<?, ?B/s]

/usr/local/lib/python3.11/dist-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


✅ Model and Trainer initialized with callbacks.


In [16]:
# Step 13: Train the model
print("🚀 Starting model training...")
if 'train_loader' in globals() and 'val_loader' in globals():
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    print("✅ Model training completed.")
    print(f"Best model checkpoint saved at: {trainer.checkpoint_callback.best_model_path}")
else:
    print("❌ train_loader or val_loader not found. Please ensure the data loading and preparation steps were executed.")

🚀 Starting model training...


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | encoder | EncoderCNN       | 10.3 M | train
1 | decoder | DecoderRNN       | 16.0 M | train
2 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
26.3 M    Trainable params
0         Non-trainable params
26.3 M    Total params
105.164   Total estimated model params size (MB)
543       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved. New best score: 3.797


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.411 >= min_delta = 0.0. New best score: 3.386


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.207 >= min_delta = 0.0. New best score: 3.179


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.141 >= min_delta = 0.0. New best score: 3.039


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.103 >= min_delta = 0.0. New best score: 2.936


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.079 >= min_delta = 0.0. New best score: 2.857


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.065 >= min_delta = 0.0. New best score: 2.791


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.050 >= min_delta = 0.0. New best score: 2.742


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.048 >= min_delta = 0.0. New best score: 2.694


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.042 >= min_delta = 0.0. New best score: 2.651


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.029 >= min_delta = 0.0. New best score: 2.622


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.028 >= min_delta = 0.0. New best score: 2.594


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.018 >= min_delta = 0.0. New best score: 2.576


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.023 >= min_delta = 0.0. New best score: 2.553


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric val_loss improved by 0.020 >= min_delta = 0.0. New best score: 2.533
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=15` reached.


✅ Model training completed.
Best model checkpoint saved at: /content/lightning_logs/version_0/checkpoints/best-checkpoint.ckpt


In [20]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [17]:
# Step 14: Load model for inference
import torch
import torchvision.transforms as transforms
from PIL import Image
import sentencepiece as spm

def load_model_for_inference(checkpoint_path, vocab_size, embed_size=512, hidden_size=1024):
    model = ImageCaptioningModel(vocab_size=vocab_size, embed_size=embed_size, hidden_size=hidden_size)
    checkpoint = torch.load(checkpoint_path, map_location=model.device)
    state_dict = checkpoint['state_dict']
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('model.'):
            new_state_dict[k[6:]] = v
        else:
            new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
    model.eval()
    return model

# Step 15: Process image for inference
def process_image_for_inference(image_path, transform, device):
    try:
        image = Image.open(image_path).convert('RGB')
        image = transform(image).unsqueeze(0)
        return image.to(device)
    except FileNotFoundError:
        print(f"❌ Error: Image file not found at {image_path}")
        return None
    except Exception as e:
        print(f"❌ Error processing image {image_path}: {e}")
        return None

# Step 16: Beam search decoding
def generate_caption_beam_search(image_tensor, model, sp, beam_width=3, max_len=30, device=None):
    if model is None:
        print("❌ Model is not loaded. Cannot generate caption.")
        return "Error: Model not loaded."
    model.eval()
    if device is None:
        device = next(model.parameters()).device
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        features = model.encoder(image_tensor)
        # Beam search initialization
        sequences = [[list(), 0.0, (torch.zeros(1, 1, model.decoder.lstm.hidden_size).to(device),
                                  torch.zeros(1, 1, model.decoder.lstm.hidden_size).to(device))]]
        # Start with BOS token embedding
        start_token = sp.bos_id()
        end_token = sp.eos_id()
        for _ in range(max_len):
            all_candidates = []
            for seq, score, (h, c) in sequences:
                if len(seq) > 0 and seq[-1] == end_token:
                    all_candidates.append((seq, score, (h, c)))
                    continue
                if len(seq) == 0:
                    inputs = model.decoder.embed(torch.tensor([start_token]).to(device)).unsqueeze(1)
                else:
                    inputs = model.decoder.embed(torch.tensor([seq[-1]]).to(device)).unsqueeze(1)
                attention_weighted_encoding, _ = model.decoder.attention(features, h.squeeze(0))
                lstm_input = torch.cat((inputs.squeeze(1), attention_weighted_encoding), dim=1).unsqueeze(1)
                output, (h, c) = model.decoder.lstm(lstm_input, (h, c))
                output = model.decoder.linear(output.squeeze(1))
                log_probs = torch.log_softmax(output, dim=1)
                top_log_probs, top_indices = torch.topk(log_probs, beam_width)
                for i in range(beam_width):
                    candidate = seq + [top_indices[0][i].item()]
                    candidate_score = score - top_log_probs[0][i].item()  # negative log likelihood
                    all_candidates.append((candidate, candidate_score, (h, c)))
            # Order all candidates by score
            ordered = sorted(all_candidates, key=lambda tup: tup[1])
            # Select top beam_width sequences
            sequences = ordered[:beam_width]
            # If all sequences end with end_token, stop early
            if all(seq[-1] == end_token for seq, _, _ in sequences):
                break
        # Choose the best sequence
        best_seq = sequences[0][0]
        # Convert token ids to words
        tokens = [sp.id_to_piece(id) for id in best_seq if id not in (sp.pad_id(), sp.bos_id(), sp.eos_id())]
        return ' '.join(tokens)

# Step 17: Greedy search decoding
def generate_caption_greedy_search(image_tensor, model, sp, max_len=30, device=None):
    if model is None:
        print("❌ Model is not loaded. Cannot generate caption.")
        return "Error: Model not loaded."
    model.eval()
    if device is None:
        device = next(model.parameters()).device
    image_tensor = image_tensor.to(device)
    with torch.no_grad():
        features = model.encoder(image_tensor)
        caption = []
        # Start with BOS token
        token = sp.bos_id()
        end_token = sp.eos_id()
        # Initialize hidden and cell states
        avg_features = features.mean(dim=1)
        hidden_state = model.decoder.init_h(avg_features).unsqueeze(0)
        cell_state = model.decoder.init_c(avg_features).unsqueeze(0)
        for _ in range(max_len):
            inputs = model.decoder.embed(torch.tensor([token]).to(device)).unsqueeze(1)
            attention_weighted_encoding, _ = model.decoder.attention(features, hidden_state.squeeze(0))
            lstm_input = torch.cat((inputs.squeeze(1), attention_weighted_encoding), dim=1).unsqueeze(1)
            output, (hidden_state, cell_state) = model.decoder.lstm(lstm_input, (hidden_state, cell_state))
            output = model.decoder.linear(output.squeeze(1))
            _, predicted_token = torch.max(output, dim=1)
            token = predicted_token.item()
            if token == end_token:
                break
            if token != sp.pad_id() and token != sp.bos_id():
                caption.append(sp.id_to_piece(token))
        return ' '.join(caption)


# Example usage for inference
test_image_path = 'data/Images/1000268201_693b08cb0e.jpg'
inference_transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device for inference: {device}")

# Update the checkpoint path based on the training output
checkpoint_path = '/content/lightning_logs/version_0/checkpoints/best-checkpoint.ckpt'

model_inference = load_model_for_inference(checkpoint_path, vocab_size=sp.get_piece_size(), embed_size=512, hidden_size=1024)
model_inference.to(device)

image_tensor = process_image_for_inference(test_image_path, inference_transform, device)
if image_tensor is not None:
    caption_beam = generate_caption_beam_search(image_tensor, model_inference, sp, beam_width=5, max_len=30, device=device)
    print("\n🖼️ Generated Caption (Beam Search):")
    print(caption_beam)

    caption_greedy = generate_caption_greedy_search(image_tensor, model_inference, sp, max_len=30, device=device)
    print("\n🖼️ Generated Caption (Greedy Search):")
    print(caption_greedy)

Using device for inference: cuda





🖼️ Generated Caption (Beam Search):
▁A ▁little ▁girl ▁in ▁a ▁pink ▁shirt ▁is ▁climbing ▁a ▁wooden ▁structure ▁.

🖼️ Generated Caption (Greedy Search):
▁A ▁little ▁girl ▁in ▁a ▁pink ▁shirt ▁is ▁climbing ▁a ▁wooden ▁structure ▁.


### Step 0a: Set up Kaggle API
Before running the next cell, please upload your `kaggle.json` file to the `/content/` directory in your Colab environment. This file contains your Kaggle API credentials.

### Step 0c: Download the dataset
I recommend using the **Flickr8k** dataset for this project. It's a standard benchmark for image captioning and should work well with your model architecture. The dataset includes images and five captions per image.

In [21]:
# Step 13c: Copy the best checkpoint to Google Drive
import shutil
import os

# Define the source and destination paths
# Make sure trainer.checkpoint_callback.best_model_path is accessible
source_path = trainer.checkpoint_callback.best_model_path
destination_dir = '/content/drive/My Drive/image_captioning_checkpoints'
destination_path = os.path.join(destination_dir, os.path.basename(source_path))

# Create the destination directory if it doesn't exist
os.makedirs(destination_dir, exist_ok=True)

# Copy the file
try:
    shutil.copyfile(source_path, destination_path)
    print(f"✅ Best model checkpoint copied to Google Drive: {destination_path}")
except FileNotFoundError:
    print(f"❌ Error: Source file not found at {source_path}. Ensure training completed successfully.")
except Exception as e:
    print(f"❌ Error copying file to Google Drive: {e}")

✅ Best model checkpoint copied to Google Drive: /content/drive/My Drive/image_captioning_checkpoints/best-checkpoint.ckpt
