<a href="https://colab.research.google.com/github/christinium/Health/blob/main/EchoProject/Gemma_prompt_only_echo_label_longer_prompt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Using Gemma to Label Echo Notes

In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set up working directory
import os
WORKING_DIR = '/content/drive/MyDrive/echo_training/'  # Change this to your preferred location
os.makedirs(WORKING_DIR, exist_ok=True)
os.chdir(WORKING_DIR)

Mounted at /content/drive


In [2]:
# Import libraries
import pandas as pd
import numpy as np
import ast
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
from tqdm import tqdm
from datetime import datetime

In [3]:
import ast
from typing import List, Dict

In [4]:
test_df = pd.read_csv('echo_test.csv')
test_df = test_df.rename(columns={test_df.columns[0]: 'id_num'})

In [5]:
train_df = pd.read_csv('echo_train.csv')

In [7]:
# Constants and configurations
LABEL_NAMES = [
    'LA_cavity', 'RA_dilated', 'LV_systolic', 'LV_cavity',
    'LV_wall', 'RV_cavity', 'RV_systolic', 'AV_stenosis',
    'MV_stenosis', 'TV_regurgitation', 'TV_stenosis',
    'TV_pulm_htn', 'AV_regurgitation', 'MV_regurgitation',
    'RA_pressure', 'LV_diastolic', 'RV_volume_overload',
    'RV_wall', 'RV_pressure_overload'
]

# Medical context for each feature
FEATURE_CONTEXT = {
    'LA_cavity': 'left atrial cavity size',
    'RA_dilated': 'right atrial dilation',
    'LV_systolic': 'left ventricular systolic function',
    'LV_cavity': 'left ventricular cavity size',
    'LV_wall': 'left ventricular wall size',
    'RV_cavity': 'right ventricular cavity size',
    'RV_systolic': 'right ventricular systolic function',
    'AV_stenosis': 'atrial virus stenoses',
    'MV_stenosis': 'mitral valve stenoses',
    'TV_regurgitation': 'tricuspid valve regurgitation',
    'TV_stenosis': 'tricuspid valve stenoses',
    'TV_pulm_htn': 'tricuspid valve pulmonary hypertension',
    'AV_regurgitation': 'atrial virus regurgitation',
    'MV_regurgitation': 'mitral valve regurgitation',
    'RA_pressure': 'right atrial pressure',
    'LV_diastolic': 'left ventricular diastolic function',
    'RV_volume_overload': 'right ventricular volume overload',
    'RV_wall': 'right ventricular wall thickness',
    'RV_pressure_overload': 'right ventricular pressure overload'
}

CODING_SCHEMA = {
    "Study is not adequate to evaluate": -3,
    "Abnormality not present": 0,
    "Abnormality is present but not quantifiable": -2,
    "Abnormality can be categorized as": {
        "hyperdynamic": -1,
        "normal": 0,
        "mild dysfunction": 1,
        "moderate dysfunction": 2,
    }
}

In [22]:
# ============================================================================
# PROMPT ENGINEERING: Create structured prompts for Gemma model
# ============================================================================

def create_structured_prompt(text: str) -> str:
    """Creates a structured medical prompt for echo analysis."""

    # Format feature context
    feature_list = "\n".join([f"- {key}: {value}" for key, value in FEATURE_CONTEXT.items()])

    # Format expected output
    expected_output = "\n".join([f"{label}: [number]" for label in LABEL_NAMES])

    return f"""<start_of_turn>user
As a cardiologist, analyze this echocardiogram report and provide a structured assessment.

For each feature:
{feature_list}

Use this coding schema to evaluate:
-3: Study is not adequate to evaluate
 0: Abnormality not present
-2: Abnormality is present but not quantifiable
-1: Hyperdynamic
 0: Normal
 1: Mild dysfunction
 2: Moderate dysfunction
 3: Severe dysfunction

**IMPORTANT: Output ONLY the feature names and numbers in the exact format shown below. Do not add any explanations, descriptions, or additional text.**

Format your response exactly as:
LA_cavity: [number]
RA_dilated: [number]
LV_systolic: [number]
LV_cavity: [number]
LV_wall: [number]
RV_cavity: [number]
RV_systolic: [number]
AV_stenosis: [number]
MV_stenosis: [number]
TV_regurgitation: [number]
TV_stenosis: [number]
TV_pulm_htn: [number]
AV_regurgitation: [number]
MV_regurgitation: [number]
RA_pressure: [number]
LV_diastolic: [number]
RV_volume_overload: [number]
RV_wall: [number]
RV_pressure_overload: [number]

Do not include any other text, explanations, or formatting.
**Examples:**

Example 1:
Report: LEFT ATRIUM: The left atrium is mildly dilated. RIGHT ATRIUM/INTERATRIAL SEPTUM: A catheter or pacing wire is seen in the right atrium and/or right ventricle. LEFT VENTRICLE: Left ventricular wall thickness, cavity size, and systolic function are normal (LVEF>55%). Due to suboptimal technical quality, a focal wall motion abnormality cannot be fully excluded. RIGHT VENTRICLE: The right ventricle is not well seen. AORTIC VALVE: A mechanical aortic valve prosthesis is present. The transaortic gradient is higher than expected for this type of prosthesis. Mild to moderate ([**1-24**]+) aortic regurgitation is seen. MITRAL VALVE: The mitral valve leaflets are mildly thickened. Mild to moderate ([**1-24**]+) mitral regurgitation is seen. PERICARDIUM: There is no pericardial effusion.

Response:
LA_cavity: 1
RA_dilated: 0
LV_systolic: 0
LV_cavity: 0
LV_wall: 0
RV_cavity: -3
RV_systolic: -3
AV_stenosis: 0
MV_stenosis: 0
TV_regurgitation: 0
TV_stenosis: 0
TV_pulm_htn: 0
AV_regurgitation: 0
MV_regurgitation: 0
RA_pressure: 0
LV_diastolic: 0
RV_volume_overload: -3
RV_wall: -3
RV_pressure_overload: -3

Example 2:
Report: LEFT ATRIUM: Mild LA enlargement. RIGHT ATRIUM/INTERATRIAL SEPTUM: Normal RA size. LEFT VENTRICLE: Normal LV wall thickness. Normal LV cavity size. Depressed LVEF. LV dysnchrony is present. RIGHT VENTRICLE: Normal RV wall thickness. Markedly dilated RV cavity. Severe global RV free wall hypokinesis. Abnormal septal motion/position. AORTA: Normal aortic root diameter. Normal ascending aorta diameter. AORTIC VALVE: Normal aortic valve leaflets (3). No AS. MITRAL VALVE: Mildly thickened mitral valve leaflets. TRICUSPID VALVE: Moderate to severe [3+] TR. Moderate PA systolic hypertension. PERICARDIUM: No pericardial effusion.

Response:
LA_cavity: 1
RA_dilated: 0
LV_systolic: -2
LV_cavity: 0
LV_wall: 0
RV_cavity: 2
RV_systolic: 3
AV_stenosis: 0
MV_stenosis: 0
TV_regurgitation: 0
TV_stenosis: 0
TV_pulm_htn: 2
AV_regurgitation: 0
MV_regurgitation: 0
RA_pressure: 0
LV_diastolic: 0
RV_volume_overload: 0
RV_wall: 0
RV_pressure_overload: 0

Example 3:
Report: LEFT ATRIUM: The left atrium is moderately dilated. RIGHT ATRIUM/INTERATRIAL SEPTUM: The right atrium is moderately dilated. LEFT VENTRICLE: There is severe symmetric left ventricular hypertrophy. The left ventricular cavity size is normal. Overall left ventricular systolic function is severely depressed. RIGHT VENTRICLE: Right ventricular chamber size and free wall motion are normal. AORTA: The aortic root is normal in diameter. AORTIC VALVE: The aortic valve leaflets (3) are mildly thickened. Mild to moderate ([**2-12**]+) aortic regurgitation is seen. MITRAL VALVE: The mitral valve leaflets are mildly thickened. There is mild mitral annular calcification. Mild (1+) mitral regurgitation is seen. TRICUSPID VALVE: Mild tricuspid [1+] regurgitation is seen.

Response:
LA_cavity: 2
RA_dilated: 1
LV_systolic: 3
LV_cavity: 0
LV_wall: 3
RV_cavity: 0
RV_systolic: 0
AV_stenosis: 0
MV_stenosis: 0
TV_regurgitation: 0
TV_stenosis: 0
TV_pulm_htn: 0
AV_regurgitation: 0
MV_regurgitation: 1
RA_pressure: 0
LV_diastolic: 0
RV_volume_overload: 0
RV_wall: 0
RV_pressure_overload: 0

**Now analyze this report:**
Report:
{text}
<end_of_turn>
<start_of_turn>model"""

In [24]:
test_df['gemma_long_prompt'] = test_df['text'].apply(create_structured_prompt)

In [25]:
# Here is a sample of what the prompt looks like
print(test_df['gemma_long_prompt'].iloc[0])

<start_of_turn>user
As a cardiologist, analyze this echocardiogram report and provide a structured assessment.

For each feature:
- LA_cavity: left atrial cavity size
- RA_dilated: right atrial dilation
- LV_systolic: left ventricular systolic function
- LV_cavity: left ventricular cavity size
- LV_wall: left ventricular wall size
- RV_cavity: right ventricular cavity size
- RV_systolic: right ventricular systolic function
- AV_stenosis: atrial virus stenoses
- MV_stenosis: mitral valve stenoses
- TV_regurgitation: tricuspid valve regurgitation
- TV_stenosis: tricuspid valve stenoses
- TV_pulm_htn: tricuspid valve pulmonary hypertension
- AV_regurgitation: atrial virus regurgitation
- MV_regurgitation: mitral valve regurgitation
- RA_pressure: right atrial pressure
- LV_diastolic: left ventricular diastolic function
- RV_volume_overload: right ventricular volume overload
- RV_wall: right ventricular wall thickness
- RV_pressure_overload: right ventricular pressure overload

Use this co

In [12]:
# ============================================================================
# MODEL LOADING: Initialize Gemma 2B instruction-tuned model
# ============================================================================

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm

# Load Gemma model
model_name = 'google/gemma-2b-it' # 2B parameter instruction-tuned variant
# Load tokenizer (converts text to tokens the model understands)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load model with optimizations:
# - torch_dtype=torch.bfloat16: Use bfloat16 precision (saves memory, faster inference)
# - device_map="auto": Automatically distribute model across available GPUs
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

print("Model loaded successfully!")

tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

Model loaded successfully!


In [26]:
# ============================================================================
# BATCH INFERENCE: Run model predictions in batches for efficiency
# ============================================================================

def run_gemma_batch(prompts_list, batch_size=8):
    """Run multiple prompts in a batch."""
    results = []

    for i in range(0, len(prompts_list), batch_size):
        batch = prompts_list[i:i+batch_size]

        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        ).to("cuda")

        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )

        # Decode each output in the batch
        for j, output in enumerate(outputs):
            response = tokenizer.decode(
                output[inputs['input_ids'].shape[1]:],
                skip_special_tokens=True
            )
            results.append(response)

    return results

In [27]:

# ============================================================================
# RUN INFERENCE: Process all echo reports through Gemma model
# ============================================================================
# Process in batches
batch_size = 64   # Adjust based on your GPU memory (64 works on Colab high-RAM)
prompts_list = test_df['gemma_long_prompt'].tolist()
results = []

# Run inference with progress bar
for i in tqdm(range(0, len(prompts_list), batch_size)):
    batch = prompts_list[i:i+batch_size]
    batch_results = run_gemma_batch(batch, batch_size=batch_size)
    results.extend(batch_results)
# Store raw model outputs
test_df['gemma_long_result'] = results

100%|██████████| 104/104 [15:20<00:00,  8.85s/it]


In [28]:
from datetime import datetime
# Save predictions to avoid re-running inference
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f'test_df_with_gemma_long_predictions_{timestamp}.csv'
test_df.to_csv(filename, index=False)

In [30]:
def calculate_accuracy(df):
    """Calculate accuracy treating unparseable predictions as WRONG."""
    accuracy_results = []

    for i, label_name in enumerate(LABEL_NAMES):
        true_vals = df['labels_parsed'].apply(
            lambda x: x[i] if isinstance(x, list) and i < len(x) else None
        ).values

        pred_vals = df['pred_labels'].apply(
            lambda x: x[i] if isinstance(x, list) and i < len(x) else None
        ).values

        valid_true_mask = ~pd.isna(true_vals)
        true_vals_all = true_vals[valid_true_mask]
        pred_vals_all = pred_vals[valid_true_mask]

        correct = sum(1 for tv, pv in zip(true_vals_all, pred_vals_all)
                     if pd.notna(pv) and tv == pv)

        total = len(true_vals_all)
        accuracy = correct / total if total > 0 else 0
        failed = sum(pd.isna(pred_vals_all))

        accuracy_results.append({
            'label': label_name,
            'correct': correct,
            'total': total,
            'accuracy': accuracy,
            'failed': failed
        })

    exact_matches = sum(
        1 for idx in range(len(df))
        if (df.iloc[idx]['pred_labels'] is not None and
            df.iloc[idx]['labels_parsed'] == df.iloc[idx]['pred_labels'])
    )

    return accuracy_results, exact_matches

In [36]:
def parse_medical_report(report_text):
    """
    Extract label:value pairs from model's text output.

    Example input: "LA_cavity: 0\nRA_dilated: 1\n..."
    Example output: {'LA_cavity': 0, 'RA_dilated': 1, ...}
    """
    pattern = r'\*?\*?([A-Z_a-z]+)\s*:\s*(\d+)\*?\*?'
    matches = re.findall(pattern, str(report_text))
    parsed_data = {}
    for key, value in matches:
        if key in LABEL_NAMES:
            parsed_data[key] = int(value)
    return parsed_data

def parse_to_list(parsed_data):
    """Convert parsed dictionary to ordered list matching LABEL_NAMES."""
    return [parsed_data.get(key, None) for key in LABEL_NAMES]

# Process predictions: Convert text outputs to numeric lists
test_df['gemma_long_result_list'] = test_df['gemma_long_result'].apply(
    lambda x: parse_to_list(parse_medical_report(str(x))) if pd.notna(x) else []
)

In [37]:
# ============================================================================
# PERFORMANCE EVALUATION: Parse model outputs and calculate accuracy
# ============================================================================

# Load your results
#test_df = pd.read_csv('/content/drive/MyDrive/echo_training/test_df_with_gemma_predictions.csv')
#test_df = test_df.rename(columns={test_df.columns[0]: 'id_num'})

# ----------------------------------------------------------------------------
# STEP 1: Parse raw model outputs into structured predictions
# We are using regular expressions because the Gemma output is
# not consistent.
# ----------------------------------------------------------------------------

def parse_medical_report(report_text):
    """
    Extract label:value pairs from model's text output.

    Example input: "LA_cavity: 0\nRA_dilated: 1\n..."
    Example output: {'LA_cavity': 0, 'RA_dilated': 1, ...}
    """
    pattern = r'\*?\*?([A-Z_a-z]+)\s*:\s*(\d+)\*?\*?'
    matches = re.findall(pattern, str(report_text))
    parsed_data = {}
    for key, value in matches:
        if key in LABEL_NAMES:
            parsed_data[key] = int(value)
    return parsed_data

def parse_to_list(parsed_data):
    """Convert parsed dictionary to ordered list matching LABEL_NAMES."""
    return [parsed_data.get(key, None) for key in LABEL_NAMES]

# Process predictions: Convert text outputs to numeric lists
test_df['gemma_long_result_list'] = test_df['gemma_long_result'].apply(
    lambda x: parse_to_list(parse_medical_report(str(x))) if pd.notna(x) else []
)

# ----------------------------------------------------------------------------
# STEP 2: Calculate accuracy metrics
# ----------------------------------------------------------------------------

def calculate_accuracy(df):
    """
    Compare model predictions against ground truth labels.

    Returns:
        Dictionary with overall accuracy and per-label breakdown
    """
    results = {
        'overall_accuracy': 0,
        'per_label_accuracy': {},
        'per_label_correct': {},
        'per_label_total': {}
    }

    correct_per_label = {label: 0 for label in LABEL_NAMES}
    total_per_label = {label: 0 for label in LABEL_NAMES}
    total_predictions = 0
    total_correct = 0

    for idx, row in df.iterrows():
        labels = row['labels_parsed']  # Ground truth
        predictions = row['gemma_long_result_list']  # Model predictions

        # Parse labels if stored as string
        if isinstance(labels, str):
            try:
                labels = ast.literal_eval(labels)
            except:
                continue

        # Skip invalid rows
        if labels is None or predictions is None or len(labels) == 0 or len(predictions) == 0:
            continue

        # Compare each label position
        for i, label_name in enumerate(LABEL_NAMES):
            if i >= min(len(labels), len(predictions)):
                break

            label_val = labels[i]
            pred_val = predictions[i]

            if label_val is None or pred_val is None:
                continue

            # Normalize to integers for comparison
            try:
                label_val = int(float(label_val))
                pred_val = int(float(pred_val))
            except (ValueError, TypeError):
                continue

            total_per_label[label_name] += 1
            total_predictions += 1

            # Check if prediction matches ground truth
            if label_val == pred_val:
                correct_per_label[label_name] += 1
                total_correct += 1

    # Calculate overall accuracy
    if total_predictions > 0:
        results['overall_accuracy'] = total_correct / total_predictions

    # Calculate per-label accuracy
    for label in LABEL_NAMES:
        if total_per_label[label] > 0:
            results['per_label_accuracy'][label] = correct_per_label[label] / total_per_label[label]
            results['per_label_correct'][label] = correct_per_label[label]
            results['per_label_total'][label] = total_per_label[label]
        else:
            results['per_label_accuracy'][label] = None
            results['per_label_correct'][label] = 0
            results['per_label_total'][label] = 0

    return results

# ----------------------------------------------------------------------------
# STEP 3: Display results
# ----------------------------------------------------------------------------

# Run accuracy calculation
accuracy_results = calculate_accuracy(test_df)

# Display results
print("=" * 80)
print("GEMMA MODEL PERFORMANCE ON ECHO NOTES")
print("=" * 80)
print(f"\nOverall Accuracy: {accuracy_results['overall_accuracy']:.4f} ({accuracy_results['overall_accuracy']*100:.2f}%)")
print("\n" + "-" * 80)
print("Per-Label Performance:")
print("-" * 80)

# Create summary DataFrame
summary_df = pd.DataFrame({
    'Label': LABEL_NAMES,
    'Accuracy': [accuracy_results['per_label_accuracy'][label] if accuracy_results['per_label_accuracy'][label] is not None else 0
                 for label in LABEL_NAMES],
    'Correct': [accuracy_results['per_label_correct'][label] for label in LABEL_NAMES],
    'Total': [accuracy_results['per_label_total'][label] for label in LABEL_NAMES]
})
summary_df['Percentage'] = summary_df['Accuracy'] * 100

print(summary_df.to_string(index=False))
print("=" * 80)

GEMMA MODEL PERFORMANCE ON ECHO NOTES

Overall Accuracy: 0.7300 (73.00%)

--------------------------------------------------------------------------------
Per-Label Performance:
--------------------------------------------------------------------------------
               Label  Accuracy  Correct  Total  Percentage
           LA_cavity  0.245293      951   3877   24.529275
          RA_dilated  0.610608     2360   3865   61.060802
         LV_systolic  0.385348     1457   3781   38.534779
           LV_cavity  0.870743     3375   3876   87.074303
             LV_wall  0.423335     1640   3874   42.333505
           RV_cavity  0.563017     1702   3023   56.301687
         RV_systolic  0.693676     2106   3036   69.367589
         AV_stenosis  0.872450     3379   3873   87.245030
         MV_stenosis  0.984766     3814   3873   98.476633
    TV_regurgitation  0.962561     3728   3873   96.256132
         TV_stenosis  0.972344     3762   3869   97.234428
         TV_pulm_htn  0.500000   

In [34]:

accuracy_results = calculate_accuracy(test_df)

In [35]:
accuracy_results

{'overall_accuracy': 0.1457091003370495,
 'per_label_accuracy': {'LA_cavity': None,
  'RA_dilated': 0.0,
  'LV_systolic': 0.0,
  'LV_cavity': None,
  'LV_wall': None,
  'RV_cavity': 0.0,
  'RV_systolic': 0.0,
  'AV_stenosis': None,
  'MV_stenosis': None,
  'TV_regurgitation': None,
  'TV_stenosis': None,
  'TV_pulm_htn': None,
  'AV_regurgitation': 0.12475247524752475,
  'MV_regurgitation': 0.17701641684511063,
  'RA_pressure': None,
  'LV_diastolic': 0.005504587155963303,
  'RV_volume_overload': None,
  'RV_wall': None,
  'RV_pressure_overload': None},
 'per_label_correct': {'LA_cavity': 0,
  'RA_dilated': 0,
  'LV_systolic': 0,
  'LV_cavity': 0,
  'LV_wall': 0,
  'RV_cavity': 0,
  'RV_systolic': 0,
  'AV_stenosis': 0,
  'MV_stenosis': 0,
  'TV_regurgitation': 0,
  'TV_stenosis': 0,
  'TV_pulm_htn': 0,
  'AV_regurgitation': 63,
  'MV_regurgitation': 496,
  'RA_pressure': 0,
  'LV_diastolic': 3,
  'RV_volume_overload': 0,
  'RV_wall': 0,
  'RV_pressure_overload': 0},
 'per_label_total'