In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision import transforms
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import random_split
import gc
import nltk
from nltk.translate.bleu_score import sentence_bleu

In [2]:
torch.cuda.empty_cache()
gc.collect()

0

In [3]:
class Vocabulary:
    def __init__(self, freq_threshold):
        """
        freq_threshold: minimum frequency for a word to be included.
        """
        self.freq_threshold = freq_threshold
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

        # Special tokens.
        self.add_word("<PAD>")
        self.add_word("<SOS>")
        self.add_word("<EOS>")
        self.add_word("<UNK>")

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1
    def tokenize(self, text):
        return str(text).lower().strip().split()


    def build_vocabulary(self, sentence_list):
        frequencies = {}
        for sentence in sentence_list:
            tokens = self.tokenize(sentence)
            for token in tokens:
                frequencies[token] = frequencies.get(token, 0) + 1

        for word, freq in frequencies.items():
            if freq >= self.freq_threshold:
                self.add_word(word)

    def numericalize(self, text):
        tokenized_text = self.tokenize(text)
        return [self.word2idx.get(token, self.word2idx["<UNK>"]) for token in tokenized_text]

In [4]:
class QuestionGenerationDataset(Dataset):
    def __init__(self, csv_file, image_folder, vocabulary, transform=None):
        """
        Args:
            csv_file: Path to the CSV file.
            image_folder: Folder containing images.
            vocabulary: An instance of the Vocabulary class.
            transform: Image transformations.
            
        The CSV is assumed to have:
          - Column 0: a URL (the image filename is the last 16 characters: 12 chars + ".jpg")
          - Column 6: the question text.
        """
        self.df = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.vocab = vocabulary
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Extract image name from the first column (last 16 characters)
        link = row.iloc[0]
        image_name = link[-16:]
        image_path = os.path.join(self.image_folder, image_name)
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Get question text from column 7 (index 6)
        question = row.iloc[6]

        # Numericalize the question and add <SOS> and <EOS> tokens.
        tokens = [self.vocab.word2idx["<SOS>"]]
        tokens += self.vocab.numericalize(question)
        tokens.append(self.vocab.word2idx["<EOS>"])
        caption = torch.tensor(tokens, dtype=torch.long)
        return image, caption

In [5]:
def collate_fn(data):
    """
    Sort a batch of data by caption length (descending order) and pad sequences.
    """
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)
    images = torch.stack(images, 0)
    lengths = [len(cap) for cap in captions]
    padded_captions = rnn_utils.pad_sequence(captions, batch_first=True, padding_value=0)
    return images, padded_captions, lengths

In [6]:
class ImageEncoder(nn.Module):
    def __init__(self, embed_size):
        super(ImageEncoder, self).__init__()
        resnet = models.resnet50(pretrained=True)
        modules = list(resnet.children())[:-1]  # remove the final fully connected layer
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
    
    def forward(self, images):
        # Extract features without computing gradients.
        with torch.no_grad():
            features = self.resnet(images)
        features = features.view(features.size(0), -1)
        features = self.bn(self.fc(features))
        return features

In [7]:
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

    def forward(self, features, captions, lengths):
        """
        Args:
            features: Encoded image features of shape (batch, embed_size)
            captions: Tensor of token indices (batch, seq_length). Note that during training, we feed
                      the ground-truth captions (teacher forcing).
            lengths: List of actual lengths for each caption.
            
        The features are prepended as the first input to the LSTM.
        """
    
        embeddings = self.embed(captions)  # (batch, L, embed_size)
        # Initialize LSTM hidden state using image features.
        h0 = features.unsqueeze(0).repeat(self.num_layers, 1, 1)  # (num_layers, batch, embed_size)
        c0 = torch.zeros_like(h0)
        # Use original lengths (input captions length remains the same) after shifting targets.
        packed = rnn_utils.pack_padded_sequence(embeddings, lengths, batch_first=True, enforce_sorted=False)
        hiddens, _ = self.lstm(packed, (h0, c0))
        outputs = self.linear(hiddens[0])
        return outputs
    
    def sample(self, features, vocab, max_len=20):
        batch_size = features.size(0)
        sampled_ids = []
        # Create a batch of <SOS> tokens.
        inputs = self.embed(torch.tensor([vocab.word2idx["<SOS>"]]*batch_size).to(features.device)).unsqueeze(1)
        # Initialize LSTM hidden state with image features.
        h0 = features.unsqueeze(0).repeat(self.num_layers, 1, 1)  # Shape: (num_layers, batch_size, embed_size)
        c0 = torch.zeros_like(h0)
        states = (h0, c0)
        for i in range(max_len):
            hiddens, states = self.lstm(inputs, states)
            outputs = self.linear(hiddens.squeeze(1))
            predicted = outputs.argmax(1)
            sampled_ids.append(predicted)
            inputs = self.embed(predicted).unsqueeze(1)
        sampled_ids = torch.stack(sampled_ids, 1)
        return sampled_ids



In [8]:
def validate(encoder, decoder, dataloader, vocab, device, max_len=20):
    """
    Runs the model on the validation set and computes the average BLEU score.
    """
    encoder.eval()
    decoder.eval()
    total_bleu = 0.0
    total_examples = 0
    with torch.no_grad():
        for images, captions, lengths in dataloader:
            images = images.to(device)
            features = encoder(images)
            #sampled_ids = decoder.sample(features, max_len=max_len)
            sampled_ids = decoder.sample(features, vocab, max_len=max_len)
            for i in range(images.size(0)):
                sampled_seq = sampled_ids[i].cpu().numpy()
                # Convert predicted ids to words; stop at <EOS>
                pred_words = []
                for word_id in sampled_seq:
                    word = vocab.idx2word[word_id]
                    if word == "<EOS>":
                        break
                    pred_words.append(word)
                # Process the ground-truth caption: remove <SOS> and stop at <EOS>
                true_caption = captions[i].cpu().numpy()
                true_words = []
                for word_id in true_caption:
                    word = vocab.idx2word.get(int(word_id), "<UNK>")
                    if word == "<SOS>":
                        continue
                    if word == "<EOS>":
                        break
                    true_words.append(word)
                # Compute sentence-level BLEU score.
                bleu = sentence_bleu([true_words], pred_words, 
                                     smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method1)
                total_bleu += bleu
                total_examples += 1
    avg_bleu = total_bleu / total_examples if total_examples > 0 else 0
    encoder.train()
    decoder.train()
    return avg_bleu

In [9]:
def main():
    # Hyperparameters.
    embed_size = 256
    hidden_size = 256
    num_layers = 1
    num_epochs = 5
    batch_size = 32
    learning_rate = 1e-3
    freq_threshold = 5  # Only include words that occur at least 5 times.

    # Image transformations.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Build vocabulary from the CSV.
    csv_file = r"C:\Users\krish\Downloads\train.csv"
    df = pd.read_csv(csv_file)
    questions = df.iloc[:, 6].tolist()  # Column 7 (index 6) has question text.
    vocab = Vocabulary(freq_threshold)
    vocab.build_vocabulary(questions)
    print("Vocabulary size:", len(vocab.word2idx))

    # Create the full dataset.
    image_folder = r"C:\Users\krish\Downloads\train"
    full_dataset = QuestionGenerationDataset(csv_file, image_folder, vocab, transform)

    # Split dataset into training (90%) and validation (10%) sets.
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    # Device configuration.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize encoder and decoder.
    encoder = ImageEncoder(embed_size).to(device)
    decoder = DecoderRNN(embed_size, hidden_size, len(vocab.word2idx), num_layers).to(device)

    # Only update decoder parameters and the encoder's fc and bn layers.
    params = list(decoder.parameters()) + list(encoder.fc.parameters()) + list(encoder.bn.parameters())
    optimizer = optim.Adam(params, lr=learning_rate)
    # Use CrossEntropyLoss and ignore the padding index.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<PAD>"])

    for epoch in range(num_epochs):
        encoder.train()
        decoder.train()
        epoch_loss = 0.0
        total_samples = 0

        for images, captions, lengths in train_loader:
            images = images.to(device)
            captions = captions.to(device)

            # === Teacher Forcing Changes (Third Change) ===
            # Input is captions excluding the last token; target is captions excluding the first token.
            inputs = captions[:, :-1]
            targets = captions[:, 1:]
            # Adjust lengths: each length is reduced by 1.
            input_lengths = [l - 1 for l in lengths]
            # Pack targets with adjusted lengths.
            targets_packed = rnn_utils.pack_padded_sequence(targets, input_lengths, batch_first=True, enforce_sorted=False)[0]

            # Forward pass.
            features = encoder(images)
            outputs = decoder(features, inputs, input_lengths)
            loss = criterion(outputs, targets_packed)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() * images.size(0)
            total_samples += images.size(0)

        avg_train_loss = epoch_loss / total_samples
        avg_bleu = validate(encoder, decoder, val_loader, vocab, device, max_len=20)
        print(f"Epoch [{epoch+1}/{num_epochs}] - Avg Train Loss: {avg_train_loss:.4f}, Avg BLEU: {avg_bleu:.4f}")

    # Save the model and vocabulary.
    torch.save({'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'vocab': vocab}, 'question_generation.pth')

    # Generate a question for a sample image.
    sample_image_path = os.path.join(image_folder, df.iloc[0, 0][-16:])
    question = generate_question(sample_image_path, encoder, decoder, vocab, transform, device)
    print("Generated question:", question)


In [10]:
def generate_question(image_path, encoder, decoder, vocab, transform, device, max_len=20):
    """
    Given an image file, generate a question using the trained model.
    """
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0)  # Add batch dimension.
    image = image.to(device)
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        feature = encoder(image)
        sampled_ids = decoder.sample(feature, max_len=max_len)
    sampled_ids = sampled_ids[0].cpu().numpy()
    # Convert word IDs back to words.
    words = []
    for word_id in sampled_ids:
        word = vocab.idx2word[word_id]
        if word == "<EOS>":
            break
        words.append(word)
    question = ' '.join(words)
    return question

In [11]:
if __name__ == "__main__":
    main()

Vocabulary size: 622




Epoch [1/5] - Avg Train Loss: 3.1641, Avg BLEU: 0.0014
Epoch [2/5] - Avg Train Loss: 3.0878, Avg BLEU: 0.0015


KeyboardInterrupt: 