In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import torch
import torchmetrics
import tqdm
from accelerate import Accelerator
from datasets import DatasetDict
from huggingface_hub import notebook_login
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, f1_score, precision_score, recall_score
from sklearn.preprocessing import OneHotEncoder
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

# The project package `src` lives two directories above this notebook; the
# path must be extended before the project imports below can resolve.
sys.path.append(os.getcwd()+"/../..")
from src import paths
from src.utils import plot_embeddings

In [None]:
# Login to Hugging Face Hub as model is gated
# (the checkpoint loaded further down cannot be downloaded without
# authenticating first).
notebook_login()

In [None]:
# Read the preprocessed line-labelling splits back from disk.
dataset_dir = paths.DATA_PATH_PREPROCESSED/'line_labelling_clean_dataset'
dataset = DatasetDict.load_from_disk(dataset_dir)

# Number of distinct label values, counted on the training split only.
train_label_values = dataset['train']['label']
num_labels = len(set(train_label_values))

In [None]:
# Pretrained checkpoint shared by the tokenizer and the model.
checkpoint = "GerMedBERT/medbert-512"

# Run on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Tokenizer matching the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Sequence-classification model with one output per label; it is moved to
# `device` in a separate step after loading.
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=num_labels,
    problem_type="multi_label_classification",
)
model = model.to(device)

In [None]:
# Filter out None labels
# train_dataset = train_dataset.filter(lambda example: example['label'] is not None)
# val_dataset = val_dataset.filter(lambda example: example['label'] is not None)
# test_dataset = test_dataset.filter(lambda example: example['label'] is not None)

In [None]:
# Define Dataset
class MedDataset(torch.utils.data.Dataset):
    """Tokenized, one-hot-labelled view of a single split of the dataset.

    The whole split is tokenized once in the constructor (padded/truncated to
    ``max_length``), so ``__getitem__`` only has to index into ready tensors.

    Args:
        dataframe: Mapping of split name to split; each split must expose a
            ``'text'`` and a ``'label'`` column. A ``'train'`` split is always
            required because the label encoder is fitted on it.
        tokenizer: Callable tokenizer returning a mapping of tensors.
        max_length: Length every sequence is padded/truncated to.
        split: Name of the split this dataset instance serves.

    Attributes set on the instance:
        enc: ``OneHotEncoder`` fitted on the training labels. Labels unseen
            during fitting are encoded as an all-zero row
            (``handle_unknown='ignore'``).
        labels: float32 array with one one-hot row per example of ``split``.
        encodings: tokenizer output for ``split``.
    """

    def __init__(self, dataframe: "DatasetDict", tokenizer, max_length: int = 512, split: str = 'train'):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Always fit on the training labels so that every split shares the
        # same column order, whichever split is being served.
        self.enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False).fit(np.stack(self.dataframe['train']['label']).reshape(-1, 1))
        # The encoder emits float64; float32 is stored instead so the targets
        # match the dtype of the model's logits when a loss is computed.
        self.labels = self.enc.transform(np.stack(self.dataframe[split]['label']).reshape(-1, 1)).astype(np.float32)
        self.encodings = self.tokenizer(self.dataframe[split]['text'], truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')

    def __getitem__(self, idx):
        """Return the tokenizer fields for example ``idx`` plus its ``"labels"`` row."""
        item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
        # Explicit dtype: the result is float32 even if `labels` was replaced
        # by an array of another floating type.
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item

    def __len__(self):
        """Number of examples in the split."""
        return len(self.labels)

In [None]:
# Create Datasets
# One MedDataset per split, built in the order train -> validation -> test.
train_dataset, val_dataset, test_dataset = (
    MedDataset(dataset, tokenizer, split=split_name)
    for split_name in ('train', 'validation', 'test')
)

In [None]:
# Pytorch Implementation

# # Set only specific layers to be trainable
# for param in model.base_model.parameters():
#     param.requires_grad = False

# # Set only specific layers to be trainable
# for param in model.classifier.parameters():
#     param.requires_grad = True

# Dataloader
# Only the training loader is shuffled. Validation and test are iterated in
# dataset order so that evaluation is deterministic and predictions can be
# matched back to the examples they belong to.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Optimizer
optimizer = Adam(model.parameters(), lr=1e-4)

# Loss
def loss_fn(logits, targets):
    """Cross-entropy training loss for single-label classification.

    The F1 score is deliberately not part of the loss: it is computed from
    arg-maxed predictions and therefore carries no gradient, and as a
    higher-is-better quantity it must not be added to a value that is being
    minimised. Report F1 as a metric instead.

    Args:
        logits: Float tensor of unnormalised class scores, one row per example.
        targets: Either integer class indices (one fewer dimension than
            ``logits``) or one-hot rows with the same shape as ``logits``.
            One-hot rows are reduced to class indices with ``argmax``.

    Returns:
        Scalar tensor holding the mean cross-entropy over the batch.
    """
    if targets.dim() == logits.dim():
        # Same rank as the logits -> rows are one-hot encoded classes.
        targets = targets.argmax(dim=-1)
    return torch.nn.functional.cross_entropy(logits, targets.long())

# GPU Memory optimization
# Gradient checkpointing trades extra compute for lower activation memory.
model.gradient_checkpointing_enable()
# Half precision is requested through `mixed_precision`; fp16 needs a CUDA
# device, so fall back to full precision when none is available.
accelerator = Accelerator(mixed_precision="fp16" if torch.cuda.is_available() else "no")
model, optimizer, train_loader, val_loader, test_loader = accelerator.prepare(model, optimizer, train_loader, val_loader, test_loader)

In [None]:
# Training
epochs = 36

for epoch in range(epochs):
    # Training mode has to be switched on explicitly at the start of every
    # epoch: the validation pass below puts the model into eval mode.
    model.train()
    pbar = tqdm.tqdm(train_loader)
    for i, batch in enumerate(pbar):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        # Labels are not passed to the model: the loss comes from loss_fn,
        # so the model's built-in loss would be computed and then discarded.
        outputs = model(input_ids, attention_mask=attention_mask)
        loss = loss_fn(outputs.logits, labels)
        # Backpropagate through the accelerator so that mixed-precision
        # gradient scaling is applied when it is enabled.
        accelerator.backward(loss)
        optimizer.step()

        # Logging Progress
        if i % 10 == 0:
            pbar.set_description(f"Epoch {epoch} training loss: {loss.item()}")
    
    # Evaluate on Validation
    model.eval()
    val_CE_loss = []
    # A single metric object accumulates statistics over all batches, which
    # gives the F1 of the whole split rather than a mean of per-batch scores.
    val_f1 = torchmetrics.classification.MulticlassF1Score(num_classes=num_labels, average='weighted').to(device)

    pbar = tqdm.tqdm(val_loader)
    pbar.set_description(f"Epoch {epoch} Validation")
    with torch.no_grad():
        for batch in pbar:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            # The dataset yields one-hot rows; both the loss and the metric
            # below are fed class indices.
            target_idx = labels.argmax(dim=1)
            val_CE_loss.append(torch.nn.functional.cross_entropy(outputs.logits.float(), target_idx).item())
            val_f1.update(outputs.logits, target_idx)
    
    print(f"Epoch {epoch} CrossEntropy Val loss: {np.mean(val_CE_loss)}")
    print(f"Epoch {epoch} F1 Val score: {val_f1.compute().item()}")
    
    # # Saving Model    
    # if epoch % 10 == 0:
    #     torch.save(model.state_dict(), paths.MODEL_PATH/f"line-label_medBERT-finetuned_{epoch}.pt")


In [None]:
# Hand cached GPU memory that is currently unused back to the device.
torch.cuda.empty_cache()

In [None]:
# Free GPU Memory
# Drop the references first: empty_cache() can only release memory that no
# live tensor occupies any more, so it has to run after the deletions.
del input_ids
del attention_mask
del labels
torch.cuda.empty_cache()

In [None]:
# Display the checkpoint path for the current value of `epoch` (the variable
# left behind by the training loop). This only builds the path; nothing is
# written to disk here.
paths.MODEL_PATH/f"line-label_medBERT-finetuned_{epoch}.pt"

In [None]:
model.eval()

# Evaluate on test set
# The loader must NOT shuffle: the collected outputs are later compared
# position-by-position with the test labels, which are in dataset order.
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
outputs = []

with torch.no_grad():
    for batch in tqdm.tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Labels are not needed for the forward pass, so they are not moved
        # to the device here.
        outputs.append(model(input_ids, attention_mask=attention_mask))


In [None]:
# Get predictions
# One arg-max class index per example, concatenated over all batches.
preds = [np.argmax(output.logits.cpu().numpy(), axis=1) for output in outputs]
preds = np.concatenate(preds)

# Get true labels
# The one-hot rows stored on the dataset, reduced to class indices.
# NOTE: these are in dataset order, so `preds` must be too - the loader that
# produced `outputs` must not shuffle.
true = np.argmax(test_dataset.labels, axis=1)

# Calculate accuracy
acc = np.sum(preds == true) / len(true)

# F1 Score
f1 = f1_score(true, preds, average='weighted')
print(f"Accuracy: {acc}")
print(f"F1 Score: {f1}")

In [None]:
# Training Arguments
training_args = TrainingArguments(
    # Output locations and logging cadence
    output_dir='./results',
    logging_dir='./logs',
    logging_steps=10,
    # Optimisation schedule
    num_train_epochs=4,
    warmup_steps=200,
    weight_decay=0.01,
    # Per-device batch sizes for training and evaluation
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    # Evaluate and checkpoint once per epoch, then restore the best checkpoint
    evaluation_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
)

# Trainer: trains `model` on the training split and evaluates on validation.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

In [None]:
# Setting pooling and head to trainable
# A parameter stays trainable only when its name contains one of the markers
# below; every other parameter is frozen.
trainable_markers = ("pooler", "classifier")
for name, param in trainer.model.named_parameters():
    param.requires_grad = any(marker in name for marker in trainable_markers)

In [None]:
# Train
# Runs the Trainer's training loop on the datasets it was constructed with.
trainer.train()

In [None]:
# Save model
# Target directory inside the project's model folder.
model_output_dir = os.path.join(paths.MODEL_PATH, "medbert-diag-label")
trainer.save_model(model_output_dir)

In [None]:
# Evaluate
# Evaluates on the Trainer's eval_dataset (the validation split) and displays
# the returned metrics.
trainer.evaluate()

In [None]:
# Predict
# Run the model over the test split. The result's `.predictions` and
# `.label_ids` fields are consumed by the metric cells that follow.
predictions = trainer.predict(test_dataset)

In [None]:
# Accuracy
# Reduce both the model scores and the stored label rows to class indices.
test_scores, test_label_rows = predictions.predictions, predictions.label_ids
preds = np.argmax(test_scores, axis=1)
labels = np.argmax(test_label_rows, axis=1)

n_correct = np.sum(preds == labels)
print(f"Accuracy: {n_correct / len(labels)}")

In [None]:
# f1 score, precision, recall
# Macro averaging: each class contributes equally, regardless of its support.
for metric_name, metric_fn in (
    ("F1 Score", f1_score),
    ("Precision", precision_score),
    ("Recall", recall_score),
):
    print(f"{metric_name}: {metric_fn(labels, preds, average='macro')}")

In [None]:
# Show only the summary metrics of the prediction run; displaying the whole
# `predictions` object would dump the full score and label arrays.
predictions.metrics

In [None]:
# Class names in one-hot column order. The fitted encoder is stored on the
# dataset objects; there is no notebook-level `enc` variable.
train_dataset.enc.categories_[0]

In [None]:
# Confusion Matrix
# Class names in one-hot column order, taken from the encoder stored on the
# dataset (there is no notebook-level `enc` variable).
class_names = test_dataset.enc.categories_[0]

# Passing the full index range as `labels` keeps rows/columns aligned with
# `class_names` even when a class never occurs in `labels` or `preds`.
disp = ConfusionMatrixDisplay.from_predictions(
    labels,
    preds,
    labels=np.arange(len(class_names)),
    display_labels=class_names,
    xticks_rotation=90,
)
# Plot the confusion matrix with rotated x-axis labels
# fig, ax = plt.subplots(figsize=(8, 6))
# disp.plot(ax=ax, xticks_rotation=45)  # Adjust the rotation angle as needed
# plt.show()

In [None]:
# Get last-layer hidden states (token level), one tensor per batch.
# Pooling over the sequence happens later, after the tensors are reloaded.
# NOTE(review): `df` is not defined anywhere in this notebook - presumably a
# polars DataFrame with a 'text' column loaded in another session; it must be
# loaded before this cell can run on a fresh kernel.
embeddings = []
batch_size = 16

# Make sure dropout is disabled while extracting embeddings.
trainer.model.eval()

for i in tqdm.tqdm(range(0, len(df), batch_size)):
    # Each batch is padded to its own longest sequence, so the sequence
    # dimension can differ between the collected tensors.
    tokens = tokenizer(df['text'][i:i+batch_size].to_list(), padding=True, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        embeddings.append(trainer.model(**tokens, output_hidden_states=True).hidden_states[-1].cpu())
    del tokens

In [None]:
# Save embeddings
# Persist the list of per-batch tensors next to the preprocessed data.
embeddings_path = os.path.join(paths.DATA_PATH_PREPROCESSED, "embeddings-fine-tuned.pt")
torch.save(embeddings, embeddings_path)

In [None]:
# Load embeddings
embeddings_path = os.path.join(paths.DATA_PATH_PREPROCESSED, "embeddings-fine-tuned.pt")
embeddings = torch.load(embeddings_path)

# Mean over sequence
# Average each batch over dim 1 (all positions of the sequence dimension),
# then join the batches along dim 0.
embeddings_mean = torch.cat(
    [torch.mean(batch_states, dim=1) for batch_states in embeddings],
    dim=0,
)

# Plot Mean Embeddings
plot_embeddings(embeddings_mean, df["class_agg"], title="Mean Embeddings", method="pca")

In [None]:
# Same mean embeddings and labels as the PCA plot, drawn with method="umap".
plot_embeddings(embeddings_mean, df["class_agg"], title="Mean Embeddings", method="umap")