In [69]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset, Dataset, DatasetDict
from sklearn.preprocessing import LabelEncoder
import joblib
import pandas as pd
import torch
from skmultilearn.model_selection import iterative_train_test_split
from database_commands.create_tables import app, mysql, Patient, Doctor, PatientDoctor, ClinicalInfo, Treatment

df = pd.read_csv('dataset.csv')
tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")

In [11]:
df.drop(columns=['patient_id','name','age', 'sex', 'id.1', 'patient_id.1', 'doctor_id', 'id.2', 'name.1', 'specialty', 'contact_info', 'id.3', 'patient_id.2', 'doctor_id.1', 'histological_subtype','second_tumor_size', 'third_tumor_size', 'method', 'ki_67', 'clinical_stage'], inplace=True)
df = df.reset_index(drop=True)

In [12]:
df = df.dropna(subset=['node_status', 'tnm_tumor', 'tumor_grade']).reset_index(drop=True)

In [13]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np

test_df = df[df["id"].isin([2, 12, 22, 29])]
val_df = df[df["id"].isin([6, 10, 16, 21, 13, 45, 25, 28])]
train_df = df[~df["id"].isin([2, 12, 22, 29, 6, 10, 16, 21, 13, 45, 25, 28])].copy()

print(test_df)


    id  tumor_size  tumor_grade  node_status  metastasis tnm_tumor  er_status  \
1    2         3.0          2.0          1.0           0         3          1   
7   12         4.7          2.0          0.0           0         2          1   
16  22         3.4          3.0          1.0           0         3          0   
28  29         5.0          1.0          1.0           1         3          0   

    pr_status  her2_status  ki_67_categorised  \
1           1            0                  3   
7           0            0                  3   
16          0            0                  3   
28          0            0                  3   

                                 final_treatment_plan  
1   Mastectomy, Axillary Surgery, Adjuvant Chemoth...  
7   Mastectomy, Axillary Surgery, Adjuvant Chemoth...  
16  Neoadjuvant Chemotherapy, Mastectomy, Axillary...  
28               Neoadjuvant Chemotherapy, Mastectomy  


In [14]:
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np
mlb = MultiLabelBinarizer()

labels = [list(map(str.strip, plan.split(","))) for plan in train_df["final_treatment_plan"]]
mlb.fit(labels)
num_classes = len(mlb.classes_)

labels_bin = mlb.transform(labels)
print("Learned labels:", mlb.classes_)
class_counts = np.sum(labels_bin, axis=0)
total_labels = np.sum(class_counts)
class_weights = {i: total_labels / (len(class_counts) * count) if count > 0 else 1.0
                 for i, count in enumerate(class_counts)}

class_weights_list = [class_weights.get(i, 1.0) for i in range(len(mlb.classes_))]

class_weights_tensor = torch.tensor(class_weights_list, dtype=torch.float32)
print(f"Class counts: {class_counts}")
print(f"Class weights: {class_weights}")
print(f"Class Weights Tensor: {class_weights_tensor}")


Learned labels: ['Adjuvant Chemotherapy' 'Axillary Surgery' 'Breast Conserving Surgery'
 'Herceptin' 'Hormonal Therapy' 'Mastectomy' 'Neoadjuvant Chemotherapy'
 'Radiotherapy']
Class counts: [15 20  8  2 13 15 12 20]
Class weights: {0: np.float64(0.875), 1: np.float64(0.65625), 2: np.float64(1.640625), 3: np.float64(6.5625), 4: np.float64(1.0096153846153846), 5: np.float64(0.875), 6: np.float64(1.09375), 7: np.float64(0.65625)}
Class Weights Tensor: tensor([0.8750, 0.6562, 1.6406, 6.5625, 1.0096, 0.8750, 1.0938, 0.6562])


In [15]:
import pandas as pd
import numpy as np
from sklearn.utils import resample

train_df["final_treatment_plan"] = train_df["final_treatment_plan"].str.split(", ")
mlb = MultiLabelBinarizer()
labels_bin = mlb.fit_transform(train_df["final_treatment_plan"])
df_labels = pd.DataFrame(labels_bin, columns=mlb.classes_)
label_counts = df_labels.sum(axis=0)
max_count = label_counts.max()

df_oversampled = train_df.copy()
for label in label_counts.index:
    if label_counts[label] < max_count:
        subset = train_df[train_df["final_treatment_plan"].apply(lambda x: label in x)]
        num_samples_needed = max_count - label_counts[label]
        resampled_subset = resample(subset, replace=True, n_samples=num_samples_needed, random_state=42)
        df_oversampled = pd.concat([df_oversampled, resampled_subset])

df_oversampled = df_oversampled.reset_index(drop=True)
ds = DatasetDict({
    "train": Dataset.from_pandas(df_oversampled),
    "validation": Dataset.from_pandas(val_df),
    "test": Dataset.from_pandas(test_df)
})
print(f"Original dataset size: {len(train_df)}")
print(f"Oversampled dataset size: {len(df_oversampled)}")

Original dataset size: 24
Oversampled dataset size: 79


In [9]:
mlb_split = MultiLabelBinarizer()
y = mlb_split.fit_transform(df_oversampled["final_treatment_plan"])
X = np.array(df_oversampled.index).reshape(-1, 1)

X_train, y_train, X_temp, y_temp = iterative_train_test_split(X, y, test_size=0.3)
X_val, y_val, X_test, y_test = iterative_train_test_split(X_temp, y_temp, test_size=0.4)

ds = DatasetDict({
    "train": Dataset.from_pandas(df_oversampled.iloc[X_train.flatten()]),
    "validation": Dataset.from_pandas(df_oversampled.iloc[X_val.flatten()]),
    "test": Dataset.from_pandas(df_oversampled.iloc[X_test.flatten()])
})

print(ds)

DatasetDict({
    train: Dataset({
        features: ['id', 'tumor_size', 'tumor_grade', 'node_status', 'metastasis', 'tnm_tumor', 'er_status', 'pr_status', 'her2_status', 'ki_67_categorised', 'final_treatment_plan', '__index_level_0__'],
        num_rows: 55
    })
    validation: Dataset({
        features: ['id', 'tumor_size', 'tumor_grade', 'node_status', 'metastasis', 'tnm_tumor', 'er_status', 'pr_status', 'her2_status', 'ki_67_categorised', 'final_treatment_plan', '__index_level_0__'],
        num_rows: 14
    })
    test: Dataset({
        features: ['id', 'tumor_size', 'tumor_grade', 'node_status', 'metastasis', 'tnm_tumor', 'er_status', 'pr_status', 'her2_status', 'ki_67_categorised', 'final_treatment_plan', '__index_level_0__'],
        num_rows: 10
    })
})


In [16]:
def create_text_representation(example):
    return {
        "text": f"Tumor size: {example['tumor_size']}, Grade: {example['tumor_grade']}, "
                f"Node status: {example['node_status']}, Metastasis: {example['metastasis']}, "
                f"TNM: {example['tnm_tumor']}, ER: {example['er_status']}, "
                f"PR: {example['pr_status']}, HER2: {example['her2_status']}, "
                f"Ki-67: {example['ki_67_categorised']}"
    }

ds = ds.map(create_text_representation)

Map: 100%|██████████| 79/79 [00:00<00:00, 1781.34 examples/s]
Map: 100%|██████████| 8/8 [00:00<00:00, 916.79 examples/s]
Map: 100%|██████████| 4/4 [00:00<00:00, 379.51 examples/s]


In [17]:
joblib.dump(mlb, 'label_encoder.pkl')

['label_encoder.pkl']

In [18]:
def encode_labels(example):
    labels = example["final_treatment_plan"]
    if isinstance(labels, str):
        labels = labels.split(", ") 
    binarized_labels = mlb.transform([labels])[0]
    return {"labels": list(map(float, binarized_labels))}

labelled_ds = ds.map(encode_labels)

Map: 100%|██████████| 79/79 [00:00<00:00, 2185.92 examples/s]
Map: 100%|██████████| 8/8 [00:00<00:00, 799.41 examples/s]
Map: 100%|██████████| 4/4 [00:00<00:00, 436.08 examples/s]


In [19]:
def tokenize_function(example):
    return tokenizer(example["text"], truncation=True, padding="max_length", max_length=512)

tokenized_ds = labelled_ds.map(tokenize_function, batched=True)
tokenized_ds.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

tokenized_ds['validation'][0],tokenized_ds['train'][20],tokenized_ds['train'][21]

Map: 100%|██████████| 79/79 [00:00<00:00, 958.62 examples/s]
Map: 100%|██████████| 8/8 [00:00<00:00, 567.29 examples/s]
Map: 100%|██████████| 4/4 [00:00<00:00, 431.77 examples/s]


({'labels': tensor([1., 1., 0., 0., 1., 1., 0., 1.]),
  'input_ids': tensor([    2,  2798,  3081,    30,    24,    18,    27,    16,  4927,    30,
             23,    18,    20,    16,  6267,  3642,    30,    21,    18,    20,
             16,  5967,    30,    20,    16, 19400,    30,    23,    16,  2668,
             30,    21,    16,  2033,    30,    21,    16,  9162,    30,    21,
             16,  8191,    17,  4838,    30,    23,     3,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,  

In [20]:
model = AutoModelForSequenceClassification.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext", num_labels=len(mlb.classes_), problem_type="multi_label_classification")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs["labels"] 
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        criterion = torch.nn.BCEWithLogitsLoss(weight=class_weights_tensor) 
        loss = criterion(logits, labels)

        return (loss, outputs) if return_outputs else loss


In [22]:
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1.288375733581335e-05,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=6,
    weight_decay=0.011386646772668109,
    warmup_ratio=0.01662716373178627,
    load_best_model_at_end = True
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["validation"],
)



In [23]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,No log,1.040119
2,No log,0.999417
3,No log,0.987248
4,No log,0.977278
5,No log,0.976778
6,No log,0.972021


TrainOutput(global_step=60, training_loss=1.0017630259195964, metrics={'train_runtime': 1052.5114, 'train_samples_per_second': 0.45, 'train_steps_per_second': 0.057, 'total_flos': 124721358815232.0, 'train_loss': 1.0017630259195964, 'epoch': 6.0})

In [72]:
import joblib
import numpy as np
import pandas as pd
from scipy.special import expit
from transformers import Trainer
from sklearn.metrics import accuracy_score, f1_score,  precision_score, recall_score
from sklearn.metrics import precision_recall_curve

trainer = Trainer(model=model) 
mlb = joblib.load("label_encoder.pkl")
validation_logits = trainer.predict(tokenized_ds["validation"]).predictions
validation_probs = expit(validation_logits) 

labels = np.array(tokenized_ds["validation"]["labels"])
optimal_thresholds = np.zeros(validation_probs.shape[1])

for i in range(validation_probs.shape[1]):
    precision, recall, thresholds = precision_recall_curve(labels[:, i], validation_probs[:, i])
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-6)
    best_label_threshold = thresholds[np.argmax(f1_scores)] 
    optimal_thresholds[i] = best_label_threshold
joblib.dump(optimal_thresholds, "optimal_thresholds.pkl")
print(f"Thresholds (from validation set): {optimal_thresholds}")

labels = np.array(tokenized_ds["test"]["labels"])
test_logits = trainer.predict(tokenized_ds["test"]).predictions
test_probs = expit(test_logits)
predictions = (test_probs > optimal_thresholds).astype(int)

for i, label in enumerate(mlb.classes_):
    label_pred_count = np.sum(test_probs[:, i] > optimal_thresholds[i])
    label_true_count = np.sum(labels[:, i]) 
    print(f"Label: {label}, Predictions Made: {label_pred_count}, Actual Frequency: {label_true_count}")

labels = np.array(tokenized_ds["test"]["labels"])
predictions_text = mlb.inverse_transform(predictions)
labels_text = mlb.inverse_transform(labels)

correct_counts = []
total_counts = []
per_question_f1_scores = []

for i, (pred, true) in enumerate(zip(predictions_text, labels_text)):
    pred_set = set(pred)
    true_set = set(true)
    
    correct = len(pred_set & true_set) 
    total = len(true_set)
    accuracy = correct / total if total > 0 else 0
    correct_counts.append(correct)
    total_counts.append(total)

    sample_f1 = f1_score(labels[i], predictions[i], average="micro") 
    per_question_f1_scores.append(sample_f1)
    print(f"Sample {i+1}:")
    print(f"  Prediction: {pred}")
    print(f"  Actual: {true}")
    print(f"  Correct Predictions: {correct}/{total} ({accuracy:.2%})")
    print("-" * 50)

macro_f1 = f1_score(labels, predictions, average="macro")
micro_f1 = f1_score(labels, predictions, average="micro")
micro_precision = precision_score(labels, predictions, average="micro")
micro_recall = recall_score(labels, predictions, average="micro")
mean_f1_per_question = np.mean(per_question_f1_scores)
print(f"Macro F1 Score: {macro_f1:.4f}")
print(f"Micro F1 Score: {micro_f1:.4f}")
print(f"Micro Precision: {micro_precision:.4f}")
print(f"Micro Recall: {micro_recall:.4f}")

Thresholds (from validation set): [0.67067975 0.6812135  0.29535976 0.35007584 0.49219078 0.57367903
 0.60886598 0.72942328]


Label: Adjuvant Chemotherapy, Predictions Made: 4, Actual Frequency: 2.0
Label: Axillary Surgery, Predictions Made: 4, Actual Frequency: 3.0
Label: Breast Conserving Surgery, Predictions Made: 3, Actual Frequency: 0.0
Label: Herceptin, Predictions Made: 0, Actual Frequency: 0.0
Label: Hormonal Therapy, Predictions Made: 2, Actual Frequency: 2.0
Label: Mastectomy, Predictions Made: 4, Actual Frequency: 4.0
Label: Neoadjuvant Chemotherapy, Predictions Made: 3, Actual Frequency: 2.0
Label: Radiotherapy, Predictions Made: 4, Actual Frequency: 3.0
Sample 1:
  Prediction: ('Adjuvant Chemotherapy', 'Axillary Surgery', 'Breast Conserving Surgery', 'Hormonal Therapy', 'Mastectomy', 'Neoadjuvant Chemotherapy', 'Radiotherapy')
  Actual: ('Adjuvant Chemotherapy', 'Axillary Surgery', 'Hormonal Therapy', 'Mastectomy', 'Radiotherapy')
  Correct Predictions: 5/5 (100.00%)
--------------------------------------------------
Sample 2:
  Prediction: ('Adjuvant Chemotherapy', 'Axillary Surgery', 'Breast Co

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [91]:
optimal_thresholds = joblib.load('optimal_thresholds.pkl')
label_encoder = joblib.load('label_encoder.pkl')
inputs = tokenizer("", truncation=True, padding="max_length", max_length=512, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
probabilities = expit(logits.detach().cpu().numpy())
predictions = (probabilities > optimal_thresholds).astype(int)  
predicted_labels = label_encoder.inverse_transform(predictions)

In [29]:
import json
metrics = {}
metrics['Text-Classification'] = {
        "macro_f1": macro_f1,
        "micro_f1": micro_f1,
        "micro_precision": micro_precision,
        "micro_recall": micro_recall,
        "mean_f1_per_question": mean_f1_per_question,
        "f1_per_question": per_question_f1_scores
    }
with open("model_metrics.json", "w") as f:
    json.dump(metrics, f)

In [70]:
trainer.save_model("models/bert_model/text_classification")
tokenizer.save_pretrained("models/bert_model/text_classification")

('models/bert_model/text_classification\\tokenizer_config.json',
 'models/bert_model/text_classification\\special_tokens_map.json',
 'models/bert_model/text_classification\\vocab.txt',
 'models/bert_model/text_classification\\added_tokens.json',
 'models/bert_model/text_classification\\tokenizer.json')