In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-1.4.9-py3-none-any.whl (925 kB)
[K     |████████████████████████████████| 925 kB 995 kB/s 
Collecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
[K     |████████████████████████████████| 829 kB 38.2 MB/s 
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2021.9.0-py3-none-any.whl (123 kB)
[K     |████████████████████████████████| 123 kB 46.8 MB/s 
[?25hCollecting torchmetrics>=0.4.0
  Downloading torchmetrics-0.5.1-py3-none-any.whl (282 kB)
[K     |████████████████████████████████| 282 kB 39.3 MB/s 
Collecting PyYAML>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |████████████████████████████████| 636 kB 32.9 MB/s 
[?25hCollecting pyDeprecate==0.3.1
  Downloading pyDeprecate-0.3.1-py3-none-any.whl (10 kB)
Collecting aiohttp
  Downloading aiohttp-3.7.4.post0-cp37-cp37m-manylinux2014_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 3

In [None]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.11.2-py3-none-any.whl (2.9 MB)
[K     |████████████████████████████████| 2.9 MB 1.1 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 33.9 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 34.6 MB/s 
Collecting huggingface-hub>=0.0.17
  Downloading huggingface_hub-0.0.17-py3-none-any.whl (52 kB)
[K     |████████████████████████████████| 52 kB 1.5 MB/s 
Installing collected packages: tokenizers, sacremoses, huggingface-hub, transformers
Successfully installed huggingface-hub-0.0.17 sacremoses-0.0.46 tokenizers-0.10.3 transformers-4.11.2


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
import pytorch_lightning as pl
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from transformers.optimization import AdamW
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, precision_score, f1_score
import os

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 8
MAX_SEQ_LEN = 512
num_labels = 19
lr = 5e-5
max_grad_norm = 1.0
num_training_steps = 10  # TODO
num_warmup_steps = max(1, num_training_steps // 1)
train_data_file = "/content/drive/MyDrive/Predicting Medical Billing Codes (ICD9) from Clinical Notes (in MIMIC-III datasets) using Deep Learning/data/clean_2k.csv"
model_ckpt = 'model0508.ckpt'

In [None]:
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
clinic_bert = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", return_dict=False)

Downloading:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/416M [00:00<?, ?B/s]

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
class ClinicBertMultiLabelClassifier(pl.LightningModule):
    def __init__(self, grad_clip=True):
        super(ClinicBertMultiLabelClassifier, self).__init__()
        self.grad_clip = grad_clip
        self.num_labels = num_labels

        # evaluation metrics
        self.best_f1 = 0
        self.train_loss_list = []
        self.val_loss_list = []
        self.train_f1_micro_list = []
        self.train_f1_macro_list = []
        self.val_f1_micro_list = []
        self.val_f1_macro_list = []
        self.train_precision_list = []
        self.train_recall_list = []
        self.val_precision_list = []
        self.val_recall_list = []

        # loss function
        self.pos_weight = torch.ones([num_labels]).to(device)
        self.criterion = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)

        # network modules
        self.bert = clinic_bert
        self.classifier = nn.Linear(self.bert.config.hidden_size, self.num_labels)
        self.activate = nn.Sigmoid()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        _, aggregated_output = self.bert(input_ids, token_type_ids, attention_mask)
        logits = self.classifier(aggregated_output)

        # to avoid gradients vanishing and sigmoid nan
        if self.grad_clip:
            logits = logits.clamp(min=-14.0, max=14.0)

        if labels is not None:
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.reshape(-1) == 1
                active_logits = logits.reshape([-1, self.num_labels])[active_loss]
                active_labels = labels.reshape([-1, self.num_labels])[active_loss]
                loss = self.criterion(active_logits, active_labels)
            else:
                loss = self.criterion(logits.reshape([-1, self.num_labels]),
                                      labels.reshape([-1, self.num_labels]))
            return loss, self.activate(logits)
        else:
            return self.activate(logits)

    @staticmethod
    def _convert_label_id_to_one_hot(code_list):
        labels = np.zeros([len(code_list), num_labels])
        for idx, codes in enumerate(code_list):
            codes = eval(codes)
            if not isinstance(codes, list):
                print("NOT a list: ", idx, codes)
            for code in codes:
                labels[idx][code] = 1
        return torch.tensor(labels)

    @staticmethod
    def _get_pos_weight(labels: torch.Tensor) -> torch.Tensor:
        total_num = labels.size(0)
        pos_cnt = labels.sum(dim=0).cpu().detach().numpy()
        neg_cnt = total_num - pos_cnt
        tmp = neg_cnt / pos_cnt
        return torch.tensor(tmp).to(device)

    def prepare_data(self):
        df = pd.read_csv(train_data_file)
        input_sequence_list = df['CLEAN_WORDS']
        input_data = tokenizer.batch_encode_plus([" ".join(eval(e)) for e in input_sequence_list],
                                                 max_length=MAX_SEQ_LEN,
                                                 truncation=True,
                                                 pad_to_max_length=True,
                                                 return_tensors='pt')
        input_ids = input_data['input_ids']  # IntTensor [batch_size, MAX_SEQ_LEN]
        print('input_ids: ', type(input_ids), input_ids.shape)
        code_list = df['CODED_HIGH_LVL_DIAG']
        labels = self._convert_label_id_to_one_hot(code_list)

        # re-define loss function with weights
        self.pos_weight = self._get_pos_weight(labels)
        print(self.pos_weight)
        self.criterion = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)

        print('labels: ', type(labels), labels.shape)
        self.train_dataset = TensorDataset(input_ids[:1700], labels[:1700])
        self.val_dataset = TensorDataset(input_ids[1700:], labels[1700:])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=lr, correct_bias=False)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps,
                                                    num_training_steps=num_training_steps)
        return {'optimizer': optimizer,
                'lr_scheduler': scheduler}
        # return optimizer

    def training_step(self, batch, batch_idx):
        input_ids, labels = batch
        batch_loss, probs = self(input_ids, labels=labels)

        probs = probs.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy().astype(int)
        labels_pred = np.round(probs).astype(int)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, labels_pred, average='micro')
        f1_macro = f1_score(labels, labels_pred, average='macro')
        accuracy = accuracy_score(labels, labels_pred)

        log = {'train_loss': batch_loss}
        return {'loss': batch_loss, 'log': log,
                'precision': torch.tensor(precision),
                'recall': torch.tensor(recall),
                'f1': torch.tensor(f1),
                'f1-macro': torch.tensor(f1_macro),
                'accuracy': torch.tensor(accuracy),
                }

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['accuracy'] for x in outputs]).mean()
        avg_f1 = torch.stack([x['f1'] for x in outputs]).mean()
        avg_f1_macro = torch.stack([x['f1-macro'] for x in outputs]).mean()
        avg_precision = torch.stack([x['precision'] for x in outputs]).mean()
        avg_recall = torch.stack([x['recall'] for x in outputs]).mean()
        print('Train F1', avg_f1.item())
        print('Train precision', avg_precision.item())
        print('Train recall', avg_recall.item())
        print('Train acc: ', avg_acc.item())

        self.train_loss_list.append(avg_loss)
        self.train_f1_micro_list.append(avg_f1)
        self.train_f1_macro_list.append(avg_f1_macro)
        self.train_precision_list.append(avg_precision)
        self.train_recall_list.append(avg_recall)

        #return {'loss': avg_loss}

    def _precision_top_n(self, probs_pred: np.ndarray, labels_true: np.ndarray, n: int):
        # TODO
        # [batch size, n_classes]
        prob_ids = probs_pred.argsort(axis=1)[:, ::-1][:, :n]  # reverse and trunc to only top n
        top_preds = np.zeros_like(labels_true)
        for i, prob_id in enumerate(prob_ids):
            top_preds[i][prob_id] = 1
        return precision_score(labels_true, top_preds, average='micro')

    def validation_step(self, batch, batch_idx):
        input_ids, labels = batch
        loss, logits = self(input_ids, labels=labels)

        logits = logits.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy().astype(int)
        # p8 = self._precision_top_n(logits, labels, 8)
        # p15 = self._precision_top_n(logits, labels, 15)
        labels_pred = np.round(logits).astype(int)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, labels_pred, average='micro')
        f1_macro = f1_score(labels, labels_pred, average='macro')
        accuracy = accuracy_score(labels, labels_pred)
        return {'precision': torch.tensor(precision),
                'recall': torch.tensor(recall),
                'f1': torch.tensor(f1),
                'f1-macro': torch.tensor(f1_macro),
                'accuracy': torch.tensor(accuracy),
                'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_acc = torch.stack([x['accuracy'] for x in outputs]).mean()
        avg_f1 = torch.stack([x['f1'] for x in outputs]).mean()
        avg_f1_macro = torch.stack([x['f1-macro'] for x in outputs]).mean()
        avg_precision = torch.stack([x['precision'] for x in outputs]).mean()
        avg_recall = torch.stack([x['recall'] for x in outputs]).mean()
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        self.val_loss_list.append(avg_loss)
        self.val_f1_micro_list.append(avg_f1)
        self.val_f1_macro_list.append(avg_f1_macro)
        self.val_precision_list.append(avg_precision)
        self.val_recall_list.append(avg_recall)
        print('Val Loss', avg_loss.item())
        print('Val F1', avg_f1.item())
        print('Val precision', avg_precision.item())
        print('Val recall', avg_recall.item())
        print('Val acc: ', avg_acc.item())
        f1 = avg_f1.item()
        if f1 > self.best_f1:
            self.best_f1 = f1
            model_dict = {
                'f1': f1,
                'model': self.state_dict(),
                'train_loss': self.train_loss_list,
                'val_loss': self.val_loss_list,
                'train_f1': self.train_f1_micro_list,
                'val_f1': self.val_f1_micro_list,
                'train_f1_macro': self.train_f1_macro_list,
                'val_f1_macro': self.val_f1_macro_list,
                'train_precision': self.train_precision_list,
                'train_recall': self.train_recall_list,
                'val_precision': self.val_precision_list,
                'val_recall': self.val_recall_list,
            }
            torch.save(model_dict, model_ckpt)
            print("Save model at f1[%f] in %s" % (f1, model_ckpt))
        return {'val_loss': avg_acc,
                'val_avg_f1': avg_f1}

In [None]:
if __name__ == '__main__':
    net = ClinicBertMultiLabelClassifier()
    trainer = pl.Trainer(max_epochs=num_training_steps, gpus=1)

    if os.path.exists(model_ckpt):
        print("Load model from %s" % model_ckpt)
        checkpoint = torch.load(model_ckpt)
        net.load_state_dict(checkpoint['model'])
        net.best_f1 = checkpoint['f1']
        print("Previous best val f1: %f" % net.best_f1)
        net.val_loss_list = list(checkpoint['val_loss'])
        net.train_loss_list = list(checkpoint['train_loss'])
    trainer.fit(net)

    model_dict = {
        'model': net.state_dict(),
        'train_loss': net.train_loss_list,
        'val_loss': net.val_loss_list,
        'train_f1': net.train_f1_micro_list,
        'val_f1': net.val_f1_micro_list,
        'train_f1_macro': net.train_f1_macro_list,
        'val_f1_macro': net.val_f1_macro_list,
        'train_precision': net.train_precision_list,
        'train_recall': net.train_recall_list,
        'val_precision': net.val_precision_list,
        'val_recall': net.val_recall_list
    }
    torch.save(model_dict, 'model0508-whole.ckpt')
    print("Save whole process in model0508-whole.ckpt")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


input_ids:  <class 'torch.Tensor'> torch.Size([2000, 512])


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


tensor([  2.8536,   4.5096,   0.5613,   1.7360,   2.4783,   2.6630,   0.2788,
          1.1436,   1.4691,   1.5157, 221.2222,   8.0909,   4.3191,  16.5439,
         14.0376,   1.6420,   1.4390,   0.8018,   2.3223], device='cuda:0',
       dtype=torch.float64)
labels:  <class 'torch.Tensor'> torch.Size([2000, 19])



  | Name       | Type              | Params
-------------------------------------------------
0 | criterion  | BCEWithLogitsLoss | 0     
1 | bert       | BertModel         | 108 M 
2 | classifier | Linear            | 14.6 K
3 | activate   | Sigmoid           | 0     
-------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
433.300   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  average, "true nor predicted", 'F-score is', len(true_sum)


Val Loss 0.8874989787134204
Val F1 0.3689889549702634
Val precision 0.31510416666666663
Val recall 0.44515810276679846
Val acc:  0.0
Save model at f1[0.368989] in model0508.ckpt


Training: -1it [00:00, ?it/s]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Validating: 0it [00:00, ?it/s]

Val Loss 0.9400105324764253
Val F1 0.3418615525169841
Val precision 0.3000519131113439
Val recall 0.40110406414769956
Val acc:  0.0
Train F1 0.33965218424733096
Train precision 0.30211010623907053
Train recall 0.3940445120889235
Train acc:  0.0


Validating: 0it [00:00, ?it/s]

Val Loss 0.8974016546674513
Val F1 0.4166939054821186
Val precision 0.33742306778255404
Val recall 0.5524713504708101
Val acc:  0.0
Save model at f1[0.416694] in model0508.ckpt
Train F1 0.399620831440695
Train precision 0.33626275404787476
Train recall 0.5024814254421175
Train acc:  0.0


Validating: 0it [00:00, ?it/s]

Val Loss 0.841398840991206
Val F1 0.5442175905466174
Val precision 0.40576153792289643
Val recall 0.8372789619394327
Val acc:  0.0
Save model at f1[0.544218] in model0508.ckpt
Train F1 0.4939210693627024
Train precision 0.3954588761219447
Train recall 0.6801688632945952
Train acc:  0.0017605633802816902


Validating: 0it [00:00, ?it/s]

Val Loss 0.8143874339327082
Val F1 0.5605628884995344
Val precision 0.4784395306820857
Val recall 0.6918019827773424
Val acc:  0.0
Save model at f1[0.560563] in model0508.ckpt
Train F1 0.5386367591929512
Train precision 0.4457983310217816
Train recall 0.7076687460047054
Train acc:  0.002347417840375587


Validating: 0it [00:00, ?it/s]

Val Loss 0.7341219407575308
Val F1 0.5917523019487275
Val precision 0.5024939897784638
Val recall 0.7309761466150795
Val acc:  0.003289473684210526
Save model at f1[0.591752] in model0508.ckpt
Train F1 0.5705805069237057
Train precision 0.48262371034883883
Train recall 0.724372455399897
Train acc:  0.0035211267605633804


Validating: 0it [00:00, ?it/s]

Val Loss 0.7082859975691103
Val F1 0.5947684263306389
Val precision 0.5412034630689625
Val recall 0.6691780746214953
Val acc:  0.006578947368421052
Save model at f1[0.594768] in model0508.ckpt
Train F1 0.6041560554074299
Train precision 0.5330159602090239
Train recall 0.7206778585893513
Train acc:  0.007629107981220657


Validating: 0it [00:00, ?it/s]

Val Loss 0.6986976891363583
Val F1 0.588339141871014
Val precision 0.5585441317495569
Val recall 0.6323510383712082
Val acc:  0.006578947368421052
Train F1 0.6326033920756658
Train precision 0.5710848392329478
Train recall 0.7260993238580211
Train acc:  0.018779342723004695


Validating: 0it [00:00, ?it/s]

Val Loss 0.6914033466434917
Val F1 0.6001804583107689
Val precision 0.587391966126159
Val recall 0.6260690179859871
Val acc:  0.013157894736842105
Save model at f1[0.600180] in model0508.ckpt
Train F1 0.6564717558182245
Train precision 0.6049267528063789
Train recall 0.734812633423657
Train acc:  0.023474178403755867


Validating: 0it [00:00, ?it/s]

Val Loss 0.7161362097401796
Val F1 0.6050689743524639
Val precision 0.5420645999138592
Val recall 0.6973221778699475
Val acc:  0.013157894736842105
Save model at f1[0.605069] in model0508.ckpt
Train F1 0.6838086480990586
Train precision 0.6378059026362727
Train recall 0.7504660670231117
Train acc:  0.03169014084507042


Validating: 0it [00:00, ?it/s]

Val Loss 0.7402008210071829
Val F1 0.6013040844537406
Val precision 0.5778688771504715
Val recall 0.6363244045578488
Val acc:  0.009868421052631578
Train F1 0.7014608117045457
Train precision 0.6572093521925835
Train recall 0.7666855190834361
Train acc:  0.028169014084507043
Save whole process in model0508-whole.ckpt


In [None]:
!ls -la

total 846532
drwxr-xr-x 1 root root      4096 Sep 30 21:36 .
drwxr-xr-x 1 root root      4096 Sep 30 20:32 ..
drwxr-xr-x 4 root root      4096 Sep 16 13:39 .config
drwx------ 5 root root      4096 Sep 30 20:34 drive
drwxr-xr-x 3 root root      4096 Sep 30 20:35 lightning_logs
-rw-r--r-- 1 root root 433407415 Sep 30 21:30 model0508.ckpt
-rw-r--r-- 1 root root 433408089 Sep 30 21:36 model0508-whole.ckpt
drwxr-xr-x 1 root root      4096 Sep 16 13:40 sample_data


In [None]:
from google.colab import files

In [None]:
files.download('/content/model0508-whole.ckpt')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>