<a href="https://colab.research.google.com/github/christinium/Health/blob/main/EchoProject/Gemma_prompt_only_echo_label_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Using Gemma to Label Echo Notes

In [None]:
# Mount Google Drive at /content/drive (requires the google.colab package,
# i.e. this cell only works inside Google Colab)
from google.colab import drive
drive.mount('/content/drive')

# Set up working directory: create it if it does not exist yet, then make it
# the current directory so relative paths in later cells resolve inside it
import os
WORKING_DIR = '/content/drive/MyDrive/echo_training/'  # Change this to your preferred location
os.makedirs(WORKING_DIR, exist_ok=True)
os.chdir(WORKING_DIR)

Mounted at /content/drive


In [None]:
# Import libraries
import pandas as pd
import numpy as np
import ast
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
from tqdm import tqdm
from datetime import datetime

In [None]:
import ast
from typing import List, Dict

In [None]:
# Load the echo test set; the relative path resolves against the current
# working directory. Later cells read its 'text' column to build prompts.
test_df = pd.read_csv('echo_test.csv')

In [None]:

# Rename the first column, whatever it is currently called, to 'id_num'.
# presumably this is an unnamed index column left by an earlier to_csv — TODO confirm
test_df = test_df.rename(columns={test_df.columns[0]: 'id_num'})

In [None]:
# Constants and configurations

# Output labels, in the fixed order used for prompts, parsing and scoring.
LABEL_NAMES = [
    'LA_cavity', 'RA_dilated', 'LV_systolic', 'LV_cavity',
    'LV_wall', 'RV_cavity', 'RV_systolic', 'AV_stenosis',
    'MV_stenosis', 'TV_regurgitation', 'TV_stenosis',
    'TV_pulm_htn', 'AV_regurgitation', 'MV_regurgitation',
    'RA_pressure', 'LV_diastolic', 'RV_volume_overload',
    'RV_wall', 'RV_pressure_overload'
]

# Medical context for each feature. These descriptions are inserted verbatim
# into the model prompt, so wording errors here directly affect predictions.
FEATURE_CONTEXT = {
    'LA_cavity': 'left atrial cavity size',
    'RA_dilated': 'right atrial dilation',
    'LV_systolic': 'left ventricular systolic function',
    'LV_cavity': 'left ventricular cavity size',
    'LV_wall': 'left ventricular wall size',
    'RV_cavity': 'right ventricular cavity size',
    'RV_systolic': 'right ventricular systolic function',
    # "AV" is the aortic valve (previously mis-expanded as "atrial virus")
    'AV_stenosis': 'aortic valve stenoses',
    'MV_stenosis': 'mitral valve stenoses',
    'TV_regurgitation': 'tricuspid valve regurgitation',
    'TV_stenosis': 'tricuspid valve stenoses',
    'TV_pulm_htn': 'tricuspid valve pulmonary hypertension',
    'AV_regurgitation': 'aortic valve regurgitation',
    'MV_regurgitation': 'mitral valve regurgitation',
    'RA_pressure': 'right atrial pressure',
    'LV_diastolic': 'left ventricular diastolic function',
    'RV_volume_overload': 'right ventricular volume overload',
    'RV_wall': 'right ventricular wall thickness',
    'RV_pressure_overload': 'right ventricular pressure overload'
}

# Numeric codes for each label value. The nested dict holds the graded
# severity scale; it includes the "severe" level (3) that the prompt text
# also offers to the model.
CODING_SCHEMA = {
    "Study is not adequate to evaluate": -3,
    "Abnormality not present": 0,
    "Abnormality is present but not quantifiable": -2,
    "Abnormality can be categorized as": {
        "hyperdynamic": -1,
        "normal": 0,
        "mild dysfunction": 1,
        "moderate dysfunction": 2,
        "severe dysfunction": 3,
    }
}

In [None]:
# ============================================================================
# PROMPT ENGINEERING: Create structured prompts for Gemma model
# ============================================================================

def create_structured_prompt(text: str) -> str:
    """Build the single-turn Gemma chat prompt for one echocardiogram report.

    The prompt lists every entry of FEATURE_CONTEXT, spells out the numeric
    coding schema, shows the required ``label: [number]`` answer layout for
    each name in LABEL_NAMES, and ends with the report text followed by an
    opened model turn.

    Args:
        text: Report text; inserted into the prompt verbatim.

    Returns:
        The complete prompt string.
    """
    # One "- name: description" bullet per feature
    feature_list = "\n".join(
        f"- {name}: {description}"
        for name, description in FEATURE_CONTEXT.items()
    )

    # One "name: [number]" line per label, in LABEL_NAMES order
    expected_output = "\n".join(f"{name}: [number]" for name in LABEL_NAMES)

    prompt = f"""<start_of_turn>user
As a cardiologist, analyze this echocardiogram report and provide a structured assessment.

For each feature:
{feature_list}

Use this coding schema to evaluate:
-3: Study is not adequate to evaluate
 0: Abnormality not present
-2: Abnormality is present but not quantifiable
-1: Hyperdynamic
 0: Normal
 1: Mild dysfunction
 2: Moderate dysfunction
 3: Severe dysfunction

Format your response exactly as follows:
{expected_output}

Report:
{text}<end_of_turn>
<start_of_turn>model"""
    return prompt

In [None]:
# Build one Gemma prompt per report from the 'text' column and store it in
# the 'gemma_long_prompt' column
test_df['gemma_long_prompt'] = test_df['text'].apply(create_structured_prompt)

In [None]:
# Here is a sample of what the prompt looks like: the first row of the
# 'gemma_long_prompt' column built with create_structured_prompt
print(test_df['gemma_long_prompt'].iloc[0])

<start_of_turn>user
As a cardiologist, analyze this echocardiogram report and provide a structured assessment. 

For each feature:
- LA_cavity: left atrial cavity size
- RA_dilated: right atrial dilation
- LV_systolic: left ventricular systolic function
- LV_cavity: left ventricular cavity size
- LV_wall: left ventricular wall size
- RV_cavity: right ventricular cavity size
- RV_systolic: right ventricular systolic function
- AV_stenosis: atrial virus stenoses
- MV_stenosis: mitral valve stenoses
- TV_regurgitation: tricuspid valve regurgitation
- TV_stenosis: tricuspid valve stenoses
- TV_pulm_htn: tricuspid valve pulmonary hypertension
- AV_regurgitation: atrial virus regurgitation
- MV_regurgitation: mitral valve regurgitation
- RA_pressure: right atrial pressure
- LV_diastolic: left ventricular diastolic function
- RV_volume_overload: right ventricular volume overload
- RV_wall: right ventricular wall thickness
- RV_pressure_overload: right ventricular pressure overload

Use this c

In [None]:
# ============================================================================
# MODEL LOADING: Initialize Gemma 2B instruction-tuned model
# ============================================================================

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm

# Load Gemma model
model_name = 'google/gemma-2b-it' # 2B parameter instruction-tuned variant
# Load tokenizer (converts text to tokens the model understands)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load model with optimizations:
# - torch_dtype=torch.bfloat16: Use bfloat16 precision (saves memory, faster inference)
# - device_map="auto": Automatically distribute model across available GPUs
# NOTE(review): the recorded output of this cell warns that `torch_dtype` is
# deprecated in favour of `dtype`; confirm the minimum transformers version
# before switching the keyword.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Reached only if both from_pretrained calls returned without raising
print("Model loaded successfully!")

tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

Model loaded successfully!


In [None]:
# ============================================================================
# BATCH INFERENCE: Run model predictions in batches for efficiency
# ============================================================================

def run_gemma_batch(prompts_list, batch_size=4):
    """Generate one completion per prompt, processing prompts in mini-batches.

    Uses the module-level ``tokenizer`` and ``model``. Decoding is greedy
    (``do_sample=False``), tokenized prompts are truncated to 2048 tokens and
    at most 200 new tokens are generated per prompt.

    Args:
        prompts_list: List of prompt strings.
        batch_size: Number of prompts tokenized and generated together.

    Returns:
        List of decoded completions (prompt tokens removed, special tokens
        skipped), in the same order as ``prompts_list``. An empty input
        yields an empty list.
    """
    results = []

    for start in range(0, len(prompts_list), batch_size):
        batch = prompts_list[start:start + batch_size]

        # Move the inputs to the device the model was actually placed on
        # instead of assuming "cuda"; the model is loaded with
        # device_map="auto", so the device is not fixed in advance.
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        ).to(model.device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

        # Each output row starts with the (padded) prompt tokens; drop that
        # shared-length prefix and decode only the continuation.
        # NOTE(review): slicing by the padded length assumes the tokenizer
        # pads on the left — confirm tokenizer.padding_side.
        prompt_len = inputs['input_ids'].shape[1]
        results.extend(
            tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
        )

    return results

In [None]:

# ============================================================================
# RUN INFERENCE: Process all echo reports through Gemma model
# ============================================================================
# Process in batches
batch_size = 64   # Adjust based on your GPU memory (64 works on Colab high-RAM)
# Use the prompts this notebook builds with create_structured_prompt, which
# are stored in the 'gemma_long_prompt' column
prompts_list = test_df['gemma_long_prompt'].tolist()
results = []

# Run inference with progress bar (one tick per batch)
for i in tqdm(range(0, len(prompts_list), batch_size)):
    batch = prompts_list[i:i+batch_size]
    batch_results = run_gemma_batch(batch, batch_size=batch_size)
    results.extend(batch_results)
# Store raw model outputs
test_df['gemma_result'] = results

100%|██████████| 104/104 [28:43<00:00, 16.57s/it]


In [None]:
from datetime import datetime
# Save predictions to avoid re-running inference.
# The filename embeds the current local time (YYYYMMDD_HHMMSS), so each run
# writes a new file; index=False keeps the DataFrame index out of the CSV.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f'test_df_with_long_gemma_predictions_{timestamp}.csv'
test_df.to_csv(filename, index=False)

In [None]:
# Read a saved predictions file from Drive by absolute path.
# NOTE(review): this filename has no timestamp and no "long" infix, unlike
# the file written by the save cell — confirm it is the intended file.
temp_df = pd.read_csv('/content/drive/MyDrive/echo_training/test_df_with_gemma_predictions.csv')

In [None]:
# Rename the first column of temp_df, whatever it is currently called, to 'id_num'
temp_df = temp_df.rename(columns={temp_df.columns[0]: 'id_num'})

In [None]:
# ============================================================================
# PERFORMANCE EVALUATION: Parse model outputs and calculate accuracy
# ============================================================================

# Load your results (replaces the in-memory test_df with the saved CSV)
# NOTE(review): this reads 'test_df_with_gemma_predictions.csv', not the
# timestamped 'test_df_with_long_gemma_predictions_<timestamp>.csv' written
# by the save cell — confirm this is the intended file.
test_df = pd.read_csv('/content/drive/MyDrive/echo_training/test_df_with_gemma_predictions.csv')
# Rename the first column, whatever it is currently called, to 'id_num'
test_df = test_df.rename(columns={test_df.columns[0]: 'id_num'})

# ----------------------------------------------------------------------------
# STEP 1: Parse raw model outputs into structured predictions
# Regular expressions are used because the formatting of Gemma's output
# is inconsistent from one response to the next.
# ----------------------------------------------------------------------------

def parse_medical_report(report_text, label_names=None):
    """
    Extract label:value pairs from model's text output.

    Values may be negative (the coding schema uses -3, -2 and -1), and
    markdown bold markers around the label or the colon are tolerated
    (``**LA_cavity**: 0``, ``**LA_cavity:** 0``).

    Args:
        report_text: Raw model output; converted with ``str()`` before matching.
        label_names: Labels to keep. Defaults to the module-level LABEL_NAMES.

    Returns:
        Dict mapping each recognised label to its integer value. Labels not
        found in the text are absent; if a label appears more than once the
        last occurrence wins.

    Example input: "LA_cavity: 0\nRA_dilated: -1\n..."
    Example output: {'LA_cavity': 0, 'RA_dilated': -1, ...}
    """
    if label_names is None:
        label_names = LABEL_NAMES
    allowed = set(label_names)

    # label, optional '*' run, ':', optional '*' run, then a signed integer.
    # The '-?' is essential: without it every negative code is silently dropped.
    pattern = r'([A-Za-z_]+)\**\s*:\s*\**\s*(-?\d+)'
    matches = re.findall(pattern, str(report_text))

    parsed_data = {}
    for key, value in matches:
        if key in allowed:
            parsed_data[key] = int(value)
    return parsed_data

def parse_to_list(parsed_data, label_names=None):
    """Convert parsed dictionary to ordered list matching LABEL_NAMES.

    Args:
        parsed_data: Dict of label -> value, e.g. from parse_medical_report.
        label_names: Label order to use. Defaults to the module-level
            LABEL_NAMES.

    Returns:
        List with one entry per label, in order; labels missing from
        ``parsed_data`` yield None.
    """
    if label_names is None:
        label_names = LABEL_NAMES
    return [parsed_data.get(key) for key in label_names]

# Turn each raw text output into a list of numeric codes ordered like
# LABEL_NAMES; rows with a missing output get an empty list
test_df['gemma_result_list'] = test_df['gemma_result'].apply(
    lambda raw_output: (
        [] if not pd.notna(raw_output)
        else parse_to_list(parse_medical_report(str(raw_output)))
    )
)

# ----------------------------------------------------------------------------
# STEP 2: Calculate accuracy metrics
# ----------------------------------------------------------------------------

def _has_items(value):
    """Return True when value is a non-empty sized container (not None/NaN/scalar)."""
    return value is not None and hasattr(value, '__len__') and len(value) > 0


def calculate_accuracy(df, label_names=None, label_col='labels_parsed',
                       pred_col='gemma_result_list'):
    """
    Compare model predictions against ground truth labels.

    Labels and predictions are compared position by position, in the order of
    ``label_names``. A position is scored only when both values are present
    and convertible to int; otherwise it is skipped and counts neither as
    correct nor as wrong.

    Args:
        df: DataFrame with one row per report.
        label_names: Label order. Defaults to the module-level LABEL_NAMES.
        label_col: Column holding the ground truth, either a sequence or its
            string literal (parsed with ast.literal_eval).
        pred_col: Column holding the prediction sequence.

    Returns:
        Dictionary with overall accuracy and per-label breakdown:
        'overall_accuracy' (0 when nothing could be scored),
        'per_label_accuracy' (None for labels that were never scored),
        'per_label_correct' and 'per_label_total'.
    """
    if label_names is None:
        label_names = LABEL_NAMES

    correct_per_label = {label: 0 for label in label_names}
    total_per_label = {label: 0 for label in label_names}

    for _, row in df.iterrows():
        labels = row[label_col]  # Ground truth
        predictions = row[pred_col]  # Model predictions

        # Parse labels if stored as string
        if isinstance(labels, str):
            try:
                labels = ast.literal_eval(labels)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Malformed literal: skip the row, but let unrelated errors
                # (e.g. KeyboardInterrupt) propagate.
                continue

        # Skip rows where either side is missing, empty or not a container
        # (e.g. NaN read back from a CSV)
        if not (_has_items(labels) and _has_items(predictions)):
            continue

        # zip stops at the shortest of the three sequences
        for label_name, label_val, pred_val in zip(label_names, labels, predictions):
            if label_val is None or pred_val is None:
                continue

            # Normalize to integers for comparison
            try:
                label_val = int(float(label_val))
                pred_val = int(float(pred_val))
            except (ValueError, TypeError, OverflowError):
                continue

            total_per_label[label_name] += 1
            if label_val == pred_val:
                correct_per_label[label_name] += 1

    total_predictions = sum(total_per_label.values())
    total_correct = sum(correct_per_label.values())

    results = {
        'overall_accuracy': total_correct / total_predictions if total_predictions > 0 else 0,
        'per_label_accuracy': {},
        'per_label_correct': {},
        'per_label_total': {}
    }

    # Calculate per-label accuracy
    for label in label_names:
        scored = total_per_label[label]
        results['per_label_accuracy'][label] = (
            correct_per_label[label] / scored if scored > 0 else None
        )
        results['per_label_correct'][label] = correct_per_label[label]
        results['per_label_total'][label] = scored

    return results

# ----------------------------------------------------------------------------
# STEP 3: Display results
# ----------------------------------------------------------------------------

# Score the parsed predictions against the ground truth
accuracy_results = calculate_accuracy(test_df)

# Report header and overall accuracy
heavy_rule = "=" * 80
light_rule = "-" * 80
print(heavy_rule)
print("GEMMA MODEL PERFORMANCE ON ECHO NOTES")
print(heavy_rule)
print(f"\nOverall Accuracy: {accuracy_results['overall_accuracy']:.4f} ({accuracy_results['overall_accuracy']*100:.2f}%)")
print("\n" + light_rule)
print("Per-Label Performance:")
print(light_rule)

# Per-label table; labels that were never scored (accuracy None) show as 0
summary_df = pd.DataFrame({
    'Label': LABEL_NAMES,
    'Accuracy': [
        0 if accuracy_results['per_label_accuracy'][label] is None
        else accuracy_results['per_label_accuracy'][label]
        for label in LABEL_NAMES
    ],
    'Correct': [accuracy_results['per_label_correct'][label] for label in LABEL_NAMES],
    'Total': [accuracy_results['per_label_total'][label] for label in LABEL_NAMES]
})
summary_df['Percentage'] = 100 * summary_df['Accuracy']

print(summary_df.to_string(index=False))
print(heavy_rule)

GEMMA MODEL PERFORMANCE ON ECHO NOTES

Overall Accuracy: 0.2024 (20.24%)

--------------------------------------------------------------------------------
Per-Label Performance:
--------------------------------------------------------------------------------
               Label  Accuracy  Correct  Total  Percentage
           LA_cavity  0.158398      530   3346   15.839809
          RA_dilated  0.191320      626   3272   19.132029
         LV_systolic  0.063887      213   3334    6.388722
           LV_cavity  0.139409      462   3314   13.940857
             LV_wall  0.253628      839   3308   25.362757
           RV_cavity  0.219824      723   3289   21.982365
         RV_systolic  0.253180      836   3302   25.317989
         AV_stenosis  0.093809      300   3198    9.380863
         MV_stenosis  0.190709      624   3272   19.070905
    TV_regurgitation  0.200738      653   3253   20.073778
         TV_stenosis  0.250787      797   3178   25.078666
         TV_pulm_htn  0.249529   