In [None]:
# Sanity check before training: list the NVIDIA GPUs visible to this runtime.
!nvidia-smi

In [22]:
import warnings
warnings.filterwarnings("ignore")

In [21]:
from collections import Counter

import matplotlib.pyplot as plt
import pandas as pd
import torch
from datasets import load_dataset
from sklearn.metrics import (
    accuracy_score,
    auc,
    average_precision_score,
    classification_report,
    confusion_matrix,
    precision_recall_curve,
    precision_recall_fscore_support,
    roc_auc_score,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

In [16]:

import matplotlib.pyplot as plt
from collections import Counter
from sklearn.preprocessing import label_binarize

# Seed for the train/validation split so reruns produce the same partition.
SEED = 42

data = load_dataset("NickyNicky/medical_mtsamples", split='train')
data = pd.DataFrame(data)
# Drop rows missing the text or the target:
# - astype(str) would turn a missing transcription into the literal text "nan";
# - a missing specialty would become category code -1, which is not a valid class index.
# Specialty strings are stripped so labels differing only by surrounding whitespace share one class.
data = (
    data.dropna(subset=['transcription', 'medical_specialty'])
    .assign(medical_specialty=lambda df: df['medical_specialty'].astype(str).str.strip())
    .reset_index(drop=True)
)

texts = data['transcription'].astype(str).tolist()
# Integer class ids: position of each specialty in the sorted category list.
labels = data['medical_specialty'].astype('category').cat.codes.tolist()
assert len(texts) == len(labels) and min(labels) >= 0, "every example needs a valid class id"

train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.2, stratify=labels, random_state=SEED
)

tokenizer = BertTokenizer.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

def tokenize_function(texts, max_length=128):
    """Tokenize a batch of texts with the module-level Bio_ClinicalBERT ``tokenizer``.

    Args:
        texts: a single string or any iterable of strings (list, tuple, pandas Series, ...).
        max_length: token budget per text. Longer texts are truncated; shorter ones are
            padded to exactly this length. Defaults to 128, the previously hard-coded value.
            NOTE(review): clinical transcriptions are often far longer than 128 tokens;
            consider raising this (BERT-base models typically allow up to 512).

    Returns:
        The tokenizer's dict-like output holding PyTorch tensors (``return_tensors="pt"``).
    """
    # The tokenizer accepts str or list[str]; materialise other iterables (e.g. a Series) as a list.
    batch = texts if isinstance(texts, (str, list)) else list(texts)
    return tokenizer(batch, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")

# Tokenize both splits with identical settings; each result is the tokenizer's dict-like output.
train_encodings, val_encodings = map(tokenize_function, (train_texts, val_texts))

class Dataset(torch.utils.data.Dataset):
    """Map-style torch dataset pairing pre-tokenized encodings with integer class labels.

    Args:
        encodings: dict-like mapping of field name -> indexable sequence/tensor whose first
            axis is the example index (e.g. the output of ``tokenize_function``).
        labels: sequence of integer class ids, one per example.

    Raises:
        ValueError: if any encoding field's length differs from the number of labels.
    """

    def __init__(self, encodings, labels):
        # Fail fast on misaligned inputs instead of an IndexError (or silent truncation) mid-training.
        for key, values in encodings.items():
            if len(values) != len(labels):
                raise ValueError(
                    f"encoding field {key!r} has {len(values)} rows but there are {len(labels)} labels"
                )
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        # Force int64: classification losses expect long targets, and labels supplied as
        # e.g. numpy int8 category codes would otherwise yield an int8 tensor.
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

    def __len__(self):
        return len(self.labels)

# Wrap each split's encodings and labels in the map-style Dataset class defined above.
train_dataset, val_dataset = (
    Dataset(train_encodings, train_labels),
    Dataset(val_encodings, val_labels),
)

def compute_metrics(pred):
    """Compute validation metrics for the Hugging Face ``Trainer``.

    Args:
        pred: the evaluation-prediction object the Trainer passes in. It has ``predictions``
            (logits, or a tuple whose first element is the logits) and ``label_ids``.

    Returns:
        dict with:
          - ``accuracy``;
          - support-weighted ``precision`` / ``recall`` / ``f1`` (original keys, unchanged);
          - ``macro_precision`` / ``macro_recall`` / ``macro_f1``, which weight every class
            equally and so expose poor performance on rare classes.
    """
    labels = pred.label_ids
    # Some models return extra outputs alongside the logits; the logits come first.
    logits = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
    preds = logits.argmax(-1)
    # zero_division=0: a class that is never predicted scores 0 precision explicitly
    # instead of emitting a warning (which the notebook globally suppresses).
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'macro_precision': macro_precision,
        'macro_recall': macro_recall,
        'macro_f1': macro_f1,
    }

# Number of output classes = number of distinct encoded specialties.
num_labels = len(set(labels))
class_ids = list(range(num_labels))

model = BertForSequenceClassification.from_pretrained('emilyalsentzer/Bio_ClinicalBERT', num_labels=num_labels)

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # A fixed warmup_steps=500 can equal or exceed the total optimizer steps of a single
    # epoch (about len(train_dataset) / 8 steps). That leaves the learning rate still ramping
    # up when training ends. A ratio scales with the actual run length.
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    # NOTE(review): newer transformers releases rename this argument to `eval_strategy`;
    # update it if the installed version rejects `evaluation_strategy`.
    evaluation_strategy='epoch',
    seed=42,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

# One prediction pass returns logits, label ids and metrics together. Calling
# trainer.evaluate() as well would run inference over the validation set a second time.
val_preds = trainer.predict(val_dataset, metric_key_prefix='eval')
metrics = val_preds.metrics
print(metrics)

print("Training Set Label Distribution:", Counter(train_labels))
print("Validation Set Label Distribution:", Counter(val_labels))

preds = val_preds.predictions.argmax(-1)
probs = torch.nn.functional.softmax(torch.tensor(val_preds.predictions), dim=-1).numpy()

# Specialty names in class-id order (category codes index into .cat.categories).
# Fall back to numeric ids if the mapping does not line up with the label space.
specialty_names = [str(name) for name in data['medical_specialty'].astype('category').cat.categories]
if len(specialty_names) != num_labels:
    specialty_names = [f'class {i}' for i in class_ids]

# Pass the full label list so the matrix and report keep one row per class, even when a
# rare class is missing from both the validation labels and the predictions.
cm = confusion_matrix(val_labels, preds, labels=class_ids)
report = classification_report(
    val_labels, preds, labels=class_ids, target_names=specialty_names, zero_division=0
)

print("Confusion Matrix:")
print(cm)
print("\nClassification Report:")
print(report)

if len(set(val_labels)) > 1:
    # One-hot targets with one column per class. Unlike label_binarize, this also gives two
    # columns (not one) when num_labels == 2, so column indexing below stays valid.
    val_labels_binarized = torch.nn.functional.one_hot(
        torch.tensor(val_labels), num_classes=num_labels
    ).numpy()

    # ROC/PR curves are undefined for a class with no positives (or no negatives) in this split.
    scorable = [i for i in class_ids if 0 < val_labels_binarized[:, i].sum() < len(val_labels)]
    skipped = [i for i in class_ids if i not in scorable]
    if skipped:
        print("Skipping ROC/PR curves for classes without both positives and negatives in validation:", skipped)

    fpr, tpr, roc_auc = {}, {}, {}
    precision, recall, pr_auc = {}, {}, {}
    for i in scorable:
        y_true_i, y_score_i = val_labels_binarized[:, i], probs[:, i]
        fpr[i], tpr[i], _ = roc_curve(y_true_i, y_score_i)
        roc_auc[i] = auc(fpr[i], tpr[i])
        precision[i], recall[i], _ = precision_recall_curve(y_true_i, y_score_i)
        # Average precision (step-wise sum) instead of trapezoidal auc(recall, precision),
        # which linearly interpolates the PR curve and overstates the area.
        pr_auc[i] = average_precision_score(y_true_i, y_score_i)

    fig, ax = plt.subplots(figsize=(10, 8))
    for i in scorable:
        ax.plot(fpr[i], tpr[i], label=f'{specialty_names[i]} (area = {roc_auc[i]:.2f})')
    ax.plot([0, 1], [0, 1], 'k--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    # With many classes an in-axes legend covers the curves; place it beside the plot.
    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), fontsize='x-small')
    plt.show()

    fig, ax = plt.subplots(figsize=(10, 8))
    for i in scorable:
        ax.plot(recall[i], precision[i], label=f'{specialty_names[i]} (AP = {pr_auc[i]:.2f})')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), fontsize='x-small')
    plt.show()
else:
    print("Not enough classes in validation set to compute ROC AUC and PR curves.")

KeyboardInterrupt: 