In [1]:
import os
import json
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, accuracy_score
import joblib
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
)
import warnings
warnings.filterwarnings("ignore")

In [None]:
# ── Experiment configuration ────────────────────────────────────────────
# Everything a reader may want to tune lives in this cell.

# Hugging Face checkpoint that gets fine-tuned.
MODEL_NAME = "answerdotai/ModernBERT-base"

# Input CSV, Trainer checkpoint folder, and final export folder.
DATA_PATH = "../data/medical_conversations.csv"
OUTPUT_DIR = "../models/modernbert_medical"
SAVE_DIR = "../models/disease"

# Tokenization / optimization hyper-parameters.
MAX_LEN = 256        # token budget per conversation (padding/truncation length)
BATCH_SIZE = 16
EPOCHS = 5
LR = 2e-5
SEED = 42

# Seed the global NumPy and PyTorch generators.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Pick the accelerator in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

print(f"‚úÖ Device: {DEVICE}")

‚úÖ Device: cuda


In [25]:
def load_and_preprocess(path: str):
    """Load the conversation CSV, clean the text and integer-encode the labels.

    Column detection is heuristic. The text column is the first column whose
    lower-cased name contains "convers", "text" or "input" (falling back to the
    first column). The label column is the first *other* column whose name
    contains "disease", "label", "class" or "specialty" (falling back to the
    second column, or to the first column if the second one is the text column).

    Args:
        path: Path to a CSV file with at least a text and a label column.

    Returns:
        Tuple ``(df, le)``: ``df`` has the columns "text", "label" and
        "label_id"; ``le`` is the ``LabelEncoder`` fitted on "label".

    Raises:
        ValueError: If the file has fewer than two distinct columns.
    """
    df = pd.read_csv(path)
    if df.shape[1] < 2:
        raise ValueError(
            f"Expected at least two columns (text and label) in {path!r}, "
            f"found {df.shape[1]}"
        )

    # Normalize column names (str() guards against non-string headers)
    df.columns = [str(c).strip().lower() for c in df.columns]

    text_keys = ("convers", "text", "input")
    label_keys = ("disease", "label", "class", "specialty")

    # Find text and label columns
    text_col = next(
        (c for c in df.columns if any(k in c for k in text_keys)),
        df.columns[0],
    )
    # Never let the label search land on the text column: that would silently
    # train the model to predict its own input.
    label_candidates = [c for c in df.columns if c != text_col]
    if not label_candidates:
        raise ValueError(f"No column other than '{text_col}' available for labels")
    label_col = next(
        (c for c in label_candidates if any(k in c for k in label_keys)),
        None,
    )
    if label_col is None:
        # Positional fallback: second column, unless that is the text column.
        label_col = df.columns[1] if df.columns[1] != text_col else label_candidates[0]

    print(f"üìå Text column: '{text_col}' | Label column: '{label_col}'")

    df = df[[text_col, label_col]].dropna()
    df.columns = ["text", "label"]

    # Clean text (astype(str) keeps the .str accessor valid for non-string columns)
    df["text"] = (
        df["text"]
        .astype(str)
        .str.replace(r"User:\s*", "", regex=True)   # Remove "User:" prefix
        .str.replace(r"\s+", " ", regex=True)       # Collapse whitespace runs
        .str.strip()
    )

    # Filter very short rows; copy so the column added below is written to an
    # independent frame rather than to a slice of the unfiltered one.
    df = df[df["text"].str.len() > 5].copy()

    print(f"\nüìä Number of samples: {len(df)}")
    print(f"üè∑Ô∏è Number of classes: {df['label'].nunique()}")
    print(f"\nLabel distribution:\n{df['label'].value_counts()}")

    # Encode labels
    le = LabelEncoder()
    df["label_id"] = le.fit_transform(df["label"])

    return df, le


In [26]:
class MedicalDataset(Dataset):
    """Map-style dataset that tokenizes one text per item on access.

    Args:
        texts: Texts to classify; a pandas Series, NumPy array or any sequence.
        labels: Integer class ids aligned with ``texts`` (same container kinds).
        tokenizer: Callable used as
            ``tokenizer(text, truncation=..., padding=..., max_length=...,
            return_tensors="pt")`` returning a mapping with "input_ids" and
            "attention_mask" tensors that carry a leading batch dimension.
        max_len: Length every example is padded/truncated to.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        # Series/ndarray expose .tolist(); plain lists and tuples do not.
        self.texts = texts.tolist() if hasattr(texts, "tolist") else list(texts)
        self.labels = labels.tolist() if hasattr(labels, "tolist") else list(labels)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels differ in length: "
                f"{len(self.texts)} vs {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )
        # Drop only the leading batch dimension. A bare squeeze() would also
        # collapse the sequence dimension whenever it has size 1.
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }


In [5]:
def compute_metrics(eval_pred):
    """Compute accuracy from an evaluation prediction pair.

    Args:
        eval_pred: Two-element iterable unpacked as ``(logits, labels)``; the
            predicted class is the argmax of ``logits`` over the last axis.

    Returns:
        Dict with the single key "accuracy".
    """
    scores, gold = eval_pred
    predicted = np.argmax(scores, axis=-1)
    return {"accuracy": accuracy_score(gold, predicted)}

In [27]:
def train(df, le):
    """Fine-tune ``MODEL_NAME`` as a sequence classifier and report validation quality.

    The data is split 85/15 (stratified on ``label_id``), the model is trained
    with per-epoch evaluation/checkpointing and early stopping (patience 2),
    and a per-class classification report on the validation split is printed.

    Args:
        df: DataFrame with a "text" column and an integer "label_id" column.
        le: Fitted label encoder; ``le.classes_`` defines the label space and
            supplies the class names.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    class_names = [str(c) for c in le.classes_]
    # Size the classification head from the encoder rather than from the ids
    # present in ``df``: if rows were dropped after encoding, nunique() would
    # undercount and the largest label id would fall outside the head.
    num_labels = len(class_names)

    # Split data
    train_df, val_df = train_test_split(
        df,
        test_size=0.15,
        random_state=SEED,
        stratify=df["label_id"]
    )

    print(f"\nüìÇ Train: {len(train_df):,} | Val: {len(val_df):,}")

    # Load tokenizer and model
    print(f"\n‚è≥ Loading {MODEL_NAME} ...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        # Store readable class names in the model config so the exported
        # checkpoint does not fall back to generic LABEL_i names.
        id2label=dict(enumerate(class_names)),
        label2id={name: idx for idx, name in enumerate(class_names)},
    )

    # Create datasets
    train_ds = MedicalDataset(train_df["text"], train_df["label_id"], tokenizer, MAX_LEN)
    val_ds = MedicalDataset(val_df["text"], val_df["label_id"], tokenizer, MAX_LEN)

    # NOTE(review): the recorded run warns that `warmup_ratio` and `logging_dir`
    # are deprecated (removal announced for v5.2, in favour of `warmup_steps`
    # and TENSORBOARD_LOGGING_DIR); migrate them when upgrading transformers.
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LR,
        weight_decay=0.01,
        warmup_ratio=0.1,
        eval_strategy="epoch",          # previously: evaluation_strategy
        save_strategy="epoch",          # kept equal to eval_strategy for load_best_model_at_end
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        logging_dir="./logs",
        logging_steps=50,
        fp16=(DEVICE == "cuda" and torch.cuda.is_available()),  # mixed precision on CUDA only
        seed=SEED,
        report_to="none",
    )

    # Define trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,                    # required for early stopping & load_best
        compute_metrics=compute_metrics,        # supplies the "accuracy" metric
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    print("\nüöÄ Starting training...")
    trainer.train()

    # Final evaluation (on the best checkpoint)
    print("\nüìà Final evaluation on validation set (best model):")
    preds_output = trainer.predict(val_ds)
    preds = np.argmax(preds_output.predictions, axis=-1)

    # Pass `labels` explicitly so the report stays aligned with `target_names`
    # even when some class never appears in the validation split.
    print(classification_report(
        val_df["label_id"],
        preds,
        labels=np.arange(num_labels),
        target_names=class_names,
        zero_division=0,
    ))

    return model, tokenizer


In [28]:
def save_model(model, tokenizer, le, save_dir: str):
    """Export the model, tokenizer, label encoder and a JSON config to ``save_dir``.

    Layout written under ``save_dir``:
        specialty_modernbert/         model + tokenizer via ``save_pretrained``
        specialty_label_encoder.pkl   the label encoder (joblib)
        specialty_classes.pkl         list of class names (joblib)
        specialty_config.json         relative model path, max_len, classes, base model name

    Args:
        model: Object exposing ``save_pretrained(path)``.
        tokenizer: Object exposing ``save_pretrained(path)``.
        le: Fitted label encoder; ``le.classes_`` gives the class names in id order.
        save_dir: Target directory; created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Convert NumPy scalars to plain Python objects: json.dump rejects NumPy
    # integers, and a pickle of plain objects does not need NumPy to load.
    classes = np.asarray(le.classes_).tolist()

    # Save model & tokenizer (HuggingFace format)
    model_path = os.path.join(save_dir, "specialty_modernbert")
    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)

    # Save label encoder and metadata
    joblib.dump(le, os.path.join(save_dir, "specialty_label_encoder.pkl"))
    joblib.dump(classes, os.path.join(save_dir, "specialty_classes.pkl"))

    # Save config for deployment/service ("model_path" is relative to save_dir)
    config = {
        "model_path": "specialty_modernbert",
        "max_len": MAX_LEN,
        "num_labels": len(classes),
        "classes": classes,
        "model_name": MODEL_NAME,
    }

    with open(os.path.join(save_dir, "specialty_config.json"), "w", encoding="utf-8") as f:
        json.dump(config, f, ensure_ascii=False, indent=2)

    print(f"\n‚úÖ Model saved to: {save_dir}/specialty_modernbert/")
    # The recorded archive listing shows model.safetensors, not pytorch_model.bin.
    print(" - config.json, model.safetensors, tokenizer files")
    print(" - specialty_label_encoder.pkl")
    print(" - specialty_classes.pkl")
    print(" - specialty_config.json")

    print(f"\nüè∑Ô∏è Detectable classes ({len(classes)}):")
    for i, c in enumerate(classes):
        print(f" {i}: {c}")


In [29]:
# Bind the config-cell constants to the lower-case names used by the driver cells below.
data_path = DATA_PATH
save_dir  = SAVE_DIR

In [30]:
# Load and clean the CSV; `le` is the fitted label encoder reused for training and export.
df, le = load_and_preprocess(data_path)

üìå Text column: 'conversations' | Label column: 'disease'

üìä Number of samples: 960
üè∑Ô∏è Number of classes: 24

Label distribution:
label
allergy                            40
bronchial asthma                   40
malaria                            40
impetigo                           40
varicose veins                     40
diabetes                           40
drug reaction                      40
psoriasis                          40
pneumonia                          40
jaundice                           40
migraine                           40
urinary tract infection            40
arthritis                          40
peptic ulcer disease               40
cervical spondylosis               40
chicken pox                        40
typhoid                            40
dimorphic hemorrhoids              40
hypertension                       40
gastroesophageal reflux disease    40
acne                               40
fungal infection                   40
common cold       

In [31]:
# Fine-tune the classifier; prints the split sizes, training log and validation report.
model, tokenizer = train(df, le)



üìÇ Train: 816 | Val: 144

‚è≥ Loading answerdotai/ModernBERT-base ...


Loading weights:   0%|          | 0/136 [00:00<?, ?it/s]

ModernBertForSequenceClassification LOAD REPORT from: answerdotai/ModernBERT-base
Key               | Status     | 
------------------+------------+-
decoder.bias      | UNEXPECTED | 
classifier.weight | MISSING    | 
classifier.bias   | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.
`logging_dir` is deprecated and will be removed in v5.2. Please set `TENSORBOARD_LOGGING_DIR` instead.



üöÄ Starting training...


Epoch,Training Loss,Validation Loss,Accuracy
1,2.560745,0.765415,0.833333
2,0.187802,0.056151,0.979167
3,0.00812,0.033127,0.986111
4,0.001217,0.033654,0.986111
5,0.000384,0.032671,0.993056


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]


üìà Final evaluation on validation set (best model):


                                 precision    recall  f1-score   support

                           acne       1.00      1.00      1.00         6
                        allergy       0.86      1.00      0.92         6
                      arthritis       1.00      1.00      1.00         6
               bronchial asthma       1.00      1.00      1.00         6
           cervical spondylosis       1.00      1.00      1.00         6
                    chicken pox       1.00      1.00      1.00         6
                    common cold       1.00      0.83      0.91         6
                         dengue       1.00      1.00      1.00         6
                       diabetes       1.00      1.00      1.00         6
          dimorphic hemorrhoids       1.00      1.00      1.00         6
                  drug reaction       1.00      1.00      1.00         6
               fungal infection       1.00      1.00      1.00         6
gastroesophageal reflux disease       1.00      1.

In [32]:
# Export model, tokenizer, label encoder and JSON config into `save_dir`.
save_model(model, tokenizer, le, save_dir)

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]


‚úÖ Model saved to: ./models/specialty_modernbert/
 - config.json, pytorch_model.bin, tokenizer files
 - specialty_label_encoder.pkl
 - specialty_classes.pkl
 - specialty_config.json

üè∑Ô∏è Detectable classes (24):
 0: acne
 1: allergy
 2: arthritis
 3: bronchial asthma
 4: cervical spondylosis
 5: chicken pox
 6: common cold
 7: dengue
 8: diabetes
 9: dimorphic hemorrhoids
 10: drug reaction
 11: fungal infection
 12: gastroesophageal reflux disease
 13: hypertension
 14: impetigo
 15: jaundice
 16: malaria
 17: migraine
 18: peptic ulcer disease
 19: pneumonia
 20: psoriasis
 21: typhoid
 22: urinary tract infection
 23: varicose veins


In [21]:
# Recursively archive the local `models` directory into models_disease.zip.
# NOTE(review): this zips ./models relative to the working directory, while SAVE_DIR in
# the config cell is "../models/disease" — confirm the path matches where save_model wrote.
!zip -r models_disease.zip models

  adding: models/ (stored 0%)
  adding: models/specialty_modernbert/ (stored 0%)
  adding: models/specialty_modernbert/config.json (deflated 71%)
  adding: models/specialty_modernbert/tokenizer_config.json (deflated 43%)
  adding: models/specialty_modernbert/model.safetensors (deflated 7%)
  adding: models/specialty_modernbert/tokenizer.json (deflated 82%)
  adding: models/specialty_label_encoder.pkl (deflated 36%)
  adding: models/specialty_config.json (deflated 48%)
  adding: models/specialty_classes.pkl (deflated 32%)


In [22]:
# Trigger a browser download of the archive.
# NOTE(review): needs the google.colab package, so presumably this cell only runs inside
# Google Colab; the import sits mid-notebook and the /content path is hard-coded.
from google.colab import files
files.download('/content/models_disease.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [23]:
# Report the total on-disk size of the local `models` directory.
!du -sh models

575M	models
