In [None]:
# Google Colab setup (auto-detects environment)
try:
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Install packages
    print("Installing packages...")
    !pip install -q torch torchvision tqdm pyyaml scikit-learn pandas matplotlib seaborn
    
    # Set project path
    import os
    PROJECT_PATH = "/content/drive/MyDrive/WOA7015 Advanced Machine Learning/data"
    os.chdir(PROJECT_PATH)
    print(f"✓ Running on Colab - Path: {PROJECT_PATH}")
    
except ImportError:
    # Running locally
    PROJECT_PATH = None
    print("✓ Running locally")

## 0. Setup (Run this first)

- **For Google Colab**: The Colab setup code cell automatically mounts your Drive and installs the required packages.
- **For Local**: The Colab setup code cell skips the Colab-specific steps.

# Text-Only VQA Baseline Training

**Goal**: Train a language model to answer medical questions based on text only (without images)

This notebook walks through:
1. Loading and exploring the data
2. Building vocabulary from questions
3. Creating an LSTM model
4. Training the model
5. Evaluating performance
6. Analyzing results

**Note**: Can run on Google Colab (free tier) - no GPU needed for text-only baseline!

## 1. Setup and Imports

In [None]:
# Force reload src modules (run this if you see import errors after updating code)
import sys

# Match only the `src` package and its submodules. A bare startswith('src')
# would also evict unrelated modules whose names merely begin with "src".
# list(...) snapshots the keys so the dict is not mutated while being iterated.
modules_to_remove = [
    m for m in list(sys.modules)
    if m == 'src' or m.startswith('src.')
]
for module in modules_to_remove:
    del sys.modules[module]
print(f"✓ Cleared {len(modules_to_remove)} cached modules")

In [None]:
import sys
import os
from pathlib import Path

# Setup paths: use the Colab project path when the setup cell defined a
# non-empty one, otherwise fall back to the parent of the current working
# directory.
colab_path = globals().get('PROJECT_PATH')
if colab_path:
    # Colab environment
    project_root = Path(colab_path)
else:
    # Local environment
    project_root = Path().absolute().parent

# Guarded insert: re-running this cell must not stack duplicate entries on
# sys.path.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import yaml
from tqdm.notebook import tqdm

# Import our modules
from src.data.dataset import TextOnlyVQADataset, create_text_dataloaders
from src.models.text_model import LSTMTextModel, create_text_model
from src.training.trainer import TextVQATrainer
from src.evaluation.metrics import VQAMetrics, calculate_accuracy

# Plotting defaults used by the figures in this notebook
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

# Environment report: confirms the imports worked and shows what we run on
device_name = 'CUDA' if torch.cuda.is_available() else 'CPU'
print("✓ Imports successful")
print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device_name}")

## 2. Load Configuration

In [None]:
# Load lightweight configuration, falling back to built-in defaults when the
# file is missing or empty.
config_path = project_root / "config_lightweight.yaml"

# Default lightweight config for Colab
default_config = {
    'data': {'train_csv': 'trainrenamed.csv', 'test_csv': 'testrenamed.csv',
             'answers_file': 'answers.txt', 'val_split': 0.15},
    'text': {'max_length': 32, 'embedding_dim': 128},
    'model': {'baseline': {'hidden_dim': 256, 'num_layers': 1, 'dropout': 0.2}},
    'training': {'batch_size': 8, 'num_epochs': 10, 'learning_rate': 0.001,
                 'weight_decay': 0.0001, 'scheduler': 'step', 'gradient_clip': 1.0,
                 'early_stopping_patience': 5},
    'paths': {'checkpoints': 'checkpoints'},
    'seed': 42
}

config = None
if config_path.exists():
    # Explicit encoding keeps the read independent of the platform locale
    # (the answers file is also read as UTF-8 later in the notebook).
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if config:
        print("✓ Loaded config_lightweight.yaml")
    else:
        # yaml.safe_load returns None for an empty document; indexing that
        # below would raise a TypeError, so fall back to the defaults.
        print("⚠️  config_lightweight.yaml is empty, using default settings")
else:
    print("⚠️  config_lightweight.yaml not found, using default settings")

if not config:
    config = default_config

print("\nConfiguration:")
print(f"  Batch size: {config['training']['batch_size']}")
print(f"  Epochs: {config['training']['num_epochs']}")
print(f"  Learning rate: {config['training']['learning_rate']}")
print(f"  Hidden dim: {config['model']['baseline']['hidden_dim']}")
print(f"  Embedding dim: {config['text']['embedding_dim']}")
print(f"  Max sequence length: {config['text']['max_length']}")

## 3. Explore the Data

In [None]:
# Read the train and test CSVs named in the config (paths are relative to
# the project root).
data_cfg = config['data']
train_df = pd.read_csv(project_root / data_cfg['train_csv'])
test_df = pd.read_csv(project_root / data_cfg['test_csv'])

for split_name, frame in (("Training", train_df), ("Test", test_df)):
    print(f"{split_name} samples: {len(frame):,}")
print(f"\nColumns: {list(train_df.columns)}")
print("\nFirst few samples:")
train_df.head()

In [None]:
# Analyze questions: length is the number of whitespace-separated words
train_df['question_length'] = train_df['question'].str.split().str.len()
question_lengths = train_df['question_length']
mean_length = question_lengths.mean()

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
length_ax, answer_ax = axes

# Left panel: question length histogram with the mean marked in red
length_ax.hist(question_lengths, bins=50, edgecolor='black')
length_ax.set_xlabel('Question Length (words)')
length_ax.set_ylabel('Frequency')
length_ax.set_title('Distribution of Question Lengths')
length_ax.axvline(mean_length, color='red', linestyle='--',
                  label=f'Mean: {mean_length:.1f}')
length_ax.legend()

# Right panel: the 20 most frequent answers, most common at the top
top_answers = train_df['answer'].value_counts().head(20)
bar_positions = range(len(top_answers))
answer_ax.barh(bar_positions, top_answers.values)
answer_ax.set_yticks(bar_positions)
answer_ax.set_yticklabels(top_answers.index, fontsize=8)
answer_ax.set_xlabel('Count')
answer_ax.set_title('Top 20 Most Common Answers')
answer_ax.invert_yaxis()

plt.tight_layout()
plt.show()

print("Question length statistics:")
print(f"  Mean: {mean_length:.2f} words")
print(f"  Median: {question_lengths.median():.0f} words")
print(f"  Max: {question_lengths.max():.0f} words")
print(f"\nUnique answers: {train_df['answer'].nunique():,}")

In [None]:
# Analyze question types
def classify_question_type(answer):
    """Label a sample by its answer: 'yes/no' or 'open-ended'.

    The answer is converted to ``str``, lower-cased and stripped before the
    comparison, so case and surrounding whitespace are ignored.

    Args:
        answer: The ground-truth answer; any object accepted by ``str()``.

    Returns:
        'yes/no' when the normalised answer is exactly "yes" or "no",
        otherwise 'open-ended'.
    """
    normalized = str(answer).lower().strip()
    return 'yes/no' if normalized in ('yes', 'no') else 'open-ended'

# Tag every training sample with its question type, then count each type
train_df['question_type'] = train_df['answer'].apply(classify_question_type)
question_type_counts = train_df['question_type'].value_counts()

# Pie chart of the type shares
plt.figure(figsize=(8, 6))
plt.pie(
    question_type_counts.values,
    labels=question_type_counts.index,
    autopct='%1.1f%%',
    startangle=90,
)
plt.title('Question Types Distribution')
plt.show()

# Same breakdown as text: absolute count and percentage of all samples
total_samples = len(train_df)
print("Question type breakdown:")
for qtype, count in question_type_counts.items():
    share = count / total_samples * 100
    print(f"  {qtype}: {count:,} ({share:.1f}%)")

## 4. Create Dataloaders

In [None]:
# Create dataloaders
print("Creating dataloaders...")

# All file paths are resolved against the project root and passed as strings
loader_kwargs = dict(
    train_csv=str(project_root / config['data']['train_csv']),
    test_csv=str(project_root / config['data']['test_csv']),
    answers_file=str(project_root / config['data']['answers_file']),
    batch_size=config['training']['batch_size'],
    val_split=config['data']['val_split'],
    num_workers=0,
    max_length=config['text']['max_length'],
)
(train_loader, val_loader, test_loader,
 vocab_size, num_classes, vocab) = create_text_dataloaders(**loader_kwargs)

print("\n✓ Dataloaders created:")
print(f"  Vocabulary size: {vocab_size:,}")
print(f"  Number of answer classes: {num_classes:,}")
for split_name, loader in (("Train", train_loader),
                           ("Val", val_loader),
                           ("Test", test_loader)):
    print(f"  {split_name} batches: {len(loader):,}")

In [1]:
# Inspect a sample batch: pull the first batch from the training loader.
# The batch is indexed by the keys 'question', 'answer', 'question_text'
# and 'answer_text'.
sample_batch = next(iter(train_loader))
encoded_questions = sample_batch['question']

print("Sample batch:")
print(f"  Question tensor shape: {encoded_questions.shape}")
print(f"  Answer tensor shape: {sample_batch['answer'].shape}")

# Show at most the first 20 token ids of the first question
first_question_ids = encoded_questions[0][:20].tolist()
print(f"\nFirst question (encoded): {first_question_ids}...")
print(f"First question (text): {sample_batch['question_text'][0]}")
print(f"First answer: {sample_batch['answer_text'][0]}")

NameError: name 'train_loader' is not defined

In [None]:
# Show vocabulary statistics: the 20 entries with the lowest indices
print("Vocabulary sample (first 20 tokens):")
lowest_index_entries = sorted(vocab.items(), key=lambda entry: entry[1])[:20]
for token, idx in lowest_index_entries:
    print(f"  {idx:3d}: '{token}'")

## 5. Create the Model

In [None]:
# Pick the compute device: GPU when available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Build the bidirectional LSTM from the config and move it to the device
baseline_cfg = config['model']['baseline']
model = create_text_model(
    model_type='lstm',
    vocab_size=vocab_size,
    num_classes=num_classes,
    embedding_dim=config['text']['embedding_dim'],
    hidden_dim=baseline_cfg['hidden_dim'],
    num_layers=baseline_cfg['num_layers'],
    dropout=baseline_cfg['dropout'],
    bidirectional=True,
)
model = model.to(device)

# Model summary: count parameters in one pass over model.parameters()
param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, trainable in param_counts if trainable)

# Size estimate assumes 4 bytes per parameter (i.e. float32 weights)
model_size_mb = total_params * 4 / (1024**2)

print("\n✓ Model created:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: ~{model_size_mb:.2f} MB")
print("\nModel architecture:")
print(model)

In [None]:
# Smoke-test the forward pass on the first four questions of the sample batch
model.eval()
with torch.no_grad():
    test_questions = sample_batch['question'][:4].to(device)
    test_outputs = model(test_questions)
    # Softmax / argmax over dim 1 of the model output
    test_probs = test_outputs.softmax(dim=1)
    test_preds = torch.argmax(test_outputs, dim=1)

max_probs = test_probs.max(dim=1).values

print("Test forward pass:")
print(f"  Input shape: {test_questions.shape}")
print(f"  Output shape: {test_outputs.shape}")
print(f"  Predictions: {test_preds.cpu().numpy()}")
print(f"  Max probabilities: {max_probs.cpu().numpy()}")
print("\n✓ Model is working correctly!")

## 6. Training Setup

In [None]:
# Set random seeds for reproducibility.
# NOTE(review): the model and dataloaders are created in earlier cells, so
# anything they randomise at construction time (presumably weight
# initialisation and the train/val split) is not covered by this seed.
seed = config['seed']
torch.manual_seed(seed)
# NumPy has its own global generator, which torch.manual_seed does not touch
np.random.seed(seed)
if torch.cuda.is_available():
    # Seed every visible GPU rather than only the current device
    torch.cuda.manual_seed_all(seed)

# Create trainer; checkpoints go under <project_root>/<paths.checkpoints>
trainer = TextVQATrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device,
    checkpoint_dir=str(project_root / config['paths']['checkpoints']),
    experiment_name="text_baseline_lstm_notebook"
)

print("✓ Trainer initialized")
print(f"  Checkpoints will be saved to: {trainer.checkpoint_dir}")

## 7. Train the Model

**Note**: Training time depends on your hardware (local machine or Colab runtime). Rough estimates:
- **Quick test (2 epochs)**: ~5-10 minutes
- **Full training (10 epochs)**: ~30-60 minutes

In [None]:
# Optional: Reduce epochs for quick testing
# NOTE(review): the trainer was already constructed with `config` in the
# previous section; confirm it reads num_epochs at train() time, otherwise
# this override has no effect.
# config['training']['num_epochs'] = 2  # Uncomment for quick test

# Run the training loop
print("Starting training...\n")
trainer.train()

best_val_acc = trainer.best_val_acc
print("\n✓ Training completed!")
print(f"Best validation accuracy: {best_val_acc:.4f}")

## 8. Evaluate on Test Set

In [None]:
# Restore the weights stored in best_model.pth inside the trainer's
# checkpoint directory; map_location loads them onto the active device.
best_model_path = trainer.checkpoint_dir / "best_model.pth"
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

print("✓ Loaded best model")
best_epoch = checkpoint['epoch']
print(f"  Epoch: {best_epoch}")
checkpoint_val_acc = checkpoint['best_val_acc']
print(f"  Best val accuracy: {checkpoint_val_acc:.4f}")

In [None]:
# Evaluate on test set: accumulate predictions batch by batch, then compute
# aggregate metrics once at the end.
model.eval()
test_metrics = VQAMetrics(num_classes)

print("Evaluating on test set...\n")

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        questions = batch['question'].to(device)
        # Named `answer_targets` rather than `answers`: the error-analysis
        # cell binds `answers` to the list of answer strings that the
        # sample-predictions cell indexes. Reusing the name here for a tensor
        # would make that cell depend on execution order.
        answer_targets = batch['answer'].to(device)

        outputs = model(questions)
        predictions = outputs.argmax(dim=1)

        # The text fields are optional in the batch, hence .get()
        test_metrics.update(
            predictions,
            answer_targets,
            batch.get('question_text'),
            batch.get('answer_text')
        )

# Compute metrics
metrics = test_metrics.compute()

print("\nTest Set Results:")
print("=" * 50)
print(f"Accuracy:           {metrics['accuracy']:.4f}")
print(f"F1 Score (macro):   {metrics['f1_macro']:.4f}")
print(f"F1 Score (weighted):{metrics['f1_weighted']:.4f}")
print(f"Precision (macro):  {metrics['precision_macro']:.4f}")
print(f"Recall (macro):     {metrics['recall_macro']:.4f}")
print(f"Exact Match:        {metrics['exact_match']:.4f}")

## 9. Analyze Results by Question Type

In [None]:
# Per-question-type analysis; skipped entirely when no per-type metrics
# are returned.
per_type_metrics = test_metrics.compute_per_question_type()

if per_type_metrics:
    print("\nPerformance by Question Type:")
    print("=" * 50)

    for qtype, stats in per_type_metrics.items():
        print(f"\n{qtype.upper()}:")
        print(f"  Count: {stats['count']:,}")
        print(f"  Accuracy: {stats['accuracy']:.4f}")

    # Visualize: one bar per question type
    fig, ax = plt.subplots(figsize=(8, 5))
    types = list(per_type_metrics)
    accs = [per_type_metrics[name]['accuracy'] for name in types]

    bars = ax.bar(types, accs, color=['skyblue', 'coral'])
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy by Question Type')
    ax.set_ylim(0, 1)

    # Add value labels centred above each bar
    for bar in bars:
        bar_height = bar.get_height()
        bar_center = bar.get_x() + bar.get_width()/2.
        ax.text(bar_center, bar_height, f'{bar_height:.3f}',
                ha='center', va='bottom')

    plt.tight_layout()
    plt.show()

## 10. Error Analysis

In [None]:
# Get confusion statistics
confusion_stats = test_metrics.get_confusion_stats(top_k=10)

print("\nTop 10 Most Confused Answer Pairs:")
print("=" * 70)
print(f"{'True Answer':<25} {'Predicted Answer':<25} {'Count':>10}")
print("-" * 70)

# Get answer vocabulary: one answer per line, looked up by class index
with open(project_root / config['data']['answers_file'], 'r', encoding='utf-8') as f:
    answers = [line.strip() for line in f]


def describe_answer(idx):
    """Return the answer text for a class index, or 'idx_<n>' if out of range."""
    if idx < len(answers):
        return answers[idx]
    return f"idx_{idx}"


for (true_idx, pred_idx), count in confusion_stats['top_confusions']:
    true_ans = describe_answer(true_idx)
    pred_ans = describe_answer(pred_idx)
    print(f"{true_ans:<25} {pred_ans:<25} {count:>10}")

print(f"\nTotal errors: {confusion_stats['total_errors']:,}")

## 11. Sample Predictions

In [None]:
# Show some example predictions for the first test batch.
# `answers` is the list of answer strings loaded in the error-analysis cell.
model.eval()
sample_batch = next(iter(test_loader))

with torch.no_grad():
    questions = sample_batch['question'].to(device)
    # predict() supplies both values used below, so no separate forward
    # pass through model(...) is needed.
    predictions, probabilities = model.predict(questions)

# Show first 10 examples (fewer if the batch is smaller)
print("Sample Predictions:")
print("=" * 100)

for i in range(min(10, len(sample_batch['question_text']))):
    question = sample_batch['question_text'][i]
    true_answer = sample_batch['answer_text'][i]
    pred_idx = predictions[i].item()
    pred_answer = answers[pred_idx] if pred_idx < len(answers) else f"idx_{pred_idx}"
    confidence = probabilities[i].max().item()

    # Normalise both sides the same way as the rest of the notebook
    # (str -> strip -> lower) so stray whitespace or a non-string answer
    # does not turn a correct prediction into a mismatch.
    is_match = pred_answer.strip().lower() == str(true_answer).strip().lower()
    is_correct = "✓" if is_match else "✗"

    print(f"\n{is_correct} Example {i+1}:")
    print(f"  Question: {question}")
    print(f"  True Answer: {true_answer}")
    print(f"  Predicted: {pred_answer} (confidence: {confidence:.3f})")

## 12. Save Results Summary

In [None]:
# Create results summary and persist it as JSON under <project_root>/results
import json


def _to_builtin(value):
    """Unwrap scalar wrappers that expose ``.item()`` into plain Python values.

    NumPy scalars and single-element torch tensors provide ``.item()``;
    ``json.dump`` cannot serialise some of them and ``isinstance(x, float)``
    does not recognise them. Any other value is returned unchanged.
    """
    return value.item() if hasattr(value, 'item') else value


results_summary = {
    key: _to_builtin(value)
    for key, value in {
        'model_type': 'LSTM',
        'vocab_size': vocab_size,
        'num_classes': num_classes,
        'total_parameters': total_params,
        'test_accuracy': metrics['accuracy'],
        'test_f1_macro': metrics['f1_macro'],
        'test_f1_weighted': metrics['f1_weighted'],
        'test_precision': metrics['precision_macro'],
        'test_recall': metrics['recall_macro'],
        'best_val_accuracy': trainer.best_val_acc,
    }.items()
}

# Save to file
results_dir = project_root / 'results'
results_dir.mkdir(exist_ok=True)

# Explicit encoding keeps the output independent of the platform locale
with open(results_dir / 'text_baseline_results.json', 'w', encoding='utf-8') as f:
    json.dump(results_summary, f, indent=2)

print("✓ Results saved to results/text_baseline_results.json")
print("\nSummary:")
for key, value in results_summary.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

## 13. Conclusion and Next Steps

### What we've achieved:
- ✅ Built vocabulary from medical questions
- ✅ Trained LSTM model to predict answers from text only
- ✅ Evaluated performance on test set
- ✅ Analyzed results by question type

### Expected Performance:
- **Text-only baseline**: 15-30% accuracy
- This is significantly better than random (0.02% for 4,593 classes)
- But limited without visual information!

### Next Steps:
1. **Add Vision**: Implement CNN image encoder
2. **Multimodal Fusion**: Combine text + image features
3. **Attention Mechanisms**: Let model focus on relevant image regions
4. **Pre-trained VLMs**: Use BLIP, CLIP, or similar models
5. **Fine-tuning**: Optimize on PathVQA dataset

### Model Checkpoint:
Your trained model is saved at:
```
checkpoints/text_baseline_lstm_notebook/best_model.pth
```

You can load it later for inference or further training!