In [2]:
# 📦 Imports
import pandas as pd
import torch
import numpy as np
import random
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
from collections import Counter

# 🔧 Set seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

RANDOM_SEED = 42
set_seed(RANDOM_SEED)

# 📁 Load data
df1 = pd.read_csv("/content/Train_data.csv")[['text', 'label']]
df2 = pd.read_csv("/content/medical_data.csv").rename(columns={"Patient_Problem": "text", "Disease": "label"})[['text', 'label']]
df_all = pd.concat([df1, df2]).reset_index(drop=True)

# ❗ Entferne Klassen mit weniger als 2 Beispielen (für stratified split notwendig)
counts = df_all['label'].value_counts()
labels_to_remove = counts[counts < 2].index.tolist()
df_all_filtered = df_all[~df_all['label'].isin(labels_to_remove)].reset_index(drop=True)

print(f"Original dataset size: {len(df_all)}")
print(f"Filtered dataset size (removed classes with < 2 samples): {len(df_all_filtered)}")
print(f"Removed labels: {labels_to_remove}")

# 🔠 Label encoding auf gefiltertem Datensatz
le = LabelEncoder()
df_all_filtered['label'] = le.fit_transform(df_all_filtered['label'])
num_labels = len(le.classes_)

# 🔤 Model und Tokenizer (ClinicalBERT)
model_name = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 📦 Dataset-Klasse
class MedicalDataset(Dataset):
    def __init__(self, texts, labels):
        self.encodings = tokenizer(texts, truncation=True, padding=True, return_tensors="pt")
        self.labels = torch.tensor(labels)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item

    def __len__(self):
        return len(self.labels)

# 🧪 Stratified Train/Test Split mit fixiertem Seed
X_train, X_test, y_train, y_test = train_test_split(
    df_all_filtered["text"].tolist(),
    df_all_filtered["label"].tolist(),
    test_size=0.2,
    random_state=RANDOM_SEED,
    stratify=df_all_filtered["label"].tolist()
)

print(f"Train size: {len(X_train)}, Test size: {len(X_test)}")
print("Train label distribution:", Counter(y_train))
print("Test label distribution:", Counter(y_test))

# 🏗️ DataLoader
train_loader = DataLoader(MedicalDataset(X_train, y_train), batch_size=16, shuffle=True)
test_loader = DataLoader(MedicalDataset(X_test, y_test), batch_size=16)

# 🧠 Modell-Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# 🛑 Early Stopping Parameter
patience = 3
best_val_loss = float("inf")
patience_counter = 0

train_losses = []
val_losses = []

# 🔁 Trainingsschleife mit Early Stopping
epoch = 0
while True:
    epoch += 1
    print(f"\nEpoch {epoch}")
    model.train()
    total_train_loss = 0

    for batch in tqdm(train_loader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in test_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            total_val_loss += outputs.loss.item()

    avg_val_loss = total_val_loss / len(test_loader)
    val_losses.append(avg_val_loss)

    print(f"Train Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_model_state = model.state_dict()
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"\nEarly stopping triggered at epoch {epoch}")
        break

# 📥 Bestes Modell laden
model.load_state_dict(best_model_state)
model.eval()

# 🧾 Evaluation
preds = []
true = []

with torch.no_grad():
    for batch in test_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        logits = model(**batch).logits
        preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        true.extend(batch["labels"].cpu().numpy())

used_labels = np.unique(true)
used_names = le.inverse_transform(used_labels)

print("\n✅ Classification Report:\n")
print(classification_report(true, preds, labels=used_labels, target_names=used_names))
print("✅ Accuracy:", accuracy_score(true, preds))
print("✅ Precision:", precision_score(true, preds, average="weighted"))
print("✅ Recall:", recall_score(true, preds, average="weighted"))
print("✅ F1 Score:", f1_score(true, preds, average="weighted"))


Original dataset size: 1607
Filtered dataset size (removed classes with < 2 samples): 1517
Removed labels: ['Chronic Obstructive Pulmonary Disease (COPD)', 'Benign Positional Vertigo', 'Pericarditis', 'Mononucleosis', 'Pancreatitis', 'Migraine with Aura', 'Esophageal Reflux', 'Diabetes Type 2', 'Lung Cancer', 'Scabies', 'Pulmonary Embolism', 'Gastrointestinal Infection', 'Myopia and Hyperopia', 'Seasonal Allergies', 'Cystitis', 'Gingivitis', 'Multiple Sclerosis', 'Tension Headaches', 'Pregnancy', 'Diabetes', 'Anxiety Disorder', 'Hereditary Hemorrhagic Telangiectasia', 'Bronchitis', 'Night Blindness', 'Streptococcal Pharyngitis', 'Dry Eye Syndrome', 'Chronic Migraine', 'Peripheral Artery Disease', 'Typhoid Fever', 'Tendinitis', 'Anorexia', 'Atopic Dermatitis', 'Polycystic Ovary Syndrome', 'Carpal Tunnel Syndrome', 'Malignant Melanoma', 'Gastroesophageal Reflux Disease (GERD)', 'Thyroid Cancer', 'Generalized Anxiety Disorder', "Cushing's Syndrome", 'Diabetes Type 1', 'Liver Disease', 'Es

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Train size: 1213, Test size: 304
Train label distribution: Counter({68: 45, 52: 41, 67: 41, 81: 41, 13: 41, 82: 41, 62: 41, 104: 40, 0: 40, 23: 40, 102: 40, 99: 40, 96: 40, 103: 40, 94: 40, 8: 40, 55: 40, 16: 40, 17: 40, 28: 40, 100: 40, 41: 40, 33: 40, 101: 40, 84: 11, 54: 10, 20: 9, 53: 6, 85: 6, 78: 6, 22: 5, 57: 5, 19: 5, 60: 5, 25: 5, 66: 4, 91: 4, 7: 4, 51: 4, 38: 4, 90: 4, 12: 4, 9: 4, 79: 4, 95: 4, 92: 3, 3: 3, 76: 3, 35: 3, 48: 3, 83: 3, 72: 3, 73: 3, 4: 3, 1: 2, 75: 2, 32: 2, 31: 2, 40: 2, 39: 2, 58: 2, 2: 2, 42: 2, 71: 2, 24: 2, 74: 2, 46: 2, 86: 2, 97: 2, 11: 2, 5: 2, 69: 2, 87: 2, 15: 2, 21: 2, 37: 2, 63: 2, 45: 2, 34: 2, 64: 2, 49: 2, 10: 2, 77: 2, 88: 2, 43: 2, 47: 2, 36: 2, 50: 2, 29: 2, 65: 2, 80: 2, 89: 2, 56: 2, 6: 2, 27: 2, 18: 2, 14: 2, 70: 2, 30: 2, 26: 2, 61: 2, 59: 2, 98: 2, 93: 2, 44: 2})
Test label distribution: Counter({13: 11, 62: 11, 68: 11, 81: 11, 82: 11, 52: 11, 96: 10, 55: 10, 33: 10, 103: 10, 17: 10, 16: 10, 23: 10, 100: 10, 67: 10, 41: 10, 101: 10, 0:

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1



  0%|          | 0/76 [00:00<?, ?it/s][A
  1%|▏         | 1/76 [00:10<13:17, 10.63s/it][A
  3%|▎         | 2/76 [00:19<12:03,  9.78s/it][A
  4%|▍         | 3/76 [00:29<11:35,  9.52s/it][A
  5%|▌         | 4/76 [00:38<11:14,  9.37s/it][A
  7%|▋         | 5/76 [00:47<10:57,  9.25s/it][A
  8%|▊         | 6/76 [00:56<10:49,  9.27s/it][A
  9%|▉         | 7/76 [01:05<10:35,  9.22s/it][A
 11%|█         | 8/76 [01:14<10:26,  9.21s/it][A
 12%|█▏        | 9/76 [01:24<10:21,  9.28s/it][A
 13%|█▎        | 10/76 [01:33<10:14,  9.31s/it][A
 14%|█▍        | 11/76 [01:43<10:10,  9.39s/it][A
 16%|█▌        | 12/76 [01:52<09:53,  9.27s/it][A
 17%|█▋        | 13/76 [02:01<09:46,  9.31s/it][A
 18%|█▊        | 14/76 [02:11<09:42,  9.40s/it][A
 20%|█▉        | 15/76 [02:20<09:33,  9.41s/it][A
 21%|██        | 16/76 [02:29<09:18,  9.31s/it][A
 22%|██▏       | 17/76 [02:38<09:05,  9.24s/it][A
 24%|██▎       | 18/76 [02:47<08:49,  9.12s/it][A
 25%|██▌       | 19/76 [02:56<08:41,  9.14s/it]

Train Loss: 4.2406 | Validation Loss: 3.7305

Epoch 2


100%|██████████| 76/76 [11:25<00:00,  9.02s/it]


Train Loss: 3.4406 | Validation Loss: 3.0109

Epoch 3


100%|██████████| 76/76 [11:02<00:00,  8.71s/it]


Train Loss: 2.7119 | Validation Loss: 2.3127

Epoch 4


100%|██████████| 76/76 [10:58<00:00,  8.66s/it]


Train Loss: 2.0430 | Validation Loss: 1.7760

Epoch 5


100%|██████████| 76/76 [11:02<00:00,  8.71s/it]


Train Loss: 1.5482 | Validation Loss: 1.4283

Epoch 6


100%|██████████| 76/76 [11:07<00:00,  8.78s/it]


Train Loss: 1.2170 | Validation Loss: 1.2056

Epoch 7


100%|██████████| 76/76 [11:00<00:00,  8.69s/it]


Train Loss: 1.0032 | Validation Loss: 1.0325

Epoch 8


100%|██████████| 76/76 [11:13<00:00,  8.86s/it]


Train Loss: 0.8577 | Validation Loss: 0.9071

Epoch 9


100%|██████████| 76/76 [11:03<00:00,  8.73s/it]


Train Loss: 0.7532 | Validation Loss: 0.8335

Epoch 10


100%|██████████| 76/76 [10:57<00:00,  8.66s/it]


Train Loss: 0.6647 | Validation Loss: 0.7761

Epoch 11


100%|██████████| 76/76 [10:56<00:00,  8.64s/it]


Train Loss: 0.6208 | Validation Loss: 0.7219

Epoch 12


100%|██████████| 76/76 [10:54<00:00,  8.62s/it]


Train Loss: 0.5464 | Validation Loss: 0.6871

Epoch 13


100%|██████████| 76/76 [10:58<00:00,  8.67s/it]


Train Loss: 0.5004 | Validation Loss: 0.6707

Epoch 14


100%|██████████| 76/76 [11:01<00:00,  8.70s/it]


Train Loss: 0.4587 | Validation Loss: 0.6145

Epoch 15


100%|██████████| 76/76 [10:59<00:00,  8.68s/it]


Train Loss: 0.4183 | Validation Loss: 0.6162

Epoch 16


100%|██████████| 76/76 [11:12<00:00,  8.84s/it]


Train Loss: 0.3837 | Validation Loss: 0.5875

Epoch 17


100%|██████████| 76/76 [11:19<00:00,  8.94s/it]


Train Loss: 0.3483 | Validation Loss: 0.5612

Epoch 18


100%|██████████| 76/76 [11:34<00:00,  9.13s/it]


Train Loss: 0.3167 | Validation Loss: 0.5408

Epoch 19


100%|██████████| 76/76 [12:20<00:00,  9.75s/it]


Train Loss: 0.2943 | Validation Loss: 0.5096

Epoch 20


100%|██████████| 76/76 [11:55<00:00,  9.41s/it]


Train Loss: 0.2684 | Validation Loss: 0.5119

Epoch 21


100%|██████████| 76/76 [11:44<00:00,  9.27s/it]


Train Loss: 0.2445 | Validation Loss: 0.5105

Epoch 22


100%|██████████| 76/76 [11:30<00:00,  9.09s/it]


Train Loss: 0.2249 | Validation Loss: 0.4930

Epoch 23


100%|██████████| 76/76 [11:13<00:00,  8.86s/it]


Train Loss: 0.2078 | Validation Loss: 0.4812

Epoch 24


100%|██████████| 76/76 [10:57<00:00,  8.65s/it]


Train Loss: 0.1935 | Validation Loss: 0.4732

Epoch 25


100%|██████████| 76/76 [10:57<00:00,  8.65s/it]


Train Loss: 0.1751 | Validation Loss: 0.4624

Epoch 26


100%|██████████| 76/76 [11:03<00:00,  8.73s/it]


Train Loss: 0.1580 | Validation Loss: 0.4644

Epoch 27


100%|██████████| 76/76 [11:02<00:00,  8.71s/it]


Train Loss: 0.1478 | Validation Loss: 0.4562

Epoch 28


100%|██████████| 76/76 [10:56<00:00,  8.63s/it]


Train Loss: 0.1362 | Validation Loss: 0.4640

Epoch 29


100%|██████████| 76/76 [10:55<00:00,  8.63s/it]


Train Loss: 0.1219 | Validation Loss: 0.4486

Epoch 30


100%|██████████| 76/76 [10:55<00:00,  8.62s/it]


Train Loss: 0.1103 | Validation Loss: 0.4376

Epoch 31


100%|██████████| 76/76 [10:54<00:00,  8.61s/it]


Train Loss: 0.1055 | Validation Loss: 0.4579

Epoch 32


100%|██████████| 76/76 [10:54<00:00,  8.61s/it]


Train Loss: 0.0975 | Validation Loss: 0.4360

Epoch 33


100%|██████████| 76/76 [10:55<00:00,  8.63s/it]


Train Loss: 0.0892 | Validation Loss: 0.4360

Epoch 34


100%|██████████| 76/76 [10:54<00:00,  8.62s/it]


Train Loss: 0.0804 | Validation Loss: 0.4375

Epoch 35


100%|██████████| 76/76 [10:58<00:00,  8.66s/it]


Train Loss: 0.0744 | Validation Loss: 0.4358

Epoch 36


100%|██████████| 76/76 [10:58<00:00,  8.67s/it]


Train Loss: 0.0671 | Validation Loss: 0.4338

Epoch 37


100%|██████████| 76/76 [11:00<00:00,  8.69s/it]


Train Loss: 0.0627 | Validation Loss: 0.4429

Epoch 38


100%|██████████| 76/76 [10:57<00:00,  8.65s/it]


Train Loss: 0.0600 | Validation Loss: 0.4396

Epoch 39


100%|██████████| 76/76 [10:55<00:00,  8.63s/it]


Train Loss: 0.0559 | Validation Loss: 0.4327

Epoch 40


100%|██████████| 76/76 [10:58<00:00,  8.67s/it]


Train Loss: 0.0498 | Validation Loss: 0.4375

Epoch 41


100%|██████████| 76/76 [10:58<00:00,  8.67s/it]


Train Loss: 0.0471 | Validation Loss: 0.4361

Epoch 42


100%|██████████| 76/76 [11:04<00:00,  8.75s/it]


Train Loss: 0.0441 | Validation Loss: 0.4421

Early stopping triggered at epoch 42

✅ Classification Report:

                                          precision    recall  f1-score   support

                                    Acne       1.00      1.00      1.00        10
        Age-related Macular Degeneration       0.00      0.00      0.00         1
                       Allergic Rhinitis       1.00      1.00      1.00         1
                         Alopecia Areata       1.00      1.00      1.00         1
                     Alzheimer's Disease       0.00      0.00      0.00         1
                            Appendicitis       0.50      1.00      0.67         1
                               Arthritis       1.00      1.00      1.00        10
                                  Asthma       1.00      1.00      1.00         1
                     Atrial Fibrillation       0.00      0.00      0.00         1
Attention Deficit Hyperactivity Disorder       0.50      1.00      0.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
