---

### **Can your model match real clinicians in rural Kenyan healthcare?**

#### **Challenge Overview**

This challenge simulates the critical, real-world medical decisions made by nurses in Kenyan rural health settings. Participants are provided with **400 authentic clinical vignettes**, each representing a scenario faced by healthcare workers with limited resources. The task is to **predict the clinician's response** to each vignette, effectively replicating the reasoning of trained professionals.

---

### **Dataset Details**

* **400 training** and **100 test samples** of clinical prompts.
* Prompts cover a wide range of domains: **maternal health, child care, critical care, etc.**
* Each prompt contains:

  * Patient presentation
  * Nurse’s experience
  * Facility type
* Responses are real, written by expert clinicians.
* Dataset is small due to the high cost of collecting high-quality, expert-validated clinical data.

---

### **Goal**

Build an AI model that:

* Accurately **predicts clinician responses**.
* Matches the **nuance and reasoning** of real professionals.
* Can perform **well in low-resource settings**.

---

### **Evaluation Metric**

* **ROUGE Score** (measures text overlap with ground truth)
* Responses are normalized (lowercase, punctuation stripped, paragraph replaced with space).

---

### **Submission Format**

A CSV file with two columns:

```
Master_Index     Clinician
ID_XXXXXX        summary a 30 yr old...
```

---

### **Model & Deployment Constraints**

Your solution **must**:

* Be **quantized** for low memory usage.
* Run inference in **< 100ms** per vignette.
* Use **< 2 GB RAM** during inference.
* Use **≤ 1 billion parameters**.
* Train within **24 hours on an NVIDIA T4 GPU**.
* Inference should work on an **NVIDIA Jetson Nano** or similar.

---

### **Prizes**

* 🥇 1st: **\$5,000**
* 🥈 2nd: **\$3,000**
* 🥉 3rd: **\$2,000**
* **5,000 Zindi points** also available.
* Winners will be acknowledged in an upcoming publication.

---

### **Judging Criteria (For Top 10 Finalists)**

You must submit a **video** explaining your solution. Judging is based on:

1. **Clarity of explanation** – 25%
2. **Insights/feature engineering** – 15%
3. **Real-world applicability** – 25%
4. **Novelty and real-world constraints** – 25%
5. **Clean, readable code** – 10%

---

### **Rules & Requirements**

* Use **open-source** tools and libraries only.
* Max **10 submissions/day**, **300 total**.
* Max **4 people per team**.
* Data **cannot be used outside** this competition.
* If ranked in top 10:

  * Submit code within **48 hours** of request.
  * Code must reproduce leaderboard score.
  * Winners must transfer IP rights of the solution to Zindi.

---

### **Code & Reproducibility**

* Code must be:

  * Deterministic (set seeds).
  * Runnable with no paid tools or credit card trials.
  * Free of custom packages.
* If code fails to run or reproduce scores, you will be disqualified from top positions.

---

### **Disqualification Policy**

* **First offence**: 6-month ban from prizes + 2000 point deduction.
* **Second offence**: Permanent ban.

---

### **Leaderboard Mechanics**

* Public leaderboard: \~20–30% of test set.
* Private leaderboard: \~70–80%, revealed at end.
* Final scores and ranks are based on **private leaderboard**.
* Ties are broken by **earliest submission time**.

---

This challenge is a **high-impact opportunity** to build real-world, deployable AI for global healthcare—especially in **resource-constrained environments**.


In [19]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [20]:
# Install required packages
!pip install huggingface_hub[hf_xet] -q
!pip install rouge-score -q
!pip install nlpaug -q



In [21]:
# Place this cell after your challenge description markdown and before data loading

import os
import random
import numpy as np
import pandas as pd
import torch
import re
import gc
import time
import psutil
import matplotlib.pyplot as plt
import seaborn as sns
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq,
    GenerationConfig,
    get_linear_schedule_with_warmup
)
from datasets import Dataset
from rouge_score import rouge_scorer
from torch.quantization import quantize_dynamic
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.char as nac

# CONSTANTS
MODEL_NAME = 't5-base'  # Changed to more powerful but efficient model
SEQ2SEQ_MODEL_NAME = MODEL_NAME
SUBMISSION_CSV = 'submission.csv'
RANDOM_SEED = 42
TRAIN_BATCH_SIZE = 4  # Increased for better optimization
EVAL_BATCH_SIZE = 4
NUM_EPOCHS = 50  # Increased with early stopping
MAX_LENGTH = 290  # Increased for better context
GEN_MAX_LENGTH = 290
NUM_BEAMS = 3  # Increased for better generation
EARLY_STOPPING_PATIENCE = 5  # Reduced to prevent overfitting
WARMUP_STEPS = 100  # Added warmup steps
LEARNING_RATE = 1e-5  # Adjusted learning rate

# Set random seed for reproducibility
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

In [22]:

# Paths and constants
DATA_DIR = '/content/drive/MyDrive/Colab Notebooks/Kenya-Challenge-V2/'

TRAIN_CSV = os.path.join(DATA_DIR, 'data/train.csv')
TEST_CSV  = os.path.join(DATA_DIR, 'data/test.csv')
# SAMPLE_SUB = os.path.join(DATA_DIR, 'SampleSubmission.csv')
OUTPUT_DIR = os.path.join(DATA_DIR, 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)

TRAINING_OUTPUT_DIR = os.path.join(OUTPUT_DIR, 'training_results')
os.makedirs(TRAINING_OUTPUT_DIR, exist_ok=True)
SUBMISSION_FILE_PATH = os.path.join(OUTPUT_DIR, SUBMISSION_CSV)

In [23]:
# Download necessary NLTK data for nlpaug (SynonymAug)
try:
    print("Checking for required NLTK data...")
    # Downloads are quiet by default unless download_dir is specified
    # Using quiet=True explicitly for clarity
    nltk.download('punkt_tab', quiet=True)
    nltk.download('punkt', quiet=True)
    nltk.download('wordnet', quiet=True)
    nltk.download('averaged_perceptron_tagger_eng', quiet=True)
    nltk.download('averaged_perceptron_tagger', quiet=True)
    print("Required NLTK data is available or downloaded.")
except Exception as e:
    print(f"Error downloading NLTK data: {e}")
    print("Warning: NLTK data download failed. Data augmentation may not work.")


Checking for required NLTK data...
Required NLTK data is available or downloaded.


## Data loading


In [24]:
try:
    df_train = pd.read_csv(TRAIN_CSV)
    df_test = pd.read_csv(TEST_CSV)

    print(f"Shape of df_train: {df_train.shape}")
    print(f"Shape of df_test: {df_test.shape}")

    print("\nColumns of df_train:")
    print(df_train.columns)
    print("\nColumns of df_test:")
    print(df_test.columns)

    print("\nFirst 5 rows of df_train:")
    display(df_train.head())
    print("\nFirst 5 rows of df_test:")
    display(df_test.head())

except FileNotFoundError:
    print("Error: One or both of the CSV files were not found.")
except pd.errors.ParserError:
    print("Error: There was a problem parsing the CSV file(s). Check the file format.")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

Shape of df_train: (400, 12)
Shape of df_test: (100, 7)

Columns of df_train:
Index(['Master_Index', 'County', 'Health level', 'Years of Experience',
       'Prompt', 'Nursing Competency', 'Clinical Panel', 'Clinician', 'GPT4.0',
       'LLAMA', 'GEMINI', 'DDX SNOMED'],
      dtype='object')

Columns of df_test:
Index(['Master_Index', 'County', 'Health level', 'Years of Experience',
       'Prompt', 'Nursing Competency', 'Clinical Panel'],
      dtype='object')

First 5 rows of df_train:


Unnamed: 0,Master_Index,County,Health level,Years of Experience,Prompt,Nursing Competency,Clinical Panel,Clinician,GPT4.0,LLAMA,GEMINI,DDX SNOMED
0,ID_VBWWP,uasin gishu,sub county hospitals and nursing homes,18.0,i am a nurse with 18 years of experience in ge...,pediatric emergency burns,surgery,summary a 4 year old with 5 superficial burns ...,given your vast experience as a nurse in uasin...,1 immediate treatment protocol for second degr...,here s a response addressing the questions reg...,288514009 burn involving 5 percent of body sur...
1,ID_XMBBY,uasin gishu,national referral hospitals,17.0,i am a nurse with 17 years of experience in ge...,child health,paediatrics,summary 6 year old present with vomiting and a...,clinical summary • a 6 year old girl with know...,based on the symptoms and signs you ve describ...,based on the presentation the 6 year old girl ...,420270002 ketoacidosis due to type 1 diabetes ...
2,ID_JZNZW,kiambu,sub county hospitals and nursing homes,12.0,i am a nurse with 12 years of experience in ge...,general emergency,internal medicine,summary a 47 year old man presents with severe...,in this case you re dealing with a 47 year old...,firstly i must commend you on your thorough hi...,this 47 year old male presenting with severe r...,13200003 peptic ulcer disorder 25458004 acute ...
3,ID_QOQTK,uasin gishu,national referral hospitals,12.0,i am a nurse with 12 years of experience in pr...,critical care,internal medicine,summary 72 year old female with inability to w...,given er s clinical presentation and vitals th...,to me with this query based on the information...,this 92 year old female patient er presents wi...,14760008 constipation finding 419284004 altere...
4,ID_ZFJBM,uasin gishu,national referral hospitals,16.0,i am a nurse with 16 years of experience in ge...,adult health,internal medicine,a 22 year old female presents with headache di...,the 22 year old female patient is presenting w...,thank you for presenting this case based on th...,this 22 year old female patient presents with ...,95874006 carbon monoxide poisoning from fire d...



First 5 rows of df_test:


Unnamed: 0,Master_Index,County,Health level,Years of Experience,Prompt,Nursing Competency,Clinical Panel
0,ID_CUAOY,uasin gishu,sub county hospitals and nursing homes,2.0,i am a nurse with 2 years of experience in gen...,adult health,surgery ent
1,ID_OGSAY,kiambu,sub county hospitals and nursing homes,22.0,i am a nurse with 22 years of experience in ge...,child health,surgery
2,ID_TYHSA,uasin gishu,national referral hospitals,,i am a nurse working in a national referral ho...,general emergency,internal medicine
3,ID_CZXLD,kakamega,dispensaries and private clinics,,i am a nurse working in a dispensaries and pri...,child health,paediatrics
4,ID_ZJQUQ,kakamega,health centres,,i am a nurse working in a health centres in ka...,child health,paediatrics


## Data exploration

Explore and analyze the loaded datasets.


In [None]:
# 1. Examine data types
print("Data Types in df_train:")
print(df_train.dtypes)
print("\nData Types in df_test:")
print(df_test.dtypes)

# 2. Investigate unique values and distributions for categorical features
categorical_cols = ['County', 'Health level', 'Years of Experience']
for col in categorical_cols:
    print(f"\nUnique values in {col} (df_train):")
    print(df_train[col].unique())
    print(f"\nUnique values in {col} (df_test):")
    print(df_test[col].unique())

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    sns.countplot(x=col, data=df_train)
    plt.title(f'Distribution of {col} (Train)')
    plt.xticks(rotation=45, ha='right')
    plt.subplot(1, 2, 2)
    sns.countplot(x=col, data=df_test)
    plt.title(f'Distribution of {col} (Test)')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

# 3. Analyze text data length
def text_length(text):
    if isinstance(text, str):
        return len(text.split())
    return 0

df_train['Prompt_Length'] = df_train['Prompt'].apply(text_length)
df_train['Clinician_Length'] = df_train['Clinician'].apply(text_length)

print("\nDescriptive statistics for Prompt Length (df_train):")
print(df_train['Prompt_Length'].describe())
plt.figure(figsize=(10, 5))
plt.hist(df_train['Prompt_Length'], bins=20)
plt.xlabel("Prompt Length")
plt.ylabel("Frequency")
plt.title("Distribution of Prompt Lengths (df_train)")
plt.show()

print("\nDescriptive statistics for Clinician Length (df_train):")
print(df_train['Clinician_Length'].describe())
plt.figure(figsize=(10, 5))
plt.hist(df_train['Clinician_Length'], bins=20)
plt.xlabel("Clinician Length")
plt.ylabel("Frequency")
plt.title("Distribution of Clinician Lengths (df_train)")
plt.show()

# 4. Check for missing values
print("\nMissing Values in df_train:")
print(df_train.isnull().sum())
print("\nMissing Values in df_test:")
print(df_test.isnull().sum())

# 5. Analyze 'Clinician' column
print("\nAverage Clinician Response Length:")
print(df_train['Clinician_Length'].mean())


# 6. Compare Prompt Length distributions (train vs. test)
df_test['Prompt_Length'] = df_test['Prompt'].apply(text_length)

plt.figure(figsize=(10, 5))
plt.hist(df_train['Prompt_Length'], bins=20, alpha=0.5, label='Train')
plt.hist(df_test['Prompt_Length'], bins=20, alpha=0.5, label='Test')
plt.xlabel("Prompt Length")
plt.ylabel("Frequency")
plt.title("Distribution of Prompt Lengths (Train vs. Test)")
plt.legend(loc='upper right')
plt.show()

In [None]:
# Display descriptive statistics
print("\nDistribution of Prompt and Clinician Lengths:")
print(df_train[['Prompt_Length', 'Clinician_Length']].describe())

# Plot histograms with KDE
plt.figure(figsize=(12, 6))

# Plot Prompt Lengths
plt.subplot(1, 2, 1)
sns.histplot(df_train['Prompt_Length'], bins=50, kde=True)
plt.axvline(df_train['Prompt_Length'].quantile(0.95), color='red', linestyle='--', label='95th percentile')
plt.title("Distribution of Prompt Lengths")
plt.xlabel("Number of tokens")
plt.ylabel("Frequency")
plt.legend()

# Plot Clinician Response Lengths
plt.subplot(1, 2, 2)
sns.histplot(df_train['Clinician_Length'], bins=50, kde=True)
plt.axvline(df_train['Clinician_Length'].quantile(0.95), color='red', linestyle='--', label='95th percentile')
plt.title("Distribution of Clinician Response Lengths")
plt.xlabel("Number of tokens")
plt.ylabel("Frequency")
plt.legend()

plt.tight_layout()
plt.show()

# Calculate percentiles to help inform max lengths
prompt_percentiles = df_train['Prompt_Length'].quantile([0.5, 0.75, 0.9, 0.95, 0.99])
clinician_percentiles = df_train['Clinician_Length'].quantile([0.5, 0.75, 0.9, 0.95, 0.99])

# Display percentile values
print("\nPrompt Length Percentiles:")
print(prompt_percentiles)

print("\nClinician Length Percentiles:")
print(clinician_percentiles)

# Suggest updated max lengths based on 99th percentile
new_max_length_prompts = int(prompt_percentiles.loc[0.99]) + 10
new_max_length_labels = int(clinician_percentiles.loc[0.99]) + 10

# Align MAX_LENGTH and GEN_MAX_LENGTH with the Clinician response length
updated_MAX_LENGTH = int(clinician_percentiles.loc[0.99]) + 20
updated_GEN_MAX_LENGTH = updated_MAX_LENGTH

print(f"\nSuggested updated MAX_LENGTH for training (based on 99th percentile Clinician Length + buffer): {updated_MAX_LENGTH}")
print(f"Suggested updated GEN_MAX_LENGTH for inference: {updated_GEN_MAX_LENGTH}")

MAX_LENGTH = updated_MAX_LENGTH
GEN_MAX_LENGTH = updated_GEN_MAX_LENGTH


## Data cleaning

Impute missing values in 'Years of Experience', handle the missing value in 'DDX SNOMED' in df_train, and ensure consistency in categorical features across both dataframes.


**Reasoning**:



In [27]:
# Impute missing 'Years of Experience'
df_train['Years of Experience'] = df_train['Years of Experience'].fillna(df_train['Years of Experience'].median())
df_test['Years of Experience'] = df_test['Years of Experience'].fillna(df_test['Years of Experience'].median())

# Drop the row with the missing value in 'DDX SNOMED'
df_train.dropna(subset=['DDX SNOMED'], inplace=True)

# Ensure consistency in categorical features
for col in ['County', 'Health level', 'Nursing Competency', 'Clinical Panel']:
    if col in df_train.columns and col in df_test.columns:
        unique_train = set(df_train[col].unique())
        unique_test = set(df_test[col].unique())

        # Find inconsistencies
        diff = unique_train.symmetric_difference(unique_test)
        print(f"Inconsistencies in column '{col}': {diff}")

        # Attempt to fix inconsistencies (example: case correction)
        for value in diff:
            if value.lower() in unique_train and value.lower() in unique_test:
                df_train[col] = df_train[col].replace(value, value.lower())
                df_test[col] = df_test[col].replace(value, value.lower())
            # Additional logic for variations in wording could be added here

# Display the first few rows of the cleaned dataframes to verify the changes
display(df_train.head())
display(df_test.head())

Inconsistencies in column 'County': {'bungoma'}
Inconsistencies in column 'Health level': {'health centers'}
Inconsistencies in column 'Nursing Competency': {'emergency care adult', 'emergency care burns', 'pediatric emergency burns', 'wound and ostomy care', 'neonatal care', 'mayernal and child health', 'emergency care rape', 'critical care', 'maternah and child health', 'obstetrics emergency', 'emergency care gbv'}
Inconsistencies in column 'Clinical Panel': {'internal medicine cardiology', 'surgery opthalmology', 'paediatric neurology', 'internal medicine psychiatry', 'surgery paediatrics', 'psychiatry'}


Unnamed: 0,Master_Index,County,Health level,Years of Experience,Prompt,Nursing Competency,Clinical Panel,Clinician,GPT4.0,LLAMA,GEMINI,DDX SNOMED,Prompt_Length,Clinician_Length
0,ID_VBWWP,uasin gishu,sub county hospitals and nursing homes,18.0,i am a nurse with 18 years of experience in ge...,pediatric emergency burns,surgery,summary a 4 year old with 5 superficial burns ...,given your vast experience as a nurse in uasin...,1 immediate treatment protocol for second degr...,here s a response addressing the questions reg...,288514009 burn involving 5 percent of body sur...,158,47
1,ID_XMBBY,uasin gishu,national referral hospitals,17.0,i am a nurse with 17 years of experience in ge...,child health,paediatrics,summary 6 year old present with vomiting and a...,clinical summary • a 6 year old girl with know...,based on the symptoms and signs you ve describ...,based on the presentation the 6 year old girl ...,420270002 ketoacidosis due to type 1 diabetes ...,124,146
2,ID_JZNZW,kiambu,sub county hospitals and nursing homes,12.0,i am a nurse with 12 years of experience in ge...,general emergency,internal medicine,summary a 47 year old man presents with severe...,in this case you re dealing with a 47 year old...,firstly i must commend you on your thorough hi...,this 47 year old male presenting with severe r...,13200003 peptic ulcer disorder 25458004 acute ...,128,111
3,ID_QOQTK,uasin gishu,national referral hospitals,12.0,i am a nurse with 12 years of experience in pr...,critical care,internal medicine,summary 72 year old female with inability to w...,given er s clinical presentation and vitals th...,to me with this query based on the information...,this 92 year old female patient er presents wi...,14760008 constipation finding 419284004 altere...,88,168
4,ID_ZFJBM,uasin gishu,national referral hospitals,16.0,i am a nurse with 16 years of experience in ge...,adult health,internal medicine,a 22 year old female presents with headache di...,the 22 year old female patient is presenting w...,thank you for presenting this case based on th...,this 22 year old female patient presents with ...,95874006 carbon monoxide poisoning from fire d...,110,137


Unnamed: 0,Master_Index,County,Health level,Years of Experience,Prompt,Nursing Competency,Clinical Panel,Prompt_Length
0,ID_CUAOY,uasin gishu,sub county hospitals and nursing homes,2.0,i am a nurse with 2 years of experience in gen...,adult health,surgery ent,102
1,ID_OGSAY,kiambu,sub county hospitals and nursing homes,22.0,i am a nurse with 22 years of experience in ge...,child health,surgery,60
2,ID_TYHSA,uasin gishu,national referral hospitals,15.0,i am a nurse working in a national referral ho...,general emergency,internal medicine,90
3,ID_CZXLD,kakamega,dispensaries and private clinics,15.0,i am a nurse working in a dispensaries and pri...,child health,paediatrics,93
4,ID_ZJQUQ,kakamega,health centres,15.0,i am a nurse working in a health centres in ka...,child health,paediatrics,150


## Text Preprocessing

Preprocess the text data by lowercasing, removing punctuation, replacing paragraph markers, and tokenizing the 'Prompt' and 'Clinician' columns in df_train and the 'Prompt' column in df_test.


In [None]:

# Lowercasing, punctuation removal, and paragraph replacement
def preprocess_text(text):
    if isinstance(text, str):
        text = text.lower()
        text = re.sub(r'[^\w\s]', '', text)
        text = re.sub(r'\s+', ' ', text)
        return text
    return ""

df_train['Prompt'] = df_train['Prompt'].apply(preprocess_text)
df_train['Clinician'] = df_train['Clinician'].apply(preprocess_text)
df_test['Prompt'] = df_test['Prompt'].apply(preprocess_text)

# Tokenization
def tokenize_text(text):
    if isinstance(text, str):
        return word_tokenize(text)
    return []

df_train['Prompt_tokens'] = df_train['Prompt'].apply(tokenize_text)
df_train['Clinician_tokens'] = df_train['Clinician'].apply(tokenize_text)
df_test['Prompt_tokens'] = df_test['Prompt'].apply(tokenize_text)

# Display first few rows to verify
display(df_train.head())
display(df_test.head())

## Data splitting

Split the training data into training and validation sets.


In [29]:
if 'df_train' in locals() and isinstance(df_train, pd.DataFrame): # Check if df_train exists and is a DataFrame

    # Impute missing 'Years of Experience' *before* filtering and splitting
    # This ensures the column used for stratification is clean
    print("Checking and imputing missing values in 'Years of Experience'...")
    if df_train['Years of Experience'].isnull().any():
        median_years = df_train['Years of Experience'].median()
        # Use .copy() on df_train before modifying to avoid potential SettingWithCopyWarning issues
        df_train = df_train.copy()
        df_train['Years of Experience'].fillna(median_years, inplace=True)
        print(f"Imputed missing 'Years of Experience' with median: {median_years}")
    else:
        print("'Years of Experience' column has no missing values.")


    # Identify values in 'Years of Experience' that appear only once for stratification
    print("Checking for single-occurrence values in 'Years of Experience' for stratification...")
    value_counts = df_train['Years of Experience'].value_counts()
    values_to_remove = value_counts[value_counts == 1].index
    print(f"Values appearing only once: {list(values_to_remove)}")


    # Remove rows with those values for stratification
    # Ensure this filtering happens *after* imputation
    df_train_filtered = df_train[~df_train['Years of Experience'].isin(values_to_remove)].copy() # Use .copy() to avoid SettingWithCopyWarning
    print(f"Shape of df_train after filtering single-occurrence values: {df_train_filtered.shape}")

    # Verify no NaNs before splitting
    if df_train_filtered['Years of Experience'].isnull().any():
         print("Error: NaN values still present in 'Years of Experience' after imputation and filtering!")
         print("Rows with NaN in 'Years of Experience' before splitting:")
         print(df_train_filtered[df_train_filtered['Years of Experience'].isnull()])
         print("finish_task")
         print('{"status": "failure", "dataframes": []}')
         raise ValueError("NaN values found in 'Years of Experience' column before stratification.")


    # Split the filtered training data
    print("Performing train_test_split...")
    # Using _ for the embedding variables as they are not used by the T5 model
    try:
        # Ensure the 'y' parameter for stratify does not contain NaNs
        stratify_col = df_train_filtered['Years of Experience']
        if stratify_col.isnull().any():
             print("Internal error: stratify_col still contains NaNs before splitting!") # Should be caught by the check above, but defensive
             print("finish_task")
             print('{"status": "failure", "dataframes": []}')
             raise ValueError("NaNs found in stratification column right before splitting.")

        df_train_split, df_val_split, _, _ = train_test_split(
            df_train_filtered,
            stratify_col, # Stratify by the clean 'Years of Experience' column
            test_size=0.2,
            stratify=stratify_col,
            random_state=RANDOM_SEED
        )
        print("Data splitting successful.")
        # Print the shapes of the resulting DataFrames
        print(f"Shape of df_train_split (within training cell): {df_train_split.shape}")
        print(f"Shape of df_val_split (within training cell): {df_val_split.shape}")

    except ValueError as e:
        print(f"ValueError during train_test_split: {e}")
        print("finish_task")
        print('{"status": "failure", "dataframes": []}')
        raise SystemExit # Exit if splitting fails


else:
    print("Error: df_train not found or is not a DataFrame. Cannot perform data splitting.")
    print("finish_task")
    print('{"status": "failure", "dataframes": []}')
    raise SystemExit

Checking and imputing missing values in 'Years of Experience'...
'Years of Experience' column has no missing values.
Checking for single-occurrence values in 'Years of Experience' for stratification...
Values appearing only once: [13.0, 5.0, 11.0]
Shape of df_train after filtering single-occurrence values: (396, 16)
Performing train_test_split...
Data splitting successful.
Shape of df_train_split (within training cell): (316, 16)
Shape of df_val_split (within training cell): (80, 16)


# Data Augmentation

In [30]:
# --- Updated Data Augmentation ---

try:
    aug = naw.SynonymAug(aug_src='wordnet')
    print("NLPaug SynonymAug loaded successfully.")
except Exception as e:
    print(f"Error initializing NLPaug SynonymAug: {e}")
    # Decide if this should halt execution.
    # raise SystemExit

def augment_prompt(prompt, num_augments=1):
    if not isinstance(prompt, str) or not prompt.strip():
        return []
    try:
        results = aug.augment([prompt], n=num_augments)
        return [text for text in results if isinstance(text, str) and text.strip()]
    except Exception as e:
        print(f"Warning: Error augmenting prompt '{prompt[:50]}...': {e}")
        return []

if 'df_train_split' in locals() and isinstance(df_train_split, pd.DataFrame):
    print("\nStarting data augmentation...")
    augmented_data = []

    # This loop now preserves all metadata for original and augmented rows
    for index, row in df_train_split.iterrows():
        # Add the original data row
        augmented_data.append(row.to_dict())

        # Create and add augmented versions
        augmented_prompts = augment_prompt(row['Prompt'], num_augments=1)
        for aug_prompt in augmented_prompts:
            new_row_augmented = row.to_dict()
            new_row_augmented['Prompt'] = aug_prompt
            augmented_data.append(new_row_augmented)

    if augmented_data:
        df_train_augmented = pd.DataFrame(augmented_data)
        print(f"Original train split shape: {df_train_split.shape}")
        print(f"Augmented train split shape: {df_train_augmented.shape}")
    else:
        print("Warning: Data augmentation resulted in an empty dataset.")
        df_train_augmented = df_train_split.copy() # Fallback to original

else:
    print("Error: df_train_split not found. Skipping data augmentation.")
    raise SystemExit

NLPaug SynonymAug loaded successfully.

Starting data augmentation...
Original train split shape: (316, 16)
Augmented train split shape: (632, 16)


In [31]:
if 'df_train_augmented' in locals() and isinstance(df_train_augmented, pd.DataFrame):

    print("\nPreparing data for Multi-Target Training...")
    multi_target_data = []

    # Define which AI columns to include based on your ROUGE evaluation
    # We'll include LLaMA as it had the highest ROUGE score
    # You can optionally include GEMINI if you think it might also help
    ai_target_columns = ['LLAMA']
    # If you want to include Gemini as well, uncomment the line below:
    # ai_target_columns = ['LLAMA', 'GEMINI']


    for index, row in df_train_augmented.iterrows():
        # Always include the human clinician response
        if pd.notna(row['Clinician']) and row['Clinician'].strip():
             row_human_target = row.to_dict()
             # Use a consistent column name for the target in the new dataset
             row_human_target['Target_Response'] = row['Clinician']
             multi_target_data.append(row_human_target)

        # Include AI responses if available
        for ai_col in ai_target_columns:
            if ai_col in row and pd.notna(row[ai_col]) and row[ai_col].strip():
                 row_ai_target = row.to_dict()
                 row_ai_target['Target_Response'] = row[ai_col]
                 # You might optionally add a source column if you want to track source
                 # row_ai_target['Target_Source'] = ai_col
                 multi_target_data.append(row_ai_target)


    if multi_target_data:
        df_train_multi_target = pd.DataFrame(multi_target_data)
        print(f"Original augmented train split shape: {df_train_augmented.shape}")
        print(f"Multi-target train data shape: {df_train_multi_target.shape}")
    else:
        print("Error: Multi-target data preparation resulted in an empty dataset.")
        # Fallback to original data if multi-target data is empty
        df_train_multi_target = df_train_augmented.copy()
        print("Using original augmented data as a fallback.")


else:
    print("Error: df_train_augmented not found. Cannot prepare multi-target data.")
    raise SystemExit


Preparing data for Multi-Target Training...
Original augmented train split shape: (632, 16)
Multi-target train data shape: (1264, 17)


## Model training

In [32]:
# Prepare the training and validation datasets
# These lines now rely on the splitting code that was just included above
import os
os.environ["WANDB_DISABLED"] = "true"
print("Wandb logging disabled.")

try:
    nltk.data.find('tokenizers/punkt')
except nltk.downloader.DownloadError:
    print("NLTK 'punkt' tokenizer not found, attempting download...")
    try:
        nltk.download('punkt', quiet=True)
        print("NLTK 'punkt' tokenizer downloaded successfully.")
    except Exception as e:
        print(f"Error downloading NLTK 'punkt' tokenizer: {e}")
        print("Warning: Sentence tokenization may fail.")

# 1. Create Hugging Face Datasets with all metadata
# This now includes all columns needed for the structured prompt + the new 'Target_Response'
if 'df_train_multi_target' in locals() and isinstance(df_train_multi_target, pd.DataFrame):
    # Define the columns to keep for the training dataset
    # Include 'Target_Response' as the new label column
    required_train_cols = [
        'Master_Index', 'Prompt', 'County', 'Health level',
        'Years of Experience', 'Nursing Competency', 'Clinical Panel',
        'Target_Response'
    ]
    # Filter to only required columns to avoid issues with missing columns
    train_cols_multi_target = [col for col in required_train_cols if col in df_train_multi_target.columns]
    # Use the same columns for validation dataset as before
    # Make sure 'Clinician' is present in val_cols for evaluation
    required_val_cols = [
        'Master_Index', 'Prompt', 'County', 'Health level',
        'Years of Experience', 'Nursing Competency', 'Clinical Panel',
        'Clinician' # Keep Clinician for evaluation target
    ]
    val_cols_split = [col for col in required_val_cols if col in df_val_split.columns]


    # Create datasets. Dataset.from_pandas automatically handles indexing.
    train_dataset_multi_target = Dataset.from_pandas(df_train_multi_target[train_cols_multi_target])
    # Keep the validation dataset as is, using 'Clinician' as the target for evaluation
    val_dataset_split = Dataset.from_pandas(df_val_split[val_cols_split])


    # Define columns to remove after mapping for the multi-target training dataset
    # The preprocess function will output 'input_ids', 'attention_mask', and 'labels'
    # So remove all original columns *not* in this list from the training dataset
    cols_to_remove_train_multi_target = [
        col for col in train_dataset_multi_target.column_names
        if col not in ['input_ids', 'attention_mask', 'labels']
    ]

    # Columns to remove for the validation dataset remain the same
    cols_to_remove_val_split = [
        col for col in val_dataset_split.column_names
        if col not in ['input_ids', 'attention_mask', 'labels']
    ]


    print("Hugging Face Multi-Target Datasets created successfully.")
else:
    print("Error: Multi-target training data not found. Cannot create dataset.")
    print("finish_task")
    print('{"status": "failure", "dataframes": []}')
    raise SystemExit


# 2. Load Tokenizer and Model (Keep using t5-base)
print(f"Loading tokenizer and model from {SEQ2SEQ_MODEL_NAME}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(SEQ2SEQ_MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(SEQ2SEQ_MODEL_NAME)
    print("Tokenizer and model loaded.")
except Exception as e:
    print(f"Error loading tokenizer or model: {e}")
    print("finish_task")
    print('{"status": "failure", "dataframes": []}')
    raise SystemExit


# 3. Update Preprocess Function for Multi-Target Labels
def preprocess_function_multi_target(examples):
    inputs = []
    for i in range(len(examples['Prompt'])):
        # Construct the detailed, structured prompt for each example
        # This part remains the same, using the original prompt and metadata
        years_exp = examples.get('Years of Experience', [0])[i]
        years_exp_str = f"{years_exp:.1f}" if isinstance(years_exp, (float, np.floating)) else str(years_exp)

        structured_prompt = f"""Context:
- County: {examples.get('County', ['N/A'])[i]}
- Health Facility Level: {examples.get('Health level', ['N/A'])[i]}
- Nurse Experience: {years_exp_str} years
- Competency: {examples.get('Nursing Competency', ['N/A'])[i]}
- Clinical Area: {examples.get('Clinical Panel', ['N/A'])[i]}

Task: Provide a clinical summary and recommendation for the following case.

Case:
{examples['Prompt'][i]}
"""
        inputs.append(structured_prompt)

    # Tokenize the structured prompts
    model_inputs = tokenizer(inputs, max_length=MAX_LENGTH, truncation=True, padding="max_length")

    # Tokenize the targets - NOW USING 'Target_Response' FOR TRAINING
    with tokenizer.as_target_tokenizer():
        # Use 'Target_Response' for training data, 'Clinician' for validation data
        # The map function handles which dataset is being processed
        # Access the correct column based on the dataset structure
        if 'Target_Response' in examples: # This is for the multi-target training dataset
             labels_text = examples['Target_Response']
        elif 'Clinician' in examples: # This is for the validation dataset
             labels_text = examples['Clinician']
        else:
             # Fallback if neither target column is found (shouldn't happen with correct data prep)
             labels_text = [""] * len(examples['Prompt'])
             print("Warning: Neither 'Target_Response' nor 'Clinician' found in examples.")

        labels_text = [str(lbl) if lbl is not None else "" for lbl in labels_text]
        labels = tokenizer(labels_text, max_length=MAX_LENGTH, truncation=True, padding="max_length")

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


print("Tokenizing datasets with structured prompts and multi-targets...")
try:
    # Map the preprocess function to the multi-target training dataset
    tokenized_train_multi_target = train_dataset_multi_target.map(
        preprocess_function_multi_target,
        batched=True,
        remove_columns=cols_to_remove_train_multi_target
    )
    # Map the preprocess function to the validation dataset (still using Clinician as target)
    tokenized_val_split = val_dataset_split.map(
        preprocess_function_multi_target, # Use the same function, it checks for 'Clinician'
        batched=True,
        remove_columns=cols_to_remove_val_split
    )
    print("Tokenization complete.")
except Exception as e:
    print(f"Error during dataset tokenization: {e}")
    print("finish_task")
    print('{"status": "failure", "dataframes": []}')
    raise SystemExit


# 4. Data Collator, Metrics, and Training Arguments (keep as they were)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Use the same advanced compute_metrics function
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # --- FIX: Ensure outputs are lists of strings and clean them ---
    # Filter out any potential non-string types and handle empty strings
    decoded_preds = [str(pred).strip() if pred is not None else "" for pred in decoded_preds]
    decoded_labels = [str(label).strip() if label is not None else "" for label in decoded_labels]
    # --- END FIX ---

    scorer = rouge_scorer.RougeScorer(["rougeLsum"], use_stemmer=True)
    # Calculate ROUGE score for each prediction/label pair
    # Use a list comprehension to get scores for each item in the batch
    scores = [scorer.score(label, pred) for label, pred in zip(decoded_labels, decoded_preds)]

    # Average the rougeLsum scores from the batch
    # Access the 'mid' field and 'fmeasure' for each score dictionary
    rouge_lsum_scores = [score['rougeLsum'].fmeasure for score in scores]
    avg_rouge_Lsum = np.mean(rouge_lsum_scores) * 100 # Convert to percentage

    # For metric_for_best_model, we can also provide average ROUGE L (not summed) if needed,
    # but rougeLsum is standard for summaries. Let's stick to rougeLsum for consistency
    # with the original code's intent for 'rougeL'.
    # If you specifically need average ROUGE L (non-summed), calculate it similarly:
    # rouge_l_scores = [score['rougeL'].fmeasure for score in scores]
    # avg_rouge_L = np.mean(rouge_l_scores) * 100


    result = {'rougeLsum': avg_rouge_Lsum}
    # The trainer looks for the key specified by metric_for_best_model ('rougeL')
    # Map rougeLsum to rougeL for compatibility with the trainer's expectation
    result['rougeL'] = result['rougeLsum']

    result = {k: round(v, 4) for k, v in result.items()}
    return result


# Define the decoder_start_token_id before using it in GenerationConfig
# Use the tokenizer's decoder_start_token_id, which is typically the pad_token_id for T5
decoder_start_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
if decoder_start_token_id is None:
    # Fallback if both are None, though unlikely for T5
    print("Warning: Could not determine decoder_start_token_id. Using 0.")
    decoder_start_token_id = 0


generation_config = GenerationConfig(
    max_length=GEN_MAX_LENGTH,
    num_beams=NUM_BEAMS,
    early_stopping=True,
    repetition_penalty=2.0,
    min_length=40,
    length_penalty=1.0,  # Encourage slightly longer outputs
    no_repeat_ngram_size=3,  # Prevent repetition of 3-grams
    pad_token_id=tokenizer.pad_token_id, # Explicitly set pad token id
    eos_token_id=tokenizer.eos_token_id, # Explicitly set eos token id
    decoder_start_token_id=decoder_start_token_id # **Explicitly set the decoder start token ID**
)

training_args = Seq2SeqTrainingArguments(
    output_dir=TRAINING_OUTPUT_DIR,
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=12,
    predict_with_generate=True,
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=True,
    metric_for_best_model="rougeL",
    greater_is_better=True,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    logging_dir=os.path.join(TRAINING_OUTPUT_DIR, 'logs'),
    logging_steps=100,
    generation_config=generation_config,
    label_smoothing_factor=0.1,
    warmup_ratio=0.06,
    dataloader_num_workers=2, # Added for potentially faster data loading
)

# Early Stopping Callback remains the same
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=3,

    early_stopping_threshold=0.001
)

# Trainer definition remains the same, but now uses the multi-target training dataset
trainer = Seq2SeqTrainer(
    model=model, # Use the same model instance loaded earlier
    args=training_args, # Use the same training arguments
    train_dataset=tokenized_train_multi_target, # Use the new multi-target training dataset
    eval_dataset=tokenized_val_split, # Use the validation dataset (still with Clinician target)
    tokenizer=tokenizer,
    data_collator=data_collator, # Use the same data collator
    compute_metrics=compute_metrics, # Use the same compute_metrics function for validation
    callbacks=[early_stopping_callback] # Use the same callbacks
)

print("Starting model training with multi-target data...")
# Train the model
try:
    trainer.train()
    print("Model training finished.")
except Exception as e:
    print(f"Error during model training: {e}")
    print("Training failed.")
    print("finish_task")
    print('{"status": "failure", "dataframes": []}')

# Get the path to the best model checkpoint
best_model_path = trainer.state.best_model_checkpoint
if best_model_path is None:
     # If early stopping didn't trigger or no checkpoint saved, use the last checkpoint
    best_model_path = trainer.state.last_checkpoint
    if best_model_path is None:
        best_model_path = TRAINING_OUTPUT_DIR
        print(f"No best checkpoint saved. Using training output dir: {best_model_path}")


trained_model = None
if best_model_path and os.path.exists(best_model_path):
    print(f"Attempting to load best trained model from: {best_model_path}")
    try:
        trained_model = AutoModelForSeq2SeqLM.from_pretrained(best_model_path)
        print("Best trained model loaded successfully.")
    except Exception as e:
        print(f"Error loading best trained model from {best_model_path}: {e}")
        print("finish_task")
        print('{"status": "failure", "dataframes": []}')
        trained_model = None
else:
    print(f"Could not find a saved checkpoint at {best_model_path}. Using the final model from training (may not be the best).")
    if 'model' in locals() and model is not None:
        trained_model = model
    else:
        print("Error: Trained model could not be loaded or accessed.")
        print("finish_task")
        print('{"status": "failure", "dataframes": []}')

Wandb logging disabled.
Hugging Face Multi-Target Datasets created successfully.
Loading tokenizer and model from t5-base...
Tokenizer and model loaded.
Tokenizing datasets with structured prompts and multi-targets...


Map:   0%|          | 0/1264 [00:00<?, ? examples/s]



Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Tokenization complete.


  trainer = Seq2SeqTrainer(


Starting model training with multi-target data...


Step,Training Loss,Validation Loss,Rougelsum,Rougel
100,7.3507,3.491699,27.4534,27.4534
200,4.0853,3.283855,28.1694,28.1694
300,3.9228,3.212824,29.6866,29.6866
400,3.828,3.1894,29.3372,29.3372


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


Model training finished.
Attempting to load best trained model from: /content/drive/MyDrive/Colab Notebooks/Kenya-Challenge-V2/output/training_results/checkpoint-300
Best trained model loaded successfully.


## Model Optimization and Quantization

In [33]:
if trained_model is not None:
    print("\nStarting model optimization and quantization...")

    # ==================== DEVICE SETUP ====================
    device = 'cpu'
    print(f"Using device for quantization and inference tests: {device}")

    # Move the trained model to CPU before quantization
    trained_model.to(device)
    trained_model.eval() # Set model to evaluation mode

    # ==================== MODEL SIZE BEFORE QUANTIZATION ====================
    def get_model_size(model_obj):
        param_size = sum(p.nelement() * p.element_size() for p in model_obj.parameters())
        buffer_size = sum(b.nelement() * b.element_size() for b in model_obj.buffers())
        return (param_size + buffer_size) / (1024 ** 2)

    print(f"Trained model size (before quantization): {get_model_size(trained_model):.2f} MB")

    # ==================== APPLY DYNAMIC QUANTIZATION ====================
    quantized_model = None # Initialize as None
    try:
        # Correction 4: Quantize the trained model, not a new instance
        quantized_model = quantize_dynamic(
            trained_model, # Quantize the trained model
            {torch.nn.Linear}, # Apply to Linear layers
            dtype=torch.qint8 # Quantize to 8-bit integers
        )
        print("Trained model quantized successfully.")
        print(f"Quantized model size: {get_model_size(quantized_model):.2f} MB")
    except Exception as e:
        print(f"Quantization failed: {e}")
        print("Quantization failed, proceeding with float32 model for evaluation and submission if constraints allow.")
        # For this challenge's constraints, failing quantization should likely stop the process.
        # If quantization is mandatory, uncomment the raise statement:
        # raise SystemExit


    # ==================== INFERENCE AND METRICS (on Quantized Model if successful, else Trained) ====================
    model_for_inference = quantized_model if quantized_model is not None else trained_model

    if model_for_inference is not None:
        print("\nRunning inference test on the model...")
        sample_prompt = "This is a test prompt for model inference."
        # Use the tokenizer loaded previously which corresponds to the trained model
        inputs = tokenizer(sample_prompt, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH).to(device)

        # Measure memory before inference
        process = psutil.Process(os.getpid())
        mem_before = process.memory_info().rss / (1024 ** 2)

        # Measure inference time
        start_time = time.time()
        with torch.no_grad():
            # Use the generate method of the appropriate model
            outputs = model_for_inference.generate(
                 inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
                 max_length=GEN_MAX_LENGTH,
                 num_beams=NUM_BEAMS,
                 early_stopping=True,
                 min_length=50,
                 pad_token_id=tokenizer.pad_token_id,
                 eos_token_id=tokenizer.eos_token_id
             )
        end_time = time.time()
        mem_after = process.memory_info().rss / (1024 ** 2)

        print(f"Inference output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
        print(f"Memory usage: Before = {mem_before:.2f} MB, After = {mem_after:.2f} MB")
        print(f"Inference time: {(end_time - start_time) * 1000:.2f} ms")
    else:
         print("\nSkipping inference test as no model is available after training and quantization attempts.")

else:
     print("\nSkipping model optimization and quantization as trained model is not available.")



Starting model optimization and quantization...
Using device for quantization and inference tests: cpu
Trained model size (before quantization): 850.31 MB
Trained model quantized successfully.
Quantized model size: 94.31 MB

Running inference test on the model...
Inference output: This is a test prompt for model inference. Please use this prompt to test model inferiority. If you are using a newer version of the Model Inference Prompt, please use this as a reference.
Memory usage: Before = 5511.76 MB, After = 5413.35 MB
Inference time: 2964.58 ms


# Model Evaluation and Submission File Generation

In [34]:
# =====================================================================================
#
#  REFACTORED PREDICTION & EVALUATION PIPELINE
#
# =====================================================================================

from rouge_score import rouge_scorer
import pandas as pd
import torch
import gc
import re # Import re module for cleaning submission text

def generate_summaries(model, tokenizer, prompts, generation_config, device='cuda', batch_size=8):
    """
    Generates summaries for a list of prompts using a consistent generation config.
    """
    model.to(device)
    model.eval()

    predictions = []
    print(f"Generating summaries for {len(prompts)} prompts on device: {device}...")
    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i+batch_size]
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                generation_config=generation_config # Use the consistent config
            )

        batch_predictions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        predictions.extend(batch_predictions)

    print("Summary generation complete.")
    return predictions

def evaluate_model(model, tokenizer, df_val, generation_config, device='cuda'):
    """
    Evaluates the model on the validation set and returns ROUGE scores.
    """
    print(f"\n--- Starting Model Evaluation on device: {device} ---")
    val_prompts = ["summarize: " + p for p in df_val['Prompt'].tolist()]
    val_references = df_val['Clinician'].tolist()

    predictions = generate_summaries(model, tokenizer, val_prompts, generation_config, device=device)

    # Use both rougeL and rougeLsum for a complete picture
    scorer = rouge_scorer.RougeScorer(["rougeL", "rougeLsum"], use_stemmer=True)

    scores = {'rougeL': [], 'rougeLsum': []}
    for pred, ref in zip(predictions, val_references):
        score = scorer.score(ref, pred)
        scores['rougeL'].append(score['rougeL'].fmeasure)
        scores['rougeLsum'].append(score['rougeLsum'].fmeasure)

    avg_rouge_L = np.mean(scores['rougeL'])
    avg_rouge_Lsum = np.mean(scores['rougeLsum'])

    print(f"Evaluation ROUGE-L Score: {avg_rouge_L:.4f}")
    print(f"Evaluation ROUGE-Lsum Score: {avg_rouge_Lsum:.4f}")
    print("--- Evaluation Complete ---")

    return avg_rouge_L, avg_rouge_Lsum

# Define the text cleaning function outside the try blocks
def clean_submission_text(text):
    if isinstance(text, str):
        # Remove punctuation except for spaces and basic text characters
        text = re.sub(r'[^\w\s]', '', text)
        # Convert to lowercase
        text = text.lower()
        # Replace multiple spaces with a single space and strip leading/trailing whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        return text
    return "" # Return empty string for non-string inputs


# --- Main Execution ---

# Ensure best_model from training is available
best_model = None
if 'trainer' in locals() and trainer is not None and trainer.state.best_model_checkpoint is not None:
    best_model_path = trainer.state.best_model_checkpoint
    try:
        best_model = AutoModelForSeq2SeqLM.from_pretrained(best_model_path)
        print(f"Loaded best model from checkpoint: {best_model_path}")
    except Exception as e:
        print(f"Error loading best model from checkpoint: {e}. Proceeding with trained_model if available.")
        if 'trained_model' in locals() and trained_model is not None:
             best_model = trained_model # Fallback to the final trained model if loading checkpoint fails
        else:
             print("Error: Could not load best model or find trained_model. Cannot proceed with evaluation or prediction.")
             print("finish_task")
             print('{"status": "failure", "dataframes": []}')
             raise SystemExit # Exit if no model is available
elif 'trained_model' in locals() and trained_model is not None:
     print("Best model checkpoint not found. Using the final trained_model for evaluation and prediction.")
     best_model = trained_model
else:
     print("Error: Neither best model checkpoint nor trained_model found. Cannot proceed with evaluation or prediction.")
     print("finish_task")
     print('{"status": "failure", "dataframes": []}')
     raise SystemExit # Exit if no model is available


# 1. Evaluate the BEST NON-QUANTIZED model first to get a true baseline
if best_model is not None and 'df_val_split' in locals() and isinstance(df_val_split, pd.DataFrame):
    print("\n\nEvaluating the best model from training (before quantization)...")
    # Ensure evaluation is on the appropriate device for the non-quantized model
    non_quantized_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    try:
        evaluate_model(best_model, tokenizer, df_val_split, generation_config, device=non_quantized_device)
    except Exception as e:
        print(f"Error during non-quantized model evaluation: {e}")
else:
    print("\nSkipping non-quantized model evaluation. Model or validation data not available.")


# 2. Quantize the best model
print("\nAttempting to quantize the best model...")
quantized_model = None # Initialize as None
if best_model is not None:
    try:
        # Correction 4: Quantize the best_model loaded from the checkpoint or trained_model
        quantized_model = torch.quantization.quantize_dynamic(
            best_model, # Quantize the loaded best model
            {torch.nn.Linear}, # Apply to Linear layers
            dtype=torch.qint8 # Quantize to 8-bit integers
        )
        print("Trained model quantized successfully.")
        print(f"Quantized model size: {get_model_size(quantized_model):.2f} MB") # Assuming get_model_size is defined earlier
    except Exception as e:
        print(f"Quantization failed: {e}")
        print("Quantization failed. Proceeding with float32 model for evaluation and submission.")
else:
    print("Skipping quantization as no best_model is available.")


# 3. Evaluate the QUANTIZED model to measure the performance impact
# Note: Quantized models run on the CPU
if quantized_model is not None and 'df_val_split' in locals() and isinstance(df_val_split, pd.DataFrame):
    print("\nEvaluating the QUANTIZED model...")
    quantized_device = 'cpu' # Quantized models run on CPU
    try:
        evaluate_model(quantized_model, tokenizer, df_val_split, generation_config, device=quantized_device)
    except RuntimeError as e:
        print(f"RuntimeError during quantized model evaluation: {e}")
        print("Skipping quantized model evaluation due to RuntimeError. This might indicate a compatibility issue with quantized model generation on the CPU.")
    except Exception as e:
         print(f"Error during quantized model evaluation: {e}")
else:
     print("\nSkipping quantized model evaluation. Quantized model or validation data not available.")


# 4. Generate submission file with the final model
model_to_predict = quantized_model if quantized_model is not None else best_model # Prefer quantized if available

if model_to_predict is not None and 'df_test' in locals() and isinstance(df_test, pd.DataFrame):
    print("\nGenerating submission file with the final model...")

    # Determine device for prediction based on the model used
    prediction_device = 'cpu' if model_to_predict == quantized_model else ('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using model on device: {prediction_device} for test set prediction.")

    test_prompts = ["summarize: " + p for p in df_test['Prompt'].tolist()]
    test_ids = df_test['Master_Index'].tolist()

    try:
        test_predictions = generate_summaries(model_to_predict, tokenizer, test_prompts, generation_config, device=prediction_device)

        submission_df = pd.DataFrame({
            'Master_Index': test_ids,
            'Clinician': test_predictions
        })

        submission_df['Clinician'] = submission_df['Clinician'].apply(clean_submission_text)


        # Display the first few rows of the submission dataframe
        print("\nSubmission DataFrame head:")
        display(submission_df.head())

        # Save the submission file
        submission_df.to_csv(SUBMISSION_FILE_PATH, index=False)
        print(f"\nSubmission file created at: {SUBMISSION_FILE_PATH}")

    except RuntimeError as e:
        print(f"RuntimeError during test set prediction with the selected model: {e}")
        # Fallback to the non-quantized model if quantized prediction failed
        if model_to_predict == quantized_model and best_model is not None:
            print("Quantized model prediction failed. Falling back to non-quantized model for submission.")
            fallback_prediction_device = 'cuda' if torch.cuda.is_available() else 'cpu'
            try:
                test_predictions_fallback = generate_summaries(best_model, tokenizer, test_prompts, generation_config, device=fallback_prediction_device)
                submission_df_fallback = pd.DataFrame({
                    'Master_Index': test_ids,
                    'Clinician': test_predictions_fallback
                })
                submission_df_fallback['Clinician'] = submission_df_fallback['Clinician'].apply(clean_submission_text) # Call the now accessible function
                print("\nSubmission DataFrame head (using fallback non-quantized model):")
                display(submission_df_fallback.head())
                submission_df_fallback.to_csv(SUBMISSION_FILE_PATH, index=False)
                print(f"\nSubmission file created at: {SUBMISSION_FILE_PATH} using the non-quantized model.")
            except Exception as fallback_e:
                print(f"Error during fallback prediction with non-quantized model: {fallback_e}")
                print("Failed to generate submission file even with the non-quantized model.")
                print("finish_task")
                print('{"status": "failure", "dataframes": []}')

        else:
            print("Prediction failed and no fallback model available or fallback also failed.")
            print("finish_task")
            print('{"status": "failure", "dataframes": []}')

    except Exception as e:
        print(f"An unexpected error occurred during test set prediction: {e}")
        print("finish_task")
        print('{"status": "failure", "dataframes": []}')

else:
    print("\nSkipping submission file generation. Appropriate model or test data (df_test) not available.")


# Optimized inference function - Ensure it can handle CPU device
def generate_prediction(model, tokenizer, prompt, device='cpu'):
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH
    ).to(device) # Move inputs to the specified device

    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=GEN_MAX_LENGTH,
            num_beams=NUM_BEAMS,
            early_stopping=True,
            repetition_penalty=2.5,
            min_length=50,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True  # Enable cache during inference
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Batch inference with caching - Ensure it can handle CPU device
def batch_generate_predictions(model, tokenizer, prompts, batch_size=8, device='cpu'):
    predictions = []

    # Process in batches
    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i+batch_size]

        # Tokenize batch
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH
        ).to(device) # Move inputs to the specified device

        # Generate predictions
        with torch.no_grad():
            outputs = model.generate(
                inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=GEN_MAX_LENGTH,
                num_beams=NUM_BEAMS,
                early_stopping=True,
                repetition_penalty=2.5,
                min_length=50,
                length_penalty=1.0,
                no_repeat_ngram_size=3,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True
            )

        # Decode predictions
        batch_predictions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        predictions.extend(batch_predictions)

        # Clear cache periodically (optional, less critical on CPU)
        # if i % (batch_size * 10) == 0:
        #     torch.cuda.empty_cache() if torch.cuda.is_available() else gc.collect()

    return predictions

# Update the evaluation section to use the optimized inference
# This section was causing the error, adding error handling
model_to_evaluate = quantized_model # Try to evaluate the quantized model first
if model_to_evaluate is not None and isinstance(df_val_split, pd.DataFrame):
    print("\nStarting optimized model evaluation...")

    # Fix: Ensure evaluation is on CPU after quantization
    device = 'cpu' # Force evaluation on CPU after quantization
    # Ensure the model is moved to the correct device for this specific evaluation
    try:
      model_to_evaluate.to(device)
      model_to_evaluate.eval()
    except Exception as e:
      print(f"Error moving model to {device} for optimized evaluation: {e}")
      model_to_evaluate = None # Set to None to skip evaluation


    if model_to_evaluate is not None:
        try:
            # Generate predictions using optimized batch inference
            val_predictions = batch_generate_predictions(
                model_to_evaluate,
                tokenizer,
                ["summarize: " + p for p in df_val_split['Prompt'].tolist()], # Add "summarize: " prefix
                batch_size=8,
                device=device # Pass 'cpu' device to the inference function
            )

            # Compute ROUGE scores
            scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
            rouge_scores = []

            for pred, ref in zip(val_predictions, df_val_split['Clinician'].tolist()):
                score = scorer.score(ref, pred)
                rouge_scores.append(score['rougeL'].fmeasure)

            # Handle case with no predictions or references
            if len(rouge_scores) > 0:
                avg_rouge = sum(rouge_scores) / len(rouge_scores)
                print(f"Optimized Average ROUGE-L Score (Quantized Model): {avg_rouge:.4f}")
            else:
                print("No valid predictions or references for ROUGE scoring on the validation set during optimized evaluation.")

        except RuntimeError as e:
            print(f"RuntimeError during optimized quantized model evaluation: {e}")
            print("Skipping optimized quantized model evaluation due to RuntimeError.")
        except Exception as e:
            print(f"An unexpected error occurred during optimized quantized model evaluation: {e}")

else:
    print("Skipping optimized model evaluation. Quantized model or validation data not available.")

Loaded best model from checkpoint: /content/drive/MyDrive/Colab Notebooks/Kenya-Challenge-V2/output/training_results/checkpoint-300


Evaluating the best model from training (before quantization)...

--- Starting Model Evaluation on device: cuda ---
Generating summaries for 80 prompts on device: cuda...
Summary generation complete.
Evaluation ROUGE-L Score: 0.2836
Evaluation ROUGE-Lsum Score: 0.2836
--- Evaluation Complete ---

Attempting to quantize the best model...
Trained model quantized successfully.
Quantized model size: 94.31 MB

Evaluating the QUANTIZED model...

--- Starting Model Evaluation on device: cpu ---
Generating summaries for 80 prompts on device: cpu...
RuntimeError during quantized model evaluation: apply_dynamic is not implemented for this packed parameter type
Skipping quantized model evaluation due to RuntimeError. This might indicate a compatibility issue with quantized model generation on the CPU.

Generating submission file with the final model...
Using model 

Unnamed: 0,Master_Index,Clinician
0,ID_CUAOY,a 24 year old female complains of sharp pain i...
1,ID_OGSAY,i am a nurse with 22 years of experience in ge...
2,ID_TYHSA,a 22 year old man was brought in with a histor...
3,ID_CZXLD,i am a nurse working in a dispensaries and pri...
4,ID_ZJQUQ,a one year old boy who had never received the ...



Submission file created at: /content/drive/MyDrive/Colab Notebooks/Kenya-Challenge-V2/output/submission.csv using the non-quantized model.

Starting optimized model evaluation...
RuntimeError during optimized quantized model evaluation: apply_dynamic is not implemented for this packed parameter type
Skipping optimized quantized model evaluation due to RuntimeError.
