# MedQA Fine-Tuning: Teaching an LLM to Pass a Medical Exam

**Task:** Fine-tune Mistral-7B-Instruct-v0.3 on USMLE-style medical multiple-choice questions using QLoRA.

**Dataset:** MedQA (12,723 examples from the US Medical Licensing Examination)

**Metric:** Accuracy (4-option MCQ)

**CS614 - Gen AI with LLMs | Individual Assignment**

## 0. Setup & Installation

In [None]:
# ── Google Colab Setup ────────────────────────────────────
# Run this cell first on Colab to clone the repo and install deps.
# Locally, the Colab-specific steps are skipped; only the path setup below runs.

import os

IN_COLAB = "COLAB_GPU" in os.environ or "COLAB_RELEASE_TAG" in os.environ

if IN_COLAB:
    from getpass import getpass

    # Clone repo (only if not already cloned)
    if not os.path.exists("/content/cs614-assignment-1"):
        token = getpass("GitHub PAT: ")
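        !git clone https://{token}@github.com/ikhwanwahid/cs614-assignment-1.git /content/cs614-assignment-1
        # The clone URL (token included) is saved in .git/config, so reset the remote to a token-free URL
        !git -C /content/cs614-assignment-1 remote set-url origin https://github.com/ikhwanwahid/cs614-assignment-1.git
        del token  # drop the token from the Python namespace as well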

    # Change to project root
    os.chdir("/content/cs614-assignment-1")

    # Install dependencies
    !pip install -q transformers datasets peft bitsandbytes trl accelerate
    !pip install -q scikit-learn matplotlib pandas tqdm

import sys

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..")) if not IN_COLAB else os.getcwd()
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print(f"Project root: {PROJECT_ROOT}")
print(f"Running on Colab: {IN_COLAB}")

## 1. Imports & Configuration

In [None]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

from configs.hyperparams import ExperimentConfig, get_all_configs
from src.data_loader import (
    format_for_inference,
    format_for_training,
    load_medqa_dataset,
    prepare_datasets,
    LABEL_MAP,
)
from src.model_loader import load_base_model_and_tokenizer, load_finetuned_model
from src.trainer import create_trainer, train_and_save
from src.evaluator import (
    compare_models,
    confidence_calibration,
    error_analysis,
    evaluate_on_dataset,
)
from src.baselines import run_few_shot_baseline, run_zero_shot_baseline
from src.topic_classifier import classify_dataset
from src.utils import (
    plot_all_configs_comparison,
    plot_calibration_curve,
    plot_error_taxonomy,
    plot_per_topic_accuracy,
    plot_training_curves,
    save_results_json,
    set_seed,
    setup_device,
)

set_seed(42)
device = setup_device()

## 2. Load & Explore Data

In [None]:
train_raw, val_raw, test_raw = load_medqa_dataset()

print(f"Train: {len(train_raw):,} examples")
print(f"Val:   {len(val_raw):,} examples")
print(f"Test:  {len(test_raw):,} examples")
print(f"\nColumns: {train_raw.column_names}")

In [None]:
# Show example questions
for i in range(3):
    ex = test_raw[i]
    print(f"\n{'='*60}")
    print(f"Example {i+1} | Answer: {LABEL_MAP[ex['label']]}")
    print(f"{'='*60}")
    print(format_for_inference(ex)[:500])

In [None]:
# Answer label distribution
train_labels = [LABEL_MAP[ex["label"]] for ex in train_raw]
label_counts = Counter(train_labels)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Label distribution
axes[0].bar(label_counts.keys(), label_counts.values(), color="#4a90d9", alpha=0.8)
axes[0].set_title("Answer Label Distribution (Train)")
axes[0].set_ylabel("Count")

# Topic distribution
test_topics = classify_dataset(test_raw)
topic_counts = Counter(test_topics)
topics_sorted = sorted(topic_counts.items(), key=lambda x: x[1], reverse=True)
axes[1].barh([t[0] for t in topics_sorted], [t[1] for t in topics_sorted],
             color="#4a90d9", alpha=0.8)
axes[1].set_title("Topic Distribution (Test Set)")
axes[1].set_xlabel("Count")

plt.tight_layout()
plt.show()

print(f"\nTopics in test set:")
for topic, count in topics_sorted:
    print(f"  {topic}: {count} ({count/len(test_raw)*100:.1f}%)")
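
`classify_dataset` lives in `src/topic_classifier.py`, which is not shown in this notebook. The Limitations section describes it as a keyword-based heuristic, so the idea can be sketched roughly as follows. The topic names, keyword lists, and the `classify_question` helper are hypothetical illustrations, not the project's actual rules.

```python
# Hypothetical keyword table; the real topics and keywords may differ.
TOPIC_KEYWORDS = {
    "Cardiology": ["heart", "cardiac", "myocardial", "murmur"],
    "Pharmacology": ["drug", "dose", "mechanism of action", "adverse effect"],
    "Infectious Disease": ["fever", "bacteria", "antibiotic", "infection"],
}


def classify_question(text, default="Other"):
    """Return the topic whose keywords occur most often in the question."""
    lowered = text.lower()
    scores = {
        topic: sum(lowered.count(kw) for kw in keywords)
        for topic, keywords in TOPIC_KEYWORDS.items()
    }
    best_topic = max(scores, key=scores.get)
    # Fall back to a catch-all label when no keyword matches at all.
    return best_topic if scores[best_topic] > 0 else default
```

In this sketch, ties go to whichever topic is listed first, which is one way a question spanning several topics can end up under a single, arguably wrong, label.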

## 3. Zero-Shot Baseline (Base Model)

Evaluate the base Mistral-7B-Instruct-v0.3 model on the test set with no fine-tuning and no in-context examples.

In [None]:
base_config = ExperimentConfig(name="base_model")
base_model, base_tokenizer = load_base_model_and_tokenizer(base_config, for_training=False)
# Set left padding for generation
base_tokenizer.padding_side = "left"

In [None]:
zero_shot_results = run_zero_shot_baseline(
    base_model, base_tokenizer, test_raw,
    topic_labels=test_topics, batch_size=8,
)

print(f"Zero-shot accuracy: {zero_shot_results['overall_accuracy']:.4f}")
print(f"Zero-shot macro F1: {zero_shot_results['macro_f1']:.4f}")
print(f"Extraction failure rate: {zero_shot_results['extraction_failure_rate']:.4f}")

save_results_json(zero_shot_results, os.path.join(PROJECT_ROOT, "results/zero_shot_results.json"))
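
The extraction failure rate printed above counts generations from which no answer letter could be parsed. The parsing logic itself sits in `src/` and is not shown here, so the following is only a sketch under the assumption that the model replies either with a phrase such as "Answer: B" or with the letter at the start of its output. Both `extract_answer_letter` and `extraction_failure_rate` are hypothetical names.

```python
import re

# Assumed reply styles, tried in order; the project's real patterns may differ.
_PATTERNS = [
    # "Answer: B", "The answer is (C)"
    re.compile(r"(?i:answer\s*(?:is)?)\s*[:\-]?\s*\(?([A-D])\b"),
    # Reply starts with the letter: "B", "(B)", "D. Metformin"
    re.compile(r"^\s*\(?([A-D])\b"),
]


def extract_answer_letter(generation):
    """Return the first A-D letter matched by an answer pattern, or None."""
    for pattern in _PATTERNS:
        match = pattern.search(generation)
        if match:
            return match.group(1)
    return None


def extraction_failure_rate(generations):
    """Fraction of generations from which no answer letter could be parsed."""
    if not generations:
        return 0.0
    failures = sum(extract_answer_letter(g) is None for g in generations)
    return failures / len(generations)
```

A heuristic like this can also produce false positives (a reply beginning "A patient..." would parse as A), so a low failure rate does not by itself guarantee that every parsed letter reflects the model's intended answer.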

## 4. Few-Shot Baseline (Base Model)

Evaluate the same base model on the test set, with 3 examples drawn from the training set as in-context exemplars.
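
How `run_few_shot_baseline` assembles its prompt is not shown in this notebook. As a rough sketch, a few-shot prompt places solved exemplars ahead of the target question so the model can imitate the answer format. The dictionary keys (`question`, `options`, `answer`) and both helper names below are hypothetical and do not reflect the actual dataset columns.

```python
def format_question(question, options):
    """Render a question with lettered options, ending with an open answer slot."""
    lines = [f"Question: {question}"]
    lines += [f"{letter}. {text}" for letter, text in zip("ABCD", options)]
    lines.append("Answer:")
    return "\n".join(lines)


def build_few_shot_prompt(exemplars, question, options):
    """Prepend solved exemplars, then append the unsolved target question."""
    blocks = [
        f"{format_question(ex['question'], ex['options'])} {ex['answer']}"
        for ex in exemplars
    ]
    blocks.append(format_question(question, options))
    return "\n\n".join(blocks)
```

Each exemplar lengthens the prompt, which may be why the few-shot cell below uses a smaller batch size than the zero-shot run.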

In [None]:
few_shot_results = run_few_shot_baseline(
    base_model, base_tokenizer, test_raw, train_raw,
    n_shots=3, topic_labels=test_topics, batch_size=4,
)

print(f"3-shot accuracy: {few_shot_results['overall_accuracy']:.4f}")
print(f"3-shot macro F1: {few_shot_results['macro_f1']:.4f}")
print(f"Extraction failure rate: {few_shot_results['extraction_failure_rate']:.4f}")

save_results_json(few_shot_results, os.path.join(PROJECT_ROOT, "results/few_shot_results.json"))

In [None]:
# Free GPU memory before training
del base_model
torch.cuda.empty_cache()
print("GPU memory freed.")

## 5. Hyperparameter Sweep -- Training

Train 6 configurations varying LoRA rank, learning rate, epochs, and dropout.

| Config | What it tests | Key difference |
|--------|--------------|----------------|
| 1 (Baseline) | Standard QLoRA defaults | r=16, lr=2e-4, 2 epochs |
| 2 (Low Rank) | Fewer params sufficient? | r=8, alpha=16 |
| 3 (High Rank) | More capacity helps? | r=64, alpha=128, lr=1e-4 |
| 4 (Low LR) | Slower convergence? | lr=5e-5, 3 epochs |
| 5 (Extended) | Optimal stopping point | 3 epochs, lr=2e-4 |
| 6 (Aggressive) | Speed + regularization | r=32, lr=3e-4, dropout=0.1 |
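
To make the rank comparison concrete, the number of trainable adapter weights can be estimated directly: LoRA adds two low-rank matrices to each adapted weight matrix, contributing `r * (in_features + out_features)` parameters. The sketch below uses the Mistral-7B attention projection shapes; which modules actually receive adapters is an assumption here (`q_proj` and `v_proj`), so check the project's LoRA settings before relying on the absolute numbers.

```python
# Mistral-7B attention projection shapes as (in_features, out_features).
MISTRAL_7B_PROJECTIONS = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
}
NUM_LAYERS = 32


def lora_trainable_params(rank, target_modules=("q_proj", "v_proj")):
    """Count adapter weights: each adapted matrix adds rank * (in + out) parameters."""
    per_layer = sum(
        rank * sum(MISTRAL_7B_PROJECTIONS[name])
        for name in target_modules
    )
    return per_layer * NUM_LAYERS


for rank in (8, 16, 32, 64):
    print(f"r={rank:>2}: {lora_trainable_params(rank):,} trainable parameters")
```

Because the count is linear in the rank, Config 3 (r=64) trains eight times as many adapter weights as Config 2 (r=8), whichever modules are targeted.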

In [None]:
configs = get_all_configs()
all_training_logs = {}
all_val_results = {}

for config_name, config in configs.items():
    print(f"\n{'='*60}")
    print(f"Training: {config_name}")
    print(f"Description: {config.description}")
    print(f"{'='*60}")

    set_seed(config.seed)

    # Prepare data
    train_ds, val_ds, _ = prepare_datasets(config)

    # Load model with LoRA
    model, tokenizer = load_base_model_and_tokenizer(config, for_training=True)

    # Train
    trainer = create_trainer(model, tokenizer, train_ds, val_ds, config)
    training_metrics = train_and_save(trainer, config)
    all_training_logs[config_name] = training_metrics["log_history"]

    # Quick validation accuracy
    tokenizer.padding_side = "left"  # switch for generation
    val_results = evaluate_on_dataset(model, tokenizer, val_raw, batch_size=8)
    all_val_results[config_name] = val_results
    print(f"  Val accuracy: {val_results['overall_accuracy']:.4f}")

    # Free memory
    del model, trainer
    torch.cuda.empty_cache()

print("\n" + "="*60)
print("All training complete!")
print("="*60)

## 6. Training Curves

In [None]:
# Individual training curves
for config_name, log_history in all_training_logs.items():
    plot_training_curves(
        log_history, title=f"Training Curves: {config_name}",
        save_path=os.path.join(PROJECT_ROOT, f"results/{config_name}/training_curve.png"),
    )

In [None]:
# All configs overlaid
plot_all_configs_comparison(
    all_training_logs,
    save_path=os.path.join(PROJECT_ROOT, "results/all_configs_comparison.png"),
)

In [None]:
# Validation accuracy summary table
val_summary = []
for name, res in all_val_results.items():
    val_summary.append({
        "Config": name,
        "Val Accuracy": res["overall_accuracy"],
        "Val Macro F1": res["macro_f1"],
        "Extraction Failures": res["extraction_failure_rate"],
    })

# Sort on the numeric values; format to 4 decimals only when printing
val_df = pd.DataFrame(val_summary).sort_values("Val Accuracy", ascending=False)
print(val_df.to_string(index=False, float_format="{:.4f}".format))

best_config_name = val_df.iloc[0]["Config"]
print(f"\nBest config: {best_config_name}")

## 7. Best Config -- Full Test Evaluation

In [None]:
# Load the best fine-tuned model
best_config = configs[best_config_name]
adapter_path = os.path.join(PROJECT_ROOT, f"results/{best_config_name}/adapter")

ft_model, ft_tokenizer = load_finetuned_model(best_config, adapter_path)
print(f"Loaded adapter from: {adapter_path}")

In [None]:
# Full test set evaluation
ft_results = evaluate_on_dataset(
    ft_model, ft_tokenizer, test_raw,
    topic_labels=test_topics, batch_size=8,
)

print(f"Fine-tuned accuracy: {ft_results['overall_accuracy']:.4f}")
print(f"Fine-tuned macro F1: {ft_results['macro_f1']:.4f}")
print(f"Extraction failures: {ft_results['extraction_failure_rate']:.4f}")
print(f"\nClassification Report:\n{ft_results['classification_report']}")

save_results_json(ft_results, os.path.join(PROJECT_ROOT, "results/finetuned_test_results.json"))
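
Accuracy and macro F1 are reported side by side because macro F1 averages the per-letter F1 scores with equal weight, which makes it sensitive to a model that over-predicts one option. The evaluator in `src/evaluator.py` is not shown; the following is a dependency-free sketch of the two metrics, with hypothetical function names. How the real evaluator treats extraction failures is not known here; in this sketch a `None` prediction simply counts as wrong.

```python
def accuracy(gold, pred):
    """Fraction of predictions that exactly match the gold letter."""
    return sum(g == p for g, p in zip(gold, pred)) / len(gold)


def macro_f1(gold, pred, labels=("A", "B", "C", "D")):
    """Unweighted mean of per-label F1, with F1 = 2*TP / (2*TP + FP + FN)."""
    f1_scores = []
    for label in labels:
        tp = sum(g == label and p == label for g, p in zip(gold, pred))
        fp = sum(g != label and p == label for g, p in zip(gold, pred))
        fn = sum(g == label and p != label for g, p in zip(gold, pred))
        denom = 2 * tp + fp + fn
        # A label that never appears in gold or pred contributes an F1 of 0.
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / len(f1_scores)
```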

## 8. Comparison: Zero-Shot vs Few-Shot vs Fine-Tuned

In [None]:
comparison = compare_models(zero_shot_results, ft_results, few_shot_results)

comp_df = pd.DataFrame({
    "Metric": ["Accuracy", "Macro F1", "Extraction Failure Rate"],
    "Zero-Shot": [
        comparison["zero_shot"]["accuracy"],
        comparison["zero_shot"]["macro_f1"],
        comparison["zero_shot"]["extraction_failures"],
    ],
    "3-Shot": [
        comparison["few_shot"]["accuracy"],
        comparison["few_shot"]["macro_f1"],
        comparison["few_shot"]["extraction_failures"],
    ],
    "Fine-Tuned": [
        comparison["fine_tuned"]["accuracy"],
        comparison["fine_tuned"]["macro_f1"],
        comparison["fine_tuned"]["extraction_failures"],
    ],
})

print(comp_df.to_string(index=False))
print(f"\nAccuracy improvement (fine-tuned vs zero-shot): {comparison['delta']['accuracy']:+.4f}")

save_results_json(comparison, os.path.join(PROJECT_ROOT, "results/comparison_summary.json"))

## 9. Per-Topic Accuracy Breakdown

In [None]:
plot_per_topic_accuracy(
    ft_results["per_topic_accuracy"],
    base_per_topic=zero_shot_results.get("per_topic_accuracy"),
    title="Per-Topic Accuracy: Base vs Fine-Tuned",
    save_path=os.path.join(PROJECT_ROOT, "results/per_topic_accuracy.png"),
)

# Print per-topic table
topic_df = pd.DataFrame([
    {
        "Topic": topic,
        "Base Accuracy": f"{zero_shot_results.get('per_topic_accuracy', {}).get(topic, 0):.3f}",
        "Fine-Tuned Accuracy": f"{acc:.3f}",
        "Delta": f"{acc - zero_shot_results.get('per_topic_accuracy', {}).get(topic, 0):+.3f}",
    }
    for topic, acc in sorted(ft_results["per_topic_accuracy"].items(),
                             key=lambda x: x[1], reverse=True)
])
print(topic_df.to_string(index=False))

## 10. Error Analysis

In [None]:
errors = error_analysis(ft_results["predictions"])

print(f"Total errors: {errors['total_errors']} / {ft_results['n_total']}")
print(f"\nError breakdown:")
for cat, count in errors["error_counts"].items():
    print(f"  {cat}: {count}")

print(f"\nMost confused answer pairs (gold -> predicted):")
for (gold, pred), count in errors["most_confused_pairs"][:5]:
    print(f"  {gold} -> {pred}: {count} times")
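
The "most confused pairs" list above has the shape `[((gold, pred), count), ...]`. A tally of that shape can be produced with `collections.Counter`, as in the sketch below. The helper name is hypothetical, and skipping unparsed (`None`) predictions is an assumption rather than the documented behaviour of `error_analysis`.

```python
from collections import Counter


def most_confused_pairs(gold, pred, top_k=5):
    """Count (gold, predicted) pairs among wrong answers, most frequent first."""
    pairs = Counter(
        (g, p)
        for g, p in zip(gold, pred)
        if p is not None and g != p  # ignore correct answers and unparsed outputs
    )
    return pairs.most_common(top_k)
```

A pair that dominates this list (for example, many gold answers of one letter predicted as another) may indicate positional bias rather than a gap in medical knowledge.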

In [None]:
plot_error_taxonomy(
    errors["error_counts"],
    save_path=os.path.join(PROJECT_ROOT, "results/error_taxonomy.png"),
)

# Show example errors
print("\nExample substantive errors:")
for ex in errors["error_examples"]["substantive_error"][:3]:
    print(f"\n  Q index: {ex['idx']}")
    print(f"  Gold: {ex['gold']} | Predicted: {ex['pred']}")
    print(f"  Topic: {ex.get('topic', 'N/A')}")
    print(f"  Raw output: {ex['raw_output'][:100]}")

In [None]:
# Error rate by topic
topic_err_df = pd.DataFrame([
    {"Topic": topic, "Error Rate": f"{rate:.3f}"}
    for topic, rate in sorted(errors["topic_error_rates"].items(),
                               key=lambda x: x[1], reverse=True)
])
print("Error rate by topic:")
print(topic_err_df.to_string(index=False))

## 11. Confidence Calibration

In [None]:
cal_data = confidence_calibration(
    ft_model, ft_tokenizer, test_raw,
    topic_labels=test_topics, batch_size=8,
)

print(f"Expected Calibration Error (ECE): {cal_data['ece']:.4f}")
print(f"Avg confidence (correct):   {cal_data['avg_confidence_correct']:.4f}")
print(f"Avg confidence (incorrect): {cal_data['avg_confidence_incorrect']:.4f}")

plot_calibration_curve(
    cal_data,
    save_path=os.path.join(PROJECT_ROOT, "results/calibration_curve.png"),
)

save_results_json(cal_data, os.path.join(PROJECT_ROOT, "results/calibration_data.json"))
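
Expected Calibration Error groups predictions into confidence bins and takes the size-weighted average of the gap between each bin's accuracy and its mean confidence. The sketch below assumes equal-width bins and takes the confidence to be the probability assigned to the predicted letter; how `confidence_calibration` actually derives confidences and chooses bins is not shown in this notebook.

```python
def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE = sum over bins of (bin size / N) * |bin accuracy - bin mean confidence|."""
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for conf, hit in zip(confidences, correct):
        # Equal-width bins over [0, 1]; a confidence of exactly 1.0 goes in the last bin.
        index = min(int(conf * n_bins), n_bins - 1)
        bins[index].append((conf, hit))

    ece = 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        mean_conf = sum(c for c, _ in members) / len(members)
        bin_acc = sum(h for _, h in members) / len(members)
        ece += len(members) / n * abs(bin_acc - mean_conf)
    return ece
```

A perfectly calibrated model has an ECE of 0. For a 4-option question, the confidence of the top choice cannot fall below 0.25, so the lowest bins stay empty when confidences are normalised over the four letters.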

In [None]:
# Clean up
del ft_model
torch.cuda.empty_cache()

## 12. Summary & Conclusions

### Key Findings

*(Fill in after running experiments)*

1. **Fine-tuning improved accuracy from X% (zero-shot) to Y% (+Z points)**
2. **Best hyperparameter configuration:** ...
3. **Strongest topics:** ...
4. **Weakest topics / failure modes:** ...

### Limitations

1. **Heuristic topic classifier** -- keyword-based, may misclassify multi-topic questions
2. **Single model family** -- only tested Mistral-7B; results may differ on Llama-3 or Phi-3
3. **No chain-of-thought** -- model outputs a single letter without reasoning; CoT training could improve accuracy
4. **Dataset scope** -- MedQA covers USMLE only (US-centric medical practice)

### Ethical Considerations

1. **Clinical risk** -- This model must NOT be used for actual medical decisions. Even at 65-70% accuracy, roughly 1 in 3 answers is wrong.
2. **Cultural bias** -- USMLE reflects US medical practice, guidelines, and drug formularies. Performance would likely degrade on non-US medical contexts.
3. **Overconfidence** -- Fine-tuned LLMs can be confidently wrong. The calibration analysis above quantifies this risk.
4. **Demographic blind spots** -- Clinical vignettes in MedQA may underrepresent certain demographics, leading to performance disparities.

### Alternative Design Choices

- **Prompt engineering (Medprompt)** -- Microsoft showed GPT-4 with careful prompting can reach 90%+ without fine-tuning
- **RAG** -- Retrieval-augmented generation using medical textbooks could ground the model's knowledge
- **Larger model** -- Fine-tuning a 70B model would likely yield higher accuracy but requires more compute
- **MedMCQA augmentation** -- Adding 183K Indian medical exam questions with explanations for CoT training

### Reproducibility

- **Model:** `mistralai/Mistral-7B-Instruct-v0.3`
- **Dataset:** `GBaker/MedQA-USMLE-4-options-hf`
- **Libraries:** See `pyproject.toml` for exact versions
- **Seed:** 42
- **Compute:** Google Colab Pro, NVIDIA A100 40GB
- **All code:** Available in `src/` directory