In [None]:
import numpy as np
import pandas as pd
from fastai import *
from fastai.text import *

from clinical_note_utils import *

from transformers import AutoTokenizer, AutoModel
from transformers import BertForSequenceClassification, BertTokenizer

In [None]:
# Running lists of AUC values; one entry is appended per evaluated checkpoint.
valid_roc, test_roc = [], []

In [None]:
# Batch size handed to `load_data` below.
bs = 64
# Base location handed to `load_data` together with 'data/bert-clinical.pkl'.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
path = "/home/littlefield/MIMIC-NLP/readmission-prediction/"

In [None]:
# Load the pretrained tokenizer published under the same identifier as the
# classification checkpoint used later in this notebook.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_Discharge_Summary_BERT")

In [None]:
class FastAiBertTokenizer(BaseTokenizer):
    """Adapter that exposes a pretrained BERT tokenizer through the fast.ai tokenizer interface."""

    def __init__(self, tokenizer: BertTokenizer, max_seq_len: int=128, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        self.max_seq_len = max_seq_len
        self._pretrained_tokenizer = tokenizer

    def __call__(self, *args, **kwargs):
        # Calling the instance returns the instance itself; all arguments are ignored.
        return self

    def tokenizer(self, t:str) -> List[str]:
        """Tokenize `t`, truncate to `max_seq_len - 2` pieces, and wrap the
        result in the "[CLS]" / "[SEP]" markers."""
        budget = self.max_seq_len - 2
        pieces = self._pretrained_tokenizer.tokenize(t)[:budget]
        return ["[CLS]", *pieces, "[SEP]"]

In [None]:
# Rebuild the saved data object from 'data/bert-clinical.pkl' with batch size `bs`.
# NOTE(review): the file has a .pkl extension — presumably a pickle; only load
# it from a trusted location.
clinical_data = load_data(path, 'data/bert-clinical.pkl', bs=bs)

In [None]:
# Pretrained checkpoint wrapped for sequence classification with two labels.
clinical_bert = BertForSequenceClassification.from_pretrained(
    "emilyalsentzer/Bio_Discharge_Summary_BERT",
    num_labels=2,
)

In [None]:
class CustomTransformerModel(nn.Module):
    """Wrapper that makes a transformer classifier return a single tensor.

    The wrapped model is called with the input batch and only element 0 of
    its output (the logits) is kept.

    Args:
        transformer_model: model whose call returns an indexable output with
            the logits at position 0.
        include_act: when True, ``act_func`` is applied to the logits and its
            result is returned as-is.
        act_func: callable applied to the logits; required when
            ``include_act`` is True, unused otherwise.

    Raises:
        ValueError: if ``include_act`` is True but ``act_func`` is None.
    """

    def __init__(self, transformer_model: "BertForSequenceClassification", include_act=False, act_func=None):
        super().__init__()
        # Fail at construction time rather than with an opaque
        # "'NoneType' object is not callable" on the first forward pass.
        if include_act and act_func is None:
            raise ValueError("act_func must be provided when include_act=True")
        self.include_act = include_act
        self.transformer = transformer_model

        if include_act:
            self.act = act_func

    def forward(self, x):
        """Return the activated logits, or the logits flattened to 1-D when no activation is configured."""
        # Keep only the logits from the transformer's output.
        logits = self.transformer(x)[0]

        if self.include_act:
            return self.act(logits)

        # NOTE(review): this branch flattens the logits while the activated
        # branch does not reshape them — confirm the asymmetry is intended.
        return logits.reshape(-1)

# Wrap the classifier so its forward pass returns sigmoid-activated logits.
model = CustomTransformerModel(
    clinical_bert,
    include_act=True,
    act_func=nn.Sigmoid(),
)

from fastai.callbacks import *

# NOTE(review): the model already applies a sigmoid, while CrossEntropyLoss is
# normally fed raw logits — confirm this pairing matches how the saved
# checkpoints were trained before changing either side.
learn = Learner(
    clinical_data,
    model,
    loss_func=nn.CrossEntropyLoss(),
    metrics=[AUROC(), Precision(), Recall()],
)

In [None]:
# Load the saved "clinical-bert-1" checkpoint into the learner.
learn.load("clinical-bert-1")

In [None]:
# Run the currently loaded checkpoint over the validation split; element 0 of
# the result holds the predictions consumed by the evaluation cells.
preds = learn.get_preds(ds_type=DatasetType.Valid)

In [None]:
def eval_model(inp, preds, ds="Validation", thresh=None):
    """Score positive-class predictions against labels and print a report.

    Args:
        inp: ground-truth labels of the split being evaluated.
        preds: positive-class scores, aligned element-wise with ``inp``.
        ds: split name; used only in the printed header.
        thresh: decision threshold. When None, it is computed from ``inp``
            and ``preds`` with ``J_statistic``.

    Returns:
        dict with keys "best_thresh", "acc", "auc", "f1", "precision" and
        "recall".
    """
    best_thresh = thresh
    if thresh is None:
        print("Theshold is not provided, calculating...")
        best_thresh = J_statistic(inp, preds)
    # Scores strictly above the threshold are labelled positive.
    final_preds = np.array([1 if p > best_thresh else 0 for p in preds])
    # Use the labels of the split being evaluated (`inp`) rather than a fixed
    # dataset, so the accuracy is also valid for the test split.
    acc = pos_accuracy(inp, final_preds, thresh=best_thresh)
    auc_score = auc(inp, preds)
    f1, precision, recall = f1_precision_recall(inp, final_preds)

    print("================================", ds, "Metrics for Postive Class ================================")
    print("Best Threshold:", best_thresh)
    print("Positive Class Acc.:", acc)
    print("AUC:", auc_score)
    print("F1 Score:", f1)
    print("Precision:", precision)
    print("Recall:", recall)

    scores = {"best_thresh": best_thresh,
              "acc": acc,
              "auc": auc_score,
              "f1": f1,
              "precision": precision,
              "recall": recall}

    return scores


In [None]:
# Validation metrics for the loaded checkpoint; no threshold is passed, so
# `eval_model` derives one and returns it under "best_thresh".
# NOTE(review): this pairs dataset-order labels with `get_preds` output —
# confirm the validation loader does not reorder samples.
valid_metrics_1 = eval_model(
    inp=learn.data.valid_ds.y.items,
    preds=preds[0][:, 1],
)
valid_roc.append(valid_metrics_1["auc"])

In [None]:
# Read the test split and attach it to the learner's data so predictions can
# be requested with DatasetType.Test.
test = pd.read_csv("./data/test.csv")
learn.data.add_test(test)
# These predictions belong to whichever checkpoint is loaded in `learn` right
# now; they must be recomputed after loading a different checkpoint.
t_preds = learn.get_preds(ds_type=DatasetType.Test)

In [None]:
# Score the test split, reusing the threshold selected on the validation split.
test_metrics_1 = eval_model(
    test.OUTPUT_LABEL,
    t_preds[0][:, 1],
    ds="Test",
    thresh=valid_metrics_1["best_thresh"],
)
test_roc.append(test_metrics_1["auc"])

In [None]:
# Load the saved "clinical-bert-2" checkpoint into the learner; predictions
# obtained before this point still belong to the previously loaded checkpoint.
learn.load("clinical-bert-2")

In [None]:
# Validation metrics (and decision threshold) for the checkpoint currently loaded.
preds = learn.get_preds(ds_type=DatasetType.Valid)
valid_metrics_2 = eval_model(learn.data.valid_ds.y.items, preds[0][:, 1])
valid_roc.append(valid_metrics_2["auc"])

# Recompute test predictions for this checkpoint; reusing `t_preds` from an
# earlier cell would score the previously loaded model instead.
t_preds = learn.get_preds(ds_type=DatasetType.Test)
test_metrics_2 = eval_model(test.OUTPUT_LABEL, t_preds[0][:, 1], ds="Test", thresh=valid_metrics_2["best_thresh"])
test_roc.append(test_metrics_2["auc"])

In [None]:
# Switch the learner to the "clinical-bert-3" checkpoint.
learn.load("clinical-bert-3")

# Validation metrics (and decision threshold) for this checkpoint.
preds = learn.get_preds(ds_type=DatasetType.Valid)
valid_metrics_3 = eval_model(learn.data.valid_ds.y.items, preds[0][:, 1])
valid_roc.append(valid_metrics_3["auc"])

# Recompute test predictions for this checkpoint; reusing `t_preds` from an
# earlier cell would score a previously loaded model instead.
t_preds = learn.get_preds(ds_type=DatasetType.Test)
test_metrics_3 = eval_model(test.OUTPUT_LABEL, t_preds[0][:, 1], ds="Test", thresh=valid_metrics_3["best_thresh"])
test_roc.append(test_metrics_3["auc"])

In [None]:
# Switch the learner to the "clinical-bert-unfrozen-1" checkpoint.
learn.load("clinical-bert-unfrozen-1")

# Validation metrics (and decision threshold) for this checkpoint.
preds = learn.get_preds(ds_type=DatasetType.Valid)
valid_metrics_4 = eval_model(learn.data.valid_ds.y.items, preds[0][:, 1])
valid_roc.append(valid_metrics_4["auc"])

# Recompute test predictions for this checkpoint; reusing `t_preds` from an
# earlier cell would score a previously loaded model instead.
t_preds = learn.get_preds(ds_type=DatasetType.Test)
test_metrics_4 = eval_model(test.OUTPUT_LABEL, t_preds[0][:, 1], ds="Test", thresh=valid_metrics_4["best_thresh"])
test_roc.append(test_metrics_4["auc"])