In [None]:
# dataset.csv is a table with two columns: 'MRI protocol class', which holds the MRI protocol class, and 'entries',
# which holds the corresponding text for preliminary diagnoses, prior treatment history, and the clinical question.

import pandas as pd
from sklearn.model_selection import KFold, StratifiedKFold
from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForSequenceClassification, AdamW
from torch.optim import AdamW
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Read the labelled dataset; the columns used below are 'MRI protocol class' and 'entries'.
df = pd.read_csv("dataset.csv")

# Encode every protocol class as an integer id. factorize() numbers the classes
# in order of first appearance (missing values, if any, would be coded as -1).
df['category_id'] = pd.factorize(df['MRI protocol class'])[0]

# One row per (class name, id) pair, ordered by id; the lookup tables are built from it.
category_id_df = (
    df[['MRI protocol class', 'category_id']]
    .drop_duplicates()
    .sort_values('category_id')
)
category_to_id = dict(category_id_df.values)  # class name -> id
id_to_category = dict(category_id_df[['category_id', 'MRI protocol class']].values)  # id -> class name

# Number of distinct protocol classes; passed to the model as `num_labels`.
num_labels = len(df['MRI protocol class'].unique())

In [None]:
# Set up the compute device and the tokenizer shared by all folds

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer loaded from the same (placeholder) path as the model; text is lower-cased.
tokenizer = BertTokenizer.from_pretrained("/Path_to_BERT_model/", do_lower_case=True)

In [None]:
# KFold Cross-Validation

# Per-fold scores: one list per metric / averaging mode, one entry appended per fold.
accuracies, balanced_accuracies, f1_scores_micro, f1_scores_macro, f1_scores_weighted = [], [], [], [], []
recall_scores_micro, recall_scores_macro, recall_scores_weighted = [], [], []
precision_scores_micro, precision_scores_macro, precision_scores_weighted = [], [], []

# True and predicted labels pooled over all folds (input to the overall confusion matrix).
all_true_labels = []
all_predicted_labels = []

# Fix the random seeds so that a fresh "Restart & Run All" produces the same fold
# assignment, the same batch order and the same classifier-head initialisation.
# (GPU kernels may still introduce small non-deterministic differences.)
SEED = 42
torch.manual_seed(SEED)

# Stratified 5-fold split: every fold keeps the class proportions of the full dataset.
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
fold = 1  # 1-based fold counter used in the progress output

# Training hyper-parameters, identical for every fold.
NUM_EPOCHS = 22
BATCH_SIZE = 32
LEARNING_RATE = 5e-5

# For each fold: fine-tune a fresh BERT classifier on the training split,
# predict the held-out split and record the fold's metrics.
for train_index, test_index in kf.split(df, df['category_id']):
    print(f"\nFold {fold}")

    # Split the dataset
    X_train, X_test = df['entries'].iloc[train_index].tolist(), df['entries'].iloc[test_index].tolist()
    y_train, y_test = df['category_id'].iloc[train_index].values, df['category_id'].iloc[test_index].values

    # Tokenize (each split is padded to its own longest sequence, truncated to the model limit)
    train_encodings = tokenizer(X_train, truncation=True, padding=True, return_tensors='pt')
    test_encodings = tokenizer(X_test, truncation=True, padding=True, return_tensors='pt')

    train_labels = torch.tensor(y_train, dtype=torch.long)
    train_dataset = TensorDataset(train_encodings['input_ids'],
                                  train_encodings['attention_mask'],
                                  train_labels)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Load and prepare the model for training (re-initialised for every fold)
    model = BertForSequenceClassification.from_pretrained("/Path_to_BERT_model/", num_labels=num_labels)
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    # Inverse-frequency class weights. CrossEntropyLoss indexes `weight` by class id,
    # so the counts must be ordered by 'category_id' (bincount position i == class id i),
    # not by class name. Counts come from the training split only, so the test fold
    # does not influence training; the clamp avoids an infinite weight for a class
    # that is absent from this training split.
    class_counts = torch.bincount(train_labels, minlength=num_labels).clamp(min=1)
    class_weights = (1.0 / class_counts.float()).to(device)
    loss_func = torch.nn.CrossEntropyLoss(weight=class_weights)

    # Fine-tune the model
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch + 1}")
        model.train()
        all_preds = []
        all_labels = []
        for batch in train_loader:
            input_ids, attention_mask, labels = (tensor.to(device) for tensor in batch)
            optimizer.zero_grad()
            # No `labels=` here: the weighted loss is computed explicitly below, so the
            # model's built-in (unweighted) loss would be wasted work.
            outputs = model(input_ids, attention_mask=attention_mask)
            loss = loss_func(outputs.logits, labels)
            loss.backward()
            optimizer.step()
            preds = outputs.logits.argmax(dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
        # Training-set balanced accuracy of this epoch (computed on train-mode outputs).
        balanced_acc = balanced_accuracy_score(all_labels, all_preds)
        print(f"Balanced Accuracy: {balanced_acc:.4f}")

    # Evaluation on the test set
    model.eval()
    all_predictions = []
    all_probabilities = []
    with torch.no_grad():
        for i in range(0, len(test_encodings['input_ids']), BATCH_SIZE):
            input_ids = test_encodings['input_ids'][i:i + BATCH_SIZE].to(device)
            attention_mask = test_encodings['attention_mask'][i:i + BATCH_SIZE].to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=1)
            all_predictions.extend(predictions.cpu().numpy())
            probabilities = F.softmax(outputs.logits, dim=1)
            all_probabilities.extend(probabilities.cpu().numpy())

    # Pool this fold's labels and predictions for the overall confusion matrix.
    all_true_labels.extend(y_test)
    all_predicted_labels.extend(all_predictions)

    accuracy = accuracy_score(y_test, all_predictions)
    b_accuracy = balanced_accuracy_score(y_test, all_predictions)
    f1_ma = f1_score(y_test, all_predictions, average="macro")
    f1_mi = f1_score(y_test, all_predictions, average="micro")
    f1_w = f1_score(y_test, all_predictions, average="weighted")
    recall_ma = recall_score(y_test, all_predictions, average="macro")
    recall_mi = recall_score(y_test, all_predictions, average="micro")
    recall_w = recall_score(y_test, all_predictions, average="weighted")
    precision_ma = precision_score(y_test, all_predictions, average="macro")
    precision_mi = precision_score(y_test, all_predictions, average="micro")
    precision_w = precision_score(y_test, all_predictions, average="weighted")

    # The per-fold printout reports the macro-averaged variants.
    print(f"Fold {fold} Accuracy: {accuracy}")
    print(f"Fold {fold} Balanced Accuracy: {b_accuracy}")
    print(f"Fold {fold} F1-Score: {f1_ma}")
    print(f"Fold {fold} Recall: {recall_ma}")
    print(f"Fold {fold} Precision: {precision_ma}")

    f1_scores_macro.append(f1_ma)
    f1_scores_micro.append(f1_mi)
    f1_scores_weighted.append(f1_w)
    recall_scores_macro.append(recall_ma)
    recall_scores_micro.append(recall_mi)
    recall_scores_weighted.append(recall_w)
    precision_scores_macro.append(precision_ma)
    precision_scores_micro.append(precision_mi)
    precision_scores_weighted.append(precision_w)
    accuracies.append(accuracy)
    balanced_accuracies.append(b_accuracy)

    fold += 1

In [None]:
# Evaluation: report every metric averaged over the folds

def _fold_mean(values):
    """Return the arithmetic mean of the per-fold scores in ``values``."""
    return sum(values) / len(values)

print(f"\nAverage Accuracy: {_fold_mean(accuracies)}")
print(f"Average Balanced Accuracy: {_fold_mean(balanced_accuracies)}")

# F1, recall and precision are each reported for macro (ma), micro (mi) and weighted (w) averaging.
print(f"\nAverage F1-Score ma: {_fold_mean(f1_scores_macro)}")
print(f"\nAverage F1-Score mi: {_fold_mean(f1_scores_micro)}")
print(f"\nAverage F1-Score w: {_fold_mean(f1_scores_weighted)}")
print(f"\nAverage Recall ma: {_fold_mean(recall_scores_macro)}")
print(f"\nAverage Recall mi: {_fold_mean(recall_scores_micro)}")
print(f"\nAverage Recall w: {_fold_mean(recall_scores_weighted)}")
print(f"\nAverage Precision ma: {_fold_mean(precision_scores_macro)}")
print(f"\nAverage Precision mi: {_fold_mean(precision_scores_micro)}")
print(f"\nAverage Precision w: {_fold_mean(precision_scores_weighted)}")

# Confusion matrix over the pooled test predictions of all folds.
conf_mat = confusion_matrix(all_true_labels, all_predicted_labels)

# Class names ordered by category id, used as tick labels on both axes.
class_names = category_id_df['MRI protocol class'].values

fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(
    conf_mat,
    annot=True,
    fmt='d',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=ax,
)
ax.set_ylabel('Actual')
ax.set_xlabel('Predicted')
ax.set_title('Confusion Matrix Across All Folds')
plt.show()