<a href="https://colab.research.google.com/github/christinium/Health/blob/main/echo_note_training_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Echo Note Training with Gemma LLM
This notebook demonstrates how to fine-tune the Gemma language model on echocardiogram notes. The model will learn to analyze echo reports and provide structured assessments across 19 different cardiac features.

## Overview
1. Set up Google Drive and dependencies
2. Load and prepare echo note data
3. Configure data labels and formatting
4. Clean and preprocess the dataset
5. Create train/tune/test splits
6. Convert to HuggingFace format
7. Configure and train the model
8. Evaluate results

## 1. Set Up Google Drive
First, let's mount Google Drive to access our data and save model checkpoints. This step is essential for persisting our data and model files across Colab sessions.

In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set up working directory
import os
WORKING_DIR = '/content/drive/MyDrive/echo_training/'  # Change this to your preferred location
os.makedirs(WORKING_DIR, exist_ok=True)
os.chdir(WORKING_DIR)

print(f"Working directory set to: {WORKING_DIR}")
print("Contents:", os.listdir())

Mounted at /content/drive
Working directory set to: /content/drive/MyDrive/echo_training/
Contents: ['echo_dataset.csv', 'gemma_echo_finetuned', 'echo_dataset_trunc.csv', 'echo_tune.csv', 'echo_train.csv', 'echo_test.csv']


## 2. Install Required Dependencies
Install and import all necessary libraries for the project. We'll need:
- transformers: For the Gemma model and training
- datasets: For data handling
- pandas & numpy: For data manipulation
- sklearn: For data splitting
- torch: For deep learning operations

In [2]:
# Install required packages
!pip install transformers datasets torch numpy pandas scikit-learn bitsandbytes -q

# Import libraries
import pandas as pd
import numpy as np
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
import ast
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Trainer, DataCollatorForLanguageModeling, TrainingArguments

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.3/61.3 MB[0m [31m35.2 MB/s[0m eta [36m0:00:00[0m
[?25hPyTorch version: 2.8.0+cu126
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB


## 3. Load and Prepare Echo Data
Load your echo data from the CSV file. Make sure your data file is uploaded to the working directory in Google Drive.

In [3]:
# Load your CSV file
df = pd.read_csv('/content/drive/MyDrive/echo_training/echo_dataset_trunc.csv')

print("Data loaded successfully!")
print(f"Total samples: {len(df)}")
print("\nColumns in the dataset:")
print(df.columns.tolist())
print("\nFirst few rows:")
print(df.head(2).to_string())

Data loaded successfully!
Total samples: 45794

Columns in the dataset:
['Unnamed: 0', 'text', 'labels', 'text_trunc', 'text_trunc_len']

First few rows:
   Unnamed: 0                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 

In [4]:
#Format the df to fewer columns

# Drop the original 'text' column
df = df.drop(columns=['text'])

# Rename 'text_trunc' to 'text'
df = df.rename(columns={'text_trunc': 'text'})

# Remove the 'text_trunc_len' column
df = df.drop(columns=['text_trunc_len'])

print("DataFrame modified successfully!")
print("\nUpdated columns in the dataset:")
print(df.columns.tolist())
print("\nFirst few rows after modification:")
display(df.head(2))

DataFrame modified successfully!

Updated columns in the dataset:
['Unnamed: 0', 'labels', 'text']

First few rows after modification:


Unnamed: 0.1,Unnamed: 0,labels,text
0,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -3, -3, 0, 0, ...",LEFT ATRIUM: The left atrium is normal in size...
1,1,"[1, 0, 0, 0, 1, 0, -3, 0, 0, 0, 0, -3, 0, 1, 0...",LEFT ATRIUM: Mild LA enlargement.\n\nLEFT VENT...


## 4. Configure Data Labels
Set up the 19 cardiac feature labels and create functions to parse the label data. These labels represent different aspects of cardiac function that we'll be predicting.

In [5]:
# Define label names
LABEL_NAMES = [
    'LA_cavity', 'RA_dilated', 'LV_systolic', 'LV_cavity',
    'LV_wall', 'RV_cavity', 'RV_systolic', 'AV_stenosis',
    'MV_stenosis', 'TV_regurgitation', 'TV_stenosis',
    'TV_pulm_htn', 'AV_regurgitation', 'MV_regurgitation',
    'RA_pressure', 'LV_diastolic', 'RV_volume_overload',
    'RV_wall', 'RV_pressure_overload'
]

def parse_labels(label_str):
    """Convert string representation of list to actual list."""
    if isinstance(label_str, str):
        return ast.literal_eval(label_str)
    elif isinstance(label_str, list):
        return label_str
    else:
        return label_str

# Parse labels
df['labels_parsed'] = df['labels'].apply(parse_labels)

# Verify labels
assert all(len(labels) == 19 for labels in df['labels_parsed']), \
    "Not all label arrays have 19 values!"

print("\nLabels parsed successfully!")
print(f"Example labels: {df['labels_parsed'].iloc[0]}")

# Display label distribution
label_counts = pd.DataFrame([
    [label, sum(x[i] for x in df['labels_parsed'])]
    for i, label in enumerate(LABEL_NAMES)
], columns=['Label', 'Count'])

print("\nLabel Distribution:")
print(label_counts)


Labels parsed successfully!
Example labels: [0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -3, -3, 0, 0, 0, 0, 0, 0, 0]

Label Distribution:
                   Label  Count
0              LA_cavity -22504
1             RA_dilated  13251
2            LV_systolic   8901
3              LV_cavity   2829
4                LV_wall  12501
5              RV_cavity -16274
6            RV_systolic -18107
7            AV_stenosis   7572
8            MV_stenosis    261
9       TV_regurgitation -22032
10           TV_stenosis  -6320
11           TV_pulm_htn -12676
12      AV_regurgitation   6303
13      MV_regurgitation -30278
14           RA_pressure    921
15          LV_diastolic  -5736
16    RV_volume_overload  -3504
17               RV_wall  -4190
18  RV_pressure_overload  -3259


In [6]:
df.head()

Unnamed: 0.1,Unnamed: 0,labels,text,labels_parsed
0,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -3, -3, 0, 0, ...",LEFT ATRIUM: The left atrium is normal in size...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -3, -3, 0, 0, ..."
1,1,"[1, 0, 0, 0, 1, 0, -3, 0, 0, 0, 0, -3, 0, 1, 0...",LEFT ATRIUM: Mild LA enlargement.\n\nLEFT VENT...,"[1, 0, 0, 0, 1, 0, -3, 0, 0, 0, 0, -3, 0, 1, 0..."
2,2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, ...",LEFT VENTRICLE: Normal regional LV systolic fu...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, ..."
3,3,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 3, 1, 0, ...",This study was compared to the report of the p...,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 3, 1, 0, ..."
4,4,"[0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",LEFT ATRIUM: Normal LA and RA cavity sizes.\n\...,"[0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


## 5. Format Data for Training
Create a formatting function to prepare the data for Gemma. We'll structure the input as a prompt with the echo report and expected output format.

In [7]:
def format_echo_prompt(row):
    """Format echo report into Gemma instruction format."""
    input_text = row['text']  # Change 'input' to your actual column name
    labels = row['labels_parsed']

    # Create formatted label string
    label_pairs = [f"{LABEL_NAMES[i]}: {labels[i]}" for i in range(19)]
    label_text = "\n".join(label_pairs)

    prompt = f"""<start_of_turn>user
Analyze this echocardiogram report and provide assessment values for each cardiac feature. Output should be in the format "feature: value" for each of the 19 features.

Report:
{input_text}<end_of_turn>
<start_of_turn>model
{label_text}<end_of_turn>"""

    return prompt

# Apply formatting
df['formatted_text'] = df.apply(format_echo_prompt, axis=1)

print(df['formatted_text'].iloc[0])
#print("FORMATTED EXAMPLE:")
#print("="*70)
#print(df['text'].iloc[0])
#print("="*70)

<start_of_turn>user
Analyze this echocardiogram report and provide assessment values for each cardiac feature. Output should be in the format "feature: value" for each of the 19 features.

Report:
LEFT ATRIUM: The left atrium is normal in size.

RIGHT ATRIUM/INTERATRIAL SEPTUM: The right atrium is normal in size.

LEFT VENTRICLE: Left ventricular wall thickness, cavity size, and systolic
function are normal (LVEF>55%). Regional left ventricular wall motion is
normal.

RIGHT VENTRICLE: Right ventricular chamber size and free wall motion are
normal.

AORTA: The aortic root is normal in diameter.

AORTIC VALVE: The aortic valve leaflets (3) appear structurally normal with
good leaflet excursion.

MITRAL VALVE: The mitral valve leaflets are structurally normal.

TRICUSPID VALVE: The tricuspid valve is not well visualized.

PERICARDIUM: There is no pericardial effusion.<end_of_turn>
<start_of_turn>model
LA_cavity: 0
RA_dilated: 0
LV_systolic: 0
LV_cavity: 0
LV_wall: 0
RV_cavity: 0
RV_systol

In [8]:
# Define label names


def parse_labels(label_str):
    """Convert string representation of list to actual list."""
    if isinstance(label_str, str):
        return ast.literal_eval(label_str)
    elif isinstance(label_str, list):
        return label_str
    else:
        return label_str

# Parse labels
df['labels_parsed'] = df['labels'].apply(parse_labels)

# Verify labels
assert all(len(labels) == 19 for labels in df['labels_parsed']), \
    "Not all label arrays have 19 values!"

print("\nLabels parsed successfully!")
print(f"Example labels: {df['labels_parsed'].iloc[0]}")

# Display label distribution
label_counts = pd.DataFrame([
    [label, sum(x[i] for x in df['labels_parsed'])]
    for i, label in enumerate(LABEL_NAMES)
], columns=['Label', 'Count'])

print("\nLabel Distribution:")
print(label_counts)


Labels parsed successfully!
Example labels: [0, 0, 0, 0, 0, 0, 0, 0, 0, -3, -3, -3, 0, 0, 0, 0, 0, 0, 0]

Label Distribution:
                   Label  Count
0              LA_cavity -22504
1             RA_dilated  13251
2            LV_systolic   8901
3              LV_cavity   2829
4                LV_wall  12501
5              RV_cavity -16274
6            RV_systolic -18107
7            AV_stenosis   7572
8            MV_stenosis    261
9       TV_regurgitation -22032
10           TV_stenosis  -6320
11           TV_pulm_htn -12676
12      AV_regurgitation   6303
13      MV_regurgitation -30278
14           RA_pressure    921
15          LV_diastolic  -5736
16    RV_volume_overload  -3504
17               RV_wall  -4190
18  RV_pressure_overload  -3259


## 6. Data Cleaning and Preprocessing
Clean the dataset by removing duplicates, handling missing values, and filtering by text length if needed.

In [9]:
print(f"Before cleaning: {len(df)} samples")

# Remove any rows with missing data
df = df.dropna(subset=['text'])

# Remove duplicates based on input text
df = df.drop_duplicates(subset=['text'])

## Analyze text length
df['text_length'] = df['text'].str.len()
print(f"\nText length stats:")
print(df['text_length'].describe())
#
## Optional: Remove extremely short or long examples
## Uncomment and adjust thresholds as needed
df = df[(df['text_length'] > 100) & (df['text_length'] < 4096)]
#
#df = df.drop(columns=['text_length'])
#
print(f"After cleaning: {len(df)} samples")

Before cleaning: 45794 samples

Text length stats:
count    44155.000000
mean      1052.156268
std        381.791316
min         14.000000
25%        821.000000
50%       1021.000000
75%       1264.000000
max       3312.000000
Name: text_length, dtype: float64
After cleaning: 44047 samples


## 7. Create Dataset Splits
Split the data into training (70%), tuning/validation (15%), and test (15%) sets. We'll save these splits to CSV files in Google Drive for reproducibility.

In [10]:
# First split: separate test set (15%)
train_tune_df, test_df = train_test_split(
    df,
    test_size=0.15,
    random_state=42,
    shuffle=True
)

# Second split: separate train and tune from remaining 85%
train_df, tune_df = train_test_split(
    train_tune_df,
    test_size=0.1765,  # This gives us 15% of original
    random_state=42,
    shuffle=True
)

print("DATASET SPLITS:")
print("="*70)
print(f"Training set:   {len(train_df)} samples ({len(train_df)/len(df)*100:.1f}%)")
print(f"Tuning set:     {len(tune_df)} samples ({len(tune_df)/len(df)*100:.1f}%)")
print(f"Test set:       {len(test_df)} samples ({len(test_df)/len(df)*100:.1f}%)")
print(f"Total:          {len(df)} samples")

# Save splits to CSV
train_df.to_csv('echo_train.csv', index=False)
tune_df.to_csv('echo_tune.csv', index=False)
test_df.to_csv('echo_test.csv', index=False)

print("\n✓ Splits saved to CSV files in Google Drive:")

DATASET SPLITS:
Training set:   30831 samples (70.0%)
Tuning set:     6608 samples (15.0%)
Test set:       6608 samples (15.0%)
Total:          44047 samples

✓ Splits saved to CSV files in Google Drive:


In [3]:
### Start from this section if you already have the datasplits
import pandas as pd

train_df = pd.read_csv('echo_train.csv')
tune_df = pd.read_csv('echo_tune.csv')
test_df = pd.read_csv('echo_test.csv')


print("Loaded data splits from CSV:")
print(f"Training set:   {len(train_df)} samples")
print(f"Tuning set:     {len(tune_df)} samples")
print(f"Test set:       {len(test_df)} samples")

Loaded data splits from CSV:
Training set:   30831 samples
Tuning set:     6608 samples
Test set:       6608 samples


In [4]:
display(train_df.iloc[1])

Unnamed: 0,1
Unnamed: 0,9958
labels,"[0, 0, 0, 0, 0, -3, -3, 0, 0, 0, 0, 2, 0, 0, 0..."
text,LEFT ATRIUM: Normal LA size.\n\nLEFT VENTRICLE...
labels_parsed,"[0, 0, 0, 0, 0, -3, -3, 0, 0, 0, 0, 2, 0, 0, 0..."
formatted_text,<start_of_turn>user\nAnalyze this echocardiogr...
text_length,889


## 8. Convert to HuggingFace Format
Transform our data into HuggingFace datasets for efficient training. This includes creating a DatasetDict and tokenizing the data.

In [5]:
# Create datasets from the DataFrames
train_dataset = Dataset.from_pandas(train_df[['formatted_text']], preserve_index=False)
tune_dataset = Dataset.from_pandas(tune_df[['formatted_text']], preserve_index=False)
test_dataset = Dataset.from_pandas(test_df[['formatted_text']], preserve_index=False)

# Combine into a DatasetDict
dataset = DatasetDict({
    'train': train_dataset,
    'validation': tune_dataset,
    'test': test_dataset
})

print("HUGGING FACE DATASET:")
print("="*70)
print(dataset)

HUGGING FACE DATASET:
DatasetDict({
    train: Dataset({
        features: ['formatted_text'],
        num_rows: 30831
    })
    validation: Dataset({
        features: ['formatted_text'],
        num_rows: 6608
    })
    test: Dataset({
        features: ['formatted_text'],
        num_rows: 6608
    })
})


## 9. Configure Model and Tokenizer
Set up the Gemma model and tokenizer. We'll use the 2B parameter version for prototyping, but you can switch to 9B for production.

In [6]:
# Load model and tokenizer
model_name = 'google/gemma-2b-it'
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Set padding side here
tokenizer.model_max_length = 8192  # Fix the corrupted max length

# Check token length
sample_text = "a" * 3000
sample_tokens = tokenizer(sample_text, return_length=True)
print(f"~3000 chars ≈ {sample_tokens['length'][0]} tokens")
print(f"Gemma-2b max context: {tokenizer.model_max_length} tokens\n")

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Tokenization function
def tokenize_function(examples):
    """Tokenize the text data for causal LM."""
    tokenized = tokenizer(
        examples["formatted_text"],
        truncation=True,
        max_length=2048,
        padding=False,
    )
    # Don't add labels here - let the data collator handle it
    return tokenized

# Then update your data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    # The collator will automatically create labels from input_ids
)

# Tokenize all splits ONCE
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
    desc="Tokenizing datasets",
)



print("TOKENIZED DATASETS:")
print("="*70)
print(tokenized_datasets)

tokenizer_config.json:   0%|          | 0.00/34.2k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

~3000 chars ≈ 376 tokens
Gemma-2b max context: 8192 tokens



config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

Tokenizing datasets:   0%|          | 0/30831 [00:00<?, ? examples/s]

Tokenizing datasets:   0%|          | 0/6608 [00:00<?, ? examples/s]

Tokenizing datasets:   0%|          | 0/6608 [00:00<?, ? examples/s]

TOKENIZED DATASETS:
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 30831
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 6608
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 6608
    })
})


9.5. Quick Data Sanity Check

In [8]:
# Data Sanity Check
print("Data Sanity Check:")
sample = tokenized_datasets['train'][0]
print(f"Sample keys: {sample.keys()}")
print(f"Input IDs length: {len(sample['input_ids'])}")
print(f"First 20 tokens: {sample['input_ids'][:20]}")
print(f"Labels exist: {len(sample['labels'])}")
print(f"First 20 labels: {sample['labels'][:20]}")
print(f"\nDecoded first 200 chars:\n{tokenizer.decode(sample['input_ids'][:1000])}")

Data Sanity Check:
Sample keys: dict_keys(['input_ids', 'attention_mask'])
Input IDs length: 479
First 20 tokens: [2, 106, 1645, 108, 124082, 736, 214509, 1899, 3484, 578, 3658, 11449, 4035, 604, 1853, 41821, 6268, 235265, 16230, 1412]


KeyError: 'labels'

## 10. Configure Training
Set up the training arguments and create the trainer. We'll use settings optimized for medical domain training.

In [9]:
#Initial config (#1)
# Create data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# Simple training arguments
training_args = TrainingArguments(
    output_dir="./gemma_echo_finetuned",
    num_train_epochs=2,
    per_device_train_batch_size=8,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=500,

    bf16=True,
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=2,
    dataloader_num_workers=2,  # Add parallel data loading
    report_to="none",
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
)

print("\n✓ Trainer configured!")
print(f"Training on:   {len(tokenized_datasets['train'])} samples")
print(f"Validating on: {len(tokenized_datasets['validation'])} samples")


✓ Trainer configured!
Training on:   30831 samples
Validating on: 6608 samples


In [18]:
### Trying new Training
# Create data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

# Improved training arguments
training_args = TrainingArguments(
    output_dir="./gemma_echo_finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,      # Effective batch size of 8
    learning_rate=2e-4,
    weight_decay=0.01,                  # Regularization
    warmup_steps=500,                   # Learning rate warmup

    bf16=True,
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",  # Specify metric
    greater_is_better=False,            # Lower loss is better
    save_total_limit=2,                 # Keep only 2 checkpoints
    report_to="none",
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
)

print("\n✓ Trainer configured!")
print(f"Training on:   {len(tokenized_datasets['train'])} samples")
print(f"Validating on: {len(tokenized_datasets['validation'])} samples")


✓ Trainer configured!
Training on:   30831 samples
Validating on: 6608 samples


In [7]:
## Trying 3rd training
# Optimized training arguments
training_args = TrainingArguments(
    output_dir="./gemma_echo_finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=8,       # Increased from 2
    gradient_accumulation_steps=4,       # Effective batch size = 32
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_steps=500,

    bf16=True,
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=2,
    dataloader_num_workers=2,
    report_to="none",
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
)


## 11. Model Training and Evaluation
Train the model and evaluate its performance. We'll also include a helper function for final evaluation on the test set.

In [8]:
# Training
print("Starting training...")
trainer.train()

Starting training...


Epoch,Training Loss,Validation Loss
1,0.1424,0.141481
2,0.1175,0.118183
3,0.1012,0.113321


There were missing keys in the checkpoint model loaded: ['lm_head.weight'].


TrainOutput(global_step=2892, training_loss=0.16087821707191308, metrics={'train_runtime': 6480.524, 'train_samples_per_second': 14.272, 'train_steps_per_second': 0.446, 'total_flos': 6.657311435474657e+17, 'train_loss': 0.16087821707191308, 'epoch': 3.0})

In [9]:


def evaluate_on_test_set():
    """Evaluate the fine-tuned model on the held-out test set."""
    print("\nEVALUATING ON TEST SET")
    print("="*70)

    test_results = trainer.evaluate(tokenized_datasets['test'])

    print("\nTest Set Results:")
    for key, value in test_results.items():
        print(f"  {key}: {value:.4f}")

    return test_results

# Evaluate on test set
test_results = evaluate_on_test_set()

# Save the final model
output_dir = "final_model"
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

print("\n✓ Model and tokenizer saved to:", output_dir)


EVALUATING ON TEST SET



Test Set Results:
  eval_loss: 0.1164
  eval_runtime: 137.3173
  eval_samples_per_second: 48.1220
  eval_steps_per_second: 6.0150
  epoch: 3.0000

✓ Model and tokenizer saved to: final_model


In [10]:
# Save test results to a file
import json

with open('test_results.json', 'w') as f:
    json.dump(test_results, f, indent=2)

# Or as CSV
import pandas as pd
pd.DataFrame([test_results]).to_csv('test_results.csv', index=False)

## Important Notes for Medical Data

1. **Data Balance**
   - Check if your 19 labels are balanced
   - Medical data often has class imbalance
   - Consider stratified splitting if certain conditions are rare

2. **Evaluation Metrics**
   - Loss alone may not be sufficient
   - Consider implementing custom metrics (F1, precision, recall per label)
   - Medical predictions need high precision

3. **Model Size**
   - Gemma 2B is good for prototyping
   - Consider Gemma 9B for production if accuracy is critical
   
4. **Training Tips**
   - Monitor validation loss closely
   - Stop if validation loss stops decreasing
   - Medical domain may need more epochs (5-10)
   - Lower learning rate is safer for specialized domains

5. **Prompt Engineering**
   - Current format outputs all 19 values at once
   - Consider one value at a time for higher reliability
   - Add medical context in prompts if needed

Remember to validate the model thoroughly before any clinical use!