In [None]:
# Mount Google Drive so the data/ folder used below is reachable under /content/drive.
# force_remount=True re-mounts even if an earlier session already mounted it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

# Import required libraries
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm
import random

Mounted at /content/drive


In [None]:
# File paths (already saved)
FEATURES_PATH = "/content/drive/MyDrive/data/image_features.pt"
CAPTIONS_FILE = "/content/drive/MyDrive/data/filtered_captions_cleaned.tsv"
IMAGE_FOLDER = "/content/drive/MyDrive/data/loadedimages"  # For inference

# Load precomputed image features.
# NOTE(review): presumably a dict of image file name -> 2048-d feature tensor
# (project_features below expects 2048 inputs) — confirm against the extraction script.
# map_location="cpu" lets the file load on a CPU-only runtime even if it was saved
# from GPU tensors; the training loop moves each batch onto `device` itself.
features_dict = torch.load(FEATURES_PATH, map_location="cpu")
print(f"Loaded features for {len(features_dict)} images from {FEATURES_PATH}")

Loaded features for 53367 images from /content/drive/MyDrive/data/image_features.pt


In [None]:
# Load captions and group them per image.
# The TSV has no header row: column 0 is the image file name, column 1 one caption.
captions_raw = pd.read_csv(CAPTIONS_FILE, sep="\t", names=["image", "caption"])
df = captions_raw.dropna().assign(
    image=lambda d: d["image"].astype(str).str.strip(),
    # Force text so a numeric-looking caption can't break .split() later.
    caption=lambda d: d["caption"].astype(str),
)
# groupby(sort=False) keeps images in first-appearance order and each image's
# captions in file order — the same order the row-by-row loop produced. The
# vocabulary (and therefore every word index) is built from this order.
captions_dict = df.groupby("image", sort=False)["caption"].agg(list).to_dict()
print(f"Captions dictionary created for {len(captions_dict)} images")


Captions dictionary created for 53367 images


In [None]:
# Build the word vocabulary from all captions.
# MAX_VOCAB = None keeps every distinct word in first-seen order (identical indices
# for the same captions file). Set an int such as 10_000 to keep only the most
# frequent words instead; words left out are encoded as <PAD> by the dataset.
MAX_VOCAB = None
SPECIAL_TOKENS = ["<PAD>", "<SOS>", "<EOS>"]

# Flatten with a comprehension: sum(list_of_lists, []) is quadratic in the number of lists.
all_captions = [caption for captions in captions_dict.values() for caption in captions]
word_counts = Counter(word for caption in all_captions for word in caption.split())

if MAX_VOCAB is None:
    kept_words = list(word_counts)
else:
    kept_words = [word for word, _ in word_counts.most_common(MAX_VOCAB)]
vocab = SPECIAL_TOKENS + kept_words

word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
VOCAB_SIZE = len(vocab)
print(f"Vocabulary size: {VOCAB_SIZE}")


Vocabulary size: 16399


In [None]:
# Dataset class
class CaptionFeatureDataset(Dataset):
    """Pairs each precomputed image feature with one randomly chosen caption.

    Each item is ``(feature, tokens)``. ``tokens`` is a LongTensor of exactly
    ``max_len`` ids: ``<SOS> w1 ... wk <EOS>`` followed by ``<PAD>`` ids. Words
    missing from ``word2idx`` are encoded as ``<PAD>`` (there is no ``<UNK>`` token).

    Args:
        features_dict: image name -> feature tensor.
        captions_dict: image name -> list of caption strings.
        word2idx: token -> id; must contain "<PAD>", "<SOS>" and "<EOS>".
        max_len: total encoded length, including <SOS>/<EOS> and padding.
    """

    def __init__(self, features_dict, captions_dict, word2idx, max_len=22):
        # Only keep images that have captions; otherwise __getitem__ would raise
        # KeyError partway through an epoch.
        self.image_names = [name for name in features_dict if name in captions_dict]
        self.features_dict = features_dict
        self.captions_dict = captions_dict
        self.word2idx = word2idx
        self.max_len = max_len
        self.pad_idx = word2idx["<PAD>"]
        self.sos_idx = word2idx["<SOS>"]
        self.eos_idx = word2idx["<EOS>"]

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        feature = self.features_dict[img_name]
        words = random.choice(self.captions_dict[img_name]).split()
        # Truncate the words (leaving room for <SOS>/<EOS>) rather than the token
        # list, so <EOS> always survives; otherwise long captions never teach the
        # model when to stop.
        words = words[:max(self.max_len - 2, 0)]
        tokens = [self.sos_idx] + [self.word2idx.get(w, self.pad_idx) for w in words] + [self.eos_idx]
        tokens = tokens[:self.max_len]  # only bites when max_len < 2
        tokens += [self.pad_idx] * (self.max_len - len(tokens))
        return feature, torch.tensor(tokens, dtype=torch.long)

In [None]:
# Split dataset into train and validation
# A fixed generator makes the 80/20 split reproducible, so validation losses (and
# which checkpoint is kept as "best") are comparable across runs.
SPLIT_SEED = 42
MAX_LEN = 22  # encoded caption length; keep <= the decoder's positional-encoding length
dataset = CaptionFeatureDataset(features_dict, captions_dict, word2idx, max_len=MAX_LEN)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
print(f"Training set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")


Training set size: 42693, Validation set size: 10674


In [None]:
# Transformer Decoder
class Transformer_Decoder(nn.Module):
    """Causal Transformer decoder conditioned on a single image embedding.

    The image embedding is used as a length-1 cross-attention memory. Caption
    tokens are embedded, get a learned positional encoding, and are decoded under
    a causal mask.

    Args:
        embed_size: model width (d_model); must be divisible by the 8 heads.
        vocab_size: number of output token classes.
        hidden_size: feed-forward width inside each decoder layer.
        num_layers: number of stacked decoder layers.
        max_len: longest token sequence the learned positional table supports.
            The default of 22 keeps existing checkpoints loadable.
    """

    def __init__(self, embed_size, vocab_size, hidden_size, num_layers, max_len=22):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Learned positional encoding, one (zero-initialised) row per position.
        self.pos_encoding = nn.Parameter(torch.zeros(1, max_len, embed_size))
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_size, nhead=8, dim_feedforward=hidden_size, dropout=0.1, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, features, captions, mask=None):
        """Return next-token logits of shape (batch, seq_len, vocab_size).

        Args:
            features: (batch, embed_size) image embeddings.
            captions: (batch, seq_len) token ids with seq_len <= max_len.
            mask: optional (seq_len, seq_len) target mask; a causal mask is built when None.

        Raises:
            ValueError: if seq_len exceeds the positional-encoding length.
        """
        seq_len = captions.size(1)
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(f"caption length {seq_len} exceeds positional-encoding length {max_len}")
        embedded = self.embedding(captions) + self.pos_encoding[:, :seq_len, :]
        if mask is None:
            # -inf above the diagonal: position i attends only to positions <= i.
            mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(features.device)
        # One memory "token" per image: (batch, 1, embed_size).
        memory = features.unsqueeze(1)
        output = self.decoder(tgt=embedded, memory=memory, tgt_mask=mask)
        return self.fc(output)

In [None]:
# Training setup
# Prefer the GPU when one is available; the models below are moved onto `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The decoder width (256) must equal the projection's output size, because the
# projected image vector is fed to the decoder as its cross-attention memory.
decoder = Transformer_Decoder(
    embed_size=256,
    vocab_size=VOCAB_SIZE,
    hidden_size=512,
    num_layers=3,
).to(device)
# Maps the 2048-d precomputed image features into the decoder's 256-d space.
project_features = nn.Linear(in_features=2048, out_features=256).to(device)

# Padding positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=word2idx["<PAD>"])

# Decoder and projection are optimised jointly.
trainable_params = [*decoder.parameters(), *project_features.parameters()]
optimizer = torch.optim.Adam(trainable_params, lr=0.001)

In [None]:
# Training loop
# Trains for up to NUM_EPOCHS and checkpoints whenever validation loss improves.
# It stops early once validation loss has not improved for PATIENCE epochs: in the
# logged run, train loss kept falling while validation loss rose (overfitting).
NUM_EPOCHS = 20
PATIENCE = 3  # set to None to always run all NUM_EPOCHS
BEST_MODEL_PATH = '/content/drive/MyDrive/data/best_model.pt'


def _caption_loss(features, captions):
    """Teacher-forced loss: feed tokens[:-1], predict tokens[1:] (PAD targets ignored)."""
    features, captions = features.to(device), captions.to(device)
    output = decoder(project_features(features), captions[:, :-1])
    return criterion(output.reshape(-1, VOCAB_SIZE), captions[:, 1:].reshape(-1))


best_loss = float('inf')
best_epoch = None
epochs_since_best = 0
print("Training Started...")
for epoch in range(NUM_EPOCHS):
    decoder.train()
    project_features.train()
    total_train_loss = 0.0
    for features, captions in tqdm(train_loader):
        optimizer.zero_grad()
        loss = _caption_loss(features, captions)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    decoder.eval()
    project_features.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for features, captions in val_loader:
            total_val_loss += _caption_loss(features, captions).item()

    avg_train_loss = total_train_loss / len(train_loader)
    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        best_epoch = epoch
        epochs_since_best = 0
        torch.save({'decoder': decoder.state_dict(), 'project_features': project_features.state_dict()}, BEST_MODEL_PATH)
    else:
        epochs_since_best += 1
        if PATIENCE is not None and epochs_since_best >= PATIENCE:
            print(f"Early stopping: no validation improvement for {PATIENCE} epochs")
            break

print("Training Finished!")

# Leave the in-memory models at the best-validation weights instead of the last
# epoch's, so later cells that use `decoder` / `project_features` get the saved model.
if best_epoch is not None:
    checkpoint = torch.load(BEST_MODEL_PATH, map_location=device)
    decoder.load_state_dict(checkpoint['decoder'])
    project_features.load_state_dict(checkpoint['project_features'])
    print(f"Restored best model from epoch {best_epoch+1} (val loss {best_loss:.4f})")

Training Started...


100%|██████████| 1335/1335 [00:20<00:00, 65.44it/s]


Epoch 1 - Train Loss: 5.2449, Val Loss: 4.7580


100%|██████████| 1335/1335 [00:20<00:00, 66.57it/s]


Epoch 2 - Train Loss: 4.4471, Val Loss: 4.5718


100%|██████████| 1335/1335 [00:20<00:00, 66.50it/s]


Epoch 3 - Train Loss: 4.0750, Val Loss: 4.5411


100%|██████████| 1335/1335 [00:20<00:00, 66.53it/s]


Epoch 4 - Train Loss: 3.7862, Val Loss: 4.5650


100%|██████████| 1335/1335 [00:20<00:00, 66.45it/s]


Epoch 5 - Train Loss: 3.5416, Val Loss: 4.6219


100%|██████████| 1335/1335 [00:19<00:00, 66.78it/s]


Epoch 6 - Train Loss: 3.3344, Val Loss: 4.6916


100%|██████████| 1335/1335 [00:19<00:00, 67.17it/s]


Epoch 7 - Train Loss: 3.1639, Val Loss: 4.7695


100%|██████████| 1335/1335 [00:19<00:00, 67.11it/s]


Epoch 8 - Train Loss: 3.0252, Val Loss: 4.8417


100%|██████████| 1335/1335 [00:20<00:00, 66.71it/s]


Epoch 9 - Train Loss: 2.9044, Val Loss: 4.9196


100%|██████████| 1335/1335 [00:19<00:00, 66.78it/s]


Epoch 10 - Train Loss: 2.7994, Val Loss: 4.9918


100%|██████████| 1335/1335 [00:20<00:00, 66.41it/s]


Epoch 11 - Train Loss: 2.7050, Val Loss: 5.0596


100%|██████████| 1335/1335 [00:20<00:00, 66.74it/s]


Epoch 12 - Train Loss: 2.6241, Val Loss: 5.1231


100%|██████████| 1335/1335 [00:20<00:00, 66.72it/s]


Epoch 13 - Train Loss: 2.5503, Val Loss: 5.1910


100%|██████████| 1335/1335 [00:20<00:00, 66.50it/s]


Epoch 14 - Train Loss: 2.4773, Val Loss: 5.2793


100%|██████████| 1335/1335 [00:19<00:00, 66.84it/s]


Epoch 15 - Train Loss: 2.4177, Val Loss: 5.3119


100%|██████████| 1335/1335 [00:19<00:00, 66.90it/s]


Epoch 16 - Train Loss: 2.3584, Val Loss: 5.3496


100%|██████████| 1335/1335 [00:20<00:00, 66.49it/s]


Epoch 17 - Train Loss: 2.3065, Val Loss: 5.4024


100%|██████████| 1335/1335 [00:19<00:00, 67.08it/s]


Epoch 18 - Train Loss: 2.2538, Val Loss: 5.4898


100%|██████████| 1335/1335 [00:20<00:00, 66.46it/s]


Epoch 19 - Train Loss: 2.2068, Val Loss: 5.5250


100%|██████████| 1335/1335 [00:20<00:00, 65.97it/s]


Epoch 20 - Train Loss: 2.1633, Val Loss: 5.5557
Training Finished!


In [None]:
# Load ResNet50 for inference
# Dropping the final fc layer leaves the global-average-pooled output,
# (N, 2048, 1, 1), which beam search flattens before project_features.
# NOTE(review): the decoder was trained on precomputed features. Confirm they were
# extracted with the same weights (torchvision's DEFAULT) and the same preprocessing
# as below; otherwise inference inputs are off-distribution.
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = torch.nn.Sequential(*list(resnet.children())[:-1])
# Use the same `device` as the decoder instead of hard-coding CUDA, so this also runs on CPU.
resnet.eval().to(device)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Beam search for caption generation
def beam_search_caption(image_path, resnet, transform, decoder, project_features, word2idx, idx2word, beam_width=5, max_len=22):
    """Generate a caption for one image using beam search.

    The image is encoded with ``resnet``, flattened to (1, -1), projected with
    ``project_features``, and decoded token by token. A beam's score is the sum of
    its token log-probabilities. Beams that emit ``<EOS>`` are set aside as
    finished, and the best-scoring finished or still-open beam is returned.

    Args:
        image_path: path to an image file readable by PIL.
        resnet: image encoder whose flattened output matches project_features' input.
        transform: PIL image -> (3, H, W) tensor preprocessing.
        decoder: Transformer_Decoder producing (1, seq_len, vocab) logits.
        project_features: maps the flattened encoder output to the decoder width.
        word2idx, idx2word: vocabulary maps containing <PAD>, <SOS> and <EOS>.
        beam_width: beams kept per step (also the tokens expanded per beam).
        max_len: maximum generation steps. The prefix fed to the decoder grows to
            max_len tokens, so this must not exceed its positional-encoding length.

    Returns:
        The caption with <PAD>/<SOS>/<EOS> removed, joined by spaces.
    """
    # Derive devices from the models themselves rather than mixing a hard-coded
    # .cuda() with a global `device`; this also makes CPU-only inference work.
    encoder_device = next(resnet.parameters()).device
    decoder_device = next(decoder.parameters()).device
    pad_idx, sos_idx, eos_idx = word2idx["<PAD>"], word2idx["<SOS>"], word2idx["<EOS>"]

    resnet.eval()
    decoder.eval()
    project_features.eval()
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(encoder_device)

    with torch.no_grad():
        cnn_feat = resnet(image_tensor).view(1, -1).to(decoder_device)
        img_embed = project_features(cnn_feat)

        beams = [([sos_idx], 0.0)]  # (token ids, cumulative log-probability)
        completed = []
        for _ in range(max_len):
            candidates = []
            for seq, score in beams:
                if seq[-1] == eos_idx:
                    completed.append((seq, score))
                    continue
                input_seq = torch.tensor([seq], dtype=torch.long, device=decoder_device)
                # log_softmax is numerically stable and avoids log(prob + eps).
                log_probs = torch.log_softmax(decoder(img_embed, input_seq)[:, -1, :], dim=-1)
                top = torch.topk(log_probs, beam_width)
                for token_logp, token in zip(top.values[0].tolist(), top.indices[0].tolist()):
                    candidates.append((seq + [token], score + token_logp))
            beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
            if not beams:
                break

    best_seq, _ = max(beams + completed, key=lambda c: c[1])
    special_ids = {pad_idx, sos_idx, eos_idx}
    return " ".join(idx2word[t] for t in best_seq if t not in special_ids)

In [None]:
# Test inference
# Caption with the best-validation checkpoint, not whatever weights the training
# loop ended on (the final epoch can be well past the validation minimum).
BEST_MODEL_PATH = "/content/drive/MyDrive/data/best_model.pt"
checkpoint = torch.load(BEST_MODEL_PATH, map_location=device)
decoder.load_state_dict(checkpoint['decoder'])
project_features.load_state_dict(checkpoint['project_features'])

test_image_path = os.path.join(IMAGE_FOLDER, "106.jpg")
caption = beam_search_caption(test_image_path, resnet, transform, decoder, project_features, word2idx, idx2word)
print("Generated Caption:", caption)

Generated Caption: i love the smell of the flowers !


In [None]:
# Load saved model
# map_location lets a checkpoint saved from GPU be restored on a CPU-only runtime too.
BEST_MODEL_PATH = '/content/drive/MyDrive/data/best_model.pt'
checkpoint = torch.load(BEST_MODEL_PATH, map_location=device)
decoder.load_state_dict(checkpoint['decoder'])
project_features.load_state_dict(checkpoint['project_features'])
decoder.eval()
project_features.eval()
print(f"Loaded saved model from {BEST_MODEL_PATH}")