In [1]:
!pip install -q transformers

[K     |████████████████████████████████| 2.6 MB 8.4 MB/s 
[K     |████████████████████████████████| 895 kB 55.9 MB/s 
[K     |████████████████████████████████| 636 kB 74.3 MB/s 
[K     |████████████████████████████████| 3.3 MB 49.0 MB/s 
[?25h

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
%matplotlib inline
import os
import torch
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn import metrics
from sklearn.metrics import classification_report
import re
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer
from transformers import AutoTokenizer, AutoModel
from transformers import BertForSequenceClassification, AdamW, BertConfig
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.data.sampler import SubsetRandomSampler
import transformers
from transformers import RobertaTokenizer, BertTokenizer, RobertaModel, BertModel, AdamW# get_linear_schedule_with_warmup
from transformers import get_linear_schedule_with_warmup
import time

!cp drive/MyDrive/Colab\ Notebooks/MSc-Individual-Project/utils.py .
from utils import *
!cp drive/MyDrive/Colab\ Notebooks/MSc-Individual-Project/Custom_Dataset_Class.py .
from Custom_Dataset_Class import CustomDataset
!cp drive/MyDrive/Colab\ Notebooks/MSc-Individual-Project/pytorchtools.py .
from pytorchtools import EarlyStopping
#from Bert_Classification import Bert_Classification_Model
#from RoBERT import RoBERT_Model

#from BERT_Hierarchical import BERT_Hierarchical_Model
import warnings
warnings.filterwarnings("ignore")

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.preprocessing import LabelBinarizer

In [4]:
import torch
# If there's a GPU available...
if torch.cuda.is_available():    

    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")


There are 1 GPU(s) available.
We will use the GPU: Tesla T4


In [5]:
np.random.seed(123)
torch.manual_seed(123)
torch.cuda.manual_seed_all(123)

In [6]:
#change to where you store mimic3 data
MIMIC_3_DIR = '/content/drive/MyDrive/Colab Notebooks/MSc-Individual-Project/datasets_date'

train_df = pd.read_csv('%s/train_50_second.csv' % MIMIC_3_DIR)
eval_df = pd.read_csv('%s/dev_50_second.csv' % MIMIC_3_DIR)
test_df = pd.read_csv('%s/test_50_second.csv' % MIMIC_3_DIR)

train_df.head()

Unnamed: 0,SUBJECT_ID,HADM_ID,TEXT,LABELS,length
0,5331,142049,hypotension,401.9;414.01;530.81;424.0;250.00;427.31,1
1,2830,193970,title,401.9;285.9;276.2;518.81;244.9;276.1;584.9;427.89,1
2,72671,188200,title,311;276.2;518.81;486;427.31,1
3,17250,134654,title,403.90;496;287.5;995.92;424.0;584.9;507.0,1
4,10502,145440,title,414.01;412;403.90;496;530.81;518.81;486;584.9;...,1


In [7]:
full_df = pd.concat([train_df, eval_df, test_df], ignore_index=True)

In [8]:
# split labels by ";", then convert to list
def split_lab (x):
    #print(x)
    return x.split(";")

full_df['LABELS'] = full_df['LABELS'].apply(split_lab)
#full_df['TEXT'] = full_df['TEXT'].apply(split_lab)

full_df.head()

Unnamed: 0,SUBJECT_ID,HADM_ID,TEXT,LABELS,length
0,5331,142049,hypotension,"[401.9, 414.01, 530.81, 424.0, 250.00, 427.31]",1
1,2830,193970,title,"[401.9, 285.9, 276.2, 518.81, 244.9, 276.1, 58...",1
2,72671,188200,title,"[311, 276.2, 518.81, 486, 427.31]",1
3,17250,134654,title,"[403.90, 496, 287.5, 995.92, 424.0, 584.9, 507.0]",1
4,10502,145440,title,"[414.01, 412, 403.90, 496, 530.81, 518.81, 486...",1


In [9]:
#load multi label binarizer for one-hot encoding
mlb = MultiLabelBinarizer(sparse_output=True)

#labels_onehot = mlb.fit_transform(train_df.pop('LABELS'))
#labels_onehot[0][1]

In [10]:
#change label to one-hot encoding per code
full_df = full_df.join(
            pd.DataFrame.sparse.from_spmatrix(
                mlb.fit_transform(full_df.pop('LABELS')),
                columns=mlb.classes_))

full_df

Unnamed: 0,SUBJECT_ID,HADM_ID,TEXT,length,038.9,244.9,250.00,272.0,272.4,276.0,276.1,276.2,285.1,285.9,287.5,305.1,311,327.23,401.9,403.90,403.91,410.71,412,414.01,424.0,424.1,427.31,427.89,428.0,486,493.90,496,507.0,511.9,518.0,518.81,530.81,584.5,584.9,585.9,599.0,774.2,785.52,995.92,997.1,V05.3,V15.82,V29.0,V30.00,V30.01,V45.81,V45.82,V58.61,V58.67
0,5331,142049,hypotension,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,1,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,2830,193970,title,1,0,1,0,0,0,0,1,1,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
2,72671,188200,title,1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
3,17250,134654,title,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
4,10502,145440,title,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,1,0,0,0,0,1,1,0,1,0,0,0,1,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
43414,98103,133463,chief complaint altered mental status hpi year...,1860,0,0,0,0,1,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1
43415,92846,125385,chief complaint hypotension fatigue shortness ...,1895,1,1,1,0,0,1,1,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,1,0,1,0,0,0,0,0,0,1,0,1,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0
43416,93610,164181,chief complaint subdural hematoma hpi 85m with...,1922,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
43417,99339,142289,chief complaint chief complaint increasing ple...,2086,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0


In [11]:
full_df.HADM_ID.unique().shape

(5181,)

In [12]:
# Convert columns to list of one hot encoding
icd_classes_50 = mlb.classes_

full_df['labels'] = full_df[icd_classes_50].values.tolist()
#train_df.sort_values(['length'], ascending=False, inplace=True)
full_df.head()


Unnamed: 0,SUBJECT_ID,HADM_ID,TEXT,length,038.9,244.9,250.00,272.0,272.4,276.0,276.1,276.2,285.1,285.9,287.5,305.1,311,327.23,401.9,403.90,403.91,410.71,412,414.01,424.0,424.1,427.31,427.89,428.0,486,493.90,496,507.0,511.9,518.0,518.81,530.81,584.5,584.9,585.9,599.0,774.2,785.52,995.92,997.1,V05.3,V15.82,V29.0,V30.00,V30.01,V45.81,V45.82,V58.61,V58.67,labels
0,5331,142049,hypotension,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,1,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
1,2830,193970,title,1,0,1,0,0,0,0,1,1,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,"[0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ..."
2,72671,188200,title,1,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, ..."
3,17250,134654,title,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ..."
4,10502,145440,title,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,1,0,0,0,0,1,1,0,1,0,0,0,1,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [13]:
#full_df = full_df.drop(full_df[full_df['length']<300].index)

In [14]:
#full_df

In [15]:
train_df, test_df = train_test_split(full_df, test_size=0.2)

In [16]:
train_df, eval_df = train_test_split(train_df, test_size=0.2)

In [17]:
dev_df, test_df = train_test_split(test_df, test_size=0.5)

In [18]:
train_df.sort_values(['length'], inplace=True)
eval_df.sort_values(['length'], inplace=True)
dev_df.sort_values(['length'], inplace=True)
test_df.sort_values(['length'], inplace=True)


In [19]:
#convert into 2 columns dataframe
train_df = pd.DataFrame(train_df, columns=['TEXT', 'labels'])
train_df.columns=['text', 'labels']
train_df.head()

eval_df = pd.DataFrame(eval_df, columns=['TEXT', 'labels'])
eval_df.columns=['text', 'labels']
eval_df.head()

dev_df = pd.DataFrame(dev_df, columns=['HADM_ID', 'TEXT', 'labels'])
dev_df.columns=['id', 'text', 'labels']
dev_df.head()

test_df = pd.DataFrame(test_df, columns=['HADM_ID', 'TEXT', 'labels'])
test_df.columns=['id', 'text', 'labels']
test_df.head()

Unnamed: 0,id,text,labels
19,167505,start feeds today,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
23,144722,resp correction rsbi,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ..."
20,167505,start feeds today,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
32,100566,atrial sensed ventricular paced,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
53,123066,please see nursing transfer note,"[0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ..."


In [20]:
train_df.reset_index(drop=True, inplace=True)
eval_df.reset_index(drop=True, inplace=True)
dev_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)
test_df.head()

Unnamed: 0,id,text,labels
0,167505,start feeds today,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,144722,resp correction rsbi,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ..."
2,167505,start feeds today,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,100566,atrial sensed ventricular paced,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
4,123066,please see nursing transfer note,"[0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ..."


In [21]:
# Creating the customized model, by adding a drop out and a dense layer on top of distil bert to get the final output for the model.

class BERTClass(torch.nn.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        '''
            Load Pretrained model here
            Use return_dict=False for compatibility for 4.x

        '''
        self.l1 = transformers.AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", return_dict=False)
        #self.l1 = transformers.BertModel.from_pretrained('bert-base-uncased', return_dict=False)


        self.l2 = torch.nn.Dropout(0.3)

        '''
            Changed Linear Output layer to 50 based on the class
        '''
        self.l3 = torch.nn.Linear(768, 50)

    def forward(self, ids, mask, token_type_ids):
#        print("ids: ", ids.size(), "mask: ", mask.size(), "token type ids: ", token_type_ids.size())
        _, output_1= self.l1(ids, attention_mask = mask, token_type_ids = token_type_ids)
        output_2 = self.l2(output_1)
        output = self.l3(output_2)
        return output

model = BERTClass()
model.to(device)

Downloading:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BERTClass(
  (l1): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    

In [22]:
# Defining some key variables to configure model training
MAX_LEN = 512
TRAIN_BATCH_SIZE = 16
VALID_BATCH_SIZE = 8
TEST_BATCH_SIZE = 8
EPOCHS = 10
LEARNING_RATE = 3e-05

#set tokenizer
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

#custom dataset for BERT class
class CustomDataset(Dataset):

    def __init__(self, dataframe, tokenizer, max_len):
        
        '''
            set text as training data
            set labels as targets
        '''
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text = dataframe.text
        self.targets = self.data.labels
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        text = str(self.text[index])
        text = " ".join(text.split())

        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            pad_to_max_length=True,
            return_token_type_ids=True
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]


        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'targets': torch.tensor(self.targets[index], dtype=torch.float)
        }



Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

In [23]:
#load df to dataset

training_set = CustomDataset(train_df, tokenizer, MAX_LEN)
valid_set = CustomDataset(eval_df, tokenizer, MAX_LEN)
dev_set = CustomDataset(dev_df, tokenizer, MAX_LEN)
testing_set = CustomDataset(test_df, tokenizer, MAX_LEN)

In [24]:
#data loader
train_params = {'batch_size': TRAIN_BATCH_SIZE,
                'shuffle': False
                }

val_params = {'batch_size': VALID_BATCH_SIZE,
                'shuffle': False
                }

dev_params = {'batch_size': TEST_BATCH_SIZE,
                'shuffle': False
                }

test_params = {'batch_size': TEST_BATCH_SIZE,
                'shuffle': False
                }

training_loader = DataLoader(training_set, **train_params)
valid_loader = DataLoader(valid_set, **val_params)
dev_loader = DataLoader(dev_set, **dev_params)
testing_loader = DataLoader(testing_set, **test_params)

In [25]:
#loss function
def loss_fn(outputs, targets):
    return torch.nn.BCEWithLogitsLoss()(outputs, targets)

#optimizer
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [26]:
def train(epoch):
    model.train()
    for _,data in enumerate(training_loader, 0):
        ids = data['ids'].to(device, dtype = torch.long)
        mask = data['mask'].to(device, dtype = torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.float)

        outputs = model(ids, mask, token_type_ids)

        optimizer.zero_grad()
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        
    print(f'Epoch: {epoch}, Training Loss:  {loss.item()}')

In [27]:
# Evaluate the model

def validation(epoch):
    model.eval()
    fin_targets=[]
    fin_outputs=[]
    losses=[]
    with torch.no_grad():
        for _, data in enumerate(valid_loader, 0):
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.float)
            outputs = model(ids, mask, token_type_ids)
            loss = loss_fn(outputs, targets)
            losses.append(loss.item())

            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
    print(f'Epoch: {epoch}, Validation Loss:  {np.mean(losses):.2f}')
    return fin_outputs, fin_targets, losses

In [None]:
start_epoch=0
DIR = '/content/drive/MyDrive/Colab Notebooks/MSc-Individual-Project/'
resume = True     
if resume:
    if os.path.isfile(f"%s/models/model_first2days_epoch{start_epoch}.pth" % DIR):
        print("Resume from checkpoint...")
        checkpoint = torch.load(f"%s/models/model_first2days_epoch{start_epoch}.pth" % DIR)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        initepoch = checkpoint['epoch']
        print("====>loaded checkpoint (epoch{})".format(checkpoint['epoch']))
    else:
        print("====>no checkpoint found.")
        initepoch = 0

#patience = 3
#early_stopping = EarlyStopping(patience, verbose=True)


for epoch in tqdm(range(EPOCHS)):
    train(epoch)
    validation(epoch)

    if (epoch+start_epoch+1)%5 == 0:
        checkpoint = {"model_state_dict": model.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch+start_epoch+1}
        path_checkpoint = f"%s/models/model_first2days_epoch{epoch+start_epoch+1}.pth" % DIR
        torch.save(checkpoint, path_checkpoint)

#

====>no checkpoint found.


  0%|          | 0/10 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Epoch: 0, Training Loss:  0.36915984749794006


 10%|█         | 1/10 [15:50<2:22:36, 950.76s/it]

Epoch: 0, Validation Loss:  0.28
Epoch: 1, Training Loss:  0.3491505980491638


 20%|██        | 2/10 [31:34<2:06:10, 946.37s/it]

Epoch: 1, Validation Loss:  0.27
Epoch: 2, Training Loss:  0.312276154756546


 30%|███       | 3/10 [47:19<1:50:21, 945.99s/it]

Epoch: 2, Validation Loss:  0.27
Epoch: 3, Training Loss:  0.2886822521686554


 40%|████      | 4/10 [1:03:03<1:34:32, 945.35s/it]

Epoch: 3, Validation Loss:  0.27
Epoch: 4, Training Loss:  0.257680207490921
Epoch: 4, Validation Loss:  0.26


 50%|█████     | 5/10 [1:18:52<1:18:53, 946.64s/it]

Epoch: 5, Training Loss:  0.23453517258167267


 60%|██████    | 6/10 [1:34:35<1:03:00, 945.18s/it]

Epoch: 5, Validation Loss:  0.25
Epoch: 6, Training Loss:  0.20862773060798645


 70%|███████   | 7/10 [1:50:18<47:13, 944.52s/it]  

Epoch: 6, Validation Loss:  0.25
Epoch: 7, Training Loss:  0.1760110706090927


 80%|████████  | 8/10 [2:06:02<31:28, 944.46s/it]

Epoch: 7, Validation Loss:  0.25
Epoch: 8, Training Loss:  0.17118039727210999


 90%|█████████ | 9/10 [2:21:47<15:44, 944.71s/it]

Epoch: 8, Validation Loss:  0.25
Epoch: 9, Training Loss:  0.146488219499588
Epoch: 9, Validation Loss:  0.26


100%|██████████| 10/10 [2:37:35<00:00, 945.51s/it]


In [28]:
DIR = '/content/drive/MyDrive/Colab Notebooks/MSc-Individual-Project/'

checkpoint = torch.load(f"%s/models/model_first2days_epoch10.pth" % DIR)
model.load_state_dict(checkpoint['model_state_dict'])



<All keys matched successfully>

In [29]:
# Evaluate the model

def evaluation():
    model.eval()

    fin_targets=[]
    fin_outputs=[]
    losses=[]
    with torch.no_grad():
        for _, data in enumerate(dev_loader, 0):
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.float)
            outputs = model(ids, mask, token_type_ids)
            loss = loss_fn(outputs, targets)
            losses.append(loss.item())

            fin_targets.extend(targets.cpu().detach().numpy())
            fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy())
    print(f'Loss:  {np.mean(losses):.2f}')
    return fin_outputs, fin_targets, losses

In [None]:
dev_out, dev_tar, losses = evaluation()

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Loss:  0.26


### Normal evaluation

In [None]:
outputs = np.array(dev_out) >= 0.5
targets = dev_tar
accuracy = metrics.accuracy_score(targets, outputs)
f1_score_micro = metrics.f1_score(targets, outputs, average='micro')
f1_score_macro = metrics.f1_score(targets, outputs, average='macro')
print(f"Accuracy Score = {accuracy}")
print(f"F1 Score (Micro) = {f1_score_micro}")
print(f"F1 Score (Macro) = {f1_score_macro}")

Accuracy Score = 0.08636573007830493
F1 Score (Micro) = 0.49388371245640517
F1 Score (Macro) = 0.4154530022306258


In [None]:
print(classification_report(targets, outputs, target_names=icd_classes_50, digits=4))

              precision    recall  f1-score   support

       038.9     0.5082    0.3953    0.4447       473
       244.9     0.7982    0.2269    0.3534       401
      250.00     0.5000    0.3029    0.3773       713
       272.0     0.6905    0.0765    0.1378       379
       272.4     0.5136    0.4545    0.4823       704
       276.0     0.7308    0.2682    0.3924       425
       276.1     0.7647    0.2167    0.3377       360
       276.2     0.5149    0.3047    0.3828       512
       285.1     0.5920    0.4568    0.5157       486
       285.9     0.6977    0.2158    0.3297       556
       287.5     0.4941    0.2222    0.3066       378
       305.1     0.7212    0.2308    0.3497       325
         311     0.7424    0.1731    0.2808       283
      327.23     0.7826    0.3103    0.4444       232
       401.9     0.6271    0.5763    0.6006      1567
      403.90     0.6272    0.4194    0.5026       341
      403.91     0.6526    0.2857    0.3974       217
      410.71     0.5318    

In [None]:
ruc_auc_score_micro = metrics.roc_auc_score(targets, outputs, average='micro')
ruc_auc_score_macro = metrics.roc_auc_score(targets, outputs, average='macro')

print(f"RUC AUC Score (Micro) = {ruc_auc_score_micro}")
print(f"RUC AUC Score (Macro) = {ruc_auc_score_macro}")

RUC AUC Score (Micro) = 0.690363845710103
RUC AUC Score (Macro) = 0.6574872060030126


In [None]:
dev_df['prediction'] = dev_out
dev_df['tar'] = dev_tar

In [None]:
dev_df

Unnamed: 0,id,text,labels,prediction,tar,out_mean,out_sum,freq_5,note_count,out_bool,num_pred,out_ewma,most_freq
0,193970,title,"[0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, ...","[0.1564287, 0.7215239, 0.2871825, 0.060666334,...","[0.7821435, 3.6076193, 1.4359126, 0.30333167, ...","[False, True, False, False, False, False, True...",5,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
1,177271,a v paced rhythm,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3298624, 0.14561419, 0.23566455, 0.12802416...","[0.6597248, 0.29122838, 0.4713291, 0.25604832,...","[True, False, False, False, False, False, Fals...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,133242,hr s not s,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,181295,sinus rhythm long qtc interval,"[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...","[0.1865225, 0.069831856, 0.17037447, 0.0480493...","[0.5595675, 0.20949556, 0.5111234, 0.14414798,...","[False, False, False, False, False, False, Fal...",3,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,191074,nsg see patient transfer note,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
4337,146659,chief complaint altered mental status bloody d...,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0.0033487554, 0.056424737, 0.42125046, 0.0055...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0.0065594586, 0.11678544, 0.58675706, 0.01665...","[False, False, False, False, True, False, Fals...",2,"[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4338,107710,chief complaint respiratory distress hpi year ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.015700001, 0.008243611, 0.058876704, 0.0064...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0.06778842, 0.022426467, 0.17893201, 0.017613...","[False, False, False, False, False, False, Fal...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."
4339,199085,admission date discharge date date of birth se...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.13779767, 0.9013208, 0.6370169, 0.121300906...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.124280356, 0.37028039, 0.48501018, 0.09647,...","[0.37284106, 1.1108412, 1.4550306, 0.28941, 0....","[False, True, True, False, False, False, False...",3,"[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.07349442, 0.5285124, 0.67090994, 0.1283422,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4340,168029,chief complaint s p fall and point hct drop hp...,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0.04504987, 0.1184832, 0.20547101, 0.07595165...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, ...","[0.021544356, 0.12078671, 0.12409455, 0.058570...","[0.08617742, 0.48314685, 0.4963782, 0.23428087...","[False, False, False, False, False, False, Fal...",4,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, ...","[0.024463832, 0.18008158, 0.18994005, 0.068110...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, ..."


### Avg outputs

In [None]:
out_mean_dict = dev_df.groupby('id').prediction.apply(np.mean).to_dict()
dev_df['out_mean'] = dev_df['id'].map(out_mean_dict)
dev_df

Unnamed: 0,id,text,labels,prediction,tar,out_mean,out_sum,freq_5,note_count,out_bool,num_pred,out_ewma,most_freq
0,193970,title,"[0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, ...","[0.1564287, 0.7215239, 0.2871825, 0.060666334,...","[0.7821435, 3.6076193, 1.4359126, 0.30333167, ...","[False, True, False, False, False, False, True...",5,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
1,177271,a v paced rhythm,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3298624, 0.14561419, 0.23566455, 0.12802416...","[0.6597248, 0.29122838, 0.4713291, 0.25604832,...","[True, False, False, False, False, False, Fals...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,133242,hr s not s,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,181295,sinus rhythm long qtc interval,"[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...","[0.1865225, 0.069831856, 0.17037447, 0.0480493...","[0.5595675, 0.20949556, 0.5111234, 0.14414798,...","[False, False, False, False, False, False, Fal...",3,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,191074,nsg see patient transfer note,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
4337,146659,chief complaint altered mental status bloody d...,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0.0033487554, 0.056424737, 0.42125046, 0.0055...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0.0065594586, 0.11678544, 0.58675706, 0.01665...","[False, False, False, False, True, False, Fals...",2,"[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4338,107710,chief complaint respiratory distress hpi year ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.015700001, 0.008243611, 0.058876704, 0.0064...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0.06778842, 0.022426467, 0.17893201, 0.017613...","[False, False, False, False, False, False, Fal...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."
4339,199085,admission date discharge date date of birth se...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.13779767, 0.9013208, 0.6370169, 0.121300906...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.124280356, 0.37028039, 0.48501018, 0.09647,...","[0.37284106, 1.1108412, 1.4550306, 0.28941, 0....","[False, True, True, False, False, False, False...",3,"[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.07349442, 0.5285124, 0.67090994, 0.1283422,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4340,168029,chief complaint s p fall and point hct drop hp...,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0.04504987, 0.1184832, 0.20547101, 0.07595165...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, ...","[0.021544356, 0.12078671, 0.12409455, 0.058570...","[0.08617742, 0.48314685, 0.4963782, 0.23428087...","[False, False, False, False, False, False, Fal...",4,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, ...","[0.024463832, 0.18008158, 0.18994005, 0.068110...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, ..."


In [None]:
loss_mean = [nn.BCELoss()(torch.tensor(dev_df['out_mean'][i]), torch.tensor(dev_df['tar'][i])) for i in dev_df.index]
np.mean(loss_mean)

0.22567055

In [None]:
out_mean = np.vstack([dev_df['out_mean'][i]>=0.5 for i in dev_df.index])

#targets = dev_tar
accuracy = metrics.accuracy_score(targets, out_mean)
f1_score_micro = metrics.f1_score(targets, out_mean, average='micro')
f1_score_macro = metrics.f1_score(targets, out_mean, average='macro')
print(f"Accuracy Score = {accuracy}")
print(f"F1 Score (Micro) = {f1_score_micro}")
print(f"F1 Score (Macro) = {f1_score_macro}")

Accuracy Score = 0.07807461999078766
F1 Score (Micro) = 0.5196164383561644
F1 Score (Macro) = 0.4117382057562599


In [None]:
print(classification_report(targets, out_mean, target_names=icd_classes_50, digits=4))

              precision    recall  f1-score   support

       038.9     0.7210    0.4207    0.5314       473
       244.9     0.8864    0.1945    0.3190       401
      250.00     0.6189    0.2847    0.3900       713
       272.0     0.6800    0.0449    0.0842       379
       272.4     0.6128    0.4205    0.4987       704
       276.0     0.8317    0.1976    0.3194       425
       276.1     0.8052    0.1722    0.2838       360
       276.2     0.6465    0.2715    0.3824       512
       285.1     0.6783    0.4815    0.5632       486
       285.9     0.8195    0.1960    0.3164       556
       287.5     0.5304    0.1614    0.2475       378
       305.1     0.8861    0.2154    0.3465       325
         311     0.9429    0.1166    0.2075       283
      327.23     0.8730    0.2371    0.3729       232
       401.9     0.6949    0.6133    0.6515      1567
      403.90     0.7488    0.4457    0.5588       341
      403.91     0.7654    0.2857    0.4161       217
      410.71     0.6391    

In [None]:
ruc_auc_score_micro = metrics.roc_auc_score(targets, out_mean, average='micro')
ruc_auc_score_macro = metrics.roc_auc_score(targets, out_mean, average='macro')

print(f"RUC AUC Score (Micro) = {ruc_auc_score_micro}")
print(f"RUC AUC Score (Macro) = {ruc_auc_score_macro}")

RUC AUC Score (Micro) = 0.6951961914860378
RUC AUC Score (Macro) = 0.6557419893900109


### Most freq 5 labels

In [None]:
out_sum_dict = dev_df.groupby('id').prediction.apply(np.sum).to_dict()
dev_df['out_sum'] = dev_df['id'].map(out_sum_dict)
dev_df

Unnamed: 0,id,text,labels,prediction,tar,out_mean,out_sum,freq_5,note_count,out_bool,num_pred,out_ewma,most_freq
0,193970,title,"[0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, ...","[0.1564287, 0.7215239, 0.2871825, 0.060666334,...","[0.7821435, 3.6076193, 1.4359126, 0.30333167, ...","[False, True, False, False, False, False, True...",5,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
1,177271,a v paced rhythm,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3298624, 0.14561419, 0.23566455, 0.12802416...","[0.6597248, 0.29122838, 0.4713291, 0.25604832,...","[True, False, False, False, False, False, Fals...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,133242,hr s not s,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,181295,sinus rhythm long qtc interval,"[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...","[0.1865225, 0.069831856, 0.17037447, 0.0480493...","[0.5595675, 0.20949556, 0.5111234, 0.14414798,...","[False, False, False, False, False, False, Fal...",3,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,191074,nsg see patient transfer note,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
4337,146659,chief complaint altered mental status bloody d...,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0.0033487554, 0.056424737, 0.42125046, 0.0055...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0.0065594586, 0.11678544, 0.58675706, 0.01665...","[False, False, False, False, True, False, Fals...",2,"[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4338,107710,chief complaint respiratory distress hpi year ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.015700001, 0.008243611, 0.058876704, 0.0064...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0.06778842, 0.022426467, 0.17893201, 0.017613...","[False, False, False, False, False, False, Fal...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."
4339,199085,admission date discharge date date of birth se...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.13779767, 0.9013208, 0.6370169, 0.121300906...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.124280356, 0.37028039, 0.48501018, 0.09647,...","[0.37284106, 1.1108412, 1.4550306, 0.28941, 0....","[False, True, True, False, False, False, False...",3,"[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.07349442, 0.5285124, 0.67090994, 0.1283422,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4340,168029,chief complaint s p fall and point hct drop hp...,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0.04504987, 0.1184832, 0.20547101, 0.07595165...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, ...","[0.021544356, 0.12078671, 0.12409455, 0.058570...","[0.08617742, 0.48314685, 0.4963782, 0.23428087...","[False, False, False, False, False, False, Fal...",4,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, ...","[0.024463832, 0.18008158, 0.18994005, 0.068110...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, ..."


In [None]:
def freq_5(df, column): # column: out_sum
    df['freq_5'] = df[column]
    for idx in df.index:
      sorted = np.sort(df[column][idx])
      thres = sorted[-5] # position 5
      df['freq_5'][idx] = df[column][idx]>= thres


In [None]:
freq_5(dev_df, 'out_sum')

In [None]:
dev_df

Unnamed: 0,id,text,labels,prediction,tar,out_mean,out_sum,freq_5,note_count,out_bool,num_pred,out_ewma,most_freq
0,193970,title,"[0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, ...","[0.1564287, 0.7215239, 0.2871825, 0.060666334,...","[0.7821435, 3.6076193, 1.4359126, 0.30333167, ...","[False, True, False, False, False, False, True...",5,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
1,177271,a v paced rhythm,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3298624, 0.14561419, 0.23566455, 0.12802416...","[0.6597248, 0.29122838, 0.4713291, 0.25604832,...","[True, False, False, False, False, False, Fals...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,133242,hr s not s,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,181295,sinus rhythm long qtc interval,"[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...","[0.1865225, 0.069831856, 0.17037447, 0.0480493...","[0.5595675, 0.20949556, 0.5111234, 0.14414798,...","[False, False, False, False, False, False, Fal...",3,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,191074,nsg see patient transfer note,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
4337,146659,chief complaint altered mental status bloody d...,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0.0033487554, 0.056424737, 0.42125046, 0.0055...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0.0065594586, 0.11678544, 0.58675706, 0.01665...","[False, False, False, False, True, False, Fals...",2,"[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4338,107710,chief complaint respiratory distress hpi year ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.015700001, 0.008243611, 0.058876704, 0.0064...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0.06778842, 0.022426467, 0.17893201, 0.017613...","[False, False, False, False, False, False, Fal...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."
4339,199085,admission date discharge date date of birth se...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.13779767, 0.9013208, 0.6370169, 0.121300906...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.124280356, 0.37028039, 0.48501018, 0.09647,...","[0.37284106, 1.1108412, 1.4550306, 0.28941, 0....","[False, True, True, False, False, False, False...",3,"[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.07349442, 0.5285124, 0.67090994, 0.1283422,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4340,168029,chief complaint s p fall and point hct drop hp...,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0.04504987, 0.1184832, 0.20547101, 0.07595165...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, ...","[0.021544356, 0.12078671, 0.12409455, 0.058570...","[0.08617742, 0.48314685, 0.4963782, 0.23428087...","[False, False, False, False, False, False, Fal...",4,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, ...","[0.024463832, 0.18008158, 0.18994005, 0.068110...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, ..."


In [None]:
out_freq_5 = np.vstack([dev_df['freq_5'][i] for i in dev_df.index])

#targets = dev_tar
accuracy = metrics.accuracy_score(targets, out_freq_5)
f1_score_micro = metrics.f1_score(targets, out_freq_5, average='micro')
f1_score_macro = metrics.f1_score(targets, out_freq_5, average='macro')
print(f"Accuracy Score = {accuracy}")
print(f"F1 Score (Micro) = {f1_score_micro}")
print(f"F1 Score (Macro) = {f1_score_macro}")

Accuracy Score = 0.024643021649009673
F1 Score (Micro) = 0.5027117509206561
F1 Score (Macro) = 0.4181097769284202


In [None]:
print(classification_report(targets, out_freq_5, target_names=icd_classes_50, digits=4))

              precision    recall  f1-score   support

       038.9     0.4412    0.3805    0.4086       473
       244.9     0.4733    0.1771    0.2577       401
      250.00     0.3329    0.3408    0.3368       713
       272.0     0.2994    0.1398    0.1906       379
       272.4     0.4618    0.4901    0.4755       704
       276.0     0.5578    0.1929    0.2867       425
       276.1     0.4602    0.2250    0.3022       360
       276.2     0.3364    0.2871    0.3098       512
       285.1     0.4434    0.5967    0.5088       486
       285.9     0.4883    0.1871    0.2705       556
       287.5     0.4335    0.2328    0.3029       378
       305.1     0.5045    0.3477    0.4117       325
         311     0.5000    0.1661    0.2493       283
      327.23     0.6168    0.2845    0.3894       232
       401.9     0.5574    0.7626    0.6440      1567
      403.90     0.6569    0.3930    0.4917       341
      403.91     0.5344    0.3226    0.4023       217
      410.71     0.4774    

In [None]:
ruc_auc_score_micro = metrics.roc_auc_score(targets, out_freq_5, average='micro')
ruc_auc_score_macro = metrics.roc_auc_score(targets, out_freq_5, average='macro')

print(f"RUC AUC Score (Micro) = {ruc_auc_score_micro}")
print(f"RUC AUC Score (Macro) = {ruc_auc_score_macro}")

RUC AUC Score (Micro) = 0.7168918673903204
RUC AUC Score (Macro) = 0.6803295268119769


In [None]:
precision_micro = metrics.average_precision_score(targets, out_freq_5, average='micro')
precision_macro = metrics.average_precision_score(targets, out_freq_5, average='macro')

print(f"Average Precision Score (Micro) = {precision_micro}")
print(f"Average Precision Score (Macro) = {precision_macro}")

Average Precision Score (Micro) = 0.3074656539004526
Average Precision Score (Macro) = 0.2734250897537291


### Predicted percentage

In [None]:
note_count_dict = dev_df.groupby('id').size().to_dict()
dev_df['note_count'] = dev_df['id'].map(note_count_dict)

In [None]:
dev_df['out_bool'] = [(dev_df['prediction'][i]>=0.5).astype(int) for i in dev_df.index]

In [None]:

out_freq_dict = dev_df.groupby('id').out_bool.apply(np.sum).to_dict()
dev_df['num_pred'] = dev_df['id'].map(out_freq_dict)
dev_df['num_pred'] = [(dev_df['num_pred'][i]>=0.4*dev_df['note_count'][i]).astype(int) for i in dev_df.index]

In [None]:
dev_df

Unnamed: 0,id,text,labels,prediction,tar,out_mean,out_sum,freq_5,note_count,out_bool,num_pred,out_ewma,most_freq
0,193970,title,"[0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, ...","[0.1564287, 0.7215239, 0.2871825, 0.060666334,...","[0.7821435, 3.6076193, 1.4359126, 0.30333167, ...","[False, True, False, False, False, False, True...",5,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
1,177271,a v paced rhythm,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3298624, 0.14561419, 0.23566455, 0.12802416...","[0.6597248, 0.29122838, 0.4713291, 0.25604832,...","[True, False, False, False, False, False, Fals...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,133242,hr s not s,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,181295,sinus rhythm long qtc interval,"[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...","[0.1865225, 0.069831856, 0.17037447, 0.0480493...","[0.5595675, 0.20949556, 0.5111234, 0.14414798,...","[False, False, False, False, False, False, Fal...",3,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,191074,nsg see patient transfer note,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
4337,146659,chief complaint altered mental status bloody d...,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0.0033487554, 0.056424737, 0.42125046, 0.0055...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0.0065594586, 0.11678544, 0.58675706, 0.01665...","[False, False, False, False, True, False, Fals...",2,"[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4338,107710,chief complaint respiratory distress hpi year ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.015700001, 0.008243611, 0.058876704, 0.0064...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0.06778842, 0.022426467, 0.17893201, 0.017613...","[False, False, False, False, False, False, Fal...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."
4339,199085,admission date discharge date date of birth se...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.13779767, 0.9013208, 0.6370169, 0.121300906...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.124280356, 0.37028039, 0.48501018, 0.09647,...","[0.37284106, 1.1108412, 1.4550306, 0.28941, 0....","[False, True, True, False, False, False, False...",3,"[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.07349442, 0.5285124, 0.67090994, 0.1283422,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4340,168029,chief complaint s p fall and point hct drop hp...,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0.04504987, 0.1184832, 0.20547101, 0.07595165...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, ...","[0.021544356, 0.12078671, 0.12409455, 0.058570...","[0.08617742, 0.48314685, 0.4963782, 0.23428087...","[False, False, False, False, False, False, Fal...",4,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, ...","[0.024463832, 0.18008158, 0.18994005, 0.068110...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, ..."


In [None]:
out_freq = np.vstack([dev_df['num_pred'][i] for i in dev_df.index])

#targets = dev_tar
accuracy = metrics.accuracy_score(targets, out_freq)
f1_score_micro = metrics.f1_score(targets, out_freq, average='micro')
f1_score_macro = metrics.f1_score(targets, out_freq, average='macro')
print(f"Accuracy Score = {accuracy}")
print(f"F1 Score (Micro) = {f1_score_micro}")
print(f"F1 Score (Macro) = {f1_score_macro}")

Accuracy Score = 0.0951174573929065
F1 Score (Micro) = 0.5601830440814934
F1 Score (Macro) = 0.4783930423631684


In [None]:
print(classification_report(targets, out_freq, target_names=icd_classes_50, digits=4))

              precision    recall  f1-score   support

       038.9     0.5658    0.5455    0.5554       473
       244.9     0.8496    0.2818    0.4232       401
      250.00     0.5188    0.3871    0.4434       713
       272.0     0.6818    0.0792    0.1418       379
       272.4     0.5348    0.5355    0.5351       704
       276.0     0.7908    0.3647    0.4992       425
       276.1     0.7891    0.2806    0.4139       360
       276.2     0.5435    0.3535    0.4284       512
       285.1     0.6241    0.5535    0.5867       486
       285.9     0.7339    0.2878    0.4134       556
       287.5     0.5333    0.2540    0.3441       378
       305.1     0.7717    0.3015    0.4336       325
         311     0.8481    0.2367    0.3702       283
      327.23     0.7398    0.3922    0.5127       232
       401.9     0.6405    0.6809    0.6601      1567
      403.90     0.6304    0.5103    0.5640       341
      403.91     0.6891    0.3779    0.4881       217
      410.71     0.5300    

In [None]:
ruc_auc_score_micro = metrics.roc_auc_score(targets, out_freq, average='micro')
ruc_auc_score_macro = metrics.roc_auc_score(targets, out_freq, average='macro')

print(f"RUC AUC Score (Micro) = {ruc_auc_score_micro}")
print(f"RUC AUC Score (Macro) = {ruc_auc_score_macro}")

RUC AUC Score (Micro) = 0.732420009083827
RUC AUC Score (Macro) = 0.6925716653349505


In [None]:
precision_micro = metrics.average_precision_score(targets, out_freq, average='micro')
precision_macro = metrics.average_precision_score(targets, out_freq, average='macro')

print(f"Average Precision Score (Micro) = {precision_micro}")
print(f"Average Precision Score (Macro) = {precision_macro}")

Average Precision Score (Micro) = 0.3721066457036031
Average Precision Score (Macro) = 0.33263860118194366


### Exponential moving average

In [None]:
def ewma(sub_df, window=3):
#    print(sub_df)
    alpha = 2 / (window + 1)
#    print(sub_df['ewma'])
    sub_df['ewma'] = sub_df['out_bool']
    for r in range(len(sub_df)):
        if r == 0:
            sub_df['ewma'].iloc[r] = sub_df['prediction'].iloc[r]
        else:
            sub_df['ewma'].iloc[r] = alpha*sub_df['prediction'].iloc[r] + (1-alpha)*sub_df['prediction'].iloc[r-1]
 #   print(type(sub_df['ewma']))
    return sub_df['ewma']

In [None]:

out_ewma_dict = dev_df.groupby('id', group_keys=False).apply(ewma).to_dict()
dev_df['out_ewma'] = pd.Series(dev_df.index, index=dev_df.index).map(out_ewma_dict)



In [None]:
dev_df

Unnamed: 0,id,text,labels,prediction,tar,out_mean,out_sum,freq_5,note_count,out_bool,num_pred,out_ewma,most_freq
0,193970,title,"[0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, ...","[0.1564287, 0.7215239, 0.2871825, 0.060666334,...","[0.7821435, 3.6076193, 1.4359126, 0.30333167, ...","[False, True, False, False, False, False, True...",5,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
1,177271,a v paced rhythm,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3298624, 0.14561419, 0.23566455, 0.12802416...","[0.6597248, 0.29122838, 0.4713291, 0.25604832,...","[True, False, False, False, False, False, Fals...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,133242,hr s not s,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,181295,sinus rhythm long qtc interval,"[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...","[0.1865225, 0.069831856, 0.17037447, 0.0480493...","[0.5595675, 0.20949556, 0.5111234, 0.14414798,...","[False, False, False, False, False, False, Fal...",3,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,191074,nsg see patient transfer note,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
4337,146659,chief complaint altered mental status bloody d...,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0.0033487554, 0.056424737, 0.42125046, 0.0055...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0.0065594586, 0.11678544, 0.58675706, 0.01665...","[False, False, False, False, True, False, Fals...",2,"[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4338,107710,chief complaint respiratory distress hpi year ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.015700001, 0.008243611, 0.058876704, 0.0064...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0.06778842, 0.022426467, 0.17893201, 0.017613...","[False, False, False, False, False, False, Fal...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."
4339,199085,admission date discharge date date of birth se...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.13779767, 0.9013208, 0.6370169, 0.121300906...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.124280356, 0.37028039, 0.48501018, 0.09647,...","[0.37284106, 1.1108412, 1.4550306, 0.28941, 0....","[False, True, True, False, False, False, False...",3,"[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.07349442, 0.5285124, 0.67090994, 0.1283422,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4340,168029,chief complaint s p fall and point hct drop hp...,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0.04504987, 0.1184832, 0.20547101, 0.07595165...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, ...","[0.021544356, 0.12078671, 0.12409455, 0.058570...","[0.08617742, 0.48314685, 0.4963782, 0.23428087...","[False, False, False, False, False, False, Fal...",4,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, ...","[0.024463832, 0.18008158, 0.18994005, 0.068110...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, ..."


In [None]:
out_ewma = np.vstack([dev_df['out_ewma'][i]>0.5 for i in dev_df.index])

#targets = dev_tar
accuracy = metrics.accuracy_score(targets, out_ewma)
f1_score_micro = metrics.f1_score(targets, out_ewma, average='micro')
f1_score_macro = metrics.f1_score(targets, out_ewma, average='macro')
print(f"Accuracy Score = {accuracy}")
print(f"F1 Score (Micro) = {f1_score_micro}")
print(f"F1 Score (Macro) = {f1_score_macro}")

Accuracy Score = 0.07554122524182405
F1 Score (Micro) = 0.47892896826472264
F1 Score (Macro) = 0.3851357639969061


In [None]:
print(classification_report(targets, out_ewma, target_names=icd_classes_50, digits=4))

              precision    recall  f1-score   support

       038.9     0.5658    0.3636    0.4427       473
       244.9     0.8152    0.1870    0.3043       401
      250.00     0.5081    0.2651    0.3484       713
       272.0     0.6552    0.0501    0.0931       379
       272.4     0.5562    0.4290    0.4844       704
       276.0     0.7840    0.2306    0.3564       425
       276.1     0.7692    0.1667    0.2740       360
       276.2     0.5205    0.2480    0.3360       512
       285.1     0.6261    0.4342    0.5128       486
       285.9     0.7226    0.1781    0.2857       556
       287.5     0.5000    0.1772    0.2617       378
       305.1     0.7765    0.2031    0.3220       325
         311     0.7955    0.1237    0.2141       283
      327.23     0.7857    0.2371    0.3642       232
       401.9     0.6488    0.5801    0.6125      1567
      403.90     0.6447    0.3724    0.4721       341
      403.91     0.6962    0.2535    0.3716       217
      410.71     0.5385    

In [None]:
ruc_auc_score_micro = metrics.roc_auc_score(targets, out_ewma, average='micro')
ruc_auc_score_macro = metrics.roc_auc_score(targets, out_ewma, average='macro')

print(f"RUC AUC Score (Micro) = {ruc_auc_score_micro}")
print(f"RUC AUC Score (Macro) = {ruc_auc_score_macro}")

RUC AUC Score (Micro) = 0.6784930237763072
RUC AUC Score (Macro) = 0.6431135506868825


In [None]:
precision_micro = metrics.average_precision_score(targets, out_ewma, average='micro')
precision_macro = metrics.average_precision_score(targets, out_ewma, average='macro')

print(f"Average Precision Score (Micro) = {precision_micro}")
print(f"Average Precision Score (Macro) = {precision_macro}")

Average Precision Score (Micro) = 0.3105792480465515
Average Precision Score (Macro) = 0.27171223658749033


### Most frequent prediction

In [None]:

most_freq_dict = dev_df.groupby('id')['out_bool'].apply(lambda x: x.value_counts().index[0]).to_dict()
dev_df['most_freq'] = dev_df['id'].map(most_freq_dict)

In [None]:
dev_df

Unnamed: 0,id,text,labels,prediction,tar,out_mean,out_sum,freq_5,note_count,out_bool,num_pred,out_ewma,most_freq
0,193970,title,"[0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, ...","[0.1564287, 0.7215239, 0.2871825, 0.060666334,...","[0.7821435, 3.6076193, 1.4359126, 0.30333167, ...","[False, True, False, False, False, False, True...",5,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, ...","[0.17306054, 0.38443032, 0.12503847, 0.0572557...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
1,177271,a v paced rhythm,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.3298624, 0.14561419, 0.23566455, 0.12802416...","[0.6597248, 0.29122838, 0.4713291, 0.25604832,...","[True, False, False, False, False, False, Fals...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.022196637, 0.052830663, 0.35087362, 0.21030...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,133242,hr s not s,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.049565446, 0.053445615, 0.05199201, 0.12330...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,181295,sinus rhythm long qtc interval,"[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, ...","[0.1865225, 0.069831856, 0.17037447, 0.0480493...","[0.5595675, 0.20949556, 0.5111234, 0.14414798,...","[False, False, False, False, False, False, Fal...",3,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.3052587, 0.09415755, 0.12252541, 0.07794826...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,191074,nsg see patient transfer note,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[False, False, False, False, False, False, Fal...",1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.01033157, 0.036302544, 0.12405683, 0.146747...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
4337,146659,chief complaint altered mental status bloody d...,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0.0033487554, 0.056424737, 0.42125046, 0.0055...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0.0065594586, 0.11678544, 0.58675706, 0.01665...","[False, False, False, False, True, False, Fals...",2,"[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, ...","[0.0032797293, 0.05839272, 0.29337853, 0.00832...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4338,107710,chief complaint respiratory distress hpi year ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.015700001, 0.008243611, 0.058876704, 0.0064...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0.06778842, 0.022426467, 0.17893201, 0.017613...","[False, False, False, False, False, False, Fal...",2,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0.03389421, 0.011213234, 0.089466006, 0.00880...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."
4339,199085,admission date discharge date date of birth se...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.13779767, 0.9013208, 0.6370169, 0.121300906...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.124280356, 0.37028039, 0.48501018, 0.09647,...","[0.37284106, 1.1108412, 1.4550306, 0.28941, 0....","[False, True, True, False, False, False, False...",3,"[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0.07349442, 0.5285124, 0.67090994, 0.1283422,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4340,168029,chief complaint s p fall and point hct drop hp...,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0.04504987, 0.1184832, 0.20547101, 0.07595165...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, ...","[0.021544356, 0.12078671, 0.12409455, 0.058570...","[0.08617742, 0.48314685, 0.4963782, 0.23428087...","[False, False, False, False, False, False, Fal...",4,"[0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, ...","[0.024463832, 0.18008158, 0.18994005, 0.068110...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, ..."


In [None]:
out_most_freq = np.vstack([dev_df['most_freq'][i] for i in dev_df.index])

#targets = dev_tar
accuracy = metrics.accuracy_score(targets, out_most_freq)
f1_score_micro = metrics.f1_score(targets, out_most_freq, average='micro')
f1_score_macro = metrics.f1_score(targets, out_most_freq, average='macro')
print(f"Accuracy Score = {accuracy}")
print(f"F1 Score (Micro) = {f1_score_micro}")
print(f"F1 Score (Macro) = {f1_score_macro}")

Accuracy Score = 0.08360202671579917
F1 Score (Micro) = 0.4749716336385466
F1 Score (Macro) = 0.3939261496460077


In [None]:
print(classification_report(targets, out_most_freq, target_names=icd_classes_50, digits=4))

              precision    recall  f1-score   support

       038.9     0.5056    0.3827    0.4356       473
       244.9     0.7544    0.2145    0.3340       401
      250.00     0.4752    0.2959    0.3647       713
       272.0     0.7143    0.0660    0.1208       379
       272.4     0.5383    0.4588    0.4954       704
       276.0     0.6941    0.2776    0.3966       425
       276.1     0.7436    0.1611    0.2648       360
       276.2     0.4565    0.2461    0.3198       512
       285.1     0.5750    0.4259    0.4894       486
       285.9     0.6753    0.1871    0.2930       556
       287.5     0.4626    0.1799    0.2590       378
       305.1     0.6900    0.2123    0.3247       325
         311     0.7800    0.1378    0.2342       283
      327.23     0.6842    0.2802    0.3976       232
       401.9     0.6297    0.5469    0.5854      1567
      403.90     0.5837    0.3988    0.4739       341
      403.91     0.6139    0.2857    0.3899       217
      410.71     0.5305    

In [None]:
ruc_auc_score_micro = metrics.roc_auc_score(targets, out_most_freq, average='micro')
ruc_auc_score_macro = metrics.roc_auc_score(targets, out_most_freq, average='macro')

print(f"RUC AUC Score (Micro) = {ruc_auc_score_micro}")
print(f"RUC AUC Score (Macro) = {ruc_auc_score_macro}")

RUC AUC Score (Micro) = 0.6798941467054299
RUC AUC Score (Macro) = 0.6477460663973096


In [None]:
precision_micro = metrics.average_precision_score(targets, out_most_freq, average='micro')
precision_macro = metrics.average_precision_score(targets, out_most_freq, average='macro')

print(f"Average Precision Score (Micro) = {precision_micro}")
print(f"Average Precision Score (Macro) = {precision_macro}")

Average Precision Score (Micro) = 0.3018685091755825
Average Precision Score (Macro) = 0.2707157800691207


In [None]:
def precision_at_5(df, column): # column: prediction
    df['p@5'] = 0
    for idx in df.index:
        idx_sorted = np.argsort(df[column][idx])
        idcs = idx_sorted[-3:]
        perc = (3 - sum(df['tar'][idx][idcs])) / 3
        df['p@5'][idx] = perc

### Testing

In [30]:
# Evaluate the model

# Evaluate the model

def testing():
    model.eval()

    fin_targets=[]
    fin_outputs=[]
    losses=[]
    with torch.no_grad():
        for _, data in enumerate(testing_loader, 0):
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.float)
            outputs = model(ids, mask, token_type_ids)
            loss = loss_fn(outputs, targets)
            losses.append(loss.item())

            fin_targets.extend(targets.cpu().detach().numpy())
            fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy())
    print(f'Loss:  {np.mean(losses):.2f}')
    return fin_outputs, fin_targets, losses

In [31]:
test_out, targets, losses = testing()
outputs = np.array(test_out) >= 0.5
accuracy = metrics.accuracy_score(targets, outputs)
f1_score_micro = metrics.f1_score(targets, outputs, average='micro')
f1_score_macro = metrics.f1_score(targets, outputs, average='macro')

print(f"F1 Score (Micro) = {f1_score_micro}")
print(f"F1 Score (Macro) = {f1_score_macro}")

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Loss:  0.26
F1 Score (Micro) = 0.49005472532098515
F1 Score (Macro) = 0.4039059357517949


In [32]:
print(classification_report(targets, outputs, target_names=icd_classes_50, digits=4))

              precision    recall  f1-score   support

       038.9     0.5306    0.4235    0.4710       451
       244.9     0.7478    0.2257    0.3468       381
      250.00     0.5631    0.3053    0.3959       760
       272.0     0.7692    0.0539    0.1008       371
       272.4     0.5332    0.4544    0.4907       724
       276.0     0.7532    0.2874    0.4161       414
       276.1     0.6832    0.1896    0.2968       364
       276.2     0.5385    0.3105    0.3939       541
       285.1     0.5755    0.4911    0.5300       450
       285.9     0.6294    0.1737    0.2723       616
       287.5     0.5635    0.2810    0.3750       363
       305.1     0.6633    0.2083    0.3171       312
         311     0.7273    0.1784    0.2866       269
      327.23     0.8250    0.2435    0.3761       271
       401.9     0.6191    0.5761    0.5969      1583
      403.90     0.5741    0.3503    0.4351       354
      403.91     0.7093    0.2618    0.3824       233
      410.71     0.5325    

In [33]:

ruc_auc_score_micro = metrics.roc_auc_score(targets, outputs, average='micro')
ruc_auc_score_macro = metrics.roc_auc_score(targets, outputs, average='macro')

print(f"RUC AUC Score (Micro) = {ruc_auc_score_micro}")
print(f"RUC AUC Score (Macro) = {ruc_auc_score_macro}")

RUC AUC Score (Micro) = 0.6884715049565272
RUC AUC Score (Macro) = 0.6541345343856356


In [34]:

test_df['prediction'] = test_out
test_df['tar'] = targets

In [35]:
note_count_dict = test_df.groupby('id').size().to_dict()
test_df['note_count'] = test_df['id'].map(note_count_dict)

In [36]:
test_df['out_bool'] = [(test_df['prediction'][i]>=0.5).astype(int) for i in test_df.index]

In [37]:

out_freq_dict = test_df.groupby('id').out_bool.apply(np.sum).to_dict()
test_df['num_pred'] = test_df['id'].map(out_freq_dict)
test_df['num_pred'] = [(test_df['num_pred'][i]>=0.4*test_df['note_count'][i]).astype(int) for i in test_df.index]

In [38]:
df_freq = test_df.drop_duplicates('id')

In [39]:
out_freq = np.vstack([df_freq['num_pred'][i] for i in df_freq.index])
targets = np.vstack([df_freq['tar'][i] for i in df_freq.index])
#targets = dev_tar

f1_score_micro = metrics.f1_score(targets, out_freq, average='micro')
f1_score_macro = metrics.f1_score(targets, out_freq, average='macro')

print(f"F1 Score (Micro) = {f1_score_micro}")
print(f"F1 Score (Macro) = {f1_score_macro}")

F1 Score (Micro) = 0.47435066338216864
F1 Score (Macro) = 0.38374212327293816


In [40]:
print(classification_report(targets, out_freq, target_names=icd_classes_50, digits=4))

              precision    recall  f1-score   support

       038.9     0.4632    0.4100    0.4350       261
       244.9     0.6875    0.1880    0.2953       234
      250.00     0.5056    0.3036    0.3794       448
       272.0     0.7778    0.0576    0.1073       243
       272.4     0.4976    0.4771    0.4871       436
       276.0     0.6329    0.2326    0.3401       215
       276.1     0.5821    0.1781    0.2727       219
       276.2     0.4842    0.2997    0.3702       307
       285.1     0.5120    0.4812    0.4961       266
       285.9     0.5579    0.1532    0.2404       346
       287.5     0.4904    0.2417    0.3238       211
       305.1     0.5690    0.1833    0.2773       180
         311     0.6410    0.1453    0.2370       172
      327.23     0.7073    0.2086    0.3222       139
       401.9     0.5862    0.6043    0.5951       973
      403.90     0.5385    0.3665    0.4361       191
      403.91     0.7241    0.2763    0.4000       152
      410.71     0.5175    

In [41]:
ruc_auc_score_micro = metrics.roc_auc_score(targets, out_freq, average='micro')
ruc_auc_score_macro = metrics.roc_auc_score(targets, out_freq, average='macro')

print(f"RUC AUC Score (Micro) = {ruc_auc_score_micro}")
print(f"RUC AUC Score (Macro) = {ruc_auc_score_macro}")

RUC AUC Score (Micro) = 0.6852030568007662
RUC AUC Score (Macro) = 0.6485034718917585
