# Importing Libraries & Data Preparation

In [None]:
import numpy as np
import pandas as pd
from fastai import *
from fastai.text import *
from fastai.callbacks.tracker import EarlyStoppingCallback
from fastai.callbacks.tracker import SaveModelCallback

from clinical_note_utils import *

from transformers import AutoTokenizer, AutoModel
from transformers import BertForSequenceClassification, BertTokenizer


In [None]:
# Word-piece tokenizer for the same pretrained clinical BERT checkpoint the classifier is loaded from.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="emilyalsentzer/Bio_Discharge_Summary_BERT",
)

In [None]:
class FastAiBertTokenizer(BaseTokenizer):
    """Wrap a pretrained BERT word-piece tokenizer for use as a fastai ``tok_func``.

    fastai's ``Tokenizer`` calls ``tok_func(lang)`` to obtain a tokenizer object,
    so this already-configured wrapper returns itself from ``__call__``. Each text
    is split into word pieces, truncated, and framed with ``[CLS]`` / ``[SEP]``.

    Args:
        tokenizer: Pretrained tokenizer exposing ``tokenize(text) -> list[str]``.
        max_seq_len: Maximum tokens per text, *including* the two special tokens.
        **kwargs: Accepted for compatibility; only ``lang`` is read (forwarded to
            ``BaseTokenizer``, default ``"en"``).
    """

    def __init__(self, tokenizer: BertTokenizer, max_seq_len: int = 128, **kwargs):
        # BaseTokenizer.__init__ stores ``lang``; skipping it left that attribute unset.
        super().__init__(lang=kwargs.get("lang", "en"))
        self._pretrained_tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __call__(self, *args, **kwargs):
        # fastai invokes ``tok_func(lang)``; hand back this configured instance.
        return self

    def tokenizer(self, t: str) -> List[str]:
        """Return ``[CLS]`` + word pieces + ``[SEP]``, truncated to ``max_seq_len`` tokens.

        The two special tokens are always present, so the result never has fewer
        than two tokens.
        """
        # Clamp at zero: a negative bound (max_seq_len < 2) would slice tokens off
        # the *end* of the list instead of leaving an empty body.
        budget = max(self.max_seq_len - 2, 0)
        pieces = self._pretrained_tokenizer.tokenize(t)[:budget]
        return ["[CLS]"] + pieces + ["[SEP]"]

In [None]:
DATA_ROOT = Path("./data")  

train, valid = [pd.read_csv(DATA_ROOT / fname) for fname in ["train_sub.csv", "valid.csv"]]

In [None]:
train.head()

In [None]:
valid.head()

In [None]:
# fastai numericalizes a token as its *position* in the itos list, so position i
# must hold the token whose BERT id is i. ``tokenizer.vocab`` key order is not
# guaranteed to follow ids (e.g. for fast tokenizers), so sort by id explicitly.
# Assumes the ids are contiguous from 0, as in standard BERT vocab files.
fastai_bert_vocab = Vocab(
    [token for token, _ in sorted(tokenizer.vocab.items(), key=lambda item: item[1])]
)

In [None]:
# fastai Tokenizer driven entirely by the BERT wrapper: fastai's default text
# pre/post rules are disabled so the word pieces reach the vocab untouched.
fastai_tokenizer = Tokenizer(
    tok_func=FastAiBertTokenizer(tokenizer, max_seq_len=256),
    pre_rules=[],
    post_rules=[],
)

In [None]:
# NOTE: fastai v1's TextClasDataBunch.create builds its own pad_collate from its
# ``pad_idx`` / ``pad_first`` arguments (defaults 1 / True). It installs that on
# every DataLoader, overriding a ``collate_fn`` passed through kwargs. Pass the
# BERT padding settings explicitly so batches are right-padded with id 0. The
# ``collate_fn`` is kept for fastai versions that honour it.
clinical_data = TextDataBunch.from_df(
    ".", train, valid,
    tokenizer=fastai_tokenizer,
    vocab=fastai_bert_vocab,
    include_bos=False,  # the BERT wrapper already adds [CLS]
    include_eos=False,  # ... and [SEP]
    text_cols="TEXT",
    label_cols="OUTPUT_LABEL",
    bs=16,
    pad_idx=0,
    pad_first=False,
    collate_fn=partial(pad_collate, pad_first=False, pad_idx=0),
)


In [None]:
# Serialize the DataBunch to "bert-clinical-sub.pkl" (presumably under the
# DataBunch path "." — confirm before relying on the location).
clinical_data.save("bert-clinical-sub.pkl")

In [None]:
# Show a sample batch to eyeball the tokenized texts and their labels.
clinical_data.show_batch()

# BERT Model

In [None]:
# Pretrained clinical BERT with a freshly initialized two-way classification head.
clinical_bert = BertForSequenceClassification.from_pretrained(
    "emilyalsentzer/Bio_Discharge_Summary_BERT",
    num_labels=2,
)

In [None]:
class CustomTransformerModel(nn.Module):
    """Adapt a Hugging Face sequence-classification model to fastai v1.

    The wrapped model returns a tuple-like output; fastai's loss and metrics
    expect a plain logits tensor, so ``forward`` keeps only the first element.

    Args:
        transformer_model: e.g. a ``BertForSequenceClassification`` instance.
        pad_idx: Token id used to pad input batches. Positions holding it are
            excluded from attention. Defaults to ``0``, which matches the
            ``pad_idx=0`` collate setting and the ``[PAD]`` id of standard BERT
            vocabularies.
    """

    def __init__(self, transformer_model: "BertForSequenceClassification", pad_idx: int = 0):
        super().__init__()
        self.transformer = transformer_model
        self.pad_idx = pad_idx

    def forward(self, x):
        """Return the classification logits for a batch of token ids ``x``."""
        # Without an attention mask the encoder attends to padding, so a sample's
        # prediction would depend on how much its batch happened to be padded.
        attention_mask = (x != self.pad_idx).long()
        return self.transformer(x, attention_mask=attention_mask)[0]

In [None]:
model = CustomTransformerModel(clinical_bert)

In [None]:
from fastai.callbacks import *

learner = Learner(clinical_data, model, metrics=[AUROC(), Precision(), Recall()])

In [None]:
def bert_clas_split(self) -> List[nn.Module]:
    """Split a wrapped BERT classifier into six fastai layer groups.

    The groups, in order, are:

    1. the embeddings;
    2.-4. three contiguous slices of the encoder layers (the last slice takes
       any remainder when the layer count is not divisible by three);
    5. the pooler;
    6. the classification head (dropout + linear).

    Every encoder layer lands in exactly one group.

    Args:
        self: The ``CustomTransformerModel`` to split. It is named ``self`` for
            backward compatibility; it is the model, not a Learner.

    Returns:
        A list of six lists of modules, suitable for ``Learner.split``.
    """
    transformer = self.transformer
    bert = transformer.bert
    layers = list(bert.encoder.layer)
    n = len(layers) // 3
    # Contiguous bounds: the previous ``[n+1:2*n]`` / ``[2*n+1:]`` slices skipped
    # layers n and 2n, leaving them outside every layer group.
    return [
        [bert.embeddings],
        layers[:n],
        layers[n:2 * n],
        layers[2 * n:],
        [bert.pooler],
        [transformer.dropout, transformer.classifier],
    ]

x = bert_clas_split(model)
# Pass the groups through as returned instead of re-packing six hard-coded
# indices, so the split cannot silently drop a group if the function changes.
learner.split(x)

In [None]:
# Learning-rate range test. Nothing above freezes the Learner, so presumably all
# six layer groups are trainable at this point — confirm.
learner.lr_find()

In [None]:
learner.recorder.plot(suggestion=True)

In [None]:
# Stage 1. ``slice(lo, hi)`` spreads learning rates across the six layer groups
# from bert_clas_split (lo for the first group, hi for the last). The six ``wd``
# values are one weight decay per group, in the same order.
# NOTE(review): learner.freeze() is never called before this cycle, so it
# presumably trains every group, embeddings included. If gradual unfreezing was
# intended (freeze_to(-2) -> freeze_to(-3) -> unfreeze below), freeze first.
learner.fit_one_cycle(3, max_lr=slice(1e-06, 1e-05), moms=(0.8, 0.7), wd =(1e-7, 1e-06, 1e-5, 1e-4, 1e-3, 1e-02))

In [None]:
learner.save("clinical-bert-1-b")

In [None]:
# Train only the last two layer groups (pooler and classifier head, per the
# group order in bert_clas_split).
learner.freeze_to(-2)

In [None]:
learner.lr_find()

In [None]:
learner.recorder.plot(suggestion=True)

In [None]:
# Stage 2: 10 epochs on the last two groups; pct_start=0.2 spends the first 20%
# of the cycle ramping the learning rate up.
learner.fit_one_cycle(10, max_lr=slice(1e-05, 1e-04), moms=(0.8, 0.7), pct_start=0.2, wd =(1e-7, 1e-06, 1e-5, 1e-4, 1e-3, 1e-02))

In [None]:
learner.save("clinical-bert-2")

In [None]:
# Additionally train the last encoder-layer group (groups -3, -2, -1).
learner.freeze_to(-3)

In [None]:
learner.lr_find()

In [None]:
learner.recorder.plot(suggestion=True)

In [None]:
# Stage 3: short cycle with lower learning rates now that encoder layers train.
learner.fit_one_cycle(3, max_lr=slice(1e-6, 1e-05), moms=(0.8, 0.7), pct_start=0.1, wd =(1e-7, 1e-06, 1e-5, 1e-4, 1e-3, 1e-02))

In [None]:
learner.save("clinical-bert-3")

In [None]:
# Make every layer group trainable.
learner.unfreeze()

In [None]:
learner.lr_find()

In [None]:
learner.recorder.plot(suggestion=True)

In [None]:
# Stage 4: fine-tune the whole model for two epochs.
learner.fit_one_cycle(2, max_lr=slice(1e-05, 1e-04), moms=(0.8, 0.7), pct_start=0.1, wd =(1e-7, 1e-06, 1e-5, 1e-4, 1e-3, 1e-02))

In [None]:
learner.save("clinical-bert-unfrozen-1")

In [None]:
test = pd.read_csv(DATA_ROOT / "test.csv")
learner.data.add_test(test.TEXT)

In [None]:
learner.data.test_ds

In [None]:
preds = learner.get_preds(DatasetType.Test)

In [None]:
final_preds = [1 if p > 0.5 else 0 for p in preds[0][:, 1]]

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
confusion_matrix(test.OUTPUT_LABEL, final_preds)