In [1]:
!pip install datasets transformers > /dev/null

In [2]:
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

In [3]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [4]:
from torch.utils.data import Dataset, DataLoader
from nltk import word_tokenize
from collections import Counter, defaultdict

class SentDataset(Dataset):
    def __init__(self, fname, word2ix=None, df=None):
        if df is None:
            self.sents = pd.read_csv(fname)
        else:
            self.sents = df.dropna(subset=['sentence'])
        # reset index necessary for getting context sentences when using unlabeled data
        #self.sents = self.sents.dropna(subset=['sentence'])
        if word2ix is not None:
            self.word2ix = word2ix
            self.ix2word = {ix:word for word,ix in self.word2ix.items()}
        else:
            self.word2ix = {'<PAD>': 0}
            self.ix2word = {}
            self.min_cnt = 0

    def __len__(self):
        return len(self.sents)

    def __getitem__(self, idx):
        row = self.sents.iloc[idx]
        if row.sentence.startswith('[') and row.sentence.endswith(']'):
            try:
                sentence = eval(row.sentence)
            except:
                sentence = word_tokenize(row.sentence)
        else:
            assert isinstance(row.sentence, str)
            sentence = word_tokenize(row.sentence)
        sent = [x.lower() for x in sentence]
        if 'labels' in row:
          labels = eval(row.labels)
        else:
          labels = []
        doc_id = row.doc_id
        sent_ix = None
        if 'sent_index' in row:
            sent_ix = row.sent_index
        elif 'sent_ix' in row:
            sent_ix = row.sent_ix
        return sent, labels, doc_id, sent_ix

In [5]:
# load dataset
import ast
def str_to_list(s):
    # Use the ast library to safely evaluate the string as a list
    return ast.literal_eval(s)

sents = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/sentence_level.csv")

# Apply the function to the column to convert the strings to lists
#sents['sentence'] = sents['sentence'].apply(str_to_list)

# Apply the function to the column to convert the strings to lists
#sents['labels'] = sents['labels'].apply(str_to_list)
train_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/train_ids.csv", header=None)
test_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/test_ids.csv", header=None)
eval_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/val_ids.csv", header=None)

In [6]:
train_dataset.columns = ['doc_id']
train_dataset = train_dataset.merge(sents, on = "doc_id")
test_dataset.columns = ['doc_id']
test_dataset = test_dataset.merge(sents, on = "doc_id")
eval_dataset.columns = ['doc_id']
eval_dataset = eval_dataset.merge(sents, on = "doc_id")

In [7]:
train_dataset = SentDataset('contet',df =train_dataset)
eval_dataset = SentDataset('content',df=eval_dataset)
test_dataset = SentDataset('content',df=test_dataset)

In [8]:
tokenizer1 = BertTokenizer.from_pretrained('prajjwal1/bert-mini')

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/286 [00:00<?, ?B/s]

# Datasets & Dataloaders


In [9]:
batch_size = 64

In [10]:
def collator(batch, tokenizer, eval=False, doc_position=False):
    # standard bert collator that applies the right label for the training task
    sents = []
    labels = []
    doc_poses = []
    doc_ids = []
    sent_ixs = []
    for sent, label, doc_id, sent_ix in batch:
        sents.append(' '.join(sent))
        label = multilabel_labeler(label)
        labels.append(label)
        doc_ids.append(doc_id)
        sent_ixs.append(sent_ix)
    tokd = tokenizer(sents, padding=True, max_length = 512, truncation=True)
    input_ids, token_type_ids, attention_mask = tokd['input_ids'], tokd['token_type_ids'], tokd['attention_mask']
    toks = torch.LongTensor(input_ids)
    mask = torch.LongTensor(attention_mask)
    labels = torch.Tensor(labels)
    # for whatever reason, pytorch's cross entropy requires long labels but BCE doesn't
    if not eval:
        return {'input_ids': toks, 'attention_mask': mask, 'labels': labels, 'doc_ids': doc_ids, 'sent_ixs': sent_ixs}
    else:
        return {'input_ids': toks, 'attention_mask': mask, 'labels': labels, 'sentences': sents, 'doc_ids': doc_ids, 'sent_ixs': sent_ixs}

data_collator = lambda x: collator(x, tokenizer = tokenizer1)
eval_collate_fn = lambda x: collator(x, tokenizer =tokenizer1, eval=True)

In [11]:
# import torch
# import tqdm

# class BertDataset(torch.utils.data.Dataset):
#   def __init__(self, txt_list, tokenizer, max_length=512):
#     self.tokenizer = tokenizer
#     self.sents = []
#     self.labels = []
#     self.doc_poses = []
#     self.doc_ids = []
#     self.sent_ixs = []
#     self.input_ids = []
#     self.token_type_ids = []
#     self.attention_maks = []


#     for doc_id, sent_ix, sent, label in tqdm.tqdm(txt_list.values, desc="Tokenizing data"):
#         #sents.append(' '.join(sent))
#         tokd = tokenizer(sent, padding=True, max_length = 512)
#         self.input_ids.append(torch.Tensor(tokd['input_ids']))
#         self.token_type_ids.append(tokd['token_type_ids'])
#         self.attention_maks.append(torch.Tensor(tokd['attention_mask']))
#         label = multilabel_labeler(label)
#         self.labels.append(label)
#         self.doc_ids.append(doc_id)
#         self.sent_ixs.append(sent_ix)
#     self.labels = torch.Tensor(self.labels)
    
#   def __len__(self):
#     return len(self.sent_ixs)

#   def __getitem__(self, idx):
#     return {'input_ids': self.input_ids[idx], 'attention_mask': self.attention_maks[idx], 'labels': self.labels[idx], 'doc_ids': self.doc_ids[idx], 'sent_ixs': self.sent_ixs[idx]} 

In [12]:
LABEL_TYPES = ['I-Imaging-related followup',
 'I-Appointment-related followup',
 'I-Medication-related followups',
 'I-Procedure-related followup',
 'I-Lab-related followup',
 'I-Case-specific instructions for patient',
 'I-Other helpful contextual information',
 ]

label2abbrev = {'I-Imaging-related followup': 'Imaging',
        'I-Appointment-related followup': 'Appointment',
        'I-Medication-related followups': 'Medication',
        'I-Procedure-related followup': 'Procedure',
        'I-Lab-related followup': 'Lab',
        'I-Case-specific instructions for patient': 'Patient instructions',
        'I-Other helpful contextual information': 'Other',
 }

def multilabel_labeler(label):
    # convert list of label names to a multi-hot array
    if isinstance(label, str):
        label = eval(label)
    label_ixs = [LABEL_TYPES.index(l) for l in label]
    label = np.zeros(len(LABEL_TYPES))
    label[label_ixs] = 1
    return label

In [13]:

# train_dataset = BertDataset(train_dataset, tokenizer, max_length=512, collate_fn=data_collator)
# eval_dataset = BertDataset(eval_dataset, tokenizer, max_length=512, collate_fn=eval_collate_fn)

# print()
# print('{:>5,} training samples'.format(len(train_dataset)))
# print('{:>5,} validation samples'.format(len(eval_dataset)))

In [14]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(
  train_dataset,
  batch_size=batch_size,
  collate_fn=data_collator
)

validation_dataloader = DataLoader(
  eval_dataset,
  shuffle=False,
  batch_size=1,
  collate_fn=eval_collate_fn
)

test_dataloader = DataLoader(
    test_dataset, batch_size=1, shuffle=False, collate_fn=eval_collate_fn
    )

# Finetune BERT Model

In [15]:
from transformers import BertPreTrainedModel, BertModel, BertForMaskedLM
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, KLDivLoss
import torch.nn.functional as F

class MiniBertForSequenceMultilabelClassification(BertPreTrainedModel):
    """
        simple mod of MiniBERT to accept multilabel or binary task
        there may or may not be a class in huggingface that does this but this is here for historical reasons
    """
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.minibert = BertModel.from_pretrained('prajjwal1/bert-mini')
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def add_doc_position_feature(self, config):
        self.classifier = nn.Linear(config.hidden_size+1, config.num_labels)

    def set_task(self, task):
        self.task = task

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        doc_positions=None,
    ):
        outputs = self.minibert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        if doc_positions is not None and len(doc_positions) > 0:
            pooled_output = torch.cat((pooled_output, doc_positions.unsqueeze(1)), dim=1)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(outputs[0], labels)
            outputs = (loss,) + outputs

        return outputs




In [16]:
label_set = LABEL_TYPES
num_labels = len(label_set)

label2id = {label:ix for ix,label in enumerate(label_set)}
id2label = {ix:label for label,ix in label2id.items()}
config = BertConfig.from_pretrained(
        'prajjwal1/bert-mini',
        num_labels=num_labels,
        finetuning_task="text_classification",
        label2id=label2id,
        id2label=id2label,
    )

In [17]:
import random
model = MiniBertForSequenceMultilabelClassification.from_pretrained('prajjwal1/bert-mini', config=config)

# Tell pytorch to run this model on the GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Set the seed value all over the place to make this reproducible.
seed_val = 42

random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

Downloading pytorch_model.bin:   0%|          | 0.00/45.1M [00:00<?, ?B/s]

Some weights of the model checkpoint at prajjwal1/bert-mini were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at prajjwal1/bert-mini were not used when initializing MiniBertForSequenceMultilabel

freeze weights if needed

In [18]:
# for name, params in model.named_parameters():
#   if name.startswith('bert.embeddings.'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.0'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.1'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.2'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.3'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.4'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.5'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.6'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.7'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.8'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.9'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.10'):
#     params.requires_grad = False
#   if name.startswith('bert.encoder.layer.11'):
#     params.requires_grad = False

In [19]:
import time

timestamp = time.strftime('%b_%d_%H:%M:%S', time.localtime())

# Note: AdamW is a class from the huggingface library (as opposed to pytorch) 
from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=5e-5)



# Compute Predictions

In [20]:
def gather_predictions(model, loader, device, doc_position=False, save_preds=False, max_preds=1e9):
    # run through data loader and get model predictions
    with torch.no_grad():
        model.eval()
        yhat_raw = np.zeros((len(loader), 7))
        yhat = np.zeros((len(loader), 7))
        y = np.zeros((len(loader), 7))
        sentences = []
        doc_ids = []
        sent_ixs = []
        for ix, x in enumerate(loader):
            if ix >= max_preds:
                break
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            sentences.extend(x['sentences'])
            doc_ids.extend(x['doc_ids'])
            sent_ixs.extend(x['sent_ixs'])
            inputs = {'input_ids': x['input_ids'], 'attention_mask': x['attention_mask'], 'labels': x['labels']}
            if 'token_type_ids' in x:
              inputs['token_type_ids'] = x['token_type_ids']
            loss, pred = model(**inputs)
            pred = torch.sigmoid(pred)
            pred = pred.cpu().numpy()
            yhat_raw[ix] = pred
            yhat[ix] = np.round(pred)
            y[ix] = x['labels'].cpu().numpy()[0]
    return yhat_raw, yhat, y, sentences, doc_ids, sent_ixs

# Evaluation Metrics

In [21]:
import datetime
import time

def format_time(elapsed):
    return str(datetime.timedelta(seconds=int(round((elapsed)))))

In [22]:
from collections import defaultdict
import csv
import json
import numpy as np
import os
import sys

from sklearn.metrics import roc_curve, auc, precision_score, recall_score, f1_score, roc_auc_score, accuracy_score
from tqdm import tqdm


def auc_metrics(yhat_raw, y, ymic):
    if yhat_raw.shape[0] <= 1:
        return
    fpr = {}
    tpr = {}
    roc_auc = {}
    #get AUC for each label individually
    relevant_labels = []
    auc_labels = {}
    for i in range(y.shape[1]):
        #only if there are true positives for this label
        if y[:,i].sum() > 0:
            fpr[i], tpr[i], _ = roc_curve(y[:,i], yhat_raw[:,i])
            if len(fpr[i]) > 1 and len(tpr[i]) > 1:
                auc_score = auc(fpr[i], tpr[i])
                if not np.isnan(auc_score): 
                    auc_labels["auc_%d" % i] = auc_score
                    relevant_labels.append(i)

    #macro-AUC: just average the auc scores
    aucs = []
    for i in relevant_labels:
        aucs.append(auc_labels['auc_%d' % i])
    roc_auc['auc_macro'] = np.mean(aucs)

    #micro-AUC: just look at each individual prediction
    yhatmic = yhat_raw.ravel()
    fpr["micro"], tpr["micro"], _ = roc_curve(ymic, yhatmic) 
    roc_auc["auc_micro"] = auc(fpr["micro"], tpr["micro"])

    return roc_auc

def union_size(yhat, y, axis):
    #axis=0 for label-level union (macro). axis=1 for instance-level
    return np.logical_or(yhat, y).sum(axis=axis).astype(float)

def intersect_size(yhat, y, axis):
    #axis=0 for label-level union (macro). axis=1 for instance-level
    return np.logical_and(yhat, y).sum(axis=axis).astype(float)


#########################################################################
#MACRO METRICS: calculate metric for each label and average across labels
#########################################################################

def macro_accuracy(yhat, y):
    num = intersect_size(yhat, y, 0) / (union_size(yhat, y, 0) + 1e-10)
    return np.mean(num)

def macro_precision(yhat, y):
    num = intersect_size(yhat, y, 0) / (yhat.sum(axis=0) + 1e-10)
    return np.mean(num)

def macro_recall(yhat, y):
    num = intersect_size(yhat, y, 0) / (y.sum(axis=0) + 1e-10)
    return np.mean(num)

def macro_f1(yhat, y):
    prec = macro_precision(yhat, y)
    rec = macro_recall(yhat, y)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2*(prec*rec)/(prec+rec)
    return f1

##########################################################################
#MICRO METRICS: treat every prediction as an individual binary prediction
##########################################################################

def micro_accuracy(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / union_size(yhatmic, ymic, 0)

def micro_precision(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / yhatmic.sum(axis=0)

def micro_recall(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / ymic.sum(axis=0)

def micro_f1(yhatmic, ymic):
    prec = micro_precision(yhatmic, ymic)
    rec = micro_recall(yhatmic, ymic)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2*(prec*rec)/(prec+rec)
    return f1


def all_macro(yhat, y):
    return macro_accuracy(yhat, y), macro_precision(yhat, y), macro_recall(yhat, y), macro_f1(yhat, y)

def all_micro(yhatmic, ymic):
    return micro_accuracy(yhatmic, ymic), micro_precision(yhatmic, ymic), micro_recall(yhatmic, ymic), micro_f1(yhatmic, ymic)


def all_metrics(yhat, y, yhat_raw=None, calc_auc=True, label_order=[]):
    """
        Inputs:
            yhat: binary predictions matrix 
            y: binary ground truth matrix
            yhat_raw: prediction scores matrix (floats)
        Outputs:
            dict holding relevant metrics
    """
    names = ["acc", "prec", "rec", "f1"]

    metrics = {}
    for ix,label in enumerate(label_order):
        metrics[f"{label2abbrev[label]}-f1"] = f1_score(y[:,ix], yhat[:,ix])

    #macro
    #print("GETTING ALL MACRO")
    macro = all_macro(yhat, y)

    #micro
    #print("GETTING ALL MICRO")
    ymic = y.ravel()
    yhatmic = yhat.ravel()
    micro = all_micro(yhatmic, ymic)

    metrics.update({names[i] + "_macro": macro[i] for i in range(len(macro))})
    metrics.update({names[i] + "_micro": micro[i] for i in range(len(micro))})

    #AUC
    #print("AUC")
    if yhat_raw is not None and calc_auc:
        roc_auc = auc_metrics(yhat_raw, y, ymic)
        metrics.update(roc_auc)

    return metrics

In [23]:
def evaluate(model, dv_loader, device, tokenizer, return_thresholds=False, doc_position=False):
    # apply model to validation set and compute metrics, including identifying best thresholds
    yhat_raw, yhat, y, sentences, doc_ids, sent_ixs = gather_predictions(model, dv_loader, device)


    # unbalanced metrics
    metrics = all_metrics(yhat, y, yhat_raw=yhat_raw, calc_auc=True, label_order=LABEL_TYPES)

    return metrics

In [24]:
# parameters

max_epochs=5
learning_rate=5e-5
max_steps=-1
gradient_accumulation_steps=4
eval_steps = 100


In [25]:
# Training loop
total_t0 = time.time()
tr_loss = 0.0
model.zero_grad()
model.train()
metrics_hist = defaultdict(list)
best_epoch = 0
best_step = 0
step = 0
losses = []
for epoch in range(max_epochs):
        step = 0
        t0 = time.time()
        for x in tqdm(train_dataloader):
            if max_steps > -1 and step > max_steps:
                break
            # transfer to gpu
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            inputs = {'input_ids': x['input_ids'], 'attention_mask': x['attention_mask'], 'labels': x['labels']}
            if 'token_type_ids' in x:
              inputs['token_type_ids'] = x['token_type_ids']
            #outputs = model(**inputs)
            loss, pred = model(**inputs)

            # update parameters
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            if (step + 1) % gradient_accumulation_steps == 0 or (
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                len(train_dataloader) <= gradient_accumulation_steps
                and (step + 1) == len(train_dataloader)
            ):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                model.zero_grad()
                losses.append(tr_loss)
                tr_loss = 0.0

            # Validation
            if (step + 1) == len(train_dataloader):
                metrics = evaluate(model, validation_dataloader, device, tokenizer=tokenizer1)
                for name, metric in metrics.items():
                    metrics_hist[name].append(metric)

            # tensor output
            if (step + 1) % eval_steps == 0:
              elapsed = format_time(time.time() - t0)
              print('  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), np.mean(losses[-10:]), elapsed))

            step += 1

#evaluation metrics
metrics = evaluate(model, validation_dataloader, device, tokenizer=tokenizer1)
for name, metric in metrics.items():
  metrics_hist[name].append(metric)

#test metrics
metrics_hist_test = defaultdict(list)
metrics_test = evaluate(model, test_dataloader, device, tokenizer=tokenizer1)
for name, metric in metrics_test.items():
  metrics_hist_test[name].append(metric)

#save model
out_dir = "/content/drive/My Drive/Deep Learning Project"
sd = model.state_dict()
torch.save(sd, out_dir + "/model_final.pth")

  

  labels = torch.Tensor(labels)
  9%|▊         | 102/1198 [00:10<01:21, 13.38it/s]

  Batch    99  of  1,198. Loss: 0.46003335192799566.   Elapsed: 0:00:10.


 17%|█▋        | 201/1198 [00:16<00:58, 16.98it/s]

  Batch   199  of  1,198. Loss: 0.32514231353998185.   Elapsed: 0:00:16.


 25%|██▌       | 302/1198 [00:22<00:50, 17.67it/s]

  Batch   299  of  1,198. Loss: 0.23087943755090237.   Elapsed: 0:00:22.


 34%|███▎      | 404/1198 [00:28<00:42, 18.70it/s]

  Batch   399  of  1,198. Loss: 0.17209042347967624.   Elapsed: 0:00:28.


 42%|████▏     | 501/1198 [00:34<00:38, 17.89it/s]

  Batch   499  of  1,198. Loss: 0.14792772121727465.   Elapsed: 0:00:34.


 50%|█████     | 601/1198 [00:40<00:44, 13.46it/s]

  Batch   599  of  1,198. Loss: 0.12727501448243855.   Elapsed: 0:00:41.


 59%|█████▉    | 704/1198 [00:47<00:25, 19.00it/s]

  Batch   699  of  1,198. Loss: 0.11654797904193401.   Elapsed: 0:00:47.


 67%|██████▋   | 804/1198 [00:53<00:22, 17.81it/s]

  Batch   799  of  1,198. Loss: 0.1110987632535398.   Elapsed: 0:00:53.


 75%|███████▌  | 903/1198 [00:59<00:18, 15.95it/s]

  Batch   899  of  1,198. Loss: 0.10441991426050663.   Elapsed: 0:01:00.


 84%|████████▎ | 1003/1198 [01:05<00:11, 17.43it/s]

  Batch   999  of  1,198. Loss: 0.10269517907872797.   Elapsed: 0:01:05.


 92%|█████████▏| 1102/1198 [01:11<00:05, 17.35it/s]

  Batch 1,099  of  1,198. Loss: 0.09712152490392327.   Elapsed: 0:01:11.


  return intersect_size(yhatmic, ymic, 0) / yhatmic.sum(axis=0)
100%|██████████| 1198/1198 [02:39<00:00,  7.52it/s]
  8%|▊         | 101/1198 [00:06<01:21, 13.48it/s]

  Batch    99  of  1,198. Loss: 0.09043871210888028.   Elapsed: 0:00:06.


 17%|█▋        | 201/1198 [00:12<00:57, 17.40it/s]

  Batch   199  of  1,198. Loss: 0.0775149367749691.   Elapsed: 0:00:12.


 25%|██▌       | 302/1198 [00:18<00:50, 17.66it/s]

  Batch   299  of  1,198. Loss: 0.06986751011572778.   Elapsed: 0:00:19.


 34%|███▎      | 403/1198 [00:24<00:40, 19.42it/s]

  Batch   399  of  1,198. Loss: 0.06225367519073188.   Elapsed: 0:00:25.


 42%|████▏     | 501/1198 [00:30<00:38, 18.25it/s]

  Batch   499  of  1,198. Loss: 0.0651960332877934.   Elapsed: 0:00:31.


 50%|█████     | 601/1198 [00:37<00:42, 14.18it/s]

  Batch   599  of  1,198. Loss: 0.06143377721309662.   Elapsed: 0:00:37.


 59%|█████▉    | 704/1198 [00:43<00:25, 19.15it/s]

  Batch   699  of  1,198. Loss: 0.05773179749958217.   Elapsed: 0:00:43.


 67%|██████▋   | 804/1198 [00:49<00:21, 18.32it/s]

  Batch   799  of  1,198. Loss: 0.05993672697804868.   Elapsed: 0:00:49.


 75%|███████▌  | 903/1198 [00:55<00:18, 15.98it/s]

  Batch   899  of  1,198. Loss: 0.054523811349645256.   Elapsed: 0:00:55.


 84%|████████▎ | 1003/1198 [01:01<00:10, 17.77it/s]

  Batch   999  of  1,198. Loss: 0.059010479226708414.   Elapsed: 0:01:01.


 92%|█████████▏| 1102/1198 [01:07<00:05, 17.24it/s]

  Batch 1,099  of  1,198. Loss: 0.05450966497883201.   Elapsed: 0:01:07.


100%|██████████| 1198/1198 [02:35<00:00,  7.68it/s]
  9%|▊         | 102/1198 [00:06<01:19, 13.71it/s]

  Batch    99  of  1,198. Loss: 0.060077344975434245.   Elapsed: 0:00:06.


 17%|█▋        | 201/1198 [00:12<00:58, 17.09it/s]

  Batch   199  of  1,198. Loss: 0.04935198244638741.   Elapsed: 0:00:13.


 25%|██▌       | 302/1198 [00:18<00:50, 17.82it/s]

  Batch   299  of  1,198. Loss: 0.046334127662703395.   Elapsed: 0:00:19.


 34%|███▎      | 404/1198 [00:24<00:41, 18.92it/s]

  Batch   399  of  1,198. Loss: 0.04110756819136441.   Elapsed: 0:00:25.


 42%|████▏     | 501/1198 [00:30<00:38, 17.98it/s]

  Batch   499  of  1,198. Loss: 0.04618859279435128.   Elapsed: 0:00:31.


 50%|█████     | 601/1198 [00:37<00:42, 14.21it/s]

  Batch   599  of  1,198. Loss: 0.04431038638576865.   Elapsed: 0:00:37.


 59%|█████▊    | 703/1198 [00:43<00:26, 18.68it/s]

  Batch   699  of  1,198. Loss: 0.03930155953858048.   Elapsed: 0:00:43.


 67%|██████▋   | 804/1198 [00:49<00:21, 17.97it/s]

  Batch   799  of  1,198. Loss: 0.04316351932939142.   Elapsed: 0:00:49.


 75%|███████▌  | 904/1198 [00:55<00:17, 16.57it/s]

  Batch   899  of  1,198. Loss: 0.04002765074837953.   Elapsed: 0:00:55.


 84%|████████▎ | 1003/1198 [01:01<00:11, 16.64it/s]

  Batch   999  of  1,198. Loss: 0.04487797438632697.   Elapsed: 0:01:01.


 92%|█████████▏| 1102/1198 [01:07<00:05, 17.58it/s]

  Batch 1,099  of  1,198. Loss: 0.042673731897957624.   Elapsed: 0:01:07.


100%|██████████| 1198/1198 [02:35<00:00,  7.71it/s]
  8%|▊         | 101/1198 [00:06<01:20, 13.59it/s]

  Batch    99  of  1,198. Loss: 0.04735151205677539.   Elapsed: 0:00:06.


 17%|█▋        | 201/1198 [00:12<00:59, 16.66it/s]

  Batch   199  of  1,198. Loss: 0.03821628113510087.   Elapsed: 0:00:13.


 25%|██▌       | 302/1198 [00:18<00:50, 17.82it/s]

  Batch   299  of  1,198. Loss: 0.035887209267821164.   Elapsed: 0:00:19.


 34%|███▎      | 402/1198 [00:24<00:40, 19.81it/s]

  Batch   399  of  1,198. Loss: 0.03344026461709291.   Elapsed: 0:00:25.


 42%|████▏     | 501/1198 [00:30<00:38, 18.03it/s]

  Batch   499  of  1,198. Loss: 0.03788269255310297.   Elapsed: 0:00:31.


 50%|█████     | 601/1198 [00:37<00:43, 13.85it/s]

  Batch   599  of  1,198. Loss: 0.037077158188913016.   Elapsed: 0:00:37.


 59%|█████▉    | 704/1198 [00:43<00:25, 19.27it/s]

  Batch   699  of  1,198. Loss: 0.030459912167862056.   Elapsed: 0:00:43.


 67%|██████▋   | 804/1198 [00:49<00:21, 18.20it/s]

  Batch   799  of  1,198. Loss: 0.03533502211794257.   Elapsed: 0:00:49.


 75%|███████▌  | 903/1198 [00:55<00:18, 15.68it/s]

  Batch   899  of  1,198. Loss: 0.03329079230315983.   Elapsed: 0:00:56.


 84%|████████▎ | 1003/1198 [01:01<00:10, 17.74it/s]

  Batch   999  of  1,198. Loss: 0.037722500471863894.   Elapsed: 0:01:01.


 92%|█████████▏| 1102/1198 [01:07<00:05, 17.62it/s]

  Batch 1,099  of  1,198. Loss: 0.036434451804962006.   Elapsed: 0:01:07.


100%|██████████| 1198/1198 [02:35<00:00,  7.70it/s]
  8%|▊         | 101/1198 [00:06<01:22, 13.34it/s]

  Batch    99  of  1,198. Loss: 0.03858049339614809.   Elapsed: 0:00:06.


 17%|█▋        | 201/1198 [00:12<00:58, 17.16it/s]

  Batch   199  of  1,198. Loss: 0.03265303391963244.   Elapsed: 0:00:12.


 25%|██▌       | 301/1198 [00:18<00:55, 16.28it/s]

  Batch   299  of  1,198. Loss: 0.029785696184262633.   Elapsed: 0:00:19.


 34%|███▎      | 404/1198 [00:25<00:42, 18.74it/s]

  Batch   399  of  1,198. Loss: 0.028280473593622447.   Elapsed: 0:00:25.


 42%|████▏     | 501/1198 [00:30<00:39, 17.66it/s]

  Batch   499  of  1,198. Loss: 0.03201115753035992.   Elapsed: 0:00:31.


 50%|█████     | 601/1198 [00:37<00:42, 14.16it/s]

  Batch   599  of  1,198. Loss: 0.03135218098759651.   Elapsed: 0:00:37.


 59%|█████▉    | 704/1198 [00:43<00:25, 19.19it/s]

  Batch   699  of  1,198. Loss: 0.024839116737712174.   Elapsed: 0:00:43.


 67%|██████▋   | 804/1198 [00:49<00:21, 18.13it/s]

  Batch   799  of  1,198. Loss: 0.03073982266942039.   Elapsed: 0:00:49.


 75%|███████▌  | 903/1198 [00:55<00:19, 15.14it/s]

  Batch   899  of  1,198. Loss: 0.028488232346717268.   Elapsed: 0:00:56.


 84%|████████▎ | 1003/1198 [01:01<00:11, 17.38it/s]

  Batch   999  of  1,198. Loss: 0.03144577117054723.   Elapsed: 0:01:02.


 92%|█████████▏| 1102/1198 [01:07<00:05, 17.56it/s]

  Batch 1,099  of  1,198. Loss: 0.032513184973504396.   Elapsed: 0:01:07.


100%|██████████| 1198/1198 [02:35<00:00,  7.71it/s]


In [29]:
print(metrics)
print(metrics_test)

{'Imaging-f1': 0.0, 'Appointment-f1': 0.829736211031175, 'Medication-f1': 0.4792079207920792, 'Procedure-f1': 0.0, 'Lab-f1': 0.0847457627118644, 'Patient instructions-f1': 0.8197981739548295, 'Other-f1': 0.0, 'acc_macro': 0.25185625069796547, 'prec_macro': 0.47723877511785506, 'rec_macro': 0.29646726552989827, 'f1_macro': 0.36573496193871124, 'acc_micro': 0.5752688172043011, 'prec_micro': 0.8185792349726776, 'rec_micro': 0.659330985915493, 'f1_micro': 0.7303754266211604, 'auc_macro': 0.9343423172534683, 'auc_micro': 0.9664496720465994}
{'Imaging-f1': 0.0, 'Appointment-f1': 0.8344370860927153, 'Medication-f1': 0.5092838196286472, 'Procedure-f1': 0.0, 'Lab-f1': 0.08695652173913045, 'Patient instructions-f1': 0.7624576681180454, 'Other-f1': 0.0, 'acc_macro': 0.2455867114446016, 'prec_macro': 0.48127934175860215, 'rec_macro': 0.2842362304799192, 'f1_macro': 0.35739841453335774, 'acc_micro': 0.5529411764705883, 'prec_micro': 0.8292544109277177, 'rec_micro': 0.623982869379015, 'f1_micro': 0.

# Full Bert Model

In [34]:
class BertForSequenceMultilabelClassification(BertPreTrainedModel):
    """
        simple mod of MiniBERT to accept multilabel or binary task
        there may or may not be a class in huggingface that does this but this is here for historical reasons
    """
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def add_doc_position_feature(self, config):
        self.classifier = nn.Linear(config.hidden_size+1, config.num_labels)

    def set_task(self, task):
        self.task = task

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        doc_positions=None,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        if doc_positions is not None and len(doc_positions) > 0:
            pooled_output = torch.cat((pooled_output, doc_positions.unsqueeze(1)), dim=1)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(outputs[0], labels)
            outputs = (loss,) + outputs

        return outputs

In [35]:
tokenizer1 = BertTokenizer.from_pretrained('bert-base-uncased')

data_collator = lambda x: collator(x, tokenizer = tokenizer1)
eval_collate_fn = lambda x: collator(x, tokenizer =tokenizer1, eval=True)

train_dataloader = DataLoader(
  train_dataset,
  batch_size=batch_size,
  collate_fn=data_collator
)

validation_dataloader = DataLoader(
  eval_dataset,
  shuffle=False,
  batch_size=1,
  collate_fn=eval_collate_fn
)

test_dataloader = DataLoader(
    test_dataset, batch_size=1, shuffle=False, collate_fn=eval_collate_fn
    )

In [36]:
config = BertConfig.from_pretrained(
        'bert-base-uncased',
        num_labels=num_labels,
        finetuning_task="text_classification",
        label2id=label2id,
        id2label=id2label,
    )


model = BertForSequenceMultilabelClassification.from_pretrained('bert-base-uncased', config=config)

# Tell pytorch to run this model on the GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Set the seed value all over the place to make this reproducible.
seed_val = 42

random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceMultilabelClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceMultilabelClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceMultilabelClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceMultilabelClassification were not 

In [37]:

# Note: AdamW is a class from the huggingface library (as opposed to pytorch) 
from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=5e-5)

In [39]:
# Training loop
total_t0 = time.time()
tr_loss = 0.0
model.zero_grad()
model.train()
metrics_hist_fullBert = defaultdict(list)
best_epoch = 0
best_step = 0
step = 0
losses = []
for epoch in range(max_epochs):
        step = 0
        t0 = time.time()
        for x in tqdm(train_dataloader):
            if max_steps > -1 and step > max_steps:
                break
            # transfer to gpu
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            inputs = {'input_ids': x['input_ids'], 'attention_mask': x['attention_mask'], 'labels': x['labels']}
            if 'token_type_ids' in x:
              inputs['token_type_ids'] = x['token_type_ids']
            #outputs = model(**inputs)
            loss, pred = model(**inputs)

            # update parameters
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            if (step + 1) % gradient_accumulation_steps == 0 or (
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                len(train_dataloader) <= gradient_accumulation_steps
                and (step + 1) == len(train_dataloader)
            ):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                model.zero_grad()
                losses.append(tr_loss)
                tr_loss = 0.0

            # Validation
            if (step + 1) == len(train_dataloader):
                metricsFullBert = evaluate(model, validation_dataloader, device, tokenizer=tokenizer1)
                for name, metric in metricsFullBert.items():
                    metrics_hist_fullBert[name].append(metric)

            # tensor output
            if (step + 1) % eval_steps == 0:
              elapsed = format_time(time.time() - t0)
              print('  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), np.mean(losses[-10:]), elapsed))

            step += 1

#evaluation metrics
metricsFullBert = evaluate(model, validation_dataloader, device, tokenizer=tokenizer1)
for name, metric in metricsFullBert.items():
  metrics_hist_fullBert[name].append(metric)

#test metrics
metrics_hist_FullBert_test = defaultdict(list)
metricsFullBert_test = evaluate(model, test_dataloader, device, tokenizer=tokenizer1)
for name, metric in metrics_test.items():
  metrics_hist_FullBert_test[name].append(metric)

#save model
out_dir = "/content/drive/My Drive/Deep Learning Project"
sd = model.state_dict()
torch.save(sd, out_dir + "/model_final_FullBert.pth")


  0%|          | 0/1198 [00:00<?, ?it/s]


OutOfMemoryError: ignored