# Kenya Medical Vignettes Model Pipeline

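This notebook orchestrates the ML pipeline for predicting clinician responses to medical vignettes: dependency setup, data preprocessing, model training, evaluation, optimization, inference, and visualization.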

## 1. Cell 1: Install Dependencies and Import Libraries

In [None]:
import subprocess 
import sys
import os
import time
import json
import threading
from pathlib import Path
from IPython.display import display, clear_output

from datetime import datetime

# Install dependencies from requirements.txt into the kernel's own environment
# (sys.executable). The exit status is checked so that a failed install is
# reported here instead of surfacing later as an unexplained ImportError.
install_result = subprocess.run(
    [sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt']
)
if install_result.returncode != 0:
    print(f"⚠️ Dependency installation failed (exit code {install_result.returncode}); "
          "the imports below may fail.")

# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns
from collections import Counter
import re
from datasets import load_from_disk 
from ipywidgets import interact, IntSlider

%matplotlib inline

## 2. Cell 2: Data Preprocessing

In [None]:
# Make sure we're in the project root directory
print("Current working directory:", os.getcwd())

# Verify the data files exist
print("Train file exists:", os.path.exists('data/train.csv'))
print("Test file exists:", os.path.exists('data/test.csv'))

print("\n🚀 PRIORITY FIXES: ENHANCED DATA PREPROCESSING")
print("=" * 60)
print("🔧 PRIORITY FIXES APPLIED:")
print("✅ Removed multitask learning components")
print("✅ Simplified prompt format")
print("✅ Removed few-shot examples")
print("✅ Implemented basic augmentation (synonym replacement and noise injection)")
print("✅ Consistent tokenizer handling with default t5-small")
print("=" * 60)

# Run the preprocessing script with the kernel's own interpreter
# (sys.executable) so it sees the packages installed for this notebook;
# a bare 'python' may resolve to a different installation on PATH.
result = subprocess.run([sys.executable, 'scripts/data_preprocessing.py'],
                        capture_output=True, text=True, cwd=os.getcwd())

print("Return code:", result.returncode)
print("STDOUT:", result.stdout)
if result.stderr:
    print("STDERR:", result.stderr)

# Only proceed if the script ran successfully
if result.returncode == 0:
    # Load processed datasets
    train_dataset = load_from_disk('outputs/train_dataset')
    val_dataset = load_from_disk('outputs/val_dataset')
    test_dataset = load_from_disk('outputs/test_dataset')

    print(f"\n📊 Dataset Sizes:")
    print(f'Train size: {len(train_dataset)} (with basic augmentation)')
    print(f'Validation size: {len(val_dataset)}')
    print(f'Test size: {len(test_dataset)}')

    # Show sample of enhanced features
    print(f"\n🔍 Sample Verification:")
    if len(train_dataset) > 0:
        print("Sample train example:")
        sample = train_dataset[0]
        print(f"Prompt length: {len(sample['Prompt'])} chars")
        if 'Clinician' in sample:
            print(f"Target length: {len(sample['Clinician'])} chars")
        else:
            print("No target (test data)")
    else:
        # Indexing [0] on an empty dataset would raise; report it instead.
        print("⚠️ Train dataset is empty; no sample to show.")

    # Verify augmentation: one pass over the rows, tolerating a missing or
    # None 'augmentation_type' (a None value would make the `in` test raise).
    print(f"\n🔄 Augmentation Verification:")
    original_prompts = []
    augmented_prompts = []
    for ex in train_dataset:
        aug_type = str(ex.get('augmentation_type') or '')
        if 'original' in aug_type:
            original_prompts.append(ex['Prompt'])
        if 'augmented' in aug_type:
            augmented_prompts.append(ex['Prompt'])
    print(f"Original prompts: {len(original_prompts)}")
    print(f"Augmented prompts: {len(augmented_prompts)}")

    print("\n" + "="*60)
    print("🎯 Preprocessing Completed Successfully!")
    print("Ready for training with simplified, consistent format")
    print("="*60)

else:
    print("❌ Preprocessing failed! Check error messages above.")
    print("Cannot proceed to training without successful preprocessing.")

## 3. Cell 3: Model Training

In [None]:
# Batch selected for this run; edit the number and re-run the cell to switch.
CURRENT_BATCH = 1

# Batch number -> list of (config, experiment name) pairs.
#   config          -> looked up as conf/experiments/<config>.yaml
#   experiment name -> logs are read from ./experiments/<name>/hydra_outputs
EXPERIMENT_BATCHES = {
    1: [
        ("baseline", "baseline"),
        ("fast", "fast_training"),
    ],
    2: [
        ("aggressive", "aggressive_training"),
        ("data_augmented", "data_augmented"),
    ],
    3: [
        ("balanced", "balanced_training"),
        ("quality", "quality_training"),
    ],
}

def check_environment():
    """Verify that everything the experiment runner needs exists on disk.

    Checks the prepared datasets, the training scripts, the base config, and
    one ``conf/experiments/<config>.yaml`` file per experiment in the batch
    selected by ``CURRENT_BATCH``. Every problem found is printed, not just
    the first one, so they can all be fixed in a single pass.

    Returns:
        bool: True when all paths exist and ``CURRENT_BATCH`` names a
        non-empty batch in ``EXPERIMENT_BATCHES``; False otherwise.
    """
    required_paths = [
        'outputs/train_dataset',
        'outputs/val_dataset',
        'scripts/model_training.py',
        'scripts/run_experiments.py',
        'conf/config.yaml'
    ]
    all_present = True
    for path in required_paths:
        if not Path(path).exists():
            print(f"❌ Missing required path: {path}")
            all_present = False

    # Use .get so an unknown batch number is reported instead of raising KeyError.
    batch = EXPERIMENT_BATCHES.get(CURRENT_BATCH)
    if not batch:
        print(f"❌ Invalid batch: {CURRENT_BATCH}")
        return False

    for config, name in batch:
        config_path = f"conf/experiments/{config}.yaml"
        if not Path(config_path).exists():
            print(f"❌ Missing configuration file: {config_path}")
            all_present = False

    if all_present:
        print("✅ Environment check passed")
    return all_present

def monitor_training(experiments, stop_event=None, poll_interval=30, max_duration=3600):
    """Plot training loss in a background thread while experiments run.

    For each experiment the most recently modified ``*.log`` file under
    ``./experiments/<exp_name>/hydra_outputs`` is scanned (last 20 lines) for
    entries containing both ``"step": <int>`` and ``"train_loss": <number>``.
    New steps are appended to the returned dict and the loss curves are
    redrawn in the cell output.

    Args:
        experiments: Re-iterable collection of ``(config, exp_name)`` pairs;
            only ``exp_name`` is used here.
        stop_event: Optional ``threading.Event``. Setting it ends the monitor
            at its next wake-up, which stops it from clearing the cell output
            after training has finished. When omitted, the monitor runs until
            ``max_duration`` elapses or an error occurs.
        poll_interval: Seconds to wait between refreshes.
        max_duration: Upper bound, in seconds, on how long the monitor runs.

    Returns:
        dict: ``{exp_name: {'loss': [...], 'steps': [...]}}``; the lists are
        extended in place by the background thread.
    """
    training_data = {exp_name: {'loss': [], 'steps': []} for _, exp_name in experiments}
    # Per-experiment set of steps already recorded, for O(1) de-duplication.
    seen_steps = {exp_name: set() for _, exp_name in experiments}
    if stop_event is None:
        stop_event = threading.Event()

    step_pattern = re.compile(r'"step":\s*(\d+)')
    # Accept plain decimals and scientific notation (e.g. 1e-05); a pattern of
    # only digits and dots would truncate "1e-05" to "1".
    loss_pattern = re.compile(
        r'"train_loss":\s*(-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)'
    )

    def update_data():
        for _, exp_name in experiments:
            log_dir = Path(f"./experiments/{exp_name}/hydra_outputs")
            if not log_dir.exists():
                continue
            log_files = list(log_dir.rglob("*.log"))
            if not log_files:
                continue
            try:
                # rglob order is arbitrary; follow the newest log file.
                latest_log = max(log_files, key=lambda p: p.stat().st_mtime)
                with open(latest_log, 'r', encoding='utf-8', errors='replace') as f:
                    lines = f.readlines()[-20:]
                for line in lines:
                    step_match = step_pattern.search(line)
                    loss_match = loss_pattern.search(line)
                    if step_match and loss_match:
                        step, loss = int(step_match.group(1)), float(loss_match.group(1))
                        if step not in seen_steps[exp_name]:
                            seen_steps[exp_name].add(step)
                            training_data[exp_name]['steps'].append(step)
                            training_data[exp_name]['loss'].append(loss)
            except Exception as e:
                print(f"⚠️ Error reading log for {exp_name}: {e}")

    def plot_progress():
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            has_lines = False
            for _, exp_name in experiments:
                data = training_data[exp_name]
                if data['steps']:
                    ax.plot(data['steps'], data['loss'], label=exp_name, marker='o')
                    has_lines = True
            ax.set_xlabel('Steps')
            ax.set_ylabel('Loss')
            ax.set_title('🚀 Training Progress')
            if has_lines:
                # Skip the legend until there is something to label.
                ax.legend()
            ax.grid(True, alpha=0.3)
            plt.tight_layout()
            display(fig)
        finally:
            # Always release the figure, even if rendering failed.
            plt.close(fig)

    def monitor_loop():
        deadline = time.monotonic() + max_duration
        while not stop_event.is_set() and time.monotonic() < deadline:
            try:
                update_data()
                plot_progress()
            except Exception as e:
                print(f"⚠️ Monitoring error: {e}")
                break
            # Sleeps up to poll_interval but wakes immediately once stopped.
            stop_event.wait(poll_interval)

    threading.Thread(target=monitor_loop, daemon=True).start()
    return training_data

def _stop_process(proc, timeout=10):
    """Terminate ``proc`` and wait for it, escalating to kill after ``timeout`` seconds."""
    if proc.poll() is not None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


print("🚀 ENHANCED EXPERIMENT RUNNER")
print("=" * 50)

if not check_environment():
    print("❌ Environment check failed. Please fix issues before proceeding.")
else:
    current_experiments = EXPERIMENT_BATCHES.get(CURRENT_BATCH, [])
    print(f"🎯 RUNNING BATCH {CURRENT_BATCH}:")
    for i, (config, name) in enumerate(current_experiments, 1):
        print(f"  {i}. {name} ({config})")

    # Set only when the experiment script exits with status 0, so the closing
    # message does not claim success after a failure or an interrupt.
    batch_succeeded = False

    if not current_experiments:
        print(f"❌ Invalid batch: {CURRENT_BATCH}")
    else:
        print("📊 Starting training with monitoring...")
        training_data = monitor_training(current_experiments)
        start_time = time.time()
        env = os.environ.copy()
        env['HYDRA_FULL_ERROR'] = '1'  # Enable full stack traces

        # The script paths below are relative, so show where they resolve from.
        print(f"Current working directory: {os.getcwd()}")

        # Launch with the kernel's interpreter so the child process uses the
        # same environment as this notebook.
        process = subprocess.Popen(
            [sys.executable, 'scripts/run_experiments.py', str(CURRENT_BATCH)],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, env=env
        )

        try:
            stdout, stderr = process.communicate()
            total_time = time.time() - start_time

            print(f"\n⏱️ Completed in {total_time/60:.1f} minutes")
            # Only the tail of each stream is shown to keep the cell output short.
            print("STDOUT:", stdout[-2000:])
            if stderr:
                print("STDERR:", stderr[-2000:])

            if process.returncode == 0:
                batch_succeeded = True
                print("✅ EXPERIMENTS COMPLETED!")
                print("📁 Results: ./experiments/experiment_results.json")
                if "WINNER" in stdout:
                    lines = stdout.split('\n')
                    for i, line in enumerate(lines):
                        if "WINNER" in line:
                            # Echo the first WINNER line and up to nine lines after it.
                            print("\n🏆 RESULTS:")
                            for j in range(i, min(i+10, len(lines))):
                                if lines[j].strip():
                                    print(lines[j])
                            break
                next_batch = CURRENT_BATCH + 1
                if next_batch in EXPERIMENT_BATCHES:
                    print(f"\n💡 NEXT: Change CURRENT_BATCH = {next_batch}")
                else:
                    print("\n🎉 ALL BATCHES COMPLETE!")
            else:
                print("❌ EXPERIMENTS FAILED!")
                print("Error details logged above")

        except KeyboardInterrupt:
            print("\n⚠️ Interrupted!")
            _stop_process(process)
        except Exception as e:
            print(f"❌ Error: {e}")
            _stop_process(process)

    if batch_succeeded:
        print(f"\nBatch {CURRENT_BATCH} complete. Change CURRENT_BATCH to run next batch.")
    else:
        print(f"\nBatch {CURRENT_BATCH} did not complete successfully. Fix the errors above and re-run.")

## 4. Cell 4: Model Evaluation

In [None]:
# Run the model evaluation script to compute ROUGE metrics, using the kernel's
# interpreter, and report success only when the script exits with status 0.
eval_result = subprocess.run([sys.executable, 'scripts/model_evaluation.py'])
if eval_result.returncode == 0:
    print("Evaluation completed. Check console output for results.")
else:
    print(f"❌ Evaluation failed (exit code {eval_result.returncode}). Check console output for details.")

## 5. Cell 5: Model Optimization

In [None]:
# Run the model optimization script to quantize the model, using the kernel's
# interpreter, and report success only when the script exits with status 0.
optimization_result = subprocess.run([sys.executable, 'scripts/model_optimization.py'])
if optimization_result.returncode == 0:
    print("Model optimization completed. Check console output for details.")
else:
    print(f"❌ Model optimization failed (exit code {optimization_result.returncode}). Check console output for details.")

## 6. Cell 6: Inference

In [None]:
# Run the inference script to generate predictions for the test dataset,
# using the kernel's interpreter.
inference_result = subprocess.run([sys.executable, 'scripts/inference.py'])

if inference_result.returncode == 0:
    # Load and display the first few rows of the submission file
    submission = pd.read_csv('outputs/submission.csv')
    print("Submission file preview:")
    print(submission.head())
else:
    # Do not preview the file after a failed run: it may be missing or left
    # over from an earlier run.
    print(f"❌ Inference failed (exit code {inference_result.returncode}). "
          "Not loading outputs/submission.csv.")

## 7. Cell 7: Visualizations

In [None]:
# Load test dataset and submission
test_dataset = load_from_disk('outputs/test_dataset')
submission = pd.read_csv('outputs/submission.csv')

# read_csv turns empty cells into NaN (a float), which would break the string
# methods below; treat a missing prediction as an empty string.
predictions = submission['Clinician'].fillna('').astype(str)

# Prediction length distribution
prediction_lengths = [len(pred.split()) for pred in predictions]
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(prediction_lengths, bins=20, kde=True, ax=ax)
ax.set_title('Distribution of Prediction Lengths (in words)')
ax.set_xlabel('Number of words')
ax.set_ylabel('Frequency')
plt.show()

# Format compliance check
# Compare against lower() rather than using str.islower(): islower() is False
# for strings with no cased characters (e.g. "" or "120"), which would report
# a compliant prediction as non-lowercase.
all_lowercase = all(pred == pred.lower() for pred in predictions)
forbidden_punctuation = set('.,!?;:"()[]{}')
no_punctuation = all(forbidden_punctuation.isdisjoint(pred) for pred in predictions)
print(f"All predictions lowercase: {all_lowercase}")
print(f"No punctuation in predictions: {no_punctuation}")

# Medical term usage: number of predictions that contain each term at least once
medical_terms = ['patient', 'diagnosis', 'treatment', 'symptoms', 'condition', 'clinical', 'assessment']
lowered_predictions = [pred.lower() for pred in predictions]
medical_term_counts = [
    sum(1 for pred in lowered_predictions if term in pred)
    for term in medical_terms
]
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x=medical_terms, y=medical_term_counts, ax=ax)
ax.set_title('Frequency of Medical Terms in Predictions')
ax.set_xlabel('Medical Terms')
ax.set_ylabel('Count')
ax.tick_params(axis='x', rotation=45)
plt.show()