## Import dependencies

In [1]:
import numpy as np
import pandas as pd
import spacy 
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torchvision.models import resnet50,ResNet50_Weights
from torch.utils.data import DataLoader, Dataset, random_split

2025-10-02 17:08:52.409293: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759424932.607631      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759424932.663477      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Load Dataset

In [2]:
# Read the Flickr8k caption annotations from the Kaggle input directory.
df = pd.read_csv('/kaggle/input/flickr8k/captions.txt')
# Bare last expression so the notebook renders the frame as the cell output.
df

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...
...,...,...
40450,997722733_0cb5439472.jpg,A man in a pink shirt climbs a rock face
40451,997722733_0cb5439472.jpg,A man is rock climbing high in the air .
40452,997722733_0cb5439472.jpg,A person in a red shirt climbing up a rock fac...
40453,997722733_0cb5439472.jpg,A rock climber in a red shirt .


## Preprocess Data


Main task: convert the caption text into numerical values.

1. Build a vocabulary that maps each word to an index.
2. Set up a PyTorch dataset to load the data.
3. Pad the captions in every batch so that all sequences have the same length, then set up the DataLoader.

In [3]:
# !python -m spacy download en_core_web_sm
# Load the spaCy English pipeline; Vocabulary.tokenizer_eng uses its tokenizer.
spacy_eng= spacy.load("en_core_web_sm")
# Sanity check: tokenize a sample sentence and lowercase every token.
[tok.text.lower() for tok in spacy_eng.tokenizer('This is a sentence.')]

['this', 'is', 'a', 'sentence', '.']

### Vocabulary Mapping

In [4]:
class Vocabulary:
    '''
    Two-way mapping between tokens and integer ids, built from caption text.

    Ids 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and <UNK>.
    Learned words are numbered from 4 upwards, in the order in which they reach
    `freq_threshold` occurrences while scanning the sentences.
    '''
    def __init__(self,freq_threshold):
        # itos: id -> token, stoi: token -> id. build_vocabulary keeps them in sync.
        self.itos = {0:"<PAD>",1:"<SOS>",2:"<EOS>",3:"<UNK>"}
        self.stoi = {"<PAD>":0,"<SOS>":1,"<EOS>":2,"<UNK>":3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        '''Number of entries in the vocabulary, special tokens included.'''
        return len(self.itos)

    @staticmethod # lives on the class for namespacing; it reads no instance state.
    def tokenizer_eng(text):
        '''Tokenize `text` with the module-level spaCy tokenizer and lowercase each token.'''
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self,sentence_list):
        '''
        sentence_list: list of sentences (strings) to scan.

        Flow: loop over the sentences -> count every token -> the moment a token's
              count reaches freq_threshold it is added to the vocabulary. Tokens that
              never reach the threshold stay out and are later mapped to <UNK>.

        Note: new ids continue from the current vocabulary size (4 on a fresh
              instance, because 0-3 are reserved). Calling this method again
              therefore extends the mapping instead of overwriting ids handed out
              by an earlier call; words that are already known keep their id.
              Frequencies are counted per call.
        '''
        frequencies = {}
        idx = len(self.itos) # 0-3 are reserved for pad,sos,eos,unk
        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word,0) + 1 # new word gets 1, otherwise the counter is incremented
                # `==` (not `>=`) so that a word is registered exactly once per scan.
                if frequencies[word] == self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self,text):
        '''
        Tokenize `text` and return the list of token ids; tokens that are not in
        the vocabulary are replaced by the id of <UNK>.
        '''
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token,unk_idx) for token in self.tokenizer_eng(text)]

        

### Dataset Definition

In [5]:
class FlickrDataset(Dataset):
    '''
    Image/caption pairs read from a captions CSV plus a directory of image files.

    root_dir: directory that contains the image files.
    captions_file: CSV file with an "image" column (file name) and a "caption" column.
    transform: optional callable applied to the PIL image before it is returned.
    freq_threshold: minimum token count for a word to enter the vocabulary.
    '''
    def __init__(self,root_dir,captions_file,transform=None,freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # get image and caption columns
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # initialize the vocabulary and build it from every caption in the file
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        '''Number of (image, caption) rows in the captions file.'''
        return len(self.df)

    def __getitem__(self,index):
        '''
        Takes a position `index`, maps it to the index-th row of the dataframe
        (image, caption pair), loads the image and numericalizes the caption
        wrapped in <SOS> ... <EOS>.

        Returns (img, caption_ids) where caption_ids is a 1-D tensor of token ids
        and img is the transformed image (or the RGB PIL image if no transform is set).
        '''
        # Dataset indices are positions, so use positional (.iloc) access rather
        # than label lookup on the Series.
        caption = self.captions.iloc[index]
        img_id = self.imgs.iloc[index]

        # Load the image with PIL; the context manager closes the file handle as
        # soon as the pixel data has been converted to RGB.
        with Image.open(os.path.join(self.root_dir,img_id)) as img_file:
            img = img_file.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        # wrap the numericalized tokens with <SOS> at the start and <EOS> at the end
        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)
    

### Collate Functionality

In [6]:
class MyCollate:
    '''
    Collate callable for the DataLoader.
    Images: stacked along a new leading batch dimension, [C,H,W] -> [B,C,H,W].
    Captions: padded with the <PAD> id so that every sequence in the batch has the same length.
    '''
    def __init__(self,pad_idx):
        self.pad_idx = pad_idx

    def __call__(self,batch):
        '''
        batch: list of (image_tensor, caption_tensor) samples, e.g.
        [(img1_tensor, tensor([1,4,8,4])), ...]. All image tensors share one shape,
        while caption tensors may differ in length.

        Returns (imgs, targets):
          imgs    -> the sample images stacked on dim 0 (the batch dimension).
          targets -> the captions padded with pad_idx; batch_first=False, so the
                     layout is [max_seq_len, batch_size].
        '''
        image_list = [sample[0] for sample in batch]
        caption_list = [sample[1] for sample in batch]

        stacked_images = torch.stack(image_list,dim=0)
        padded_captions = pad_sequence(
            caption_list,
            batch_first=False,
            padding_value=self.pad_idx,
        )
        return stacked_images,padded_captions

### Generate DataLoaders 

In [7]:
def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=32,
    num_workers=8,
    shuffle=True,
    pin_memory=True,
    split_ratio=0.8,
    seed=42
):
    '''
    Build the Flickr dataset and return train/validation DataLoaders over a random split.

    root_folder: directory with the image files.
    annotation_file: captions CSV passed to FlickrDataset.
    transform: image transform; both subsets share it because they are views of
               the same FlickrDataset instance.
    batch_size, num_workers, pin_memory: forwarded to both DataLoaders.
    shuffle: shuffling of the training loader only; the validation loader never shuffles.
    split_ratio: fraction of the samples assigned to the training subset.
    seed: seed of the generator that draws the split. With a fixed seed every call
          on the same data yields the same train/validation partition, so a loader
          rebuilt later (e.g. for inference) holds out the same samples that were
          held out during training. Pass None for a non-reproducible split.

    Returns (train_loader, val_loader, dataset); `dataset` is the full FlickrDataset
    and gives access to the vocabulary.
    '''
    dataset = FlickrDataset(root_folder,annotation_file,transform=transform)
    #train val split
    train_size = int(split_ratio * len(dataset))
    val_size = len(dataset) - train_size
    if seed is None:
        train_dataset, val_dataset = random_split(dataset,[train_size,val_size])
    else:
        # A dedicated generator keeps the split independent of the global RNG state.
        split_generator = torch.Generator().manual_seed(seed)
        train_dataset, val_dataset = random_split(
            dataset,[train_size,val_size],generator=split_generator
        )

    pad_idx = dataset.vocab.stoi["<PAD>"]

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
        drop_last = False
    )
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
        drop_last = False
    )
    return train_loader,val_loader,dataset

## Utility functions
* Useful for pausing and resuming training via checkpoints.
1. save_checkpoint
2. load_checkpoint

In [8]:
def save_checkpoint(state,filename='my_checkpoint.pth.tar'):
    """Serialize `state` to `filename` with torch.save, announcing it on stdout first.

    state: object to save; train() passes a dict with the keys "state_dict",
           "optimizer" and "step".
    filename: destination path; an existing file at that path is overwritten.
    """
    print("Saving Checkpoint")
    torch.save(state,filename)

In [9]:
def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from an already-loaded checkpoint dict.

    checkpoint: mapping with the keys "state_dict", "optimizer" and "step".
    model, optimizer: objects exposing load_state_dict; both are updated in place,
                      the model first.
    Returns the training step stored in the checkpoint.
    """
    print("=> Loading checkpoint")
    restore_plan = ((model, "state_dict"), (optimizer, "optimizer"))
    for target, key in restore_plan:
        target.load_state_dict(checkpoint[key])
    return checkpoint["step"]

## Model Architecture

1. CNN as an encoder (ResNet50) for images to capture those features
2. RNN as a decoder (LSTM) for text generation

In [10]:
class EncoderCNN(nn.Module):
    '''
    Image encoder: a pretrained ResNet50 whose classifier head is replaced by a
    linear projection to `embed_size`, followed by ReLU and dropout.

    embed_size: size of the feature vector produced per image.
    train_CNN: if False (default) only the new `fc` layer is trainable and the
               rest of the backbone is frozen; if True every ResNet parameter is trained.
    '''
    def __init__(self,embed_size,train_CNN=False):
        super(EncoderCNN,self).__init__()
        self.train_CNN = train_CNN
        self.resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        # Swap the classification layer for a projection into the embedding space.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features,embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        for name,param in self.resnet.named_parameters():
            param.requires_grad = (train_CNN or 'fc' in name)
        # Apply the frozen-BatchNorm policy right away (modules start in training mode).
        self.train()

    def train(self,mode=True):
        '''
        Same as nn.Module.train, except that while the backbone is frozen
        (train_CNN=False) its BatchNorm layers are kept in eval mode.

        requires_grad=False only stops gradient updates; a BatchNorm layer in
        training mode still normalizes with batch statistics and overwrites its
        running mean/variance, so a "frozen" backbone would otherwise drift away
        from the pretrained statistics.
        '''
        super(EncoderCNN,self).train(mode)
        if mode and not self.train_CNN:
            for module in self.resnet.modules():
                if isinstance(module,nn.BatchNorm2d):
                    module.eval()
        return self

    def forward(self,images):
        '''
        images: batch of image tensors fed to the ResNet (assumed [B,3,H,W] - TODO confirm with caller).
        Returns the projected features after ReLU and dropout, one `embed_size` vector per image.
        '''
        features = self.resnet(images)
        return self.dropout(self.relu(features))


In [11]:
class DecoderRNN(nn.Module):
    '''
    Caption decoder: token embedding -> LSTM -> linear projection onto the vocabulary.
    The LSTM is sequence-first (batch_first is left at its default).
    '''
    def __init__(self,embed_size,hidden_size,vocab_size,num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size,embed_size)
        self.lstm = nn.LSTM(embed_size,hidden_size,num_layers)
        self.linear = nn.Linear(hidden_size,vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self,features,captions):
        '''
        features: image features, [batch_size,embed_size].
        captions: token ids, [seq_len,batch_size].

        The image feature is prepended as the first timestep of the LSTM input and
        the (dropout-regularized) caption embeddings follow it, giving an input of
        shape [1+seq_len,batch_size,embed_size].
        Returns vocabulary scores for every timestep: [1+seq_len,batch_size,vocab_size].
        '''
        token_vectors = self.embed(captions)
        token_vectors = self.dropout(token_vectors)
        image_step = features.unsqueeze(0)  # [1,batch_size,embed_size]
        lstm_input = torch.cat((image_step,token_vectors),dim=0)
        lstm_out,_final_state = self.lstm(lstm_input)
        return self.linear(lstm_out)

In [12]:
class CNNtoRNN(nn.Module):
    '''
    Full captioning model: EncoderCNN turns an image into a feature vector and
    DecoderRNN turns that vector plus the caption prefix into vocabulary scores.

    train_CNN is forwarded to EncoderCNN; the default (False) keeps the previous
    behavior of training only the encoder's projection layer.
    '''
    def __init__(self,embed_size,hidden_size,vocab_size,num_layers,train_CNN=False):
        super(CNNtoRNN,self).__init__()
        self.CNN = EncoderCNN(embed_size,train_CNN)
        self.RNN = DecoderRNN(embed_size,hidden_size,vocab_size,num_layers)

    def forward(self,images,captions):
        '''Teacher-forced pass: encode `images`, then decode conditioned on `captions`.'''
        features = self.CNN(images)
        outputs = self.RNN(features,captions)
        return outputs

    def caption_image(self,image,vocabulary,max_len=50):
        '''
        Greedy decoding for a single image; no target caption is needed.

        image: batch containing exactly one image (predicted.item() below requires batch size 1).
        vocabulary: object with an `itos` mapping from token id to token string.
        max_len: maximum number of tokens to generate.

        The model is switched to eval mode for the duration of the call, so that
        dropout is disabled and BatchNorm statistics are not touched, and its
        previous train/eval mode is restored afterwards.
        Returns the generated tokens as a list of strings; generation stops after <EOS>.
        '''
        caption_result = []
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # The decoder LSTM is sequence-first: add a sequence dimension at dim0.
                x = self.CNN(image).unsqueeze(0)
                states = None
                for _ in range(max_len):
                    hiddens,states = self.RNN.lstm(x,states)
                    output = self.RNN.linear(hiddens.squeeze(0))
                    predicted = output.argmax(1)

                    token_id = predicted.item()
                    caption_result.append(token_id)
                    if vocabulary.itos[token_id] == '<EOS>':
                        break
                    # Feed the predicted token back in as the next input step.
                    x = self.RNN.embed(predicted).unsqueeze(0)
        finally:
            self.train(was_training)

        return [vocabulary.itos[idx] for idx in caption_result]
                

## Train Model

In [13]:
# Image pipeline used by the loaders built below: resize, random 224x224 crop,
# tensor conversion and normalization.
transform = transforms.Compose(
    [
        transforms.Resize(356),
        transforms.RandomCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # normalization constants from the PyTorch docs (resnet)
    ]
    )

# Build the train/validation loaders; `dataset` is the full FlickrDataset and
# carries the vocabulary. Both loaders share `transform`.
train_loader ,val_loader, dataset = get_loader(
    root_folder="/kaggle/input/flickr8k/Images/",
    annotation_file="/kaggle/input/flickr8k/captions.txt",
    transform = transform,
    num_workers=2
    )

# Run-level switches read by train().
device = 'cuda' if torch.cuda.is_available() else 'cpu'
load_model = False  # resume from my_checkpoint.pth.tar if it exists
save_model = True   # write a checkpoint after every epoch
# NOTE(review): intended to unfreeze the whole ResNet, but train() constructs
# CNNtoRNN without passing this flag, so it currently has no effect.
train_CNN = False #set this true to train all the params of the ResNet model

# Model and optimization hyper-parameters.
embed_size = 256
hidden_size = 256
vocab_size = len(dataset.vocab)  # includes the 4 special tokens
num_layers = 1
learning_rate = 3e-4
num_epochs = 10

In [14]:
def train():
    """Train the captioning model with the notebook-level configuration.

    Reads the globals defined in the configuration cell (device, load_model,
    save_model, the hyper-parameters, train_loader and dataset). Logs the per-batch
    loss to TensorBoard under runs/flickr, optionally resumes from and saves
    my_checkpoint.pth.tar, and writes the final weights to final_model.pth.
    """
    torch.backends.cudnn.benchmark = True # lets cuDNN auto-tune kernels; usually faster training on gpu

    writer = SummaryWriter("runs/flickr") # logging details from training for review

    step = 0 # global counter of processed batches (used as the TensorBoard x-axis)

    model = CNNtoRNN(embed_size,hidden_size,vocab_size,num_layers).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"]) # padded positions do not contribute to the loss
    optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad),lr=learning_rate) # just to avoid wasting memory on frozen params.

    # Can continue the training from checkpoint if saved.
    if load_model:
        checkpoint_path="my_checkpoint.pth.tar"
        if os.path.exists(checkpoint_path):
            try:
                # map_location lets a checkpoint written on GPU be restored on a CPU-only session.
                step = load_checkpoint(torch.load(checkpoint_path,map_location=device),model,optimizer)
                print("Checkpoint Loaded.")
            except Exception as e:
                print(f"Error loading checkpoint:{e}")
                print("Training from scratch...")
                step=0

    model.train()
    try:
        for epoch in range(num_epochs):
            total_train_loss = 0
            loop = tqdm(
                enumerate(train_loader),
                total=len(train_loader),
                desc=f"Epoch [{epoch+1}/{num_epochs}]",
                unit="batch",
                leave=True
            )
            for batch_idx, (imgs, captions) in loop:
                imgs, captions = imgs.to(device), captions.to(device)
                # The decoder prepends the image feature as the first timestep, so
                # feeding captions[:-1] yields exactly one output per caption token.
                outputs = model(imgs, captions[:-1])
                loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_train_loss += loss.item()
                loop.set_postfix(batch_loss=loss.item(),
                avg_loss=total_train_loss/(batch_idx+1))

                writer.add_scalar("Training loss", loss.item(), global_step=step)
                step += 1
            if save_model:
                checkpoint = {
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step
                            }
                save_checkpoint(checkpoint)
            avg_train_loss = total_train_loss / len(train_loader)
            print(f"Epoch {epoch+1} complete. Avg Train Loss: {avg_train_loss:.4f}")
    finally:
        # Flush pending events to disk even if training is interrupted.
        writer.close()
    torch.save(model.state_dict(),"final_model.pth")
    print("Model saved successfully!")

train()

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 212MB/s]
Epoch [1/10]: 100%|██████████| 1012/1012 [02:50<00:00,  5.92batch/s, avg_loss=3.8, batch_loss=3.46]


Saving Checkpoint
Epoch 1 complete. Avg Train Loss: 3.7959


Epoch [2/10]: 100%|██████████| 1012/1012 [02:25<00:00,  6.95batch/s, avg_loss=3.11, batch_loss=3.08]


Saving Checkpoint
Epoch 2 complete. Avg Train Loss: 3.1116


Epoch [3/10]: 100%|██████████| 1012/1012 [02:24<00:00,  6.98batch/s, avg_loss=2.89, batch_loss=2.99]


Saving Checkpoint
Epoch 3 complete. Avg Train Loss: 2.8881


Epoch [4/10]: 100%|██████████| 1012/1012 [02:26<00:00,  6.93batch/s, avg_loss=2.74, batch_loss=3.2] 


Saving Checkpoint
Epoch 4 complete. Avg Train Loss: 2.7403


Epoch [5/10]: 100%|██████████| 1012/1012 [02:25<00:00,  6.94batch/s, avg_loss=2.63, batch_loss=2.47]


Saving Checkpoint
Epoch 5 complete. Avg Train Loss: 2.6305


Epoch [6/10]: 100%|██████████| 1012/1012 [02:25<00:00,  6.95batch/s, avg_loss=2.54, batch_loss=2.13]


Saving Checkpoint
Epoch 6 complete. Avg Train Loss: 2.5448


Epoch [7/10]: 100%|██████████| 1012/1012 [02:25<00:00,  6.94batch/s, avg_loss=2.48, batch_loss=2.65]


Saving Checkpoint
Epoch 7 complete. Avg Train Loss: 2.4756


Epoch [8/10]: 100%|██████████| 1012/1012 [02:28<00:00,  6.82batch/s, avg_loss=2.42, batch_loss=2.82]


Saving Checkpoint
Epoch 8 complete. Avg Train Loss: 2.4167


Epoch [9/10]: 100%|██████████| 1012/1012 [02:25<00:00,  6.96batch/s, avg_loss=2.37, batch_loss=2.41]


Saving Checkpoint
Epoch 9 complete. Avg Train Loss: 2.3658


Epoch [10/10]: 100%|██████████| 1012/1012 [02:24<00:00,  7.00batch/s, avg_loss=2.32, batch_loss=2.58]


Saving Checkpoint
Epoch 10 complete. Avg Train Loss: 2.3212
Model saved successfully!


In [18]:
!tensorboard --logdir=runs #for training details after completion

2025-10-02 17:37:23.043647: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759426643.064453     237 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759426643.071423     237 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.18.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C


In [19]:
def eval():
    """Compute the average per-batch cross-entropy loss on the validation loader.

    Builds a fresh CNNtoRNN from the notebook-level hyper-parameters, loads the
    weights saved in final_model.pth and runs one pass over the global val_loader
    without gradient tracking. Prints the average loss; returns nothing.

    NOTE(review): this name shadows Python's built-in eval() for the rest of the
    notebook; it is kept because the cell calls it under this name.
    """
    model = CNNtoRNN(embed_size,hidden_size,vocab_size,num_layers).to(device)
    # map_location lets weights saved on GPU be loaded in a CPU-only session.
    model.load_state_dict(torch.load("final_model.pth",map_location=device))
    criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])

    model.eval() # disables dropout for evaluation
    total_val_loss = 0
    loop = tqdm(
        enumerate(val_loader),
        total=len(val_loader),
        desc="Evaluation",
        unit="batch",
        leave=True
    )
    with torch.no_grad():
        for batch_idx, (imgs, captions) in loop:
            imgs, captions = imgs.to(device), captions.to(device)
            # Same teacher-forced setup as in training: drop the last token from the input.
            outputs = model(imgs, captions[:-1])
            loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))
            total_val_loss += loss.item()
            loop.set_postfix(batch_loss=loss.item(),
            avg_loss=total_val_loss/(batch_idx+1))
    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Evaluation complete. Avg Validation Loss: {avg_val_loss:.4f}")

eval()

Evaluation: 100%|██████████| 253/253 [00:37<00:00,  6.77batch/s, avg_loss=2.45, batch_loss=2.35]

Evaluation complete. Avg Validation Loss: 2.4506





In [22]:
def inference():
    """Generate and print a caption for one image taken from a validation loader.

    A deterministic transform (center crop instead of random crop) is used, the
    loaders are rebuilt with it, the weights from final_model.pth are loaded and
    the 4th image of the first validation batch is captioned.

    NOTE(review): the split is drawn again inside get_loader, so the image is
    only guaranteed to be unseen during training if get_loader reproduces the
    split that was used for training.
    """
    inference_transform = transforms.Compose(
        [
            transforms.Resize(356),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    _,val_loader,dataset = get_loader(
        root_folder="/kaggle/input/flickr8k/Images/",
        annotation_file="/kaggle/input/flickr8k/captions.txt",
        transform = inference_transform,
        num_workers=2)

    # Size the model from the vocabulary that is actually used for decoding below,
    # so the output layer and the id -> token table cannot get out of sync.
    model = CNNtoRNN(embed_size,hidden_size,len(dataset.vocab),num_layers).to(device)
    # map_location lets weights saved on GPU be loaded in a CPU-only session.
    model.load_state_dict(torch.load("final_model.pth",map_location=device))
    model.eval()
    imgs, _ = next(iter(val_loader))  # the reference captions are not needed here
    imgs = imgs.to(device)
    y_pred = model.caption_image(imgs[3].unsqueeze(0), dataset.vocab)  # one image
    print("\nPredicted:", " ".join(y_pred))


inference()
    


Predicted: <SOS> a white dog is running through the grass . <EOS>
