In [1]:
!pip install datasets transformers > /dev/null

In [2]:
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

In [3]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [4]:
from torch.utils.data import Dataset, DataLoader
from nltk import word_tokenize
from collections import Counter, defaultdict

class SentDataset(Dataset):
    def __init__(self, fname, word2ix=None, df=None):
        if df is None:
            self.sents = pd.read_csv(fname)
        else:
            self.sents = df.dropna(subset=['sentence'])
        # reset index necessary for getting context sentences when using unlabeled data
        #self.sents = self.sents.dropna(subset=['sentence'])
        if word2ix is not None:
            self.word2ix = word2ix
            self.ix2word = {ix:word for word,ix in self.word2ix.items()}
        else:
            self.word2ix = {'<PAD>': 0}
            self.ix2word = {}
            self.min_cnt = 0

    def __len__(self):
        return len(self.sents)

    def __getitem__(self, idx):
        row = self.sents.iloc[idx]
        if row.sentence.startswith('[') and row.sentence.endswith(']'):
            try:
                sentence = eval(row.sentence)
            except:
                sentence = word_tokenize(row.sentence)
        else:
            assert isinstance(row.sentence, str)
            sentence = word_tokenize(row.sentence)
        sent = [x.lower() for x in sentence]
        if 'labels' in row:
          labels = eval(row.labels)
        else:
          labels = []
        doc_id = row.doc_id
        sent_ix = None
        if 'sent_index' in row:
            sent_ix = row.sent_index
        elif 'sent_ix' in row:
            sent_ix = row.sent_ix
        return sent, labels, doc_id, sent_ix

In [5]:
# load dataset
import ast
def str_to_list(s):
    # Use the ast library to safely evaluate the string as a list
    return ast.literal_eval(s)

sents = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/sentence_level.csv")

# Apply the function to the column to convert the strings to lists
#sents['sentence'] = sents['sentence'].apply(str_to_list)

# Apply the function to the column to convert the strings to lists
#sents['labels'] = sents['labels'].apply(str_to_list)
train_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/train_ids.csv", header=None)
test_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/test_ids.csv", header=None)
eval_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/val_ids.csv", header=None)

In [6]:
train_dataset.columns = ['doc_id']
train_dataset = train_dataset.merge(sents, on = "doc_id")
test_dataset.columns = ['doc_id']
test_dataset = test_dataset.merge(sents, on = "doc_id")
eval_dataset.columns = ['doc_id']
eval_dataset = eval_dataset.merge(sents, on = "doc_id")

In [7]:
train_dataset = SentDataset('contet',df =train_dataset)
eval_dataset = SentDataset('content',df=eval_dataset)
test_dataset = SentDataset('content',df=test_dataset)

In [8]:
tokenizer1 = BertTokenizer.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

# Datasets & Dataloaders


In [9]:
batch_size = 64

In [10]:
def collator(batch, tokenizer, eval=False, doc_position=False):
    # standard bert collator that applies the right label for the training task
    sents = []
    labels = []
    doc_poses = []
    doc_ids = []
    sent_ixs = []
    for sent, label, doc_id, sent_ix in batch:
        sents.append(' '.join(sent))
        label = multilabel_labeler(label)
        labels.append(label)
        doc_ids.append(doc_id)
        sent_ixs.append(sent_ix)
    tokd = tokenizer(sents, padding=True, max_length = 512, truncation=True)
    input_ids, token_type_ids, attention_mask = tokd['input_ids'], tokd['token_type_ids'], tokd['attention_mask']
    toks = torch.LongTensor(input_ids)
    mask = torch.LongTensor(attention_mask)
    labels = torch.Tensor(labels)
    # for whatever reason, pytorch's cross entropy requires long labels but BCE doesn't
    if not eval:
        return {'input_ids': toks, 'attention_mask': mask, 'labels': labels, 'doc_ids': doc_ids, 'sent_ixs': sent_ixs}
    else:
        return {'input_ids': toks, 'attention_mask': mask, 'labels': labels, 'sentences': sents, 'doc_ids': doc_ids, 'sent_ixs': sent_ixs}

data_collator = lambda x: collator(x, tokenizer = tokenizer1)
eval_collate_fn = lambda x: collator(x, tokenizer =tokenizer1, eval=True)

In [11]:
# import torch
# import tqdm

# class BertDataset(torch.utils.data.Dataset):
#   def __init__(self, txt_list, tokenizer, max_length=512):
#     self.tokenizer = tokenizer
#     self.sents = []
#     self.labels = []
#     self.doc_poses = []
#     self.doc_ids = []
#     self.sent_ixs = []
#     self.input_ids = []
#     self.token_type_ids = []
#     self.attention_maks = []


#     for doc_id, sent_ix, sent, label in tqdm.tqdm(txt_list.values, desc="Tokenizing data"):
#         #sents.append(' '.join(sent))
#         tokd = tokenizer(sent, padding=True, max_length = 512)
#         self.input_ids.append(torch.Tensor(tokd['input_ids']))
#         self.token_type_ids.append(tokd['token_type_ids'])
#         self.attention_maks.append(torch.Tensor(tokd['attention_mask']))
#         label = multilabel_labeler(label)
#         self.labels.append(label)
#         self.doc_ids.append(doc_id)
#         self.sent_ixs.append(sent_ix)
#     self.labels = torch.Tensor(self.labels)
    
#   def __len__(self):
#     return len(self.sent_ixs)

#   def __getitem__(self, idx):
#     return {'input_ids': self.input_ids[idx], 'attention_mask': self.attention_maks[idx], 'labels': self.labels[idx], 'doc_ids': self.doc_ids[idx], 'sent_ixs': self.sent_ixs[idx]} 

In [12]:
LABEL_TYPES = ['I-Imaging-related followup',
 'I-Appointment-related followup',
 'I-Medication-related followups',
 'I-Procedure-related followup',
 'I-Lab-related followup',
 'I-Case-specific instructions for patient',
 'I-Other helpful contextual information',
 ]

label2abbrev = {'I-Imaging-related followup': 'Imaging',
        'I-Appointment-related followup': 'Appointment',
        'I-Medication-related followups': 'Medication',
        'I-Procedure-related followup': 'Procedure',
        'I-Lab-related followup': 'Lab',
        'I-Case-specific instructions for patient': 'Patient instructions',
        'I-Other helpful contextual information': 'Other',
 }

def multilabel_labeler(label):
    # convert list of label names to a multi-hot array
    if isinstance(label, str):
        label = eval(label)
    label_ixs = [LABEL_TYPES.index(l) for l in label]
    label = np.zeros(len(LABEL_TYPES))
    label[label_ixs] = 1
    return label

In [13]:

# train_dataset = BertDataset(train_dataset, tokenizer, max_length=512, collate_fn=data_collator)
# eval_dataset = BertDataset(eval_dataset, tokenizer, max_length=512, collate_fn=eval_collate_fn)

# print()
# print('{:>5,} training samples'.format(len(train_dataset)))
# print('{:>5,} validation samples'.format(len(eval_dataset)))

In [14]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(
  train_dataset,
  batch_size=batch_size,
  collate_fn=data_collator
)

validation_dataloader = DataLoader(
  eval_dataset,
  shuffle=False,
  batch_size=1,
  collate_fn=eval_collate_fn
)

test_dataloader = DataLoader(
    test_dataset, batch_size=1, shuffle=False, collate_fn=eval_collate_fn
    )

# Finetune BERT Model

In [15]:
from transformers import BertPreTrainedModel, BertModel, BertForMaskedLM
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, KLDivLoss
import torch.nn.functional as F

class BioClinicalBertForSequenceMultilabelClassification(BertPreTrainedModel):
    """
        simple mod of MiniBERT to accept multilabel or binary task
        there may or may not be a class in huggingface that does this but this is here for historical reasons
    """
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.biobert = BertModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def add_doc_position_feature(self, config):
        self.classifier = nn.Linear(config.hidden_size+1, config.num_labels)

    def set_task(self, task):
        self.task = task

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        doc_positions=None,
    ):
        outputs = self.biobert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        if doc_positions is not None and len(doc_positions) > 0:
            pooled_output = torch.cat((pooled_output, doc_positions.unsqueeze(1)), dim=1)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(outputs[0], labels)
            outputs = (loss,) + outputs

        return outputs




In [16]:
label_set = LABEL_TYPES
num_labels = len(label_set)

label2id = {label:ix for ix,label in enumerate(label_set)}
id2label = {ix:label for label,ix in label2id.items()}
config = BertConfig.from_pretrained(
        'emilyalsentzer/Bio_ClinicalBERT',
        num_labels=num_labels,
        finetuning_task="text_classification",
        label2id=label2id,
        id2label=id2label,
    )

In [17]:
import random
model = BioClinicalBertForSequenceMultilabelClassification.from_pretrained('emilyalsentzer/Bio_ClinicalBERT', config=config)

# Tell pytorch to run this model on the GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Set the seed value all over the place to make this reproducible.
seed_val = 42

random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

Downloading pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BioClinicalBertForSequenceMultilabelC

freeze weights if needed

In [18]:
for name, params in model.named_parameters():
  if name.startswith('biobert.embeddings.'):
    params.requires_grad = False
  if name.startswith('biobert.encoder.layer.0'):
    params.requires_grad = False
  if name.startswith('biobert.encoder.layer.1'):
    params.requires_grad = False
  if name.startswith('biobert.encoder.layer.2'):
    params.requires_grad = False
  if name.startswith('biobert.encoder.layer.3'):
    params.requires_grad = False
  if name.startswith('biobert.encoder.layer.4'):
    params.requires_grad = True
  if name.startswith('biobert.encoder.layer.5'):
    params.requires_grad = True
  if name.startswith('biobert.encoder.layer.6'):
    params.requires_grad = True
  if name.startswith('biobert.encoder.layer.7'):
    params.requires_grad = True
  if name.startswith('biobert.encoder.layer.8'):
    params.requires_grad = True
  if name.startswith('biobert.encoder.layer.9'):
    params.requires_grad = True
  if name.startswith('biobert.encoder.layer.10'):
    params.requires_grad = True
  if name.startswith('biobert.encoder.layer.11'):
    params.requires_grad = True

In [19]:
for name, params in model.named_parameters():
  print(name)
  print(params.requires_grad)

biobert.embeddings.word_embeddings.weight
False
biobert.embeddings.position_embeddings.weight
False
biobert.embeddings.token_type_embeddings.weight
False
biobert.embeddings.LayerNorm.weight
False
biobert.embeddings.LayerNorm.bias
False
biobert.encoder.layer.0.attention.self.query.weight
False
biobert.encoder.layer.0.attention.self.query.bias
False
biobert.encoder.layer.0.attention.self.key.weight
False
biobert.encoder.layer.0.attention.self.key.bias
False
biobert.encoder.layer.0.attention.self.value.weight
False
biobert.encoder.layer.0.attention.self.value.bias
False
biobert.encoder.layer.0.attention.output.dense.weight
False
biobert.encoder.layer.0.attention.output.dense.bias
False
biobert.encoder.layer.0.attention.output.LayerNorm.weight
False
biobert.encoder.layer.0.attention.output.LayerNorm.bias
False
biobert.encoder.layer.0.intermediate.dense.weight
False
biobert.encoder.layer.0.intermediate.dense.bias
False
biobert.encoder.layer.0.output.dense.weight
False
biobert.encoder.layer.

In [20]:
import time

timestamp = time.strftime('%b_%d_%H:%M:%S', time.localtime())

# Note: AdamW is a class from the huggingface library (as opposed to pytorch) 
from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=5e-5)



# Compute Predictions

In [21]:
def gather_predictions(model, loader, device, doc_position=False, save_preds=False, max_preds=1e9):
    # run through data loader and get model predictions
    with torch.no_grad():
        model.eval()
        yhat_raw = np.zeros((len(loader), 7))
        yhat = np.zeros((len(loader), 7))
        y = np.zeros((len(loader), 7))
        sentences = []
        doc_ids = []
        sent_ixs = []
        for ix, x in enumerate(loader):
            if ix >= max_preds:
                break
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            sentences.extend(x['sentences'])
            doc_ids.extend(x['doc_ids'])
            sent_ixs.extend(x['sent_ixs'])
            inputs = {'input_ids': x['input_ids'], 'attention_mask': x['attention_mask'], 'labels': x['labels']}
            if 'token_type_ids' in x:
              inputs['token_type_ids'] = x['token_type_ids']
            loss, pred = model(**inputs)
            pred = torch.sigmoid(pred)
            pred = pred.cpu().numpy()
            yhat_raw[ix] = pred
            yhat[ix] = np.round(pred)
            y[ix] = x['labels'].cpu().numpy()[0]
    return yhat_raw, yhat, y, sentences, doc_ids, sent_ixs

# Evaluation Metrics

In [22]:
import datetime
import time

def format_time(elapsed):
    return str(datetime.timedelta(seconds=int(round((elapsed)))))

In [23]:
from collections import defaultdict
import csv
import json
import numpy as np
import os
import sys

from sklearn.metrics import roc_curve, auc, precision_score, recall_score, f1_score, roc_auc_score, accuracy_score
from tqdm import tqdm


def auc_metrics(yhat_raw, y, ymic):
    if yhat_raw.shape[0] <= 1:
        return
    fpr = {}
    tpr = {}
    roc_auc = {}
    #get AUC for each label individually
    relevant_labels = []
    auc_labels = {}
    for i in range(y.shape[1]):
        #only if there are true positives for this label
        if y[:,i].sum() > 0:
            fpr[i], tpr[i], _ = roc_curve(y[:,i], yhat_raw[:,i])
            if len(fpr[i]) > 1 and len(tpr[i]) > 1:
                auc_score = auc(fpr[i], tpr[i])
                if not np.isnan(auc_score): 
                    auc_labels["auc_%d" % i] = auc_score
                    relevant_labels.append(i)

    #macro-AUC: just average the auc scores
    aucs = []
    for i in relevant_labels:
        aucs.append(auc_labels['auc_%d' % i])
    roc_auc['auc_macro'] = np.mean(aucs)

    #micro-AUC: just look at each individual prediction
    yhatmic = yhat_raw.ravel()
    fpr["micro"], tpr["micro"], _ = roc_curve(ymic, yhatmic) 
    roc_auc["auc_micro"] = auc(fpr["micro"], tpr["micro"])

    return roc_auc

def union_size(yhat, y, axis):
    #axis=0 for label-level union (macro). axis=1 for instance-level
    return np.logical_or(yhat, y).sum(axis=axis).astype(float)

def intersect_size(yhat, y, axis):
    #axis=0 for label-level union (macro). axis=1 for instance-level
    return np.logical_and(yhat, y).sum(axis=axis).astype(float)


#########################################################################
#MACRO METRICS: calculate metric for each label and average across labels
#########################################################################

def macro_accuracy(yhat, y):
    num = intersect_size(yhat, y, 0) / (union_size(yhat, y, 0) + 1e-10)
    return np.mean(num)

def macro_precision(yhat, y):
    num = intersect_size(yhat, y, 0) / (yhat.sum(axis=0) + 1e-10)
    return np.mean(num)

def macro_recall(yhat, y):
    num = intersect_size(yhat, y, 0) / (y.sum(axis=0) + 1e-10)
    return np.mean(num)

def macro_f1(yhat, y):
    prec = macro_precision(yhat, y)
    rec = macro_recall(yhat, y)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2*(prec*rec)/(prec+rec)
    return f1

##########################################################################
#MICRO METRICS: treat every prediction as an individual binary prediction
##########################################################################

def micro_accuracy(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / union_size(yhatmic, ymic, 0)

def micro_precision(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / yhatmic.sum(axis=0)

def micro_recall(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / ymic.sum(axis=0)

def micro_f1(yhatmic, ymic):
    prec = micro_precision(yhatmic, ymic)
    rec = micro_recall(yhatmic, ymic)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2*(prec*rec)/(prec+rec)
    return f1


def all_macro(yhat, y):
    return macro_accuracy(yhat, y), macro_precision(yhat, y), macro_recall(yhat, y), macro_f1(yhat, y)

def all_micro(yhatmic, ymic):
    return micro_accuracy(yhatmic, ymic), micro_precision(yhatmic, ymic), micro_recall(yhatmic, ymic), micro_f1(yhatmic, ymic)


def all_metrics(yhat, y, yhat_raw=None, calc_auc=True, label_order=[]):
    """
        Inputs:
            yhat: binary predictions matrix 
            y: binary ground truth matrix
            yhat_raw: prediction scores matrix (floats)
        Outputs:
            dict holding relevant metrics
    """
    names = ["acc", "prec", "rec", "f1"]

    metrics = {}
    for ix,label in enumerate(label_order):
        metrics[f"{label2abbrev[label]}-f1"] = f1_score(y[:,ix], yhat[:,ix])

    #macro
    #print("GETTING ALL MACRO")
    macro = all_macro(yhat, y)

    #micro
    #print("GETTING ALL MICRO")
    ymic = y.ravel()
    yhatmic = yhat.ravel()
    micro = all_micro(yhatmic, ymic)

    metrics.update({names[i] + "_macro": macro[i] for i in range(len(macro))})
    metrics.update({names[i] + "_micro": micro[i] for i in range(len(micro))})

    #AUC
    #print("AUC")
    if yhat_raw is not None and calc_auc:
        roc_auc = auc_metrics(yhat_raw, y, ymic)
        metrics.update(roc_auc)

    return metrics

In [24]:
def evaluate(model, dv_loader, device, tokenizer, return_thresholds=False, doc_position=False):
    # apply model to validation set and compute metrics, including identifying best thresholds
    yhat_raw, yhat, y, sentences, doc_ids, sent_ixs = gather_predictions(model, dv_loader, device)


    # unbalanced metrics
    metrics = all_metrics(yhat, y, yhat_raw=yhat_raw, calc_auc=True, label_order=LABEL_TYPES)

    return metrics

In [25]:
# parameters

max_epochs=5
learning_rate=5e-5
max_steps=-1
gradient_accumulation_steps=4
eval_steps = 100


In [26]:
# Training loop
total_t0 = time.time()
tr_loss = 0.0
model.zero_grad()
model.train()
metrics_hist = defaultdict(list)
best_epoch = 0
best_step = 0
step = 0
losses = []
for epoch in range(max_epochs):
        step = 0
        t0 = time.time()
        for x in tqdm(train_dataloader):
            if max_steps > -1 and step > max_steps:
                break
            # transfer to gpu
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            inputs = {'input_ids': x['input_ids'], 'attention_mask': x['attention_mask'], 'labels': x['labels']}
            if 'token_type_ids' in x:
              inputs['token_type_ids'] = x['token_type_ids']
            #outputs = model(**inputs)
            loss, pred = model(**inputs)

            # update parameters
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            if (step + 1) % gradient_accumulation_steps == 0 or (
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                len(train_dataloader) <= gradient_accumulation_steps
                and (step + 1) == len(train_dataloader)
            ):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                model.zero_grad()
                losses.append(tr_loss)
                tr_loss = 0.0


            # tensor output
            if (step + 1) % eval_steps == 0:
              elapsed = format_time(time.time() - t0)
              print('  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), np.mean(losses[-10:]), elapsed))

            step += 1

#evaluation metrics
metrics = evaluate(model, validation_dataloader, device, tokenizer=tokenizer1)
for name, metric in metrics.items():
  metrics_hist[name].append(metric)

#test metrics
metrics_hist_test = defaultdict(list)
metrics_test = evaluate(model, test_dataloader, device, tokenizer=tokenizer1)
for name, metric in metrics_test.items():
  metrics_hist_test[name].append(metric)

#save model
out_dir = "/content/drive/My Drive/Deep Learning Project"
sd = model.state_dict()
torch.save(sd, out_dir + "/model_final_BioClincial.pth")

  

  labels = torch.Tensor(labels)
  8%|▊         | 100/1198 [00:38<07:24,  2.47it/s]

  Batch    99  of  1,198. Loss: 0.18798042386770247.   Elapsed: 0:00:38.


 17%|█▋        | 201/1198 [01:09<04:30,  3.68it/s]

  Batch   199  of  1,198. Loss: 0.1011527406051755.   Elapsed: 0:01:10.


 25%|██▌       | 301/1198 [01:44<03:51,  3.88it/s]

  Batch   299  of  1,198. Loss: 0.07052444708533585.   Elapsed: 0:01:45.


 33%|███▎      | 400/1198 [02:15<04:18,  3.08it/s]

  Batch   399  of  1,198. Loss: 0.052933424315415326.   Elapsed: 0:02:16.


 42%|████▏     | 500/1198 [02:46<02:11,  5.32it/s]

  Batch   499  of  1,198. Loss: 0.06251525122206658.   Elapsed: 0:02:46.


 50%|█████     | 600/1198 [03:20<03:18,  3.02it/s]

  Batch   599  of  1,198. Loss: 0.048913332051597536.   Elapsed: 0:03:21.


 59%|█████▊    | 701/1198 [03:55<01:49,  4.54it/s]

  Batch   699  of  1,198. Loss: 0.043844261299818756.   Elapsed: 0:03:55.


 67%|██████▋   | 801/1198 [04:27<01:55,  3.45it/s]

  Batch   799  of  1,198. Loss: 0.04724556610453874.   Elapsed: 0:04:27.


 75%|███████▌  | 901/1198 [04:58<01:48,  2.73it/s]

  Batch   899  of  1,198. Loss: 0.03728429033653811.   Elapsed: 0:04:58.


 84%|████████▎ | 1001/1198 [05:26<00:53,  3.66it/s]

  Batch   999  of  1,198. Loss: 0.041477080923505126.   Elapsed: 0:05:27.


 92%|█████████▏| 1100/1198 [05:55<00:22,  4.34it/s]

  Batch 1,099  of  1,198. Loss: 0.040056271752109754.   Elapsed: 0:05:56.


100%|██████████| 1198/1198 [06:25<00:00,  3.11it/s]
  8%|▊         | 100/1198 [00:35<07:24,  2.47it/s]

  Batch    99  of  1,198. Loss: 0.04472177331917919.   Elapsed: 0:00:36.


 17%|█▋        | 201/1198 [01:06<04:29,  3.69it/s]

  Batch   199  of  1,198. Loss: 0.0328934132528957.   Elapsed: 0:01:07.


 25%|██▌       | 301/1198 [01:41<03:50,  3.89it/s]

  Batch   299  of  1,198. Loss: 0.029988906043581665.   Elapsed: 0:01:41.


 33%|███▎      | 400/1198 [02:12<04:18,  3.09it/s]

  Batch   399  of  1,198. Loss: 0.02776574089657515.   Elapsed: 0:02:13.


 42%|████▏     | 500/1198 [02:43<02:11,  5.31it/s]

  Batch   499  of  1,198. Loss: 0.03386275724624284.   Elapsed: 0:02:43.


 50%|█████     | 600/1198 [03:17<03:18,  3.02it/s]

  Batch   599  of  1,198. Loss: 0.03356903553358279.   Elapsed: 0:03:18.


 59%|█████▊    | 701/1198 [03:52<01:49,  4.53it/s]

  Batch   699  of  1,198. Loss: 0.022371929208748042.   Elapsed: 0:03:52.


 67%|██████▋   | 801/1198 [04:24<01:54,  3.46it/s]

  Batch   799  of  1,198. Loss: 0.02963237201620359.   Elapsed: 0:04:24.


 75%|███████▌  | 901/1198 [04:55<01:48,  2.73it/s]

  Batch   899  of  1,198. Loss: 0.027694344409974293.   Elapsed: 0:04:55.


 84%|████████▎ | 1001/1198 [05:23<00:53,  3.66it/s]

  Batch   999  of  1,198. Loss: 0.028871821489883587.   Elapsed: 0:05:23.


 92%|█████████▏| 1100/1198 [05:52<00:22,  4.36it/s]

  Batch 1,099  of  1,198. Loss: 0.03018966636736877.   Elapsed: 0:05:53.


100%|██████████| 1198/1198 [06:22<00:00,  3.13it/s]
  8%|▊         | 100/1198 [00:35<07:25,  2.47it/s]

  Batch    99  of  1,198. Loss: 0.034621725368197076.   Elapsed: 0:00:36.


 17%|█▋        | 201/1198 [01:06<04:29,  3.70it/s]

  Batch   199  of  1,198. Loss: 0.024358938133809717.   Elapsed: 0:01:07.


 25%|██▌       | 301/1198 [01:41<03:51,  3.87it/s]

  Batch   299  of  1,198. Loss: 0.023234902147669344.   Elapsed: 0:01:41.


 33%|███▎      | 400/1198 [02:12<04:18,  3.09it/s]

  Batch   399  of  1,198. Loss: 0.022095591493416576.   Elapsed: 0:02:13.


 42%|████▏     | 500/1198 [02:43<02:11,  5.31it/s]

  Batch   499  of  1,198. Loss: 0.026451301234192214.   Elapsed: 0:02:43.


 50%|█████     | 600/1198 [03:17<03:18,  3.01it/s]

  Batch   599  of  1,198. Loss: 0.02563029950542841.   Elapsed: 0:03:18.


 59%|█████▊    | 701/1198 [03:51<01:49,  4.55it/s]

  Batch   699  of  1,198. Loss: 0.017591299075866117.   Elapsed: 0:03:52.


 67%|██████▋   | 801/1198 [04:24<01:55,  3.44it/s]

  Batch   799  of  1,198. Loss: 0.021640905033564195.   Elapsed: 0:04:24.


 75%|███████▌  | 901/1198 [04:54<01:48,  2.73it/s]

  Batch   899  of  1,198. Loss: 0.02182697125826962.   Elapsed: 0:04:55.


 84%|████████▎ | 1001/1198 [05:23<00:53,  3.66it/s]

  Batch   999  of  1,198. Loss: 0.024573367653647437.   Elapsed: 0:05:23.


 92%|█████████▏| 1100/1198 [05:52<00:22,  4.34it/s]

  Batch 1,099  of  1,198. Loss: 0.025986023267614657.   Elapsed: 0:05:53.


100%|██████████| 1198/1198 [06:22<00:00,  3.13it/s]
  8%|▊         | 100/1198 [00:35<07:24,  2.47it/s]

  Batch    99  of  1,198. Loss: 0.027512274641776458.   Elapsed: 0:00:36.


 17%|█▋        | 201/1198 [01:06<04:29,  3.70it/s]

  Batch   199  of  1,198. Loss: 0.01927780747937504.   Elapsed: 0:01:07.


 25%|██▌       | 301/1198 [01:41<03:50,  3.89it/s]

  Batch   299  of  1,198. Loss: 0.019420328352134676.   Elapsed: 0:01:42.


 33%|███▎      | 400/1198 [02:12<04:19,  3.08it/s]

  Batch   399  of  1,198. Loss: 0.01849491619359469.   Elapsed: 0:02:13.


 42%|████▏     | 500/1198 [02:43<02:11,  5.30it/s]

  Batch   499  of  1,198. Loss: 0.022915302333422004.   Elapsed: 0:02:43.


 50%|█████     | 600/1198 [03:17<03:18,  3.01it/s]

  Batch   599  of  1,198. Loss: 0.02225398366281297.   Elapsed: 0:03:18.


 59%|█████▊    | 701/1198 [03:52<01:49,  4.54it/s]

  Batch   699  of  1,198. Loss: 0.014405238031758926.   Elapsed: 0:03:52.


 67%|██████▋   | 801/1198 [04:24<01:55,  3.45it/s]

  Batch   799  of  1,198. Loss: 0.01840077169035794.   Elapsed: 0:04:24.


 75%|███████▌  | 901/1198 [04:55<01:48,  2.73it/s]

  Batch   899  of  1,198. Loss: 0.01742238488077419.   Elapsed: 0:04:55.


 84%|████████▎ | 1001/1198 [05:23<00:53,  3.65it/s]

  Batch   999  of  1,198. Loss: 0.020125441362324636.   Elapsed: 0:05:24.


 92%|█████████▏| 1100/1198 [05:52<00:22,  4.35it/s]

  Batch 1,099  of  1,198. Loss: 0.022010187638807112.   Elapsed: 0:05:53.


100%|██████████| 1198/1198 [06:22<00:00,  3.13it/s]
  8%|▊         | 100/1198 [00:35<07:25,  2.46it/s]

  Batch    99  of  1,198. Loss: 0.022592314184294082.   Elapsed: 0:00:36.


 17%|█▋        | 201/1198 [01:06<04:30,  3.69it/s]

  Batch   199  of  1,198. Loss: 0.01653091477637645.   Elapsed: 0:01:07.


 25%|██▌       | 301/1198 [01:41<03:51,  3.88it/s]

  Batch   299  of  1,198. Loss: 0.016439586268097627.   Elapsed: 0:01:41.


 33%|███▎      | 400/1198 [02:12<04:20,  3.07it/s]

  Batch   399  of  1,198. Loss: 0.01569100109336432.   Elapsed: 0:02:13.


 42%|████▏     | 500/1198 [02:43<02:11,  5.31it/s]

  Batch   499  of  1,198. Loss: 0.019535391866520514.   Elapsed: 0:02:43.


 50%|█████     | 600/1198 [03:17<03:18,  3.01it/s]

  Batch   599  of  1,198. Loss: 0.019005494186421855.   Elapsed: 0:03:18.


 59%|█████▊    | 701/1198 [03:51<01:49,  4.55it/s]

  Batch   699  of  1,198. Loss: 0.013092208132729866.   Elapsed: 0:03:52.


 67%|██████▋   | 801/1198 [04:24<01:55,  3.44it/s]

  Batch   799  of  1,198. Loss: 0.016899962327443063.   Elapsed: 0:04:24.


 75%|███████▌  | 901/1198 [04:54<01:49,  2.72it/s]

  Batch   899  of  1,198. Loss: 0.014379937735793647.   Elapsed: 0:04:55.


 84%|████████▎ | 1001/1198 [05:23<00:54,  3.65it/s]

  Batch   999  of  1,198. Loss: 0.016178866534028204.   Elapsed: 0:05:23.


 92%|█████████▏| 1100/1198 [05:52<00:22,  4.34it/s]

  Batch 1,099  of  1,198. Loss: 0.02017108709696913.   Elapsed: 0:05:53.


100%|██████████| 1198/1198 [06:22<00:00,  3.13it/s]


In [27]:
print(metrics)
print(metrics_test)

{'Imaging-f1': 0.5479452054794521, 'Appointment-f1': 0.8454106280193237, 'Medication-f1': 0.6835820895522388, 'Procedure-f1': 0.5486725663716815, 'Lab-f1': 0.5347593582887701, 'Patient instructions-f1': 0.8407758231844836, 'Other-f1': 0.2, 'acc_macro': 0.45832366632079075, 'prec_macro': 0.5953648425113938, 'rec_macro': 0.6298090522890244, 'f1_macro': 0.6121027697694905, 'acc_micro': 0.6372653205809422, 'prec_micro': 0.765531914893617, 'rec_micro': 0.7918133802816901, 'f1_micro': 0.778450887061878, 'auc_macro': 0.9640015312768829, 'auc_micro': 0.9810778491530974}
{'Imaging-f1': 0.6176470588235294, 'Appointment-f1': 0.8422597212032282, 'Medication-f1': 0.6613226452905812, 'Procedure-f1': 0.5238095238095237, 'Lab-f1': 0.672566371681416, 'Patient instructions-f1': 0.7829279486002754, 'Other-f1': 0.16513761467889906, 'acc_macro': 0.46615958800863017, 'prec_macro': 0.6453442423450563, 'rec_macro': 0.5979576220954591, 'f1_macro': 0.6207478965847043, 'acc_micro': 0.6125356125356125, 'prec_micr