# Echo Note Training with Gemma LLM
This notebook demonstrates how to fine-tune the Gemma language model on echocardiogram notes. The model will learn to analyze echo reports and provide structured assessments across 19 different cardiac features.

## Overview
1. Set up Google Drive
2. Install dependencies
3. Load echo note data
4. Configure data labels
5. Format data for training
6. Clean and preprocess the dataset
7. Create train/tune/test splits
8. Convert to Hugging Face format
9. Configure the model and tokenizer
10. Configure training
11. Train and evaluate the model

## 1. Set Up Google Drive
First, let's mount Google Drive to access our data and save model checkpoints. This step is essential for persisting our data and model files across Colab sessions.

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set up working directory
import os
WORKING_DIR = '/content/drive/MyDrive/echo_training'  # Change this to your preferred location
os.makedirs(WORKING_DIR, exist_ok=True)
os.chdir(WORKING_DIR)

print(f"Working directory set to: {WORKING_DIR}")
print("Contents:", os.listdir())

## 2. Install Required Dependencies
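Install and import the libraries used in this notebook:
- transformers: For the Gemma model, tokenizer, and Trainer
- accelerate: Required by `device_map="auto"` when loading the model
- datasets: For data handling
- pandas & numpy: For data manipulation
- scikit-learn: For data splitting
- torch: For deep learning operations
- bitsandbytes (optional): Only needed if you load the model in 8-bit or 4-bit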

In [None]:
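# Install required packages (accelerate is needed for device_map="auto";
# bitsandbytes is optional and only used for 8-bit/4-bit loading)
!pip install -q transformers datasets accelerate torch numpy pandas scikit-learn bitsandbytes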

# Import libraries
import pandas as pd
import numpy as np
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
import ast
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Trainer, DataCollatorForLanguageModeling, TrainingArguments

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

## 3. Load and Prepare Echo Data
Load your echo data from the CSV file. Make sure your data file is uploaded to the working directory in Google Drive.

In [None]:
# Load your CSV file
# Change 'your_echo_data.csv' to your actual filename in Google Drive
df = pd.read_csv('your_echo_data.csv')

print("Data loaded successfully!")
print(f"Total samples: {len(df)}")
print("\nColumns in the dataset:")
print(df.columns.tolist())
print("\nFirst few rows:")
print(df.head(2).to_string())

## 4. Configure Data Labels
Define the 19 cardiac feature labels and a helper function that parses the label column. These labels represent the aspects of cardiac structure and function that the model will predict.

In [None]:
# Define label names
LABEL_NAMES = [
    'LA_cavity', 'RA_dilated', 'LV_systolic', 'LV_cavity',
    'LV_wall', 'RV_cavity', 'RV_systolic', 'AV_stenosis',
    'MV_stenosis', 'TV_regurgitation', 'TV_stenosis',
    'TV_pulm_htn', 'AV_regurgitation', 'MV_regurgitation',
    'RA_pressure', 'LV_diastolic', 'RV_volume_overload',
    'RV_wall', 'RV_pressure_overload'
]

def parse_labels(label_str):
    """Convert a string such as "[0, 1, ...]" into a list; pass lists through unchanged."""
    if isinstance(label_str, str):
        return ast.literal_eval(label_str)
    return label_str

# Parse labels
df['labels_parsed'] = df['labels'].apply(parse_labels)

# Verify that every row has exactly one value per label name
assert all(len(labels) == len(LABEL_NAMES) for labels in df['labels_parsed']), \
    f"Not all label arrays have {len(LABEL_NAMES)} values!"

print("\nLabels parsed successfully!")
print(f"Example labels: {df['labels_parsed'].iloc[0]}")

# Display label distribution
label_counts = pd.DataFrame([
    [label, sum(x[i] for x in df['labels_parsed'])] 
    for i, label in enumerate(LABEL_NAMES)
], columns=['Label', 'Count'])

print("\nLabel Distribution:")
print(label_counts)
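The `assert` above stops on the first bad row without saying which one it was. Below is a minimal sketch of a validator that reports every unusable label entry instead. `find_invalid_label_rows` is a hypothetical helper, not part of any library, and it assumes labels are stored either as list-literal strings or as Python lists.

```python
import ast

def find_invalid_label_rows(raw_labels, expected_len=19):
    """Return (position, reason) pairs for label entries that cannot be used.

    Each entry may be a string such as "[0, 1, ...]" or an already-parsed list.
    """
    problems = []
    for pos, raw in enumerate(raw_labels):
        if isinstance(raw, str):
            try:
                parsed = ast.literal_eval(raw)
            except (ValueError, SyntaxError, TypeError):
                problems.append((pos, "unparseable string"))
                continue
        else:
            parsed = raw  # e.g. a list, or NaN for a missing cell
        if not isinstance(parsed, (list, tuple)):
            problems.append((pos, f"expected a list, got {type(parsed).__name__}"))
        elif len(parsed) != expected_len:
            problems.append((pos, f"expected {expected_len} values, got {len(parsed)}"))
    return problems

# Toy input: one good row, one short row, one malformed string, one missing value
print(find_invalid_label_rows(["[0, 1, 2]", "[0, 1]", "not a list", float("nan"), [1, 0, 1]], expected_len=3))
```

In the notebook, `find_invalid_label_rows(df['labels'], expected_len=len(LABEL_NAMES))` returns positional row numbers, not DataFrame index labels.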

## 5. Format Data for Training
Create a formatting function that turns each row into a Gemma chat-style prompt. The user turn contains the instruction and the echo report; the model turn contains the 19 `feature: value` lines that the model should learn to produce.

In [None]:
def format_echo_prompt(row):
    """Format echo report into Gemma instruction format."""
    input_text = row['input']  # Change 'input' to your actual column name
    labels = row['labels_parsed']
    
    # Create formatted label string
    label_pairs = [f"{name}: {value}" for name, value in zip(LABEL_NAMES, labels)]
    label_text = "\n".join(label_pairs)
    
    prompt = f"""<start_of_turn>user
Analyze this echocardiogram report and provide assessment values for each cardiac feature. Output should be in the format "feature: value" for each of the 19 features.

Report:
{input_text}<end_of_turn>
<start_of_turn>model
{label_text}<end_of_turn>"""
    
    return prompt

# Apply formatting
df['text'] = df.apply(format_echo_prompt, axis=1)

print("FORMATTED EXAMPLE:")
print("="*70)
print(df['text'].iloc[0])
print("="*70)
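At inference time, the model needs the same prompt up to the point where its answer begins. The sketch below assumes you reuse the exact instruction text from `format_echo_prompt`; `build_inference_prompt` and `INSTRUCTION` are illustrative names. The returned string ends with `<start_of_turn>model\n`, so generation continues with the `feature: value` lines.

```python
# Instruction text copied from format_echo_prompt; keep the two in sync.
INSTRUCTION = (
    "Analyze this echocardiogram report and provide assessment values for each "
    "cardiac feature. Output should be in the format \"feature: value\" for each "
    "of the 19 features."
)

def build_inference_prompt(report_text: str) -> str:
    """Return the training prompt up to and including the opening model turn."""
    return (
        "<start_of_turn>user\n"
        f"{INSTRUCTION}\n"
        "\n"
        "Report:\n"
        f"{report_text}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

print(build_inference_prompt("Normal LV size and systolic function."))
```

Keeping the instruction in one constant avoids a silent mismatch between training and inference prompts, which would put the model outside the distribution it was fine-tuned on.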

## 6. Data Cleaning and Preprocessing
Clean the dataset by removing rows with missing or duplicate reports, and optionally filter by text length.

In [None]:
print(f"Before cleaning: {len(df)} samples")

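# Remove rows with a missing report. (Checking 'text' would not help: a missing
# input has already been rendered into the prompt as the string "nan".)
df = df.dropna(subset=['input'])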

# Remove duplicates based on input text
df = df.drop_duplicates(subset=['input'])

# Analyze text length
df['text_length'] = df['text'].str.len()
print(f"\nText length stats:")
print(df['text_length'].describe())

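# Optional: Remove extremely short or long examples.
# These lengths are in characters, not tokens. The tokenizer later truncates at
# 512 tokens, which can cut off the labels at the end of a long prompt.
# Uncomment and adjust thresholds as needed
# df = df[(df['text_length'] > 100) & (df['text_length'] < 4096)]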

df = df.drop(columns=['text_length'])

print(f"After cleaning: {len(df)} samples")
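The length statistics above count characters, but truncation happens later at `max_length=512` tokens. Because the labels come last in each prompt, a truncated example loses part or all of its answer. Below is a sketch for counting affected examples, assuming a token-counting function is passed in. `find_truncated` is hypothetical, and the whitespace counter in the demo is only a rough stand-in for Gemma's tokenizer.

```python
from typing import Callable, Iterable, List

def find_truncated(texts: Iterable[str],
                   count_tokens: Callable[[str], int],
                   max_length: int = 512) -> List[int]:
    """Return positions of texts whose token count exceeds max_length."""
    return [pos for pos, text in enumerate(texts) if count_tokens(text) > max_length]

def whitespace_count(text: str) -> int:
    # Rough proxy for demonstration only; real subword token counts are higher
    return len(text.split())

samples = ["short report", "word " * 600]
print(find_truncated(samples, whitespace_count, max_length=512))  # -> [1]
```

Once the tokenizer is loaded (section 9), pass `lambda t: len(tokenizer(t)['input_ids'])` as `count_tokens` and run it on `df['text']`. If many examples exceed the limit, raise `max_length` or shorten the reports.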

## 7. Create Dataset Splits
Split the data into training (70%), tuning/validation (15%), and test (15%) sets. We'll save these splits to CSV files in Google Drive for reproducibility.

In [None]:
# First split: separate test set (15%)
train_tune_df, test_df = train_test_split(
    df,
    test_size=0.15,
    random_state=42,
    shuffle=True
)

# Second split: separate train and tune from remaining 85%
train_df, tune_df = train_test_split(
    train_tune_df,
    test_size=0.15 / 0.85,  # 15% of the original data (about 17.65% of the remainder)
    random_state=42,
    shuffle=True
)

print("DATASET SPLITS:")
print("="*70)
print(f"Training set:   {len(train_df)} samples ({len(train_df)/len(df)*100:.1f}%)")
print(f"Tuning set:     {len(tune_df)} samples ({len(tune_df)/len(df)*100:.1f}%)")
print(f"Test set:       {len(test_df)} samples ({len(test_df)/len(df)*100:.1f}%)")
print(f"Total:          {len(df)} samples")

# Save splits to CSV
train_df.to_csv('echo_train.csv', index=False)
tune_df.to_csv('echo_tune.csv', index=False)
test_df.to_csv('echo_test.csv', index=False)

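print("\n✓ Splits saved to CSV files in Google Drive:")
for name in ['echo_train.csv', 'echo_tune.csv', 'echo_test.csv']:
    print(f"  {os.path.join(WORKING_DIR, name)}")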
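A random split can leave a rare finding with few or no positive examples in the tuning or test set. Below is a sketch that counts positives per label in a split. It assumes numeric labels where any nonzero value means the finding is present, so adjust the check if your encoding differs. Both helpers are hypothetical.

```python
from typing import Dict, List, Sequence

def positives_per_label(label_rows: Sequence[Sequence[float]],
                        names: Sequence[str]) -> Dict[str, int]:
    """Count rows with a nonzero value for each label."""
    counts = {name: 0 for name in names}
    for row in label_rows:
        for name, value in zip(names, row):
            if value != 0:
                counts[name] += 1
    return counts

def labels_missing_from_split(label_rows: Sequence[Sequence[float]],
                              names: Sequence[str]) -> List[str]:
    """Labels that have no positive example in this split."""
    return [name for name, count in positives_per_label(label_rows, names).items() if count == 0]

# Toy split with three labels; label "B" never appears
rows = [[1, 0, 0], [1, 0, 2]]
print(positives_per_label(rows, ["A", "B", "C"]))        # {'A': 2, 'B': 0, 'C': 1}
print(labels_missing_from_split(rows, ["A", "B", "C"]))  # ['B']
```

In the notebook, `labels_missing_from_split(test_df['labels_parsed'], LABEL_NAMES)` lists the features that the test set cannot evaluate at all.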

## 8. Convert to Hugging Face Format
Convert each split into a Hugging Face `Dataset` and combine them into a `DatasetDict`. Only the formatted `text` column is kept. Tokenization happens in the next section, after the tokenizer is loaded.

In [None]:
# Create datasets from the DataFrames
train_dataset = Dataset.from_pandas(train_df[['text']], preserve_index=False)
tune_dataset = Dataset.from_pandas(tune_df[['text']], preserve_index=False)
test_dataset = Dataset.from_pandas(test_df[['text']], preserve_index=False)

# Combine into a DatasetDict
dataset = DatasetDict({
    'train': train_dataset,
    'validation': tune_dataset,
    'test': test_dataset
})

print("HUGGING FACE DATASET:")
print("="*70)
print(dataset)

## 9. Configure Model and Tokenizer
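Set up the Gemma model and tokenizer. We use the 2B-parameter version for prototyping; for production you can switch to a larger model such as Gemma 2 9B (`google/gemma-2-9b`). Gemma checkpoints are gated on the Hugging Face Hub, so accept the license on the model page and authenticate (for example with `huggingface_hub.login()`) before loading. Training with `bf16=True` requires an Ampere-or-newer GPU, such as an A100 or L4.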

In [None]:
# Load model and tokenizer
model_name = 'google/gemma-2b'  # or 'google/gemma-2-9b' (Gemma 2) for production
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",           # requires the accelerate package
    torch_dtype=torch.bfloat16,  # bfloat16 weights use half the memory of float32
)

# Enable gradient checkpointing after model creation
model.gradient_checkpointing_enable()
model.config.use_cache = False  # Disable KV cache for training

# Tokenization function
def tokenize_function(examples):
    """Tokenize the text data; padding is applied per batch by the data collator."""
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,  # kept short to save memory; longer prompts lose their trailing labels
    )

# Tokenize all splits
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
    desc="Tokenizing datasets",
)

print("TOKENIZED DATASETS:")
print("="*70)
print(tokenized_datasets)
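With `mlm=False`, `DataCollatorForLanguageModeling` copies `input_ids` into `labels` and masks only padding, so the loss also covers the instruction and report tokens, not just the 19 answers. One common alternative is to train only on the model turn. Below is a minimal, library-free sketch of that masking, assuming you know the token IDs of the `<start_of_turn>model\n` marker. `mask_prompt_tokens` is a hypothetical helper, and the IDs in the demo are toy values.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss in Hugging Face causal LMs

def mask_prompt_tokens(input_ids: Sequence[int],
                       response_template_ids: Sequence[int]) -> List[int]:
    """Return labels equal to input_ids, with every token up to and including the
    last occurrence of response_template_ids replaced by IGNORE_INDEX."""
    n = len(response_template_ids)
    if n == 0:
        raise ValueError("response_template_ids must not be empty")
    template = list(response_template_ids)
    labels = list(input_ids)
    # Search backwards so the model-turn marker closest to the answer is used
    for start in range(len(labels) - n, -1, -1):
        if labels[start:start + n] == template:
            answer_start = start + n
            return [IGNORE_INDEX] * answer_start + labels[answer_start:]
    # Template not found (for example, truncated away): contribute no loss
    return [IGNORE_INDEX] * len(labels)

# Toy IDs: pretend [9, 8] encodes "<start_of_turn>model\n"
print(mask_prompt_tokens([1, 2, 3, 9, 8, 5, 6, 7], [9, 8]))
```

To use it, get `template_ids = tokenizer("<start_of_turn>model\n", add_special_tokens=False)["input_ids"]`, add a `labels` column in `tokenize_function`, and confirm on one example that these IDs actually appear (the newline's tokenization can depend on context). `DataCollatorForLanguageModeling` rebuilds `labels` from `input_ids` and would discard this masking, so this approach also needs a collator that pads an existing `labels` column with -100, such as `DataCollatorForSeq2Seq(tokenizer, padding=True)`.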

## 10. Configure Training
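Set up the training arguments and create the trainer. These settings are chosen so that a full fine-tune fits in limited GPU memory: a per-device batch size of 1 with 8 gradient-accumulation steps (an effective batch size of 8), bf16 mixed precision, and gradient checkpointing.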

In [None]:
# Create data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# Training arguments tuned for limited GPU memory
training_args = TrainingArguments(
    output_dir="./gemma_echo_finetuned",
    num_train_epochs=5,
    
    # Small per-device batches; accumulation gives an effective batch size of 1 x 8 = 8
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    
    learning_rate=2e-5,
    weight_decay=0.01,
    
    # Memory optimization settings
    fp16=False,
    bf16=True,           # requires an Ampere-or-newer GPU (e.g., A100, L4)
    gradient_checkpointing=True,
    
    logging_dir='./logs',
    logging_steps=50,
    
    # Evaluation and saving settings; load_best_model_at_end needs matching
    # strategies and save_steps that is a multiple of eval_steps
    eval_strategy="steps",  # called evaluation_strategy in older transformers releases
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    
    # Training optimization
    warmup_steps=100,
    lr_scheduler_type="cosine",
    optim="adamw_torch",
    max_grad_norm=1.0,
    
    report_to="none",
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
)

print("\n✓ Trainer configured for echo report data!")
print(f"Training on:   {len(tokenized_datasets['train'])} samples")
print(f"Validating on: {len(tokenized_datasets['validation'])} samples")
print(f"Test set:      {len(tokenized_datasets['test'])} samples (held out)")
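With a per-device batch size of 1 and 8 accumulation steps, each optimizer update consumes 8 examples, so a small training set may yield fewer updates than expected. If the total is close to or below `eval_steps=100`, you get few or no validation checks during training, which weakens `load_best_model_at_end`, and the 100 warmup steps can take up a large share of a short run. Below is a back-of-the-envelope sketch, assuming a single GPU. The Trainer's handling of a final partial accumulation window can differ by about one step per epoch, and the 700-example training set is hypothetical.

```python
import math

def optimizer_steps(num_train_examples: int, per_device_batch: int,
                    grad_accum: int, epochs: int, num_devices: int = 1) -> int:
    """Approximate number of optimizer updates for a full training run."""
    effective_batch = per_device_batch * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_train_examples / effective_batch)
    return steps_per_epoch * epochs

# Hypothetical training-set size; substitute len(tokenized_datasets["train"])
n_train = 700
total = optimizer_steps(n_train, per_device_batch=1, grad_accum=8, epochs=5)
eval_steps, warmup_steps = 100, 100
print(f"total optimizer steps:   {total}")               # 440
print(f"evaluations/checkpoints: {total // eval_steps}")  # 4
print(f"warmup share:            {warmup_steps / total:.0%}")  # 23%
```

If the warmup share is large or there are only one or two evaluations, consider lowering `eval_steps`/`save_steps` together and replacing `warmup_steps` with a smaller value.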

## 11. Model Training and Evaluation
Train the model and evaluate its performance. We'll also include a helper function for final evaluation on the test set.

In [None]:
# Training
print("Starting training...")
trainer.train()

def evaluate_on_test_set():
    """Evaluate the fine-tuned model on the held-out test set."""
    print("\nEVALUATING ON TEST SET")
    print("="*70)
    
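    # metric_key_prefix labels the metrics as test_* instead of the default eval_*
    test_results = trainer.evaluate(tokenized_datasets['test'], metric_key_prefix="test")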
    
    print("\nTest Set Results:")
    for key, value in test_results.items():
        print(f"  {key}: {value:.4f}")
    
    return test_results

# Evaluate on test set
test_results = evaluate_on_test_set()

# Save the final model
output_dir = "final_model"
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

print("\n✓ Model and tokenizer saved to:", output_dir)
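Validation and test loss do not show whether each of the 19 features is predicted correctly. Below is a sketch of per-label scoring on generated text. It assumes you have already generated and decoded one output string per test report (for example with `model.generate` on a prompt that ends at `<start_of_turn>model`). `parse_feature_output` and `per_label_accuracy` are hypothetical helpers. Scoring is an exact string match against `str(value)`, which mirrors how `format_echo_prompt` writes the targets.

```python
from typing import Dict, List, Optional, Sequence

def parse_feature_output(generated: str, names: Sequence[str]) -> Dict[str, Optional[str]]:
    """Parse 'feature: value' lines from generated text; missing features map to None."""
    parsed: Dict[str, Optional[str]] = {name: None for name in names}
    text = generated.replace("<end_of_turn>", "")
    for line in text.splitlines():
        if ":" not in line:
            continue
        key, value = line.split(":", 1)
        key = key.strip()
        if key in parsed and parsed[key] is None:  # keep the first occurrence
            parsed[key] = value.strip()
    return parsed

def per_label_accuracy(predictions: List[Dict[str, Optional[str]]],
                       references: Sequence[Sequence],
                       names: Sequence[str]) -> Dict[str, float]:
    """Share of examples whose predicted string equals str(reference value), per label.
    Missing predictions count as wrong."""
    if not references:
        return {name: 0.0 for name in names}
    scores = {}
    for i, name in enumerate(names):
        correct = sum(1 for pred, ref in zip(predictions, references)
                      if pred.get(name) == str(ref[i]))
        scores[name] = correct / len(references)
    return scores

names = ["LA_cavity", "RA_dilated"]
pred = parse_feature_output("LA_cavity: 1\nRA_dilated: 0<end_of_turn>", names)
print(pred)                                          # {'LA_cavity': '1', 'RA_dilated': '0'}
print(per_label_accuracy([pred], [[1, 1]], names))   # {'LA_cavity': 1.0, 'RA_dilated': 0.0}
```

Per-label precision, recall, and F1 can be computed from the same parsed dictionaries once you decide which values count as positive for each feature.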

## Important Notes for Medical Data

1. **Data Balance**
   - Check if your 19 labels are balanced
   - Medical data often has class imbalance
   - Consider stratified splitting if certain conditions are rare

2. **Evaluation Metrics**
   - Loss alone may not be sufficient
   - Consider implementing custom metrics (F1, precision, recall per label)
   - Track precision and recall per label: false positives and false negatives carry different clinical costs

3. **Model Size**
   - Gemma 2B is good for prototyping
   - Consider Gemma 2 9B (`google/gemma-2-9b`) for production if accuracy is critical
   
4. **Training Tips**
   - Monitor validation loss closely
   - Stop training if validation loss stops decreasing (transformers' `EarlyStoppingCallback` can automate this)
   - Medical domain may need more epochs (5-10)
   - Lower learning rate is safer for specialized domains

5. **Prompt Engineering**
   - Current format outputs all 19 values at once
   - Predicting one value at a time may be more reliable, at the cost of 19 generation calls per report
   - Add medical context in prompts if needed
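For the stratified splitting mentioned under Data Balance, scikit-learn's `train_test_split` accepts a `stratify=` argument with one class label per row, which covers stratifying on a single rare label. The dependency-free sketch below shows the same idea. `stratified_split` is a hypothetical helper, and stratifying on all 19 labels jointly needs a multi-label method that this sketch does not cover.

```python
import random
from collections import defaultdict
from typing import Hashable, List, Sequence, Tuple

def stratified_split(strata: Sequence[Hashable], test_frac: float,
                     seed: int = 42) -> Tuple[List[int], List[int]]:
    """Split row positions so each stratum contributes about test_frac of its rows to test."""
    groups = defaultdict(list)
    for pos, key in enumerate(strata):
        groups[key].append(pos)
    rng = random.Random(seed)
    train_pos, test_pos = [], []
    for key in sorted(groups, key=repr):  # fixed order keeps results reproducible
        members = groups[key]
        rng.shuffle(members)
        n_test = round(len(members) * test_frac)
        test_pos.extend(members[:n_test])
        train_pos.extend(members[n_test:])
    return sorted(train_pos), sorted(test_pos)

# Hypothetical example: 80 negatives and 20 positives for one rare label
strata = [0] * 80 + [1] * 20
train_pos, test_pos = stratified_split(strata, test_frac=0.15)
print(len(test_pos), sum(strata[p] for p in test_pos))  # 15 3
```

With the two-stage split in section 7, apply stratification in both stages so the rare label stays proportionally represented in the train, tune, and test sets.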

Remember to validate the model thoroughly before any clinical use!