<a href="https://colab.research.google.com/github/christinium/Health/blob/main/EchoProject/Echo_Fine_Tunning_vs_Prompt_Stats.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Echocardiogram Analysis: Comparing Fine-tuned vs Prompt-only Models

## Purpose
This notebook compares two approaches for extracting structured labels from
echocardiogram reports using the Gemma language model:

1. **Fine-tuned Model**: Gemma model fine-tuned on labeled echo reports
2. **Prompt-only Model**: Base Gemma-2B-it model using detailed prompting

## Task
Both models extract 19 cardiac features from unstructured echo reports and
classify each feature using a standardized coding schema described in prior notebooks.

## Evaluation Method: Strict Evaluation
This notebook uses **strict evaluation**, where:
- Failed predictions (unparseable outputs) count as **incorrect**
- All predictions are evaluated, not just successful parses
- This provides a fair comparison reflecting real-world deployment scenarios

## Key Findings

### Fine-tuned Model Performance
- **Average Per-Label Accuracy**: ~99.8%
- **Exact Match Accuracy**: ~98.1% (all 19 labels correct)
- **Failed Predictions**: <0.1%
- The fine-tuned model consistently produces valid, structured output

### Prompt-only Model Performance  
- **Average Per-Label Accuracy**: ~15-20%
- **Exact Match Accuracy**: <1%
- **Failed Predictions**: ~50%
- The prompt-only model frequently fails to produce parseable output despite
  detailed instructions and medical context

### Impact of Fine-tuning
Fine-tuning improved accuracy by **~80 percentage points** and reduced failed
predictions from 50% to nearly 0%. This demonstrates that for structured
extraction tasks from specialized medical text, fine-tuning is essential—prompt
engineering alone is insufficient.

## Files Required
- `test_inference_results_batch_run.csv` - Fine-tuned model predictions
- `test_df_with_gemma_predictions.csv` - Prompt with description only predictions
- `test_df_with_gemma_long_predictions_20251005_013525.csv`  - Prompt with description and examples

All files should contain ground truth labels for comparison.

In [2]:
import pandas as pd
import numpy as np
import ast
import re
from datetime import datetime


In [3]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set up working directory
import os
WORKING_DIR = '/content/drive/MyDrive/echo_training/'  # Change this to your preferred location
os.makedirs(WORKING_DIR, exist_ok=True)
os.chdir(WORKING_DIR)

Mounted at /content/drive


In [4]:

# ==============================================================================
# CONFIGURATION
# ==============================================================================

# Label definitions
LABEL_NAMES = [
    'LA_cavity', 'RA_dilated', 'LV_systolic', 'LV_cavity',
    'LV_wall', 'RV_cavity', 'RV_systolic', 'AV_stenosis',
    'MV_stenosis', 'TV_regurgitation', 'TV_stenosis',
    'TV_pulm_htn', 'AV_regurgitation', 'MV_regurgitation',
    'RA_pressure', 'LV_diastolic', 'RV_volume_overload',
    'RV_wall', 'RV_pressure_overload'
]


In [5]:
# ==============================================================================
# FUNCTIONS
# ==============================================================================

def load_predictions(model_type):
    """Load predictions for specified model type."""
    if model_type == "fine-tuned":
        df = pd.read_csv('test_inference_results_batch_run.csv')
        return df, 'prediction_text', 'true_labels'
    elif model_type == "gemma-description-only":
        df = pd.read_csv('test_df_with_gemma_predictions.csv')
        df = df.rename(columns={df.columns[0]: 'id_num'})
        return df, 'gemma_result', 'labels'
    elif model_type == "gemma-with-examples":
        df = pd.read_csv('test_df_with_gemma_long_predictions_20251005_013525.csv')
        df = df.rename(columns={df.columns[0]: 'id_num'})
        return df, 'gemma_long_result', 'labels'
    else:
        raise ValueError(f"Unknown model_type: {model_type}. Use 'fine-tuned', 'gemma-description-only', or 'gemma-with-examples'")

In [6]:
def parse_prediction(pred_text: str, model_type: str) -> list:
    """Parse model output into label list."""
    if model_type == "fine-tuned":
        predicted = []
        lines = str(pred_text).split('\n')

        for label_name in LABEL_NAMES:
            found = False
            for line in lines:
                if label_name in line and ':' in line:
                    try:
                        value_str = line.split(':')[1].strip()
                        value = int(value_str)
                        predicted.append(value)
                        found = True
                        break
                    except:
                        pass
            if not found:
                predicted.append(None)
        return predicted
    else:
        pattern = r'\*?\*?([A-Z_a-z]+)\s*:\s*(\d+)\*?\*?'
        matches = re.findall(pattern, str(pred_text))
        parsed_data = {}
        for key, value in matches:
            if key in LABEL_NAMES:
                parsed_data[key] = int(value)
        return [parsed_data.get(label, None) for label in LABEL_NAMES]

In [7]:
def calculate_accuracy(df):
    """Calculate accuracy treating unparseable predictions as WRONG."""
    accuracy_results = []

    for i, label_name in enumerate(LABEL_NAMES):
        true_vals = df['labels_parsed'].apply(
            lambda x: x[i] if isinstance(x, list) and i < len(x) else None
        ).values

        pred_vals = df['pred_labels'].apply(
            lambda x: x[i] if isinstance(x, list) and i < len(x) else None
        ).values

        valid_true_mask = ~pd.isna(true_vals)
        true_vals_all = true_vals[valid_true_mask]
        pred_vals_all = pred_vals[valid_true_mask]

        correct = sum(1 for tv, pv in zip(true_vals_all, pred_vals_all)
                     if pd.notna(pv) and tv == pv)

        total = len(true_vals_all)
        accuracy = correct / total if total > 0 else 0
        failed = sum(pd.isna(pred_vals_all))

        accuracy_results.append({
            'label': label_name,
            'correct': correct,
            'total': total,
            'accuracy': accuracy,
            'failed': failed
        })

    exact_matches = sum(
        1 for idx in range(len(df))
        if (df.iloc[idx]['pred_labels'] is not None and
            df.iloc[idx]['labels_parsed'] == df.iloc[idx]['pred_labels'])
    )

    return accuracy_results, exact_matches

In [8]:
def display_results(model_type, accuracy_results, exact_matches, total_samples):
    """Display results for a model."""
    print("\n" + "="*80)
    print(f"{model_type.upper()} MODEL RESULTS")
    print("="*80)

    for result in accuracy_results:
        failed_str = f" ({result['failed']} failed)" if result['failed'] > 0 else ""
        print(f"{result['label']:20s}: {result['correct']:5d}/{result['total']:5d} = {result['accuracy']:.4f}{failed_str}")

    avg_accuracy = np.mean([r['accuracy'] for r in accuracy_results])
    total_failed = sum(r['failed'] for r in accuracy_results)
    total_predictions = total_samples * len(LABEL_NAMES)

    print("\n" + "-"*80)
    print(f"Average Per-Label Accuracy: {avg_accuracy:.4f} ({avg_accuracy*100:.2f}%)")
    print(f"Exact Match Accuracy: {exact_matches}/{total_samples} = {exact_matches/total_samples:.4f}")
    print(f"Total Failed Predictions: {total_failed}/{total_predictions} ({total_failed/total_predictions*100:.2f}%)")
    print("="*80)


In [11]:


# ==============================================================================
# EVALUATE ALL MODELS
# ==============================================================================

results_comparison = {}

for model_type in ["fine-tuned", "gemma-description-only", "gemma-with-examples"]:
    print(f"\n{'='*80}")
    print(f"Processing {model_type.upper()} model...")
    print('='*80)

    # Load data
    df, pred_col, label_col = load_predictions(model_type)
    print(f"Loaded {len(df)} predictions")

    # Parse predictions
    print("Parsing predictions...")
    df['pred_labels'] = df[pred_col].apply(lambda x: parse_prediction(x, model_type))
    df['labels_parsed'] = df[label_col].apply(
        lambda x: ast.literal_eval(x) if isinstance(x, str) else x
    )

    # Calculate accuracy
    print("Calculating accuracy...")
    accuracy_results, exact_matches = calculate_accuracy(df)

    # Store results
    results_comparison[model_type] = {
        'accuracy_results': accuracy_results,
        'exact_matches': exact_matches,
        'total_samples': len(df)
    }

    # Display results
    display_results(model_type, accuracy_results, exact_matches, len(df))

    # Save individual results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_df = pd.DataFrame(accuracy_results)
    results_df.to_csv(f'accuracy_{model_type}_strict_{timestamp}.csv', index=False)
    print(f"\nSaved to: accuracy_{model_type}_strict_{timestamp}.csv")




Processing FINE-TUNED model...
Loaded 6608 predictions
Parsing predictions...
Calculating accuracy...

FINE-TUNED MODEL RESULTS
LA_cavity           :  6596/ 6608 = 0.9982 (1 failed)
RA_dilated          :  6607/ 6608 = 0.9998 (1 failed)
LV_systolic         :  6598/ 6608 = 0.9985 (1 failed)
LV_cavity           :  6604/ 6608 = 0.9994 (1 failed)
LV_wall             :  6597/ 6608 = 0.9983 (1 failed)
RV_cavity           :  6605/ 6608 = 0.9995 (1 failed)
RV_systolic         :  6601/ 6608 = 0.9989 (1 failed)
AV_stenosis         :  6599/ 6608 = 0.9986 (1 failed)
MV_stenosis         :  6602/ 6608 = 0.9991 (1 failed)
TV_regurgitation    :  6602/ 6608 = 0.9991 (1 failed)
TV_stenosis         :  6606/ 6608 = 0.9997 (1 failed)
TV_pulm_htn         :  6603/ 6608 = 0.9992 (1 failed)
AV_regurgitation    :  6579/ 6608 = 0.9956 (1 failed)
MV_regurgitation    :  6581/ 6608 = 0.9959 (1 failed)
RA_pressure         :  6607/ 6608 = 0.9998 (1 failed)
LV_diastolic        :  6598/ 6608 = 0.9985 (1 failed)
RV_volu

In [13]:
# ==============================================================================
# SIDE-BY-SIDE COMPARISON
# ==============================================================================

print("\n\n" + "="*100)
print("SIDE-BY-SIDE COMPARISON")
print("="*100)
print(f"\n{'Label':<20} {'Fine-tuned':>15} {'Gemma-Base':>15} {'Gemma-Examples':>15} {'Best Model':>15}")
print("-"*100)

for i, label in enumerate(LABEL_NAMES):
    ft_acc = results_comparison['fine-tuned']['accuracy_results'][i]['accuracy']
    gb_acc = results_comparison['gemma-description-only']['accuracy_results'][i]['accuracy']
    ge_acc = results_comparison['gemma-with-examples']['accuracy_results'][i]['accuracy']

    # Find best model for this label
    best_acc = max(ft_acc, gb_acc, ge_acc)
    if best_acc == ft_acc:
        best = "Fine-tuned"
    elif best_acc == gb_acc:
        best = "Gemma-Base"
    else:
        best = "Gemma-Examples"

    print(f"{label:<20} {ft_acc:>14.4f} {gb_acc:>14.4f} {ge_acc:>14.4f} {best:>15}")

print("\n" + "-"*100)

# Calculate averages
ft_avg = np.mean([r['accuracy'] for r in results_comparison['fine-tuned']['accuracy_results']])
gb_avg = np.mean([r['accuracy'] for r in results_comparison['gemma-description-only']['accuracy_results']])
ge_avg = np.mean([r['accuracy'] for r in results_comparison['gemma-with-examples']['accuracy_results']])

# Calculate exact match rates
ft_exact = results_comparison['fine-tuned']['exact_matches'] / results_comparison['fine-tuned']['total_samples']
gb_exact = results_comparison['gemma-description-only']['exact_matches'] / results_comparison['gemma-description-only']['total_samples']
ge_exact = results_comparison['gemma-with-examples']['exact_matches'] / results_comparison['gemma-with-examples']['total_samples']

# Determine best for each metric
best_avg = "Fine-tuned" if ft_avg >= max(gb_avg, ge_avg) else ("Gemma-Base" if gb_avg >= ge_avg else "Gemma-Examples")
best_exact = "Fine-tuned" if ft_exact >= max(gb_exact, ge_exact) else ("Gemma-Base" if gb_exact >= ge_exact else "Gemma-Examples")

print(f"{'Average Accuracy':<20} {ft_avg:>14.4f} {gb_avg:>14.4f} {ge_avg:>14.4f} {best_avg:>15}")
print(f"{'Exact Match Rate':<20} {ft_exact:>14.4f} {gb_exact:>14.4f} {ge_exact:>14.4f} {best_exact:>15}")

print("\n" + "="*100)

# ==============================================================================
# PERFORMANCE DELTAS
# ==============================================================================

print("\n\nPERFORMANCE DELTAS (vs Fine-tuned)")
print("="*80)
print(f"\n{'Label':<20} {'Gemma-Base Δ':>20} {'Gemma-Examples Δ':>20}")
print("-"*80)

for i, label in enumerate(LABEL_NAMES):
    ft_acc = results_comparison['fine-tuned']['accuracy_results'][i]['accuracy']
    gb_diff = results_comparison['gemma-description-only']['accuracy_results'][i]['accuracy'] - ft_acc
    ge_diff = results_comparison['gemma-with-examples']['accuracy_results'][i]['accuracy'] - ft_acc

    print(f"{label:<20} {gb_diff:>+19.4f} {ge_diff:>+19.4f}")

print("\n" + "-"*80)
print(f"{'Average Accuracy Δ':<20} {gb_avg-ft_avg:>+19.4f} {ge_avg-ft_avg:>+19.4f}")
print(f"{'Exact Match Rate Δ':<20} {gb_exact-ft_exact:>+19.4f} {ge_exact-ft_exact:>+19.4f}")
print("="*80)

# ==============================================================================
# PROMPT ENGINEERING IMPACT (Gemma with examples vs Gemma base)
# ==============================================================================

print("\n\nPROMPT ENGINEERING IMPACT (Examples vs Base)")
print("="*80)
print(f"\n{'Label':<20} {'Base':>15} {'With Examples':>15} {'Improvement':>15}")
print("-"*80)

for i, label in enumerate(LABEL_NAMES):
    gb_acc = results_comparison['gemma-description-only']['accuracy_results'][i]['accuracy']
    ge_acc = results_comparison['gemma-with-examples']['accuracy_results'][i]['accuracy']
    improvement = ge_acc - gb_acc

    print(f"{label:<20} {gb_acc:>14.4f} {ge_acc:>14.4f} {improvement:>+14.4f}")

print("\n" + "-"*80)
improvement_avg = ge_avg - gb_avg
improvement_exact = ge_exact - gb_exact

print(f"{'Average Accuracy':<20} {gb_avg:>14.4f} {ge_avg:>14.4f} {improvement_avg:>+14.4f}")
print(f"{'Exact Match Rate':<20} {gb_exact:>14.4f} {ge_exact:>14.4f} {improvement_exact:>+14.4f}")
print("="*80)



SIDE-BY-SIDE COMPARISON

Label                     Fine-tuned      Gemma-Base  Gemma-Examples      Best Model
----------------------------------------------------------------------------------------------------
LA_cavity                    0.9982         0.0802         0.1439      Fine-tuned
RA_dilated                   0.9998         0.0947         0.3571      Fine-tuned
LV_systolic                  0.9985         0.0322         0.2205      Fine-tuned
LV_cavity                    0.9994         0.0699         0.5107      Fine-tuned
LV_wall                      0.9983         0.1270         0.2482      Fine-tuned
RV_cavity                    0.9995         0.1094         0.2576      Fine-tuned
RV_systolic                  0.9989         0.1265         0.3187      Fine-tuned
AV_stenosis                  0.9986         0.0454         0.5113      Fine-tuned
MV_stenosis                  0.9991         0.0944         0.5772      Fine-tuned
TV_regurgitation             0.9991         0.098

In [None]:

# ==============================================================================
# LABEL DISTRIBUTION IN TEST SET
#  This give you an idea of what the distribution of the 19 features are like
#  in the test dataset
# ==============================================================================
test_df = pd.read_csv('echo_test.csv')
print("\n" + "="*70)
print("LABEL DISTRIBUTION IN TEST SET")
print("="*70)

for i, label_name in enumerate(LABEL_NAMES):
    print(f"\n{label_name}:")

    # Extract the i-th value from each label list
    label_values = []
    for idx in range(len(test_df)):
        labels_raw = test_df.iloc[idx]['labels']

        # Parse if string
        if isinstance(labels_raw, str):
            labels = ast.literal_eval(labels_raw)
        else:
            labels = labels_raw

        label_values.append(labels[i])

    # Count values
    value_counts = pd.Series(label_values).value_counts().sort_index()
    null_count = pd.Series(label_values).isna().sum()
    total = len(label_values)

    for value, count in value_counts.items():
        pct = (count/total)*100
        print(f"  {value:>3}: {count:>5} ({pct:>5.1f}%)")
    if null_count > 0:
        pct = (null_count/total)*100
        print(f"  Null: {null_count:>5} ({pct:>5.1f}%)")


LABEL DISTRIBUTION IN TEST SET

LA_cavity:
  -50:    91 (  1.4%)
   -3:    10 (  0.2%)
   -2:   486 (  7.4%)
    0:  4379 ( 66.3%)
    1:  1076 ( 16.3%)
    2:   565 (  8.6%)
    3:     1 (  0.0%)

RA_dilated:
    0:  4563 ( 69.1%)
    1:  2045 ( 30.9%)

LV_systolic:
  -50:    30 (  0.5%)
   -3:    38 (  0.6%)
   -2:   123 (  1.9%)
   -1:   262 (  4.0%)
    0:  4719 ( 71.4%)
    1:   452 (  6.8%)
    2:   385 (  5.8%)
    3:   599 (  9.1%)

LV_cavity:
  -50:     8 (  0.1%)
   -3:    28 (  0.4%)
   -2:    31 (  0.5%)
   -1:   138 (  2.1%)
    0:  5806 ( 87.9%)
    1:   232 (  3.5%)
    2:   292 (  4.4%)
    3:    73 (  1.1%)

LV_wall:
  -50:    11 (  0.2%)
   -3:    26 (  0.4%)
   -2:    30 (  0.5%)
    0:  4418 ( 66.9%)
    1:  1768 ( 26.8%)
    2:   280 (  4.2%)
    3:    75 (  1.1%)

RV_cavity:
  -50:    41 (  0.6%)
   -3:   189 (  2.9%)
   -2:   311 (  4.7%)
   -1:    23 (  0.3%)
    0:  5240 ( 79.3%)
    1:   491 (  7.4%)
    2:   313 (  4.7%)

RV_systolic:
  -50:    48 (  0.7%)
 