# Automated Image Captioning Using PyTorch

## Step 1: Import Necessary Libraries

In [19]:
import os
import os

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from nltk.tokenize import wordpunct_tokenize
from torch.utils.data import DataLoader, Dataset
from torchvision.io import ImageReadMode
from torchvision.io import read_image

## Step 2: Load Dataset

In [20]:
class CustomImageDataset(Dataset):
    """Image/caption dataset backed by a CSV annotations file.

    The CSV's first column is an image file name (joined onto ``img_dir``) and
    its second column is the caption text. Captions are tokenized with
    ``wordpunct_tokenize``, wrapped in ``<SOS>``/``<EOS>`` and mapped to integer
    ids; id 0 is reserved for ``<PAD>``, 1 for ``<SOS>`` and 2 for ``<EOS>``.

    Attributes:
        max_sentence_length: longest caption length in tokens, counting <SOS>/<EOS>.
        vocab_size: number of distinct caption tokens plus the 3 special tokens.
        word_to_idx: token -> id mapping.
        idx_to_word: id -> token mapping (inverse of ``word_to_idx``).
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform

        self.img_labels['tokenized'] = self.img_labels.iloc[:, 1].apply(wordpunct_tokenize)
        # Address the column by name: positional index 2 is only the new
        # 'tokenized' column when the CSV has exactly two columns.
        self.img_labels['len+2'] = self.img_labels['tokenized'].apply(len) + 2  # +2 for <SOS> and <EOS>
        self.max_sentence_length = int(self.img_labels['len+2'].max())

        # Sorted so that token ids are identical on every run; iteration order of a
        # set of strings depends on the per-process hash seed.
        unique_tokens = sorted({tok for toks in self.img_labels['tokenized'] for tok in toks})
        self.vocab_size = len(unique_tokens) + 3

        self.img_labels['tokenized'] = self.img_labels['tokenized'].apply(
            lambda toks: ['<SOS>'] + toks + ['<EOS>'])
        self.word_to_idx = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2}
        self.word_to_idx.update({word: idx for idx, word in enumerate(unique_tokens, start=3)})
        self.idx_to_word = {idx: word for word, idx in self.word_to_idx.items()}
        self.img_labels['sent_to_idx'] = self.img_labels['tokenized'].apply(
            lambda toks: [self.word_to_idx.get(tok) for tok in toks])

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return ``(image, label, tgt_key_padding_mask)`` for row ``idx``.

        ``label`` is the caption's id sequence right-padded with 0 (<PAD>) to
        ``max_sentence_length``; the mask is ``True`` wherever ``label[:-1]`` is
        padding, matching the decoder input slice used during training.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # Force 3-channel RGB so grayscale / RGBA files can be batched together
        # and fed to the ResNet encoder.
        image = read_image(img_path, mode=ImageReadMode.RGB)
        if self.transform:
            image = self.transform(image)

        ids = self.img_labels['sent_to_idx'].iloc[idx]
        label = torch.tensor(ids + [0] * (self.max_sentence_length - len(ids)))
        tgt_key_padding_mask = label[:-1] == 0
        return image, label, tgt_key_padding_mask

## Step 3: Feature Extraction using CNN (ResNet)

In [21]:
class CNNEncoder(nn.Module):
    """Map an image batch to one ``embed_size``-dimensional feature vector per image.

    The backbone is a torchvision ResNet-50 with its default pretrained weights
    and its final fully-connected layer removed. The backbone is frozen: its
    parameters get no gradients and it is kept in eval mode. Only the linear
    projection ``fc`` is trained.
    """

    def __init__(self, embed_size):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)  # Load pre-trained ResNet
        for param in resnet.parameters():
            param.requires_grad = False  # Freeze ResNet weights

        # Keep every layer except the final classifier (`fc`).
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        # requires_grad=False does not stop BatchNorm layers from updating their
        # running mean/variance while in train mode, so the frozen backbone is
        # put (and kept, see `train`) in eval mode.
        self.resnet.eval()
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform initialise the projection weight (the bias keeps its default init)."""
        nn.init.xavier_uniform_(self.fc.weight)

    def train(self, mode=True):
        """Set train/eval mode for the module, but always keep the frozen backbone in eval mode."""
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Return projected image features of shape (batch, embed_size).

        ``images`` is presumably a (batch, 3, H, W) float tensor produced by the
        dataset transform -- TODO confirm against the caller.
        """
        features = self.resnet(images)              # frozen backbone features
        features = torch.flatten(features, 1)       # flatten everything after the batch dim
        return self.fc(features)

## Step 4: Caption Generation using Transformer

In [22]:
class TransformerCaptioningModel(nn.Module):
    """Caption decoder: an ``nn.Transformer`` conditioned on one image feature vector.

    The image feature is fed to the transformer encoder as a length-1 source
    sequence; caption token ids (plus learned positional embeddings) form the
    target sequence. Outputs are per-position logits over the vocabulary.

    Args:
        embed_size: model width (``d_model``); must be divisible by the 8 heads.
        vocab_size: number of token ids; sizes the token embedding and output layer.
        hidden_size: feed-forward width inside each transformer layer.
        num_layers: number of encoder layers and of decoder layers.
        max_seq_len: longest target sequence the model must handle; sizes the
            positional-embedding table and the cached causal mask.
        device: device on which the cached causal mask is created.
    """

    def __init__(self, embed_size, vocab_size, hidden_size, num_layers, max_seq_len, device):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # nn.Transformer adds no positional information of its own; without this
        # the decoder only sees token order implicitly through the causal mask.
        self.pos_embedding = nn.Embedding(max_seq_len, embed_size)
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=8,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_size,
            batch_first=True
        )
        self.fc = nn.Linear(embed_size, vocab_size)
        # A buffer (not a plain attribute) so it moves with .to()/.cuda() together
        # with the module; non-persistent so it is not written to state_dict.
        self.register_buffer(
            'tgt_mask',
            nn.Transformer.generate_square_subsequent_mask(max_seq_len).to(device),
            persistent=False,
        )

    def forward(self, features, captions, tgt_key_padding_mask, tgt_mask=None):
        """Return logits of shape (batch, seq_len, vocab_size).

        Args:
            features: image embeddings, one row per batch element, e.g. the
                (batch, embed_size) output of CNNEncoder; unsqueezed into a
                length-1 source sequence.
            captions: (batch, seq_len) integer token ids, seq_len <= max_seq_len.
            tgt_key_padding_mask: (batch, seq_len) bool, True at padding positions.
            tgt_mask: optional (seq_len, seq_len) attention mask; defaults to the
                cached causal mask cropped to seq_len.
        """
        seq_len = captions.size(1)
        positions = torch.arange(seq_len, device=captions.device)
        # The (seq_len, embed_size) position table broadcasts over the batch dim.
        embeddings = self.embedding(captions) + self.pos_embedding(positions)
        if tgt_mask is None:
            # The top-left LxL corner of a causal mask is the causal mask of size L.
            tgt_mask = self.tgt_mask[:seq_len, :seq_len]
        transformer_output = self.transformer(features.unsqueeze(1),
                                              embeddings,
                                              tgt_mask=tgt_mask,
                                              tgt_key_padding_mask=tgt_key_padding_mask)
        return self.fc(transformer_output)


## Step 5: Training the Model

In [23]:
def train_model(encoder, decoder, loss_fn,dataloader, num_epochs, learning_rate, vocab_size,device):
    """Jointly train ``encoder`` and ``decoder`` with teacher forcing.

    Each caption is split into a decoder input (every token but the last) and a
    target (every token but the first), so at each position the decoder learns
    to predict the next token.

    Args:
        encoder: module mapping an image batch to features for ``decoder``.
        decoder: module called as ``decoder(features, in_captions, tgt_key_padding_mask)``
            that returns per-position logits of size ``vocab_size``.
        loss_fn: criterion applied to the flattened logits and targets.
        dataloader: iterable of ``(images, captions, tgt_key_padding_mask)`` batches.
        num_epochs: number of passes over ``dataloader``.
        learning_rate: Adam learning rate.
        vocab_size: size of the decoder's logit dimension.
        device: device each batch is moved to.

    Returns:
        List with the mean batch loss of every epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Frozen parameters (requires_grad=False) never receive gradients; leave them out.
    params = [p for module in (encoder, decoder) for p in module.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=learning_rate)
    # Restore training behaviour in case eval() was called earlier (e.g. by an
    # inference cell run before re-training).
    encoder.train()
    decoder.train()

    loss_vec = []
    for epoch in range(num_epochs):
        total_loss, num_batches = 0.0, 0
        for images, captions, tgt_key_padding_mask in dataloader:
            images = images.to(device)
            captions = captions.to(device)
            tgt_key_padding_mask = tgt_key_padding_mask.to(device)
            # Shift by one token: predict captions[:, t+1] from captions[:, :t+1].
            in_captions = captions[:, :-1]
            out_captions = captions[:, 1:]

            features = encoder(images)
            outputs = decoder(features, in_captions, tgt_key_padding_mask)

            loss = loss_fn(outputs.reshape(-1, vocab_size), out_captions.reshape(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
        if num_batches == 0:
            raise ValueError("dataloader yielded no batches")
        loss_in_epoch = total_loss / num_batches
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss_in_epoch}")
        loss_vec.append(loss_in_epoch)
    return loss_vec

## Step 6: Initialize the Model Components

In [None]:
# Hyperparameters
embed_size = 256   # transformer d_model; must be divisible by nhead=8
hidden_size = 512  # feed-forward width inside each transformer layer
num_epochs = 100
learning_rate = 0.0005
num_layers = 2     # number of encoder layers and of decoder layers
batch_size = 16

Image_Dir = 'Path//to//image_folder'
Annot_Dir = 'Path//to//csv_file'

# The ResNet-50 backbone was pretrained on ImageNet-normalized inputs.
# Normalize(0, 1) would be an identity transform, so use the ImageNet statistics.
composer = transforms.Compose([transforms.ToPILImage(),
                               transforms.Resize([256,256]),
                               transforms.ToTensor(),
                               transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                    std=[0.229, 0.224, 0.225])])

full_dataset = CustomImageDataset(Annot_Dir,Image_Dir,transform=composer)
# Seeded split so the train/test partition is reproducible between runs.
train_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(42))

vocab_size = full_dataset.vocab_size

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

encoder = CNNEncoder(embed_size).to(device)
decoder = TransformerCaptioningModel(embed_size, vocab_size, hidden_size, num_layers,full_dataset.max_sentence_length,device).to(device)
# Id 0 is <PAD>: padded target positions must not contribute to the loss or gradients.
loss_fn = nn.CrossEntropyLoss(ignore_index=0).to(device)

# Load data using a DataLoader; evaluation data does not need shuffling.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

loss_vec = train_model(encoder, decoder, loss_fn,train_dataloader, num_epochs, learning_rate, vocab_size, device)


## Step 7: Plot the Learning Curve

In [None]:
# Mean training loss per epoch (epochs numbered from 1).
fig, ax = plt.subplots()
ax.plot(range(1, len(loss_vec) + 1), loss_vec, label='Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Mean training loss')
ax.legend()

## Step 8: Generate a Caption for a New Image

In [None]:
Image_Address = 'Path//to//sigle_image'
# Decode as 3-channel RGB so grayscale or RGBA files match what the encoder expects.
image = read_image(Image_Address, mode=ImageReadMode.RGB)
plt.imshow(image.permute(1, 2, 0).numpy())  # (C, H, W) tensor -> (H, W, C) for matplotlib
plt.axis('off')

# Switch to eval mode BEFORE extracting features: in train mode the ResNet
# BatchNorm layers would normalise with this single image's statistics (and
# update their running averages), and any dropout layers would be active.
encoder.eval()
decoder.eval()

special_idx = {0, 1, 2}  # ids of <PAD>, <SOS>, <EOS>; producing any of them ends the caption
# Cap on the generated sequence length (counting the leading <SOS>).
max_caption_len = full_dataset.max_sentence_length - 1

with torch.inference_mode():
    resized_image = composer(image).unsqueeze(0).to(device)
    features = encoder(resized_image)
    generated_caption = torch.ones([1, 1], dtype=torch.int64, device=device)  # <SOS>
    # Greedy decoding: append the most likely next token until a special token
    # is produced or the length cap is reached.
    while True:
        tgt_key_padding_mask = generated_caption == 0
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(generated_caption.size(1)).to(device)
        output = decoder(features, generated_caption, tgt_key_padding_mask, tgt_mask)
        next_token = output[:, -1].argmax(dim=-1, keepdim=True)
        generated_caption = torch.cat((generated_caption, next_token), dim=1)
        # '>=' (not '==') so the loop also ends if the cap is below the starting length.
        if next_token.item() in special_idx or generated_caption.size(1) >= max_caption_len:
            break

idx_list = generated_caption.squeeze(0).tolist()
token = [full_dataset.idx_to_word.get(x) for x in idx_list]
title = ' '.join(x for x in token if x not in ('<PAD>', '<SOS>', '<EOS>'))
plt.title(title)
plt.show()