# Imports

In [1]:
import os
import zipfile
import nltk
import json
from tqdm import tqdm
from collections import Counter
from datasets import load_dataset, load_from_disk
from pprint import pprint
from PIL import Image

import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from huggingface_hub import hf_hub_download

# Fetch the NLTK data packages this notebook relies on; quiet=True suppresses the download log.
# 'punkt_tab' is tokenizer data — presumably what nltk.word_tokenize (used in
# VocabularyBuilder.tokenize) needs on recent NLTK versions; confirm against the installed version.
nltk.download('punkt_tab', quiet=True)
# NOTE(review): nothing in the visible cells uses WordNet directly — presumably needed by a
# later entity-extraction step; confirm before removing.
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)

  from .autonotebook import tqdm as notebook_tqdm


True

# Constants

### Dataset Constants

In [2]:
# Target size handed to transforms.Resize when the image pipeline is built.
IMG_SIZE = (224, 224)
# NOTE(review): not referenced in the visible cells — vocabularies are sized by the data
# (see VocabularyBuilder); confirm this constant is still needed.
VOCAB_SIZE = 5000
# Number of samples per DataLoader batch.
BATCH_SIZE = 32
# NOTE(review): not referenced in the visible cells; presumably for a later
# knowledge-graph step — confirm.
MAX_NODES_PER_QUESTION = 10

# Directory Information
# Everything the notebook downloads or generates lives under DATA_DIR.
DATA_DIR = "data/"
DATASET_PATH = os.path.join(DATA_DIR, 'dataset/')
# The image archive is extracted into DATA_DIR; this assumes it contains a top-level
# 'imgs/' folder — TODO confirm against the archive layout.
IMAGE_PATH = os.path.join(DATA_DIR, 'imgs/')
VOCABS_PATH = os.path.join(DATA_DIR, 'vocabs/')

# Huggingface Repository Information
repo_id = "BoKelvin/SLAKE"
repo_type = "dataset"
img_file = "imgs.zip"

# Entity Extraction
MAX_PHRASE_LENGTH = 5 # Look up to 5-grams for entity matching

# Device
# Plain string ("cuda" or "cpu"); it is passed to .to(device) by the training code.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


# Dataset Setup

### Data Download

In [3]:
# Utility function for downloading and extracting ZIP file
def download_and_store_ZIP(filename, save_dir):
    """Download ``filename`` from the Hugging Face repo and unzip it into ``save_dir``.

    The repository is identified by the module-level ``repo_id`` / ``repo_type``.
    This is a best-effort helper: any failure while downloading or extracting is
    reported with ``print`` and the function returns normally.

    Args:
        filename: Name of the ZIP file inside the repository.
        save_dir: Directory the archive is extracted into; created if missing.

    Returns:
        None.
    """
    print(f"Fetching file {filename} from {repo_id} repo")

    try:
        # Caches the file locally and returns the path to the cached file
        cached_zip_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type=repo_type,
        )
        print(f"{filename} download complete. Cached at: {cached_zip_path}")

        # exist_ok avoids the check-then-create race of a separate os.path.exists test.
        os.makedirs(save_dir, exist_ok=True)

        # Extract the contents
        print(f"Extracting to {save_dir}...")
        with zipfile.ZipFile(cached_zip_path, 'r') as zip_ref:
            zip_ref.extractall(save_dir)

        print("Extraction complete.")
        print(f"{filename} files are located in: {os.path.abspath(save_dir)}")
    except Exception as e:
        # Deliberately broad: the caller only needs to see that the step failed.
        print(f"Failed to download or extract {filename}: {e}")

# Scoping to English only
def filter_language(original):
    """Keep only the records whose ``q_lang`` field equals ``'en'``.

    ``original`` must expose a ``filter(predicate)`` method; whatever that
    method returns is passed straight back to the caller.
    """
    def _is_english(record):
        return record['q_lang'] == 'en'

    return original.filter(_is_english)

# Download and store the dataset
def download_and_store_english_dataset():
    """Download the dataset named by ``repo_id``, keep English rows, and save it.

    The filtered dataset is pretty-printed, written to ``DATASET_PATH`` with
    ``save_to_disk`` and returned.

    Returns:
        The language-filtered dataset object produced by ``filter_language``.
    """
    print(f"Downloading dataset from {repo_id} repo")

    # Load from Hugging Face
    original = load_dataset(repo_id)

    # Scope to English Only
    original = filter_language(original)

    # Show the dataset formatting
    pprint(original)

    # Save the filtered dataset. makedirs creates missing parents (DATA_DIR included),
    # and exist_ok makes a separate existence check unnecessary.
    os.makedirs(DATASET_PATH, exist_ok=True)

    original.save_to_disk(DATASET_PATH)
    return original

# Download and store the image files
def download_and_store_image():
    """Fetch the archive named by ``img_file`` and unpack it under ``DATA_DIR``."""
    download_and_store_ZIP(filename=img_file, save_dir=DATA_DIR)

# Download necessary files
def download_and_store_slake():
    """Download the English-only dataset, then the image archive.

    Returns:
        The dataset returned by ``download_and_store_english_dataset``.
    """
    english_dataset = download_and_store_english_dataset()
    download_and_store_image()
    return english_dataset

### Vocabs list

In [59]:
class VocabularyBuilder:
    """Two-way mapping between tokens and integer ids.

    Ids 0-3 are reserved for ``<pad>``, ``<start>``, ``<end>`` and ``<unk>``.
    ``stoi`` maps token -> id and ``itos`` maps id -> token. Tokenisation always
    goes through ``self.tokenize``, so replacing that attribute on an instance
    changes how both building and lookup split text.
    """

    def __init__(self, min_freq=1):
        # Tokens seen fewer than min_freq times are left out of the vocabulary.
        self.min_freq = min_freq
        reserved = ("<pad>", "<start>", "<end>", "<unk>")
        self.itos = dict(enumerate(reserved))
        self.stoi = {token: index for index, token in self.itos.items()}

    def tokenize(self, text):
        """Lower-case ``text`` and split it with ``nltk.word_tokenize``."""
        lowered = text.lower()
        return nltk.word_tokenize(lowered)

    def __len__(self):
        """Number of known tokens, reserved ones included."""
        return len(self.stoi)

    def build_word_vocabs(self, sentences):
        """Add every sufficiently frequent token of ``sentences`` to the vocabulary.

        New ids continue from the current vocabulary size, in order of first
        appearance. Prints the resulting vocabulary size.
        """
        frequencies = Counter()
        for sentence in sentences:
            frequencies.update(self.tokenize(sentence))

        next_index = len(self.stoi)
        for word, seen in frequencies.items():
            # Skip rare tokens and anything already registered.
            if seen < self.min_freq or word in self.stoi:
                continue
            self.stoi[word] = next_index
            self.itos[next_index] = word
            next_index += 1

        print(f"Vocabulary Built. Vocabulary Size: {len(self.stoi)}")

    def numericalize(self, text):
        """Return the list of ids for the tokens of ``text``; unknown tokens map to ``<unk>``."""
        unknown_index = self.stoi["<unk>"]
        return [self.stoi.get(token, unknown_index) for token in self.tokenize(text)]

### Utilities

In [60]:
# Build vocabularies for questions and answers
def build_vocabs(dataset):
    """Build the question (word-level) and answer (whole-string) vocabularies.

    Args:
        dataset: Iterable of records with ``'question'`` and ``'answer'`` keys.

    Returns:
        ``(question_vocab, answer_vocab)`` — two ``VocabularyBuilder`` instances.
        The answer builder has its ``tokenize`` attribute replaced so that each
        answer is one lower-cased, stripped token, i.e. one class per answer.
    """
    # Single pass over the records instead of one pass per field.
    questions = []
    answers = []
    for item in dataset:
        questions.append(item['question'])
        # str() mirrors SlakeDataset.__getitem__, which stringifies the answer
        # before looking it up; without it a non-string answer fails on .lower().
        answers.append(str(item['answer']))

    # Question Vocabulary
    questvocab_builder = VocabularyBuilder(min_freq=1)
    questvocab_builder.build_word_vocabs(questions)

    # Answer Vocabulary
    ansvocab_builder = VocabularyBuilder(min_freq=1)

    # Use a dummy tokenizer that just returns the whole lowercased string as one token
    def whole_answer_tokenizer(text):
        return [text.lower().strip()]

    ansvocab_builder.tokenize = whole_answer_tokenizer

    ansvocab_builder.build_word_vocabs(answers)

    return questvocab_builder, ansvocab_builder

# Save vocabularies to JSON files
def save_vocabs(quest_vocab, ans_vocab):
    """Write both vocabularies as JSON files under ``VOCABS_PATH``.

    Each file holds ``{'stoi': ..., 'itos': ...}``. JSON object keys are always
    strings, so the integer keys of ``itos`` are stored as strings and must be
    converted back with ``int`` by whoever reloads the file.

    Args:
        quest_vocab: Vocabulary written to ``question_vocab.json``.
        ans_vocab: Vocabulary written to ``answer_vocab.json``.
    """
    # exist_ok replaces the separate existence check (no check-then-create race).
    os.makedirs(VOCABS_PATH, exist_ok=True)

    vocab_files = (
        ('question_vocab.json', quest_vocab),
        ('answer_vocab.json', ans_vocab),
    )
    for file_name, vocab in vocab_files:
        with open(os.path.join(VOCABS_PATH, file_name), 'w') as f:
            json.dump({'stoi': vocab.stoi, 'itos': vocab.itos}, f)

    print("Vocabularies saved successfully.")

### Dataset Class

In [61]:
class SlakeDataset(Dataset):
    """Torch dataset that turns one record into image / question / answer tensors.

    Args:
        dataset: Indexable collection of records with ``'img_name'``,
            ``'question'``, ``'qid'`` and (optionally) ``'answer'`` keys.
        question_vocab: Object whose ``numericalize(text)`` returns token ids.
        answer_vocab: Object whose ``numericalize(text)`` returns answer ids.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, dataset, question_vocab, answer_vocab, transform=None):
        self.data = dataset
        self.question_vocab = question_vocab
        self.answer_vocab = answer_vocab
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict of tensors plus the raw text."""
        item = self.data[idx]

        # 1. Image Processing
        image_path = os.path.join(IMAGE_PATH, item['img_name'])
        # The context manager releases the file handle as soon as the pixel data
        # has been converted, rather than leaving that to garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # "is not None" so that a transform object that happens to be falsy still runs.
        if self.transform is not None:
            image = self.transform(image)

        # 2. Question Processing
        question = item['question']
        question_indices = self.question_vocab.numericalize(question)

        # 3. Answer Processing
        answer = str(item.get('answer', ''))  # Answer may be missing in test set
        answer_index = self.answer_vocab.numericalize(answer)

        return {
            'image': image,
            # Explicit dtype: torch.tensor([]) would otherwise be float32 for a
            # question with no tokens, which an embedding layer cannot index with.
            'question': torch.tensor(question_indices, dtype=torch.long),
            # numericalize returns a list, so this is a 1-D tensor.
            'answer': torch.tensor(answer_index, dtype=torch.long),
            # Add original items for reference
            'original_question': question,
            'original_answer': answer,
            # Add ID for tracking
            'id': item['qid']
        }

### Collate function

Questions vary in length, so each batch is padded to the length of its longest question; this gives every sequence in the batch the same length.

In [74]:
def slake_collate_fn(batch, pad_index=0):
    """Combine a list of dataset samples into one padded batch.

    Args:
        batch: List of sample dicts with ``'image'``, ``'question'``, ``'answer'``,
            ``'original_question'``, ``'original_answer'`` and ``'id'`` keys.
        pad_index: Value used to pad questions shorter than the longest one.

    Returns:
        Dict with stacked ``'image'``, padded ``'question'`` (batch first),
        ``'question_lengths'`` (lengths before padding), ``'answer'`` (one class
        id per sample, 1-D) and the untouched text / id lists.

    Raises:
        ValueError: If a sample's answer does not hold exactly one value.
    """
    # Stack images along a new leading batch dimension.
    images = torch.stack([sample['image'] for sample in batch])

    questions = [sample['question'] for sample in batch]

    # Get question lengths BEFORE padding
    question_lengths = torch.tensor([len(q) for q in questions])

    # Pad questions to the longest sequence in THIS batch
    questions_padded = pad_sequence(questions, batch_first=True, padding_value=pad_index)

    # Each answer is a single class id. Flatten explicitly instead of relying on
    # torch.tensor() converting one-element tensors to Python scalars one by one.
    answer_tensors = [torch.as_tensor(sample['answer']).reshape(-1) for sample in batch]
    if any(answer.numel() != 1 for answer in answer_tensors):
        raise ValueError("each sample's 'answer' must contain exactly one class index")
    answers = torch.cat(answer_tensors)

    return {
        'image': images,
        'question': questions_padded,
        'question_lengths': question_lengths,
        'answer': answers,
        'original_question': [sample['original_question'] for sample in batch],
        'original_answer': [sample['original_answer'] for sample in batch],
        'id': [sample['id'] for sample in batch]
    }

## Getting everything ready

In [75]:
# First run only: uncomment the next line to download the dataset and images,
# and comment out the load_from_disk line below.
# dataset = download_and_store_slake()

# Subsequent runs: reuse the copy previously saved under DATASET_PATH.
dataset = load_from_disk(DATASET_PATH)

# Build vocabularies for training
# Only the 'train' split is used, so question tokens / answers that appear in
# no training record fall back to <unk> at lookup time.
train_data = dataset['train']
question_vocab, answer_vocab = build_vocabs(train_data)

# Define image transformations
# Resize to IMG_SIZE, convert to a tensor, then normalise each channel with the
# mean/std values commonly used for ImageNet-pretrained torchvision models.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create train dataset and dataloader
# slake_collate_fn pads the variable-length questions of each batch.
train_dataset = SlakeDataset(train_data, question_vocab, answer_vocab, transform=transform)
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=slake_collate_fn
)

Vocabulary Built. Vocabulary Size: 281
Vocabulary Built. Vocabulary Size: 225


## Modeling (Baseline)

1. Basic CNN-LSTM <br>
2. CNN-Bidirectional LSTM with self attention


In [64]:
class MedicalVQABaseline(nn.Module):
    """Baseline VQA classifier: ResNet-34 image features + single-layer LSTM question encoder.

    The projected image feature and the final LSTM hidden state are concatenated
    and fed to an MLP that scores every answer class.

    Args:
        vocab_size: Number of question tokens (size of the embedding table).
        num_classes: Number of answer classes.
        embed_dim: Word-embedding size.
        hidden_dim: Size of the LSTM hidden state and of the projected image feature.
    """

    def __init__(self, vocab_size, num_classes, embed_dim=256, hidden_dim=512):
        super().__init__()

        # 1. CNN - ResNet
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
        # of `weights=`; switch once the minimum supported torchvision allows it.
        resnet = models.resnet34(pretrained=True)
        num_features = 512
        # Remove the last classification layer
        self.resnet_features = nn.Sequential(*list(resnet.children())[:-1])
        self.img_projector = nn.Linear(num_features, hidden_dim)
        self.bn_img = nn.BatchNorm1d(hidden_dim)

        # 2. LSTM
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

        # 3. Classifier
        fusion_dim = hidden_dim * 2

        # NOTE(review): `attention` and `kg_gate` are not used by forward(); they are
        # kept so existing checkpoints and parameter initialisation order stay the same.
        self.attention = nn.Linear(embed_dim, 1)
        self.kg_gate = nn.Parameter(torch.tensor(0.0))

        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, images, questions, question_lengths=None):
        """Return answer logits of shape (batch, num_classes).

        Args:
            images: Image batch accepted by the ResNet feature extractor.
            questions: Padded token ids, (batch, seq_len).
            question_lengths: Optional true length of every question. When given,
                the LSTM stops at each question's last real token; when omitted,
                it runs over the full padded sequence.
        """
        # CNN
        # Extract features
        img_feats = self.resnet_features(images).view(images.size(0), -1)
        img_feats = self.img_projector(img_feats)
        img_feats = self.bn_img(img_feats)  # Normalize
        img_feats = torch.relu(img_feats)

        # LSTM
        embeds = self.embedding(questions)  # (Batch, Seq, Embed_Dim)
        if question_lengths is not None:
            # Without packing, h_n is the state after the last *padded* step, so the
            # text feature would depend on how much padding the batch happens to have.
            # clamp(min=1): pack_padded_sequence rejects zero-length sequences.
            lengths = question_lengths.clamp(min=1).cpu()
            lstm_input = nn.utils.rnn.pack_padded_sequence(
                embeds, lengths, batch_first=True, enforce_sorted=False
            )
        else:
            lstm_input = embeds
        # h_n: (num_layers, Batch, Hidden), returned in the original batch order
        _, (h_n, _) = self.lstm(lstm_input)
        text_feats = h_n[-1]  # (Batch, Hidden)

        # Fusion
        combined = torch.cat((img_feats, text_feats), dim=1)

        # Classification
        logits = self.classifier(combined)
        return logits


In [65]:
# Bidirectional LSTM with Self-Attention for question encoding
class BiLSTMWithSelfAttention(nn.Module):
    """Question encoder: embedding -> BiLSTM -> multi-head self-attention -> pooling.

    Args:
        vocab_size: Number of question tokens (index 0 is the padding token).
        embed_dim: Word-embedding size.
        hidden_dim: LSTM hidden size per direction; outputs have ``2 * hidden_dim`` features.
        num_layers: Number of stacked BiLSTM layers.
        dropout: Dropout used on embeddings, attention, the attended output and
            (only when ``num_layers > 1``) between LSTM layers.
        pooling_strategy: ``'mean'``, ``'max'``, or anything else to use the
            concatenated final forward/backward hidden states.
    """

    def __init__(self, vocab_size, embed_dim=300, hidden_dim=512, num_layers=1, dropout=0.5, pooling_strategy='mean'):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.pooling_strategy = pooling_strategy
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

        # Bidirectional LSTM
        self.bilstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Self-attention mechanism
        # BiLSTM outputs hidden_dim * 2 (forward + backward)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim * 2,
            num_heads=8,
            dropout=dropout,
            batch_first=True
        )

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_dim * 2)

    def forward(self, questions, question_lengths=None):
        """Encode a batch of padded questions.

        Args:
            questions: Token ids, (batch, seq_len).
            question_lengths: Optional true lengths. When given, padded positions
                are excluded from the LSTM, from attention (as keys) and from
                mean/max pooling. When omitted, every position is treated as real.

        Returns:
            ``(question_feature, attn_weights)`` where ``question_feature`` is
            (batch, 2 * hidden_dim) and ``attn_weights`` is what
            ``nn.MultiheadAttention`` returns with ``need_weights=True``.
        """
        # Embed questions
        embeds = self.dropout(self.embedding(questions))  # [B, seq_len, embed_dim]

        valid = None  # [B, seq_len] bool, True on real tokens; None = no length info
        if question_lengths is not None:
            # clamp(min=1): pack_padded_sequence rejects zero-length sequences.
            lengths = question_lengths.clamp(min=1)
            packed = nn.utils.rnn.pack_padded_sequence(
                embeds, lengths.cpu(),
                batch_first=True, enforce_sorted=False
            )
            packed_out, (hidden, _) = self.bilstm(packed)
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
                packed_out, batch_first=True
            )
            positions = torch.arange(lstm_out.size(1), device=lstm_out.device)
            valid = positions.unsqueeze(0) < lengths.to(lstm_out.device).unsqueeze(1)
        else:
            lstm_out, (hidden, _) = self.bilstm(embeds)

        # lstm_out: [B, seq_len, hidden_dim * 2]

        # Self-attention: query = key = value = lstm_out.
        # key_padding_mask (True = ignore) keeps real tokens from attending to padding;
        # otherwise a question's encoding changes with the other questions in its batch.
        attn_out, attn_weights = self.attention(
            query=lstm_out,
            key=lstm_out,
            value=lstm_out,
            key_padding_mask=None if valid is None else ~valid,
            need_weights=True
        )

        # Residual connection + Layer Norm
        attn_out = self.layer_norm(lstm_out + attn_out)
        attn_out = self.dropout(attn_out)

        # Pooling over real positions only (all positions when no lengths were given)
        if self.pooling_strategy == 'mean':
            if valid is None:
                question_feature = attn_out.mean(dim=1)  # [B, hidden_dim * 2]
            else:
                keep = valid.unsqueeze(-1).to(attn_out.dtype)
                question_feature = (attn_out * keep).sum(dim=1) / keep.sum(dim=1)
        elif self.pooling_strategy == 'max':
            if valid is not None:
                attn_out = attn_out.masked_fill(~valid.unsqueeze(-1), float('-inf'))
            question_feature = attn_out.max(dim=1)[0]
        else:
            # Last hidden state (concatenate forward and backward)
            question_feature = torch.cat([hidden[-2], hidden[-1]], dim=1)

        return question_feature, attn_weights

In [66]:
# Complete VQA model: ResNet34 + BiLSTM with Self-Attention
class VQA_ResNet_BiLSTM_Attention(nn.Module):
    """VQA classifier joining ResNet-34 image features with an attentive BiLSTM question encoder.

    The image vector and the question vector are concatenated, passed through a
    two-stage MLP (``fusion``) and finally mapped to answer logits.
    """

    def __init__(self, vocab_size, num_classes, embed_dim=300,
                 lstm_hidden=512, fusion_dim=1024, dropout=0.5):
        super().__init__()

        # Image branch: ResNet34 with its final fully connected layer dropped.
        backbone = models.resnet34(pretrained=True)
        feature_layers = list(backbone.children())[:-1]
        self.image_encoder = nn.Sequential(*feature_layers)
        self.image_feature_dim = 512  # ResNet34 final layer

        # Question branch: two stacked BiLSTM layers followed by self-attention.
        self.question_encoder = BiLSTMWithSelfAttention(
            vocab_size=vocab_size,
            embed_dim=embed_dim,
            hidden_dim=lstm_hidden,
            num_layers=2,
            dropout=dropout,
        )
        self.question_feature_dim = 2 * lstm_hidden  # forward + backward

        # Fusion MLP: joint -> fusion_dim -> fusion_dim // 2.
        joint_dim = self.image_feature_dim + self.question_feature_dim
        reduced_dim = fusion_dim // 2
        fusion_layers = [
            nn.Linear(joint_dim, fusion_dim),
            nn.BatchNorm1d(fusion_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_dim, reduced_dim),
            nn.BatchNorm1d(reduced_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.fusion = nn.Sequential(*fusion_layers)

        # Answer scores.
        self.classifier = nn.Linear(reduced_dim, num_classes)

    def forward(self, images, questions, question_lengths=None):
        """Return answer logits for a batch of (image, question) pairs."""
        # Drop the two trailing singleton spatial axes left by the encoder's pooling.
        pooled = self.image_encoder(images)
        image_vec = pooled.squeeze(-1).squeeze(-1)

        # The encoder also returns attention weights; they are not used here.
        question_vec, _ = self.question_encoder(questions, question_lengths)

        joint = torch.cat([image_vec, question_vec], dim=1)
        return self.classifier(self.fusion(joint))

## Training-Validation-Testing and Hyperparameter Tuning

For now, for testing purposes, no validation step has been added.

In [82]:
def calculate_accuracy(predictions, targets):
    """Return the fraction of rows whose highest-scoring class equals the target.

    Args:
        predictions: Per-class scores, one row per sample.
        targets: Expected class index for every sample.

    Returns:
        A scalar tensor holding the share of correct rows.
    """
    predicted_classes = predictions.max(dim=1).indices
    hits = predicted_classes.eq(targets).float()
    return hits.sum() / len(hits)

def train_one_epoch(model, loader, criterion, optimizer):
    """Run one training pass over ``loader`` and report loss and accuracy.

    Args:
        model: Module called as ``model(images, questions, question_lengths)``.
        loader: Iterable of batches with ``'image'``, ``'question'``,
            ``'question_lengths'`` and ``'answer'`` entries; must support ``len``.
        criterion: Loss called as ``criterion(logits, answers)``.
        optimizer: Optimizer stepping the model's parameters.

    Returns:
        ``(avg_loss, accuracy)``: mean of the per-batch losses, and the percentage
        (0-100) of correctly classified samples. Both are 0 for an empty loader.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    # Progress bar; the label is fixed, so it is set once here rather than on every step.
    loop = tqdm(loader, total=len(loader), leave=True, desc="Train")

    for batch in loop:
        # 1. Move data to Device
        images = batch['image'].to(device)
        questions = batch['question'].to(device)
        question_lengths = batch['question_lengths'].to(device)
        answers = batch['answer'].to(device)

        # 2. Forward Pass
        logits = model(images, questions, question_lengths)

        # 3. Calculate Loss
        loss = criterion(logits, answers)

        # 4. Backward Pass
        optimizer.zero_grad()  # Clear previous gradients
        loss.backward()        # Compute gradients

        # Gradient clipping (important for LSTMs)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

        optimizer.step()       # Update weights

        # 5. Metrics — computed once and reused for the running totals and the bar.
        batch_loss = loss.item()
        batch_size = answers.size(0)
        batch_correct = (torch.argmax(logits, dim=1) == answers).sum().item()

        total_loss += batch_loss
        correct += batch_correct
        total += batch_size

        # 6. Update progress bar
        loop.set_postfix(loss=batch_loss, acc=batch_correct / batch_size)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    avg_loss = total_loss / max(len(loader), 1)
    accuracy = 100 * correct / max(total, 1)

    return avg_loss, accuracy

In [83]:
# Testing purpose: build the baseline model that the next cell trains.
# 1. Instantiate the model, sized by the question and answer vocabularies.
model_basic = MedicalVQABaseline(
    vocab_size=len(question_vocab),
    num_classes=len(answer_vocab)
).to(device)

# 2. FREEZE the ResNet
# NOTE(review): requires_grad=False stops weight updates only. train_one_epoch calls
# model.train(), so any BatchNorm layers inside the frozen backbone presumably keep
# updating their running statistics — confirm whether resnet_features should be in eval mode.
for param in model_basic.resnet_features.parameters():
    param.requires_grad = False

# Alternative (disabled): fine-tune only the last layers instead of freezing everything
# for param in list(model_basic.resnet_features.parameters())[:-20]:
#     param.requires_grad = False

# 3. Optimise only the parameters that are still trainable.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model_basic.parameters()), lr=0.0003)
criterion = nn.CrossEntropyLoss()

In [84]:
# Train the baseline for 5 epochs. train_loss / train_acc are overwritten on every
# iteration, so only the last epoch's values remain after the loop.
for i in range(5):
    train_loss, train_acc = train_one_epoch(model_basic, train_loader, criterion, optimizer)

Train: 100%|██████████| 154/154 [00:33<00:00,  4.56it/s, acc=0.304, loss=2.67]
Train: 100%|██████████| 154/154 [00:33<00:00,  4.59it/s, acc=0.522, loss=1.79]
Train: 100%|██████████| 154/154 [00:34<00:00,  4.42it/s, acc=0.478, loss=1.53] 
Train: 100%|██████████| 154/154 [00:36<00:00,  4.28it/s, acc=0.609, loss=1.55] 
Train: 100%|██████████| 154/154 [00:36<00:00,  4.27it/s, acc=0.652, loss=0.839]


In [85]:
# Second, freshly initialised baseline with the same configuration as model_basic.
# 1. Instantiate the model, sized by the question and answer vocabularies.
new_model = MedicalVQABaseline(
    vocab_size=len(question_vocab),
    num_classes=len(answer_vocab)
).to(device)

# 2. FREEZE the ResNet (Crucial for Baseline stability)
for param in new_model.resnet_features.parameters():
    param.requires_grad = False

# 3. Separate optimizer and loss, so no optimizer state is shared with model_basic.
optimizer_new = torch.optim.Adam(filter(lambda p: p.requires_grad, new_model.parameters()), lr=0.0003)
criterion_new = nn.CrossEntropyLoss()

In [86]:
# Train the second model for 5 epochs. train_loss / train_acc are overwritten on
# every iteration, so only the last epoch's values remain after the loop.
for i in range(5):
    train_loss, train_acc = train_one_epoch(new_model, train_loader, criterion_new, optimizer_new)

Train: 100%|██████████| 154/154 [00:34<00:00,  4.42it/s, acc=0.391, loss=2.01] 
Train: 100%|██████████| 154/154 [00:38<00:00,  4.04it/s, acc=0.435, loss=1.68]
Train: 100%|██████████| 154/154 [00:37<00:00,  4.08it/s, acc=0.652, loss=1.17] 
Train: 100%|██████████| 154/154 [00:37<00:00,  4.11it/s, acc=0.522, loss=1.37] 
Train: 100%|██████████| 154/154 [00:38<00:00,  4.05it/s, acc=0.478, loss=1.23] 
