<a href="https://colab.research.google.com/github/indrad123/imagecaptioning/blob/main/fin_cnn_rnn_attention_image_caption_id.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchtext.data import Field
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import models
from collections import defaultdict
from PIL import Image
import os
import numpy as np
from datasets import load_dataset
from torch.cuda.amp import GradScaler, autocast
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sklearn.model_selection import train_test_split
import h5py
from tqdm import tqdm
from google.colab import drive

nltk.download('punkt')

# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Mount Google Drive *before* touching /content/drive. The checkpoint and model
# directories below are meant to live on Drive; without the mount they would be
# created on the VM's ephemeral disk (and a non-empty mount point can make a
# later mount attempt fail).
drive.mount('/content/drive')

# Define the checkpoint path
checkpoint_dir = '/content/drive/MyDrive/CnnRnnAttention/checkpoint'
checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.pth.tar')

# Define the model save path
model_dir = '/content/drive/MyDrive/CnnRnnAttention/model'

# Ensure both output directories exist (no-op when they already do).
os.makedirs(model_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)

# Load dataset
dataset = load_dataset("Mozilla/flickr30k-transformed-captions")

# One row per (image, caption) pair, taken from the dataset's "test" split.
data = pd.DataFrame({
    "image": dataset["test"]["image"],
    "caption": dataset["test"]["original_alt_text"]
})

# Hold out 5% for validation; fixed seed so the split is reproducible.
train_data, val_data = train_test_split(data, test_size=0.05, random_state=42)

# Non-sequential field: every token passed to build_vocab is counted as a single
# vocabulary entry; '<start>' / '<end>' are registered as special tokens.
captions = Field(sequential=False, init_token='<start>', eos_token='<end>')

# Build the vocabulary from lower-cased, whitespace-split *training* captions only.
all_captions = train_data['caption'].tolist()
all_tokens = [w.lower() for c in all_captions for w in c.split()]
captions.build_vocab(all_tokens)



In [None]:
class Vocab:
    """Bare attribute holder for the ``itos`` / ``stoi`` token mappings."""


vocab = Vocab()
# Reserve id 0 for padding: collate functions pad targets with zeros.
captions.vocab.itos.insert(0, '<pad>')
vocab.itos = captions.vocab.itos
# Out-of-vocabulary tokens resolve to the (shifted) '<unk>' id.
vocab.stoi = defaultdict(lambda: captions.vocab.itos.index('<unk>'))
vocab.stoi['<pad>'] = 0
# torchtext's ids predate the '<pad>' insertion above, so shift each by one.
vocab.stoi.update({token: idx + 1 for token, idx in captions.vocab.stoi.items()})

# Custom dataset class
class CaptioningData(Dataset):
    """Map-style dataset yielding ``(image_tensor, token_ids, raw_caption)`` triples.

    Args:
        df: DataFrame with an ``image`` column and a ``caption`` column. The
            image may be a PIL image (how Hugging Face ``datasets`` decodes
            Image columns), a ``{'bytes', 'path'}`` dict, or a file path.
        vocab: object exposing ``stoi`` (token -> id mapping).
        transform: optional torchvision transform. The default is a
            training-style pipeline (resize, random 224 crop, random
            horizontal flip, ImageNet normalisation).
    """

    def __init__(self, df, vocab, transform=None):
        self.df = df.reset_index(drop=True)
        self.vocab = vocab
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.Resize(224),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
        )

    @staticmethod
    def _load_image(source):
        """Return a PIL image for ``source`` (PIL image, HF image dict, or path)."""
        # Image.open() rejects an already-decoded PIL image, which is what the
        # `datasets` library puts in Image columns.
        if isinstance(source, Image.Image):
            return source
        if isinstance(source, dict):
            if source.get('bytes') is not None:
                import io
                return Image.open(io.BytesIO(source['bytes']))
            source = source['path']
        return Image.open(source)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        image = self.transform(self._load_image(row.image).convert('RGB'))
        tokens = row.caption.lower().split()
        stoi = self.vocab.stoi
        ids = [stoi['<start>'], *(stoi[token] for token in tokens), stoi['<end>']]
        target = torch.tensor(ids, dtype=torch.long)
        return image, target, row.caption

    def __len__(self):
        return len(self.df)

    def collate_fn(self, data):
        """Stack a list of samples into a batch.

        Returns:
            ``(images, padded_targets, lengths, raw_captions)``, sorted by
            caption length in descending order (pack_padded_sequence with its
            default ``enforce_sorted=True`` requires this). Padding uses id 0.
            All tensors stay on the CPU: moving to CUDA inside collate_fn breaks
            ``pin_memory=True`` and CUDA use in DataLoader worker processes; the
            training/eval code moves them to the device itself.
        """
        data = sorted(data, key=lambda sample: len(sample[1]), reverse=True)
        images, targets, captions = zip(*data)
        images = torch.stack(images, 0)
        lengths = [len(target) for target in targets]
        padded_targets = torch.zeros(len(targets), max(lengths), dtype=torch.long)
        for i, target in enumerate(targets):
            padded_targets[i, :lengths[i]] = target
        return images, padded_targets, torch.tensor(lengths, dtype=torch.long), captions



In [None]:
# Save intermediate data to H5
def save_to_h5(dataset, file_name):
    """Materialise a captioning dataset into an HDF5 file.

    Writes two datasets:
      * ``images``: shape ``(N, *image_shape)``, one pre-transformed tensor per sample;
      * ``captions``: ``N`` variable-length int64 token-id sequences.

    Samples are streamed to disk one at a time, so peak memory is a single
    sample rather than the whole split.

    Args:
        dataset: indexable object with ``len()`` whose items start with
            ``(image_tensor, token_tensor, ...)``; all image tensors must share
            one shape.
        file_name: destination path (overwritten if present).
    """
    n = len(dataset)
    with h5py.File(file_name, 'w') as h:
        # Captions differ in length, so np.array(list_of_captions) would be a
        # ragged array that numpy/h5py reject; store variable-length rows.
        caption_ds = h.create_dataset(
            'captions', shape=(n,), dtype=h5py.vlen_dtype(np.dtype('int64')))
        image_ds = None
        for idx in tqdm(range(n)):
            img, cap = dataset[idx][:2]
            img = img.numpy()
            if image_ds is None:
                # Shape/dtype are only known once the first sample is transformed.
                image_ds = h.create_dataset('images', shape=(n, *img.shape), dtype=img.dtype)
            image_ds[idx] = img
            caption_ds[idx] = np.asarray(cap.numpy(), dtype=np.int64)
        if image_ds is None:
            # Empty input: still create the dataset so readers find the key.
            h.create_dataset('images', shape=(0,), dtype=np.float32)

# Load intermediate data from H5
def load_from_h5(file_name):
    """Read the ``images`` and ``captions`` datasets of an HDF5 file into memory.

    Returns:
        ``(images, captions)`` as in-memory arrays; the file is closed on return.
    """
    with h5py.File(file_name, 'r') as handle:
        return handle['images'][:], handle['captions'][:]

# Save training and validation datasets to H5.
# Validation images get a deterministic resize + centre crop. The default
# CaptioningData transform is random (crop + flip), which would bake a random
# augmentation into the stored validation set.
# NOTE(review): the training augmentation is also applied only once here (the
# H5 file stores one fixed crop/flip per image), so it does not vary per epoch.
val_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
save_to_h5(CaptioningData(train_data, vocab), 'train_data.h5')
save_to_h5(CaptioningData(val_data, vocab, transform=val_transform), 'val_data.h5')



In [None]:
# Custom dataset class for loading H5 data
class H5Dataset(Dataset):
    """Dataset over an HDF5 file holding ``images`` and ``captions`` arrays.

    Args:
        file_name: path to the HDF5 file.
        itos: optional id -> token list used to rebuild each sample's reference
            caption text in ``collate_fn``; defaults to the module-level
            ``vocab.itos``.

    NOTE(review): the whole file is loaded into RAM; for large splits consider
    reading rows lazily instead.
    """

    # Tokens dropped when turning stored ids back into a reference caption.
    _SPECIAL_TOKENS = ('<start>', '<end>', '<pad>')

    def __init__(self, file_name, itos=None):
        self.images, self.captions = load_from_h5(file_name)
        self.itos = itos

    def __getitem__(self, index):
        image = torch.tensor(self.images[index])
        caption = torch.tensor(self.captions[index]).long()
        return image, caption

    def __len__(self):
        return len(self.images)

    def _decode(self, ids):
        """Join the tokens for ``ids`` with spaces, skipping special tokens."""
        itos = self.itos if self.itos is not None else vocab.itos
        words = (itos[i] for i in ids)
        return ' '.join(w for w in words if w not in self._SPECIAL_TOKENS)

    def collate_fn(self, data):
        """Stack ``(image, caption_ids)`` samples into a batch.

        Returns:
            ``(images, padded_targets, lengths, reference_captions)``: the same
            4-item layout as ``CaptioningData.collate_fn``, which the training
            and evaluation functions unpack. Samples are sorted by caption length,
            longest first, as ``pack_padded_sequence`` requires. Padding uses id 0.
            Tensors stay on the CPU: CUDA transfers inside collate_fn break
            ``pin_memory=True`` and multi-worker loading.
        """
        data = sorted(data, key=lambda sample: len(sample[1]), reverse=True)
        images, targets = zip(*data)
        images = torch.stack(images, 0)
        lengths = [len(target) for target in targets]
        padded_targets = torch.zeros(len(targets), max(lengths), dtype=torch.long)
        for i, target in enumerate(targets):
            padded_targets[i, :lengths[i]] = target
        references = tuple(self._decode(target.tolist()) for target in targets)
        return images, padded_targets, torch.tensor(lengths, dtype=torch.long), references

# Load datasets from H5
train_ds = H5Dataset('train_data.h5')
val_ds = H5Dataset('val_data.h5')

# Never ask for more workers than the machine has CPUs (small Colab VMs have
# few); extra workers only trigger a DataLoader warning and duplicate memory.
num_workers = min(8, os.cpu_count() or 1)

# Training batches are reshuffled every epoch. drop_last avoids a trailing batch
# of size 1, which the encoder's BatchNorm1d head rejects in training mode.
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, drop_last=True,
                      collate_fn=train_ds.collate_fn, num_workers=num_workers,
                      pin_memory=True, prefetch_factor=2)
val_dl = DataLoader(val_ds, batch_size=32, collate_fn=val_ds.collate_fn,
                    num_workers=num_workers, pin_memory=True, prefetch_factor=2)



In [None]:
# Encoder
class EncoderCNN(nn.Module):
    """Frozen pretrained ResNet-152 backbone plus a trainable Linear + BatchNorm head.

    Args:
        embed_size: size of the feature vector produced for each image.

    ``forward(images)`` returns a ``(batch, embed_size)`` tensor.
    """

    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet152(pretrained=True)
        # Drop the final classification layer; keep everything before it.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.resnet.eval()

    def train(self, mode=True):
        """Switch the head between train/eval; the frozen backbone always stays in eval."""
        super().train(mode)
        # The backbone gets no gradients (see forward). Its BatchNorm layers must
        # therefore keep the pretrained running statistics; in train mode they
        # would normalise with per-batch statistics and overwrite those stats.
        self.resnet.eval()
        return self

    def forward(self, images):
        # Frozen backbone: no autograd graph is built through it.
        with torch.no_grad():
            features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.bn(self.linear(features))
        return features

encoder = EncoderCNN(embed_size=256).to(device)  # 256-d image embedding on the active device

# Decoder
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    The image feature is fed as time step 0, followed by the embedded caption
    tokens (teacher forcing during training).

    Args:
        embed_size: word-embedding size; must equal the image feature size.
        hidden_size: LSTM hidden size.
        vocab_size: number of output classes.
        num_layers: number of stacked LSTM layers.
        max_seq_length: maximum number of tokens ``predict`` generates.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=80):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seq_length = max_seq_length

    def forward(self, features, captions, lengths):
        """Return logits for the packed (non-padding) steps, ``(sum(lengths), vocab_size)``.

        Args:
            features: ``(batch, embed_size)`` image features.
            captions: ``(batch, max_len)`` padded token ids.
            lengths: tensor or list of valid caption lengths, in descending
                order (``pack_padded_sequence`` default ``enforce_sorted=True``).
        """
        embeddings = self.embed(captions)
        # Image feature at step 0, so the output at step t predicts captions[:, t].
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        lengths_cpu = lengths.cpu() if torch.is_tensor(lengths) else lengths
        packed = pack_padded_sequence(embeddings, lengths_cpu, batch_first=True)
        outputs, _ = self.lstm(packed)
        # Same packed order as pack_padded_sequence(captions, lengths).data.
        return self.linear(outputs.data)

    def predict(self, features, states=None, itos=None):
        """Greedily decode one caption per image feature.

        Args:
            features: ``(batch, embed_size)`` image features.
            states: optional initial LSTM ``(h, c)`` state.
            itos: optional id -> token list; defaults to the module-level ``vocab.itos``.

        Returns:
            One space-joined string per image. Each stops after the first
            ``'<end>'`` (which is kept) or after ``max_seq_length`` tokens.
        """
        if itos is None:
            itos = vocab.itos
        if self.max_seq_length <= 0:
            return [''] * features.size(0)
        sampled_ids = []
        inputs = features.unsqueeze(1)
        for _ in range(self.max_seq_length):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens.squeeze(1))
            _, predicted = outputs.max(1)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted).unsqueeze(1)
        # One device->host transfer for the whole batch instead of one per sentence.
        id_rows = torch.stack(sampled_ids, 1).tolist()
        sentences = []
        for row in id_rows:
            words = []
            for word_id in row:
                word = itos[word_id]
                words.append(word)
                if word == '<end>':
                    break
            sentences.append(' '.join(words))
        return sentences

decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(vocab.itos), num_layers=1).to(device)  # single-layer LSTM over the full vocab



In [None]:
# Loss, optimiser and AMP scaler.
criterion = nn.CrossEntropyLoss()
# Only the decoder and the encoder's Linear/BatchNorm head are optimised;
# the ResNet backbone is not in this list, so it stays frozen.
params = [
    *decoder.parameters(),
    *encoder.linear.parameters(),
    *encoder.bn.parameters(),
]
optimizer = torch.optim.AdamW(params, lr=1e-3)

# Gradient scaler for mixed-precision (autocast) training.
scaler = GradScaler()

# Function to save the model checkpoint
def save_checkpoint(state, filename):
    """Atomically save ``state`` with ``torch.save`` to ``filename``.

    The data is first written to ``filename + '.tmp'`` and then renamed over
    the target. An interrupted save (e.g. a dropped Colab session) therefore
    never leaves a truncated checkpoint behind. The parent directory is
    created if missing.
    """
    os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
    tmp_path = filename + '.tmp'
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, filename)
    finally:
        # Only still present if torch.save or os.replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

# Function to load the model checkpoint
def load_checkpoint(filename, encoder, decoder, optimizer):
    """Restore encoder, decoder and optimizer state from ``filename`` if it exists.

    Returns:
        The checkpoint's ``'epoch'`` entry (the epoch index to resume from),
        or 0 when no checkpoint file exists.
    """
    if not os.path.isfile(filename):
        print(f"No checkpoint found at '{filename}'")
        return 0
    print(f"Loading checkpoint '{filename}'")
    # Load onto the CPU so a checkpoint written on a GPU also opens on a
    # CPU-only runtime; load_state_dict copies tensors onto the modules'/
    # optimizer parameters' own devices.
    checkpoint = torch.load(filename, map_location='cpu')
    start_epoch = checkpoint['epoch']
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(f"Checkpoint loaded successfully from '{filename}' at (epoch {start_epoch})")
    return start_epoch

# Training function
def train_batch(data, encoder, decoder, optimizer, criterion):
    """Run one mixed-precision optimisation step on a batch and return its loss.

    Args:
        data: a batch from a ``collate_fn``. Only the first three items
            ``(images, padded_captions, lengths)`` are used, so batches with or
            without a 4th raw-caption item are accepted.
        encoder, decoder: modules to train.
        optimizer: optimizer over the trainable parameters.
        criterion: loss over packed logits vs. packed target ids.

    Returns:
        The batch loss as a Python float.
    """
    encoder.train()
    decoder.train()
    images, captions, lengths = data[:3]
    images = images.to(device)
    captions = captions.to(device)
    # Targets = non-padding tokens, in the same packed order as the decoder output.
    targets = pack_padded_sequence(captions, lengths.cpu(), batch_first=True)[0]

    optimizer.zero_grad()

    with autocast():
        features = encoder(images)
        outputs = decoder(features, captions, lengths)
        loss = criterion(outputs, targets)

    # Scaled backward pass guards fp16 gradients against underflow; the scaler
    # unscales them before the optimizer step and adapts its scale factor.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    return loss.item()

@torch.no_grad()
def validate_batch(data, encoder, decoder, criterion):
    """Return the loss on one batch in eval mode, without updating any weights.

    Only the first three batch items ``(images, padded_captions, lengths)`` are
    used, so batches with or without a 4th raw-caption item are accepted.
    """
    encoder.eval()
    decoder.eval()
    images, captions, lengths = data[:3]
    images = images.to(device)
    captions = captions.to(device)
    targets = pack_padded_sequence(captions, lengths.cpu(), batch_first=True)[0]
    features = encoder(images)
    outputs = decoder(features, captions, lengths)
    return criterion(outputs, targets).item()

@torch.no_grad()
def evaluate_model(encoder, decoder, data_loader):
    """Return the mean sentence-level BLEU of greedy captions over ``data_loader``.

    References are tokenised exactly like the training captions (lower-cased,
    whitespace-split), so they are comparable with the vocabulary tokens the
    decoder emits.

    Returns:
        Mean BLEU (NLTK smoothing method4), or 0.0 when the loader yields no samples.
    """
    encoder.eval()
    decoder.eval()
    smooth = SmoothingFunction().method4
    specials = {'<start>', '<end>', '<pad>'}
    total_bleu_score = 0.0
    total_samples = 0
    for data in data_loader:
        images = data[0].to(device)
        references = _batch_reference_texts(data, specials)
        predicted_captions = decoder.predict(encoder(images))
        for pred, true in zip(predicted_captions, references):
            true_tokens = true.lower().split()
            pred_tokens = [w for w in pred.split() if w not in specials]
            total_bleu_score += sentence_bleu([true_tokens], pred_tokens, smoothing_function=smooth)
            total_samples += 1
    return total_bleu_score / total_samples if total_samples else 0.0


def _batch_reference_texts(data, specials):
    """Return one reference caption string per sample of a collated batch.

    Uses the raw captions when the batch carries them as a 4th item. Otherwise
    it rebuilds them from the padded target ids via the module-level
    ``vocab.itos``, dropping ``specials``.
    """
    if len(data) >= 4:
        return data[3]
    _, targets, lengths = data[:3]
    itos = vocab.itos
    texts = []
    for row, n in zip(targets.tolist(), lengths.tolist()):
        words = (itos[i] for i in row[:n])
        texts.append(' '.join(w for w in words if w not in specials))
    return texts



In [None]:
# Load model from checkpoint if exists
start_epoch = load_checkpoint(checkpoint_path, encoder, decoder, optimizer)

# Training loop
n_epochs = 10
lr_drop_epoch = 5   # epochs with index >= this train with the reduced learning rate
reduced_lr = 1e-4

for epoch in range(start_epoch, n_epochs):
    if epoch >= lr_drop_epoch:
        # Lower the LR in place rather than building a new AdamW: a fresh
        # optimizer would throw away the accumulated moment estimates.
        # Re-applying it each epoch is idempotent and also covers resumed runs.
        for group in optimizer.param_groups:
            group['lr'] = reduced_lr

    train_loss = 0.0
    for i, data in enumerate(train_dl):
        batch_loss = train_batch(data, encoder, decoder, optimizer, criterion)
        train_loss += batch_loss
        if (i + 1) % 10 == 0:
            print(f'Epoch [{epoch + 1}/{n_epochs}], Step [{i + 1}/{len(train_dl)}], Loss: {batch_loss:.4f}')

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    train_loss /= max(len(train_dl), 1)
    print(f'Epoch [{epoch + 1}/{n_epochs}], Training Loss: {train_loss:.4f}')

    val_loss = 0.0
    for data in val_dl:
        val_loss += validate_batch(data, encoder, decoder, criterion)

    val_loss /= max(len(val_dl), 1)
    print(f'Epoch [{epoch + 1}/{n_epochs}], Validation Loss: {val_loss:.4f}')

    # Evaluate model
    bleu_score = evaluate_model(encoder, decoder, val_dl)
    print(f'Epoch [{epoch + 1}/{n_epochs}], BLEU Score: {bleu_score:.4f}')

    # Save checkpoint; 'epoch' is the index of the next epoch to run.
    save_checkpoint({
        'epoch': epoch + 1,
        'encoder_state_dict': encoder.state_dict(),
        'decoder_state_dict': decoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, checkpoint_path)



In [None]:
# Save final model.
# `epoch` only exists if the training loop above ran at least once; it is
# skipped when resuming at or past n_epochs. Record the number of completed
# epochs explicitly instead, using the same convention as the per-epoch
# checkpoints.
completed_epochs = max(start_epoch, n_epochs)
save_checkpoint({
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': completed_epochs
}, os.path.join(model_dir, 'model.pth'))

# Persist the vocabulary next to the model: id->token list plus token->id map.
import json
vocab_path = os.path.join(model_dir, 'vocab.json')
with open(vocab_path, 'w') as vocab_file:
    payload = {'itos': vocab.itos, 'stoi': dict(vocab.stoi)}
    json.dump(payload, vocab_file)

# Load model function for inference
def load_model(encoder, decoder, optimizer, model_dir):
    """Load saved weights from ``model_dir`` into the given modules (and optimizer).

    Reads the combined ``model.pth`` file with keys ``encoder_state_dict`` /
    ``decoder_state_dict`` / ``optimizer_state_dict``, which is what the
    final-save cell writes. When that file is absent, falls back to separate
    ``encoder.pth``, ``decoder.pth`` and ``optimizer.pth`` state-dict files.

    Args:
        encoder, decoder: modules to load into.
        optimizer: optimizer to restore, or ``None`` for inference-only loading.
        model_dir: directory containing the saved files.

    Returns:
        ``(encoder, decoder, optimizer)``.
    """
    combined_path = os.path.join(model_dir, 'model.pth')
    if os.path.isfile(combined_path):
        # map_location='cpu' lets GPU-saved weights load on CPU-only runtimes.
        checkpoint = torch.load(combined_path, map_location='cpu')
        encoder_state = checkpoint['encoder_state_dict']
        decoder_state = checkpoint['decoder_state_dict']
        optimizer_state = checkpoint.get('optimizer_state_dict')
    else:
        encoder_state = torch.load(os.path.join(model_dir, 'encoder.pth'), map_location='cpu')
        decoder_state = torch.load(os.path.join(model_dir, 'decoder.pth'), map_location='cpu')
        optimizer_state = None
        if optimizer is not None:
            optimizer_state = torch.load(os.path.join(model_dir, 'optimizer.pth'), map_location='cpu')
    encoder.load_state_dict(encoder_state)
    decoder.load_state_dict(decoder_state)
    if optimizer is not None and optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)
    return encoder, decoder, optimizer

# Prediction function
@torch.no_grad()
def load_image_and_predict(image_path):
    """Caption the image at ``image_path`` using the module-level encoder/decoder.

    The image is resized and centre-cropped to 224x224, the same spatial size
    as the 224-cropped training/validation tensors, then normalised with the
    same mean/std.

    Returns:
        The greedy-decoded caption from ``decoder.predict``. It may contain
        special tokens such as ``'<end>'``.
    """
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    # Context manager closes the underlying file handle once the image is converted.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)
    encoder.eval()
    decoder.eval()
    features = encoder(batch)
    return decoder.predict(features)[0]




In [None]:
# Example usage: point `image_path` at a real image file before running this cell.
image_path = 'path_to_an_image.jpg'
caption = load_image_and_predict(image_path)
print('Generated Caption:', caption)