In [None]:
from typing import Any, Optional, Tuple, Union
import functools
from torch.utils.data import Dataset, DataLoader, RandomSampler
import torch
from collections import OrderedDict
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch import nn
from torch import Tensor as T
from transformers import BertForMaskedLM, BertConfig, BertTokenizer, AutoModel
from transformers.modeling_outputs import ModelOutput
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
import pytorch_lightning as pl
import CXRBERT


In [None]:
checkpoint = 'allenai/scibert_scivocab_cased'
config = CXRBERT.CXRBertConfig.from_pretrained(checkpoint)
# NOTE(review): padding/truncation/max_length are normally arguments of the
# tokenizer *call*, not of from_pretrained — confirm they have any effect here.
# This tokenizer is also replaced by an AutoTokenizer in a later cell.
tokenizer = CXRBERT.CXRBertTokenizer.from_pretrained(checkpoint, padding="max_length", truncation=True, max_length=512)
# Call from_pretrained on the class and hand it the config loaded above,
# instead of first building (and then discarding) a randomly initialised
# CXRBertModel(config) instance. Assumes CXRBertModel exposes the usual
# Hugging Face classmethod from_pretrained(..., config=...) — TODO confirm.
model = CXRBERT.CXRBertModel.from_pretrained(checkpoint, config=config)

In [None]:
path = './CXR-BERT/2v87h10f/checkpoints/epoch=49-step=25000.ckpt'

# map_location="cpu" lets a checkpoint saved on a GPU be loaded on a CPU-only
# machine; by default torch.load restores tensors onto the device they were
# saved from.
model_cpt = torch.load(path, map_location="cpu")

# State-dict keys are expected to carry a "model." prefix — presumably the
# LightningModule attribute that held the network; TODO confirm. The prefix is
# stripped only where present rather than slicing a fixed number of characters
# off every key, so a key without it is not silently mangled (a mismatch then
# surfaces in load_state_dict instead).
CKPT_KEY_PREFIX = "model."
new_state_dict = OrderedDict()
for k, v in model_cpt['state_dict'].items():
    name = k[len(CKPT_KEY_PREFIX):] if k.startswith(CKPT_KEY_PREFIX) else k
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

In [None]:
from datasets import load_dataset

# One CSV file per split; all three are read through the generic "csv" builder.
split_files = {
    "train": "5000_train_data_full_reports.csv",
    "validation": "full_val_with_c.csv",
    "test": "full_test_with_c.csv",
}
dataset = load_dataset("csv", data_files=split_files)

In [None]:
# Every train column that is not an identifier or report text is treated as a
# label; their order fixes the label index used by id2label / label2id.
non_label_columns = {'subject_id', 'study_id', 'report', 'c_report'}
labels = [column for column in dataset['train'].features if column not in non_label_columns]
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
labels

In [None]:
from transformers import AutoTokenizer
import numpy as np

# The checkpoint is a cased vocabulary, so lower-casing is explicitly disabled.
# This rebinds `tokenizer`, replacing the CXRBertTokenizer created earlier.
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    do_lower_case=False,
)

def preprocess_data(examples, tok=None, label_columns=None,
                    text_column="c_report", max_length=512):
    """Tokenize a batch of reports and attach a multi-hot label matrix.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` maps each
    column name to the list of that column's values for the batch.

    Args:
        examples: Batch dict, column name -> list of values.
        tok: Tokenizer to call; defaults to the module-level ``tokenizer``.
        label_columns: Label column names, in output order; defaults to the
            module-level ``labels``.
        text_column: Column holding the text to encode.
        max_length: Passed to the tokenizer together with
            ``padding="max_length"`` and ``truncation=True``.

    Returns:
        The tokenizer output with an extra ``"labels"`` entry: one list of
        floats per example, one value per label column.

    Raises:
        KeyError: if ``text_column`` or a label column is missing from the batch.
    """
    # Resolve the module-level defaults lazily so callers may inject their own.
    active_tokenizer = tokenizer if tok is None else tok
    columns = labels if label_columns is None else label_columns

    text = examples[text_column]
    encoding = active_tokenizer(text, padding="max_length", truncation=True,
                                max_length=max_length)

    # (batch_size, num_labels) float matrix; float targets are what a
    # multi-label (sigmoid/BCE-style) loss consumes.
    labels_matrix = np.zeros((len(text), len(columns)))
    for idx, column in enumerate(columns):
        labels_matrix[:, idx] = examples[column]

    encoding["labels"] = labels_matrix.tolist()
    return encoding

In [None]:
import os
from collections import OrderedDict
import torch
from transformers import AutoModelForSequenceClassification, AutoModel
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
# torch.cuda.get_device_name(0) raises when there is no CUDA device, so only
# query it on the CUDA path; the cell still displays which hardware is used.
device_name = torch.cuda.get_device_name(0) if device.type == "cuda" else "cpu"
device_name

In [None]:
# Tokenize every split and drop the raw columns, keeping only the tokenizer
# outputs plus "labels" produced by preprocess_data.
encoded_dataset = dataset.map(preprocess_data, batched=True,
                              remove_columns=dataset['train'].column_names)
# Keep the formatted tensors on the CPU: the Trainer moves each batch to the
# model's device itself, so formatting the whole dataset onto `device` only
# pins it in GPU memory and adds a device round-trip in batch collation.
encoded_dataset.set_format("torch")
encoded_dataset

In [None]:
# Sequence-classification model with one output per label column, configured
# for multi-label classification and carrying the label-name mappings.
model2 = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    problem_type="multi_label_classification",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    # attention_probs_dropout_prob=0.5,
    # hidden_dropout_prob=0.5,
)

# For the pre-trained models in this thesis, the encoder weights are swapped in
# from `model` (the network restored from the Lightning checkpoint above).
# NOTE(review): these lines are commented out, so as written model2 starts
# from `checkpoint` and the loaded `model` is not used here — confirm intended.
# The swap has to happen before the .to(device) call below.
# model2.bert.embeddings = model.base_model.embeddings
# model2.bert.encoder = model.base_model.encoder
# model2.bert.pooler.dense = torch.nn.Linear(in_features=768, out_features=768, bias=True)
# model2.bert.pooler.activation = torch.nn.Tanh()

model2 = model2.to(device)

In [None]:
# ##freezing params for main model and only training the classifier layers for experiments
# for param in model2.bert.parameters():
#     param.requires_grad = False

In [None]:
batch_size = 32
# Key of the dict returned by my_compute_metrics that selects the best checkpoint.
metric_name = "f1_macro"
from transformers import EarlyStoppingCallback, IntervalStrategy

from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    "./",  # output_dir
    evaluation_strategy="epoch",
    # load_best_model_at_end needs the save and evaluation strategies to match.
    save_strategy="epoch",
    logging_strategy='epoch',
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=4,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    #save_total_limit=1,
    #push_to_hub=True,
    seed=9000,
    # report_to=None does not disable logging integrations in transformers 4.x
    # (it falls back to every installed one, e.g. wandb); the string "none" does.
    report_to="none",
)

In [None]:
import torch
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, precision_recall_fscore_support, roc_curve
from transformers import EvalPrediction
import evaluate

def _best_f1_thresholds(y_true, probs):
    """Return, for each label column, the probability threshold maximising F1.

    Thresholds follow sklearn's ``precision_recall_curve`` convention: a sample
    counts as positive when its score is ``>=`` the threshold.
    """
    best = []
    for col in range(probs.shape[1]):
        precision, recall, threshold = precision_recall_curve(y_true[:, col], probs[:, col])
        # The curve ends with an extra (precision=1, recall=0) point that has no
        # threshold; drop it so indices line up with `threshold`.
        precision, recall = precision[:-1], recall[:-1]
        denom = precision + recall
        # Safe division: where P+R is 0 (or NaN) F1 is taken as 0 rather than
        # NaN — np.argmax would otherwise return the first NaN position.
        f1 = np.divide(2 * precision * recall, denom,
                       out=np.zeros_like(denom, dtype=float), where=denom > 0)
        best.append(float(threshold[np.argmax(f1)]))
    return best


def my_multi_label_metrics(predictions, labels, thresholds=None):
    """Multi-label classification metrics from raw logits.

    Args:
        predictions: Logits, array-like of shape (n_samples, n_labels).
        labels: Multi-hot targets, array of shape (n_samples, n_labels).
        thresholds: Optional per-label decision thresholds on the sigmoid
            probabilities. When omitted they are tuned per label to maximise
            F1 on ``labels`` themselves and printed. NOTE: tuning on the set
            being scored makes the F1/accuracy values optimistic.

    Returns:
        Dict with micro/macro/weighted F1 and ROC-AUC, subset accuracy, and
        the per-label ROC-AUC rounded to 3 decimals.
    """
    probs = torch.sigmoid(torch.as_tensor(predictions, dtype=torch.float32)).numpy()
    y_true = np.asarray(labels)

    if thresholds is None:
        thresholds = _best_f1_thresholds(y_true, probs)
        print(thresholds)

    # ">=" matches the precision_recall_curve definition the thresholds come from.
    y_pred = (probs >= np.asarray(thresholds, dtype=float)).astype(int)

    # Per-label AUC computed directly with sklearn. evaluate's multilabel
    # "roc_auc" metric delegates to the same function, but re-loading it on
    # every evaluation call is slow and needs network access.
    per_class_auc = roc_auc_score(y_true, probs, average=None)

    metrics = {'f1_micro': f1_score(y_true=y_true, y_pred=y_pred, average='micro'),
               'roc_auc_micro': roc_auc_score(y_true, probs, average='micro'),
               'f1_macro': f1_score(y_true=y_true, y_pred=y_pred, average='macro'),
               'roc_auc_macro': roc_auc_score(y_true, probs, average='macro'),
               'f1_weighted': f1_score(y_true=y_true, y_pred=y_pred, average='weighted'),
               'roc_auc_weighted': roc_auc_score(y_true, probs, average='weighted'),
               # Subset accuracy: every label of a sample must be correct.
               'accuracy': accuracy_score(y_true, y_pred),
               'roc_auc_per_class': [round(float(auc), 3) for auc in per_class_auc]
               }
    return metrics

def my_compute_metrics(p: EvalPrediction):
    """Trainer ``compute_metrics`` hook: unpack the logits and score them.

    When ``p.predictions`` is a tuple, its first element is taken as the logits.
    """
    logits = p.predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    return my_multi_label_metrics(predictions=logits, labels=p.label_ids)

In [None]:
# Train on the train split, evaluating (and checkpointing) on validation
# according to `args`; metrics come from my_compute_metrics.
trainer = Trainer(
    model=model2,
    args=args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=my_compute_metrics,
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

trainer.train()

In [None]:
# NOTE(review): my_compute_metrics picks per-label thresholds on whatever set
# it scores, so the thresholded test metrics use thresholds fitted on test.
raw_pred = trainer.predict(test_dataset=encoded_dataset["test"])
raw_pred.metrics