# Medical Chatbot Memory Study - Google Colab

This notebook runs the complete experiment with real LLMs in Google Colab.

**Requirements:**
- Google Colab with GPU runtime (Runtime → Change runtime type → GPU)
- ~15 GB GPU memory (use T4 or better)

**What this does:**
1. Clones the repository
2. Installs dependencies
3. Loads Mistral-7B models with 4-bit quantization
4. Runs experiments comparing chatbot WITH vs WITHOUT memory
5. Generates results and metrics

## 1. Setup: Clone Repository & Install Dependencies

In [None]:
# Clone the repository (replace with your actual repo URL)
!git clone https://github.com/yourusername/health-experiment.git
%cd health-experiment

In [None]:
# Install dependencies
# Version specifiers are quoted so the shell does not treat ">" as an output redirect
!pip install -q torch torchvision torchaudio
!pip install -q "transformers>=4.35.0"
!pip install -q "accelerate>=0.24.0"
!pip install -q "bitsandbytes>=0.41.0"
!pip install -q "sentencepiece>=0.1.99"
!pip install -q pyyaml numpy pandas python-dateutil tqdm
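
The minimum versions named in the install cell can be verified after installation. The sketch below is one way to do that using only the standard library; the helper names (`parse_version`, `at_least`, `check_versions`) are illustrative and not part of the repository, and the version parsing is deliberately simple (numeric components only).

```python
from importlib import metadata

# Minimum versions taken from the install cell above.
MINIMUMS = {
    "transformers": "4.35.0",
    "accelerate": "0.24.0",
    "bitsandbytes": "0.41.0",
    "sentencepiece": "0.1.99",
}

def parse_version(text):
    """Turn '4.35.0' into (4, 35, 0); anything after the leading digits of a part is ignored."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)

def at_least(installed, minimum):
    """Compare two version strings after padding them to the same length."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b

def check_versions(minimums, lookup=metadata.version):
    """Return {package: (installed_version_or_None, ok)} for each requirement."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = lookup(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        report[name] = (installed, at_least(installed, minimum))
    return report

for name, (installed, ok) in check_versions(MINIMUMS).items():
    status = "OK " if ok else "BAD"
    print(f"{status} {name}: {installed}")
```

A package reported as `BAD` with version `None` is not installed at all, which is the likely symptom if a version specifier was swallowed by the shell.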

## 2. Verify GPU and Environment

In [None]:
import torch
import sys
from pathlib import Path

# Add the repository root to sys.path so the `src` package can be imported
sys.path.insert(0, str(Path.cwd()))

# Check GPU
print("=" * 60)
print("ENVIRONMENT CHECK")
print("=" * 60)
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: No GPU detected! LLM will be VERY slow on CPU.")
    print("   Go to Runtime → Change runtime type → Select GPU")

## 3. Test LLM Loading

In [None]:
from src.utils.llm_utils import test_model_loading, get_device

print("Testing LLM loading with a small model...")
print("(This will take 1-2 minutes)\n")

device = get_device()
print(f"\nUsing device: {device}")

# Test with small GPT-2 model first
success = test_model_loading("gpt2")

if success:
    print("\n✓ LLM system working correctly!")
else:
    print("\n✗ LLM test failed. Check errors above.")

## 4. Initialize Simulation with LLMs

This step loads Mistral-7B with 4-bit quantization (roughly 3.5 GB of GPU memory per model for the weights, plus some runtime overhead).
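
The 3.5 GB figure follows from simple arithmetic on the weight storage. The sketch below works it out, assuming the nominal "7B" parameter count is taken at face value and counting the weights only (activations, the KV cache, and quantization metadata are not included). The final line assumes the patient and doctor models are loaded as two separate copies.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory needed for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

NOMINAL_PARAMS = 7e9  # assumption: "7B" taken literally

fp16 = weight_memory_gb(NOMINAL_PARAMS, 16)
int4 = weight_memory_gb(NOMINAL_PARAMS, 4)

print(f"fp16 weights:      {fp16:.1f} GB")
print(f"4-bit weights:     {int4:.1f} GB per model")
print(f"two 4-bit models:  {2 * int4:.1f} GB")
```

At 16 bits per weight a single model would need about 14 GB, which is why quantization is what makes two models fit on a T4-class GPU.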

In [None]:
from src.core.simulation import MedicalChatbotSimulation

print("Initializing simulation with Mistral-7B models...")
print("(The first run downloads the models, which may take 5-10 minutes)\n")

sim = MedicalChatbotSimulation(
    config_path="config/experiment_config.yaml",
    random_seed=42
)

print("\n✓ Simulation initialized!")
print(f"  Patient model: {sim.patient_model_name}")
print(f"  Doctor model: {sim.doctor_model_name}")

## 5. Run Quick Test (1 day, both conditions)

Test the system with a single day to verify everything works.

In [None]:
print("Running quick test: 1 day, both conditions...\n")

# Test WITHOUT history
print("=" * 60)
print("Test 1: WITHOUT history")
print("=" * 60)
results_without = sim.run_experiment(
    patient_id="patient_001",
    num_days=1,
    condition="without_history",
    random_seed=42
)

print(f"\n✓ Test 1 complete:")
print(f"  Conversations: {results_without['total_conversations']}")
print(f"  Total turns: {results_without['total_turns']}")

# Test WITH history
print("\n" + "=" * 60)
print("Test 2: WITH history")
print("=" * 60)
results_with = sim.run_experiment(
    patient_id="patient_001",
    num_days=1,
    condition="with_history",
    random_seed=43
)

print(f"\n✓ Test 2 complete:")
print(f"  Conversations: {results_with['total_conversations']}")
print(f"  Total turns: {results_with['total_turns']}")

print("\n✓ Quick test passed! System is ready for full experiments.")
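
The two conditions differ only in whether the chatbot sees material from earlier days. The repository's actual prompt construction is not shown in this notebook, so the following is a hypothetical sketch of the idea: `build_prompt`, the system line, and the summary format are all assumptions made for illustration.

```python
def build_prompt(patient_message, past_summaries, condition):
    """Assemble the chatbot prompt for one turn (hypothetical format).

    past_summaries is a list of (day, summary) pairs from earlier conversations.
    They are included only in the "with_history" condition.
    """
    lines = ["You are a medical assistant chatbot."]
    if condition == "with_history" and past_summaries:
        lines.append("Summaries of earlier conversations with this patient:")
        for day, summary in past_summaries:
            lines.append(f"- Day {day}: {summary}")
    lines.append(f"Patient: {patient_message}")
    lines.append("Chatbot:")
    return "\n".join(lines)

# Illustrative inputs, not taken from real experiment output.
history = [(1, "Reported mild headache; advised hydration.")]
message = "The headache is back today."

print(build_prompt(message, history, "without_history"))
print("-" * 40)
print(build_prompt(message, history, "with_history"))
```

Keeping everything else in the prompt identical is what lets a difference in redundant questions be attributed to the history block rather than to other wording changes.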

## 6. View Sample Conversation

Let's look at what the LLMs generated.

In [None]:
import json
from pathlib import Path

# Load the last conversation file by name (plain string sort, so "day_10" would sort before "day_2")
conv_dir = Path("data/conversations/patient_001/without_history")
conv_files = sorted(conv_dir.glob("day_*.json"))

if conv_files:
    with open(conv_files[-1], 'r') as f:
        conv = json.load(f)
    
    print("=" * 60)
    print(f"SAMPLE CONVERSATION: {conv['conversation_id']}")
    print("=" * 60)
    print(f"Day: {conv['simulation_day']}")
    print(f"Condition: {conv['condition']}")
    print(f"Total turns: {len(conv['turns'])}")
    print("\n" + "-" * 60)
    
    # Show conversation
    for i, turn in enumerate(conv['turns'], 1):
        speaker = "PATIENT" if turn['speaker'] == 'patient' else "CHATBOT"
        print(f"\n[{i}] {speaker}:")
        print(f"{turn['message']}")
    
    # Show metrics
    print("\n" + "=" * 60)
    print("METRICS")
    print("=" * 60)
    metrics = conv.get('metrics', {}).get('automatic', {})
    print(f"Bot questions asked: {metrics.get('bot_questions_asked', 0)}")
    print(f"Info gathered: {metrics.get('key_info_gathered', [])}")
    print(f"Info missed: {metrics.get('key_info_missed', [])}")
else:
    print("No conversations found yet.")
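
The cell above picks the last file from a plain `sorted()` call, which orders names as strings. If the files are named without zero padding (an assumption; the exact naming is not shown here), runs longer than 9 days would sort out of order. A small sketch of a numeric sort key, with `day_number` as an illustrative helper name:

```python
import re
from pathlib import Path

def day_number(path):
    """Extract the first integer from a file name such as 'day_10.json'."""
    match = re.search(r"\d+", Path(path).stem)
    return int(match.group()) if match else -1

names = ["day_10.json", "day_2.json", "day_1.json"]

print(sorted(names))                  # string order: day_1, day_10, day_2
print(sorted(names, key=day_number))  # numeric order: day_1, day_2, day_10
```

In the notebook this would be used as `sorted(conv_dir.glob("day_*.json"), key=day_number)`.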

## 7. Run Full Experiment (7 days, both conditions)

Now run the complete experiment. This will take ~30-45 minutes.

In [None]:
print("Running FULL experiment: 7 days, both conditions")
print("This will take approximately 30-45 minutes...\n")

# Run for patient_001, 7 days, both conditions
results = sim.run_full_experiment(
    patient_ids=["patient_001"],
    conditions=["without_history", "with_history"],
    num_days=7
)

print("\n" + "=" * 60)
print("EXPERIMENT COMPLETE!")
print("=" * 60)
print(f"Total experiments run: {len(results['results'])}")
print(f"Successful: {sum(1 for r in results['results'] if 'error' not in r)}")

# Summary
for result in results['results']:
    if 'error' not in result:
        print(f"\n{result['patient_id']} - {result['condition']}:")
        print(f"  Conversations: {result['total_conversations']}")
        print(f"  Total turns: {result['total_turns']}")

## 8. Analyze Results

Compare performance between WITH and WITHOUT history conditions.

In [None]:
import pandas as pd
import json
from pathlib import Path

# Load all conversations and extract metrics
metrics_data = []

for condition in ["with_history", "without_history"]:
    conv_dir = Path(f"data/conversations/patient_001/{condition}")
    
    for conv_file in conv_dir.glob("day_*.json"):
        with open(conv_file, 'r') as f:
            conv = json.load(f)
        
        metrics = conv.get('metrics', {}).get('automatic', {})
        
        metrics_data.append({
            'condition': condition,
            'day': conv['simulation_day'],
            'turns': len(conv['turns']),
            'bot_questions': metrics.get('bot_questions_asked', 0),
            'redundant_questions': metrics.get('redundant_questions', 0),
            'history_references': metrics.get('references_to_history', 0),
            'info_gathered_count': len(metrics.get('key_info_gathered', [])),
            'info_missed_count': len(metrics.get('key_info_missed', []))
        })

df = pd.DataFrame(metrics_data)

print("=" * 60)
print("RESULTS COMPARISON")
print("=" * 60)

# Group by condition and calculate means
summary = df.groupby('condition').agg({
    'turns': 'mean',
    'bot_questions': 'mean',
    'redundant_questions': 'mean',
    'history_references': 'mean',
    'info_gathered_count': 'mean',
    'info_missed_count': 'mean'
}).round(2)

print("\nAverage per conversation:")
print(summary)

# Report the per-condition means side by side
print("\n" + "=" * 60)
print("KEY FINDINGS")
print("=" * 60)

with_hist = summary.loc['with_history']
without_hist = summary.loc['without_history']

print(f"\nRedundant questions:")
print(f"  WITH history: {with_hist['redundant_questions']:.2f}")
print(f"  WITHOUT history: {without_hist['redundant_questions']:.2f}")

print(f"\nHistory references:")
print(f"  WITH history: {with_hist['history_references']:.2f}")
print(f"  WITHOUT history: {without_hist['history_references']:.2f}")

print(f"\nInformation gathering efficiency:")
print(f"  WITH history: {with_hist['info_gathered_count']:.2f} gathered, {with_hist['info_missed_count']:.2f} missed")
print(f"  WITHOUT history: {without_hist['info_gathered_count']:.2f} gathered, {without_hist['info_missed_count']:.2f} missed")
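
The cell above prints the two means next to each other. A relative change can make the comparison easier to read; the sketch below is one way to compute it, using the WITHOUT-history condition as the baseline. The function name and the sample numbers are illustrative only, not experiment results.

```python
def percent_difference(baseline, value):
    """Relative change of value versus baseline, in percent.

    Returns None when the baseline is 0, since the ratio is undefined.
    """
    if baseline == 0:
        return None
    return (value - baseline) / baseline * 100.0

# Illustrative numbers only.
print(percent_difference(4.0, 3.0))   # -25.0
print(percent_difference(0, 2.0))     # None
```

Applied to the summary table, this would be `percent_difference(without_hist[col], with_hist[col])` for each `col` in `summary.columns`. With a single patient and 7 days per condition, such percentages describe this run rather than a statistically tested effect.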

## 9. Download Results

Download all data for further analysis.

In [None]:
from google.colab import files
import shutil

# Create zip of all results
print("Creating results archive...")
shutil.make_archive('experiment_results', 'zip', 'data')

print("Downloading results...")
files.download('experiment_results.zip')

print("\n✓ Download requested! Check your browser's downloads.")
print("\nThe zip contains:")
print("  - All conversation logs (JSON)")
print("  - Patient state files")
print("  - Results summaries")
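
Before relying on the archive, it can be useful to confirm what it actually contains. The sketch below counts files by top-level folder and by extension using the standard `zipfile` module; `summarize_archive` is an illustrative helper, and the folder names it reports depend on what is present under `data/` in your run.

```python
import zipfile
from collections import Counter
from pathlib import Path, PurePosixPath

def summarize_archive(zip_path):
    """Count archive members by top-level folder and by file suffix."""
    by_folder, by_suffix = Counter(), Counter()
    with zipfile.ZipFile(zip_path) as archive:
        for name in archive.namelist():
            if name.endswith("/"):
                continue  # directory entry, not a file
            path = PurePosixPath(name)
            top = path.parts[0] if len(path.parts) > 1 else "."
            by_folder[top] += 1
            by_suffix[path.suffix or "(none)"] += 1
    return by_folder, by_suffix

archive_path = Path("experiment_results.zip")
if archive_path.exists():
    folders, suffixes = summarize_archive(archive_path)
    print("Files per top-level folder:", dict(folders))
    print("Files per extension:", dict(suffixes))
else:
    print(f"{archive_path} not found; run the archive step first.")
```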

## 10. Clear GPU Memory (Optional)

Run this cell to free GPU memory, for example before running the experiment again.

In [None]:
from src.utils.llm_utils import clear_model_cache
import gc
import torch  # re-imported so this cell also works after a runtime restart

print("Clearing GPU memory...")
clear_model_cache()
gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU memory allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
    print(f"GPU memory cached: {torch.cuda.memory_reserved()/1e9:.2f} GB")

print("\n✓ Memory cleared!")