# Tabula-8B VA Cause of Death Prediction on Google Colab

This notebook runs the Tabula-8B model on Google Colab with GPU acceleration for fast inference.

**Requirements:**
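- Google Colab with a GPU runtime that has enough memory for an 8B-parameter model in float16 (~16 GB for the weights alone), e.g. an L4 or A100. On a 16 GB T4 the model does not fully fit: layers may be offloaded to CPU (much slower inference) or loading may fail.
- PHMRC adult dataset CSV file
- ~20 GB of free disk space for the model download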


## Step 1: Setup GPU Runtime

**IMPORTANT**: Before running this notebook:
1. Go to `Runtime` → `Change runtime type`
2. Select a `GPU` hardware accelerator. An L4 or A100 is recommended, because the float16 model weights (~16 GB) do not fully fit on a 16 GB T4.
3. Click `Save`

In [None]:
# Check GPU availability
import torch
import subprocess
import sys

if torch.cuda.is_available():
    print(f"✅ GPU Available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    !nvidia-smi
else:
    print("❌ No GPU detected! Please enable GPU runtime:")
    print("   Runtime → Change runtime type → GPU")

## Step 2: Clone Repository and Install Dependencies

In [None]:
# Clone the repository
!git clone https://github.com/cliu238/tabula-8b-va-prediction.git
%cd tabula-8b-va-prediction

# Install required packages
!pip install -q transformers torch accelerate datasets pandas numpy scikit-learn tqdm python-dotenv pillow

## Step 3: Upload PHMRC Data

Upload your PHMRC CSV file when prompted. The file should be named:
`IHME_PHMRC_VA_DATA_ADULT_Y2013M09D11_0.csv`

In [None]:
from google.colab import files
import os

# Create data directory
os.makedirs('data/raw/PHMRC', exist_ok=True)

print("Please upload the PHMRC adult dataset CSV file:")
uploaded = files.upload()

# Move uploaded file to correct location
for filename in uploaded.keys():
    if 'ADULT' in filename:
        !mv "{filename}" data/raw/PHMRC/
        print(f"✅ Moved {filename} to data/raw/PHMRC/")

# Verify file exists
!ls -la data/raw/PHMRC/

## Step 4: Create GPU-Optimized Model Loader

In [None]:
%%writefile run_colab.py
#!/usr/bin/env python
"""
Colab-optimized script for Tabula-8B VA prediction with GPU support
"""

import sys
import torch
import pandas as pd
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# Add src to path
sys.path.insert(0, str(Path.cwd()))

from src.data.preprocessor import PHMRCPreprocessor
from src.data.serializer import VADataSerializer

def load_model_gpu():
    """Load the Tabula-8B model and tokenizer for batched generation.

    The model is loaded in float16 with ``device_map="auto"``, so accelerate
    decides where each layer lives (possibly CPU if GPU memory is short).
    The tokenizer is given a pad token if it lacks one and is set to pad on
    the left, which batched decoder-only generation requires.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    print("Loading Tabula-8B model on GPU...")

    model_name = "mlfoundations/tabula-8b"

    # Load with GPU and half precision for speed.
    # NOTE(review): trust_remote_code executes Python code shipped in the Hub
    # repo; if this checkpoint uses a stock architecture it can likely be
    # dropped — confirm.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Decoder-only models continue from the last position of each row, so
    # pads must come before the prompt; with right padding, generation for
    # every shorter prompt would start after its pad tokens.
    tokenizer.padding_side = "left"

    # Report where the weights actually landed instead of assuming cuda:0
    # exists (get_device_name raises when no GPU is available).
    device = next(model.parameters()).device
    if device.type == "cuda":
        print(f"✅ Model loaded on {torch.cuda.get_device_name(device)}")
    else:
        print(f"⚠️ Model loaded on {device}; no GPU in use, inference will be slow")
    return model, tokenizer

def _clean_prediction(generated):
    """Reduce one raw model continuation to a single cause-of-death label.

    Keeps the first line and, if it contains ':', only the text after the
    last ':'. Returns "Unknown" when nothing is left.
    """
    cause = generated.strip().split('\n')[0].strip()
    if ':' in cause:
        cause = cause.split(':')[-1].strip()
    return cause if cause else "Unknown"


def predict_batch_gpu(model, tokenizer, texts, batch_size=8):
    """Predict a cause of death for each serialized patient record.

    Each text is wrapped in a fixed instruction prompt ending in
    "Cause of death:". The model's greedy continuation is reduced to one label.

    Args:
        model: Causal LM exposing ``generate``. Inputs are moved to the device
            of its first parameter.
        tokenizer: Tokenizer matching ``model``. Its ``padding_side`` and
            ``truncation_side`` are temporarily set to ``"left"`` and are
            restored before returning.
        texts: Sliceable sequence of serialized patient records.
        batch_size: Number of prompts per ``generate`` call. Must be >= 1.

    Returns:
        list[str]: One label per input text, in input order. The label is
        "Unknown" when the model produced no usable text.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    predictions = []
    device = next(model.parameters()).device

    # Left padding: generation continues from the last column of each row, so
    # pads must precede the prompt. Left truncation: a record longer than
    # max_length loses its beginning (and the instruction header) instead of
    # the trailing "Cause of death:" cue the model has to complete.
    saved_sides = (tokenizer.padding_side, tokenizer.truncation_side)
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    try:
        for start in tqdm(range(0, len(texts), batch_size), desc="Predicting"):
            batch = texts[start:start + batch_size]

            prompts = [
                f"Based on the following patient information, predict the most likely cause of death.\n\n"
                f"Patient: {text}\n\n"
                f"Cause of death:"
                for text in batch
            ]

            # Tokenize batch
            inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024,
            ).to(device)

            # Greedy decoding keeps the evaluation reproducible across runs.
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=20,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                )

            # generate() returns prompt + continuation. With left padding every
            # prompt ends at the same column, so drop that prefix.
            prompt_len = inputs['input_ids'].shape[1]
            decoded = tokenizer.batch_decode(
                outputs[:, prompt_len:], skip_special_tokens=True
            )
            predictions.extend(_clean_prediction(text) for text in decoded)
    finally:
        # Leave the caller's tokenizer configured as it was.
        tokenizer.padding_side, tokenizer.truncation_side = saved_sides

    return predictions

def _exact_match_accuracy(true_labels, predicted_labels):
    """Count case-insensitive exact matches between true and predicted labels.

    Pairs whose true label is missing (``pd.isna``) are skipped: they count
    neither as correct nor toward the total. Other values go through
    ``str()`` before lower-casing, so non-string labels do not raise.

    Returns:
        tuple[int, int]: ``(correct, total)`` over the labelled pairs.
    """
    correct = total = 0
    for true, pred in zip(true_labels, predicted_labels):
        if pd.isna(true):
            continue
        total += 1
        if str(true).lower() == str(pred).lower():
            correct += 1
    return correct, total


def main(sample_size=100,
         data_path='data/raw/PHMRC/IHME_PHMRC_VA_DATA_ADULT_Y2013M09D11_0.csv',
         batch_size=8):
    """Run the complete pipeline: sample, preprocess, serialize, predict, save.

    Args:
        sample_size: Number of records to sample (``random_state=42``). It is
            capped at the dataset size and must be >= 1.
        data_path: Path of the PHMRC adult CSV to load.
        batch_size: Prompts per generation call.

    Returns:
        pandas.DataFrame: The preprocessed sample with a ``predicted_cause``
        column added. Selected columns are also written to
        ``predictions_gpu_<sample_size>.csv``.

    Raises:
        ValueError: If ``sample_size`` is less than 1.
    """
    # Fail before the multi-GB model load when there is nothing to predict.
    if sample_size < 1:
        raise ValueError(f"sample_size must be >= 1, got {sample_size}")

    # Load and preprocess data
    print("\n" + "="*60)
    print("Loading and preprocessing data...")
    print("="*60)

    preprocessor = PHMRCPreprocessor()
    df = preprocessor.load_data(data_path)

    # Sample data
    sample_df = df.sample(n=min(sample_size, len(df)), random_state=42)
    processed_df = preprocessor.preprocess(sample_df)

    # Serialize to text
    print("\nSerializing patient records...")
    serializer = VADataSerializer(verbose=False)
    texts = serializer.serialize_batch(processed_df, show_progress=True)

    # Load model
    print("\n" + "="*60)
    print("Loading Tabula-8B model...")
    print("="*60)
    model, tokenizer = load_model_gpu()

    # Run predictions
    print("\n" + "="*60)
    print("Running predictions on GPU...")
    print("="*60)
    predictions = predict_batch_gpu(model, tokenizer, texts, batch_size=batch_size)

    # Save results, keeping only the report columns that actually exist so a
    # missing column does not raise a KeyError before the accuracy check.
    results_df = processed_df.copy()
    results_df['predicted_cause'] = predictions

    output_path = f'predictions_gpu_{sample_size}.csv'
    report_cols = [col for col in ('age', 'gender', 'cause_of_death', 'predicted_cause')
                   if col in results_df.columns]
    results_df[report_cols].to_csv(output_path, index=False)

    # Calculate accuracy
    if 'cause_of_death' in results_df.columns:
        correct, total = _exact_match_accuracy(results_df['cause_of_death'], predictions)
        if total:
            print(f"\n✅ Accuracy: {correct / total:.2%} ({correct}/{total})")
        else:
            print("\n⚠️ No labelled records; accuracy not computed")

    print(f"\n✅ Results saved to {output_path}")
    return results_df

if __name__ == "__main__":
    # Optional first CLI argument: number of records to sample (default 100).
    try:
        sample_size = int(sys.argv[1]) if len(sys.argv) > 1 else 100
    except ValueError:
        sys.exit(f"Usage: python {sys.argv[0]} [SAMPLE_SIZE]  (SAMPLE_SIZE must be an integer)")
    results = main(sample_size)

## Step 5: Download Model and Run Predictions

In [None]:
# Run predictions with GPU acceleration
# This will download the model on first run (~16GB)
# The argument is the number of records to sample; run_colab.py writes
# predictions_gpu_<N>.csv. Steps 6 and 7 read predictions_gpu_50.csv, so
# update them too if you change N.
!python run_colab.py 50  # Process 50 samples

## Step 6: Analyze Results

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Load results written by run_colab.py in Step 5
results = pd.read_csv('predictions_gpu_50.csv')

# Display sample predictions, limited to the columns that are present so a
# missing column does not raise before the guarded section below
preview_cols = [col for col in ['age', 'gender', 'cause_of_death', 'predicted_cause']
                if col in results.columns]
print("Sample Predictions:")
print(results[preview_cols].head(10))

# Calculate accuracy by cause
if {'cause_of_death', 'predicted_cause'} <= set(results.columns):
    # Compare as lower-cased text. astype(str) keeps the .str accessor from
    # failing on non-string columns. Rows with a missing true label count
    # as incorrect.
    true_text = results['cause_of_death'].astype(str).str.lower()
    pred_text = results['predicted_cause'].astype(str).str.lower()
    results['correct'] = results['cause_of_death'].notna() & (true_text == pred_text)

    print(f"\nOverall Accuracy: {results['correct'].mean():.2%}")

    # Accuracy by top causes
    cause_counts = results.groupby('cause_of_death').agg({
        'correct': ['sum', 'count', 'mean']
    }).round(3)
    cause_counts.columns = ['Correct', 'Total', 'Accuracy']
    cause_counts = cause_counts.sort_values('Total', ascending=False).head(10)

    print("\nAccuracy by Top Causes:")
    print(cause_counts)

## Step 7: Download Results

In [None]:
from google.colab import files

# Trigger a browser download of the predictions file produced in Step 5
predictions_file = 'predictions_gpu_50.csv'
files.download(predictions_file)
print("✅ Results downloaded!")

## Step 8: Run Full Dataset (Optional)

⚠️ **Warning**: Processing the full dataset (7841 records) will take ~15-30 minutes even on GPU

In [None]:
# Uncomment to run on full dataset
# !python run_colab.py 7841