In [1]:
!pip install datasets transformers > /dev/null

In [2]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

In [3]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [4]:
from torch.utils.data import Dataset, DataLoader
from nltk import word_tokenize
from collections import Counter, defaultdict

class SentDataset(Dataset):
    def __init__(self, fname, word2ix=None, df=None):
        if df is None:
            self.sents = pd.read_csv(fname)
        else:
            self.sents = df.dropna(subset=['sentence'])
        # reset index necessary for getting context sentences when using unlabeled data
        #self.sents = self.sents.dropna(subset=['sentence'])
        if word2ix is not None:
            self.word2ix = word2ix
            self.ix2word = {ix:word for word,ix in self.word2ix.items()}
        else:
            self.word2ix = {'<PAD>': 0}
            self.ix2word = {}
            self.min_cnt = 0

    def __len__(self):
        return len(self.sents)

    def __getitem__(self, idx):
        row = self.sents.iloc[idx]
        if row.sentence.startswith('[') and row.sentence.endswith(']'):
            try:
                sentence = eval(row.sentence)
            except:
                sentence = word_tokenize(row.sentence)
        else:
            assert isinstance(row.sentence, str)
            sentence = word_tokenize(row.sentence)
        sent = [x.lower() for x in sentence]
        if 'labels' in row:
          labels = eval(row.labels)
        else:
          labels = []
        doc_id = row.doc_id
        sent_ix = None
        if 'sent_index' in row:
            sent_ix = row.sent_index
        elif 'sent_ix' in row:
            sent_ix = row.sent_ix
        return sent, labels, doc_id, sent_ix

In [5]:
# load dataset
import ast
def str_to_list(s):
    # Use the ast library to safely evaluate the string as a list
    return ast.literal_eval(s)

sents = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/sentence_level.csv")

# Apply the function to the column to convert the strings to lists
#sents['sentence'] = sents['sentence'].apply(str_to_list)

# Apply the function to the column to convert the strings to lists
#sents['labels'] = sents['labels'].apply(str_to_list)
train_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/train_ids.csv", header=None)
test_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/test_ids.csv", header=None)
eval_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/val_ids.csv", header=None)

In [6]:
train_dataset.columns = ['doc_id']
train_dataset = train_dataset.merge(sents, on = "doc_id")
test_dataset.columns = ['doc_id']
test_dataset = test_dataset.merge(sents, on = "doc_id")
eval_dataset.columns = ['doc_id']
eval_dataset = eval_dataset.merge(sents, on = "doc_id")

In [7]:
train_dataset = SentDataset('contet',df =train_dataset)
eval_dataset = SentDataset('content',df=eval_dataset)
test_dataset = SentDataset('content',df=test_dataset)

In [8]:
tokenizer1 =  AutoTokenizer.from_pretrained('Sedigh/RoBERTa-large-PM-M3-Voc', use_auth_token = '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')

Downloading (…)okenizer_config.json:   0%|          | 0.00/1.19k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/912k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/472k [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/17.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

# Datasets & Dataloaders


In [9]:
batch_size = 64

In [10]:
def collator(batch, tokenizer, eval=False, doc_position=False):
    # standard bert collator that applies the right label for the training task
    sents = []
    labels = []
    doc_poses = []
    doc_ids = []
    sent_ixs = []
    for sent, label, doc_id, sent_ix in batch:
        sents.append(' '.join(sent))
        label = multilabel_labeler(label)
        labels.append(label)
        doc_ids.append(doc_id)
        sent_ixs.append(sent_ix)
    tokd = tokenizer(sents, padding=True, max_length = 512, truncation=True)
    input_ids, attention_mask = tokd['input_ids'], tokd['attention_mask']
    toks = torch.LongTensor(input_ids)
    mask = torch.LongTensor(attention_mask)
    labels = torch.Tensor(labels)
    # for whatever reason, pytorch's cross entropy requires long labels but BCE doesn't
    if not eval:
        return {'input_ids': toks, 'attention_mask': mask, 'labels': labels, 'doc_ids': doc_ids, 'sent_ixs': sent_ixs}
    else:
        return {'input_ids': toks, 'attention_mask': mask, 'labels': labels, 'sentences': sents, 'doc_ids': doc_ids, 'sent_ixs': sent_ixs}

data_collator = lambda x: collator(x, tokenizer = tokenizer1)
eval_collate_fn = lambda x: collator(x, tokenizer =tokenizer1, eval=True)

In [11]:
LABEL_TYPES = ['I-Imaging-related followup',
 'I-Appointment-related followup',
 'I-Medication-related followups',
 'I-Procedure-related followup',
 'I-Lab-related followup',
 'I-Case-specific instructions for patient',
 'I-Other helpful contextual information',
 ]

label2abbrev = {'I-Imaging-related followup': 'Imaging',
        'I-Appointment-related followup': 'Appointment',
        'I-Medication-related followups': 'Medication',
        'I-Procedure-related followup': 'Procedure',
        'I-Lab-related followup': 'Lab',
        'I-Case-specific instructions for patient': 'Patient instructions',
        'I-Other helpful contextual information': 'Other',
 }

def multilabel_labeler(label):
    # convert list of label names to a multi-hot array
    if isinstance(label, str):
        label = eval(label)
    label_ixs = [LABEL_TYPES.index(l) for l in label]
    label = np.zeros(len(LABEL_TYPES))
    label[label_ixs] = 1
    return label

In [12]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(
  train_dataset,
  batch_size=batch_size,
  collate_fn=data_collator
)

validation_dataloader = DataLoader(
  eval_dataset,
  shuffle=False,
  batch_size=1,
  collate_fn=eval_collate_fn
)

test_dataloader = DataLoader(
    test_dataset, batch_size=1, shuffle=False, collate_fn=eval_collate_fn
    )

# Finetune BERT Model

In [13]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, RobertaPreTrainedModel

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, KLDivLoss
import torch.nn.functional as F

class RoBERTaForSequenceMultilabelClassification(AutoModelForSequenceClassification):
    """
        simple mod of MiniBERT to accept multilabel or binary task
        there may or may not be a class in huggingface that does this but this is here for historical reasons
    """
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.RoBERTa = AutoModelForSequenceClassification.from_pretrained('Sedigh/RoBERTa-large-PM-M3-Voc', use_auth_token = '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')
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def add_doc_position_feature(self, config):
        self.classifier = nn.Linear(config.hidden_size+1, config.num_labels)

    def set_task(self, task):
        self.task = task

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        doc_positions=None,
    ):
        outputs = self.RoBERTa(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        if doc_positions is not None and len(doc_positions) > 0:
            pooled_output = torch.cat((pooled_output, doc_positions.unsqueeze(1)), dim=1)

        logits = self.classifier(pooled_output)
        #print(logits)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:

            loss = F.binary_cross_entropy_with_logits(outputs[0], labels)

            #print(loss)
            outputs = (loss,) + outputs
        
        #print(outputs)

        return outputs




In [14]:
label_set = LABEL_TYPES
num_labels = len(label_set)

label2id = {label:ix for ix,label in enumerate(label_set)}
id2label = {ix:label for label,ix in label2id.items()}
config = AutoConfig.from_pretrained(
        'Sedigh/RoBERTa-large-PM-M3-Voc',
        num_labels=num_labels,
        finetuning_task="text_classification",
        label2id=label2id,
        id2label=id2label,
        problem_type="multi_label_classification",
        use_auth_token = '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'
    )

Downloading (…)lve/main/config.json:   0%|          | 0.00/816 [00:00<?, ?B/s]

In [15]:
import random
model = RoBERTaForSequenceMultilabelClassification.from_pretrained('Sedigh/RoBERTa-large-PM-M3-Voc', use_auth_token = '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', ignore_mismatched_sizes=True, config=config)

# Tell pytorch to run this model on the GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Set the seed value all over the place to make this reproducible.
seed_val = 11

random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

Downloading pytorch_model.bin:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at Sedigh/RoBERTa-large-PM-M3-Voc and are newly initialized because the shapes did not match:
- classifier.out_proj.weight: found shape torch.Size([2, 1024]) in the checkpoint and torch.Size([7, 1024]) in the model instantiated
- classifier.out_proj.bias: found shape torch.Size([2]) in the checkpoint and torch.Size([7]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


freeze weights if needed

In [16]:
for name, params in model.named_parameters():
  if name.startswith('roberta.embeddings'):
    params.requires_grad = False
  if name.startswith('roberta.encoder'):
    params.requires_grad = False
  # if name.startswith('bert.encoder.layer.13'):
  #   params.requires_grad = True
  # if name.startswith('bert.encoder.layer.14'):
  #   params.requires_grad = True
  # if name.startswith('bert.encoder.layer.15'):
  #   params.requires_grad = True
  # if name.startswith('bert.encoder.layer.16'):
  #   params.requires_grad = True
  # if name.startswith('bert.encoder.layer.17'):
  #   params.requires_grad = True
  if name.startswith('roberta.encoder.layer.18'):
    params.requires_grad = True
  if name.startswith('roberta.encoder.layer.19'):
    params.requires_grad = True
  if name.startswith('roberta.encoder.layer.20'):
    params.requires_grad = True
  if name.startswith('robert.encoder.layer.21'):
    params.requires_grad = True
  if name.startswith('roberta.encoder.layer.22'):
    params.requires_grad = True
  if name.startswith('roberta.encoder.layer.23'):
    params.requires_grad = True
  if name.startswith('roberta.encoder.ln'):
    params.requires_grad = True

In [17]:
for name, params in model.named_parameters():
  print(name)
  print(params.shape)

roberta.embeddings.word_embeddings.weight
torch.Size([50008, 1024])
roberta.embeddings.position_embeddings.weight
torch.Size([514, 1024])
roberta.embeddings.token_type_embeddings.weight
torch.Size([1, 1024])
roberta.embeddings.LayerNorm.weight
torch.Size([1024])
roberta.embeddings.LayerNorm.bias
torch.Size([1024])
roberta.encoder.layer.0.attention.self.query.weight
torch.Size([1024, 1024])
roberta.encoder.layer.0.attention.self.query.bias
torch.Size([1024])
roberta.encoder.layer.0.attention.self.key.weight
torch.Size([1024, 1024])
roberta.encoder.layer.0.attention.self.key.bias
torch.Size([1024])
roberta.encoder.layer.0.attention.self.value.weight
torch.Size([1024, 1024])
roberta.encoder.layer.0.attention.self.value.bias
torch.Size([1024])
roberta.encoder.layer.0.attention.output.dense.weight
torch.Size([1024, 1024])
roberta.encoder.layer.0.attention.output.dense.bias
torch.Size([1024])
roberta.encoder.layer.0.attention.output.LayerNorm.weight
torch.Size([1024])
roberta.encoder.layer.0

In [18]:
import time

timestamp = time.strftime('%b_%d_%H:%M:%S', time.localtime())

# Note: AdamW is a class from the huggingface library (as opposed to pytorch) 
from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=5e-5)



# Compute Predictions

In [19]:
def gather_predictions(model, loader, device, doc_position=False, save_preds=False, max_preds=1e9):
    # run through data loader and get model predictions
    with torch.no_grad():
        model.eval()
        yhat_raw = np.zeros((len(loader), 7))
        yhat = np.zeros((len(loader), 7))
        y = np.zeros((len(loader), 7))
        sentences = []
        doc_ids = []
        sent_ixs = []
        for ix, x in enumerate(loader):
            if ix >= max_preds:
                break
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            sentences.extend(x['sentences'])
            doc_ids.extend(x['doc_ids'])
            sent_ixs.extend(x['sent_ixs'])
            inputs = {'input_ids': x['input_ids'], 'attention_mask': x['attention_mask'], 'labels': x['labels']}
            #if 'token_type_ids' in x:
            #  inputs['token_type_ids'] = x['token_type_ids']
            outputs = model(**inputs)
            #print(outputs)
            loss = outputs['loss']
            pred = outputs['logits']
            pred = torch.sigmoid(pred)
            pred = pred.cpu().numpy()
            yhat_raw[ix] = pred
            yhat[ix] = np.round(pred)
            y[ix] = x['labels'].cpu().numpy()[0]
    return yhat_raw, yhat, y, sentences, doc_ids, sent_ixs

# Evaluation Metrics

In [20]:
import datetime
import time

def format_time(elapsed):
    return str(datetime.timedelta(seconds=int(round((elapsed)))))

In [21]:
from collections import defaultdict
import csv
import json
import numpy as np
import os
import sys

from sklearn.metrics import roc_curve, auc, precision_score, recall_score, f1_score, roc_auc_score, accuracy_score
from tqdm import tqdm


def auc_metrics(yhat_raw, y, ymic):
    if yhat_raw.shape[0] <= 1:
        return
    fpr = {}
    tpr = {}
    roc_auc = {}
    #get AUC for each label individually
    relevant_labels = []
    auc_labels = {}
    for i in range(y.shape[1]):
        #only if there are true positives for this label
        if y[:,i].sum() > 0:
            fpr[i], tpr[i], _ = roc_curve(y[:,i], yhat_raw[:,i])
            if len(fpr[i]) > 1 and len(tpr[i]) > 1:
                auc_score = auc(fpr[i], tpr[i])
                if not np.isnan(auc_score): 
                    auc_labels["auc_%d" % i] = auc_score
                    relevant_labels.append(i)

    #macro-AUC: just average the auc scores
    aucs = []
    for i in relevant_labels:
        aucs.append(auc_labels['auc_%d' % i])
    roc_auc['auc_macro'] = np.mean(aucs)

    #micro-AUC: just look at each individual prediction
    yhatmic = yhat_raw.ravel()
    fpr["micro"], tpr["micro"], _ = roc_curve(ymic, yhatmic) 
    roc_auc["auc_micro"] = auc(fpr["micro"], tpr["micro"])

    return roc_auc

def union_size(yhat, y, axis):
    #axis=0 for label-level union (macro). axis=1 for instance-level
    return np.logical_or(yhat, y).sum(axis=axis).astype(float)

def intersect_size(yhat, y, axis):
    #axis=0 for label-level union (macro). axis=1 for instance-level
    return np.logical_and(yhat, y).sum(axis=axis).astype(float)


#########################################################################
#MACRO METRICS: calculate metric for each label and average across labels
#########################################################################

def macro_accuracy(yhat, y):
    num = intersect_size(yhat, y, 0) / (union_size(yhat, y, 0) + 1e-10)
    return np.mean(num)

def macro_precision(yhat, y):
    num = intersect_size(yhat, y, 0) / (yhat.sum(axis=0) + 1e-10)
    return np.mean(num)

def macro_recall(yhat, y):
    num = intersect_size(yhat, y, 0) / (y.sum(axis=0) + 1e-10)
    return np.mean(num)

def macro_f1(yhat, y):
    prec = macro_precision(yhat, y)
    rec = macro_recall(yhat, y)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2*(prec*rec)/(prec+rec)
    return f1

##########################################################################
#MICRO METRICS: treat every prediction as an individual binary prediction
##########################################################################

def micro_accuracy(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / union_size(yhatmic, ymic, 0)

def micro_precision(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / yhatmic.sum(axis=0)

def micro_recall(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / ymic.sum(axis=0)

def micro_f1(yhatmic, ymic):
    prec = micro_precision(yhatmic, ymic)
    rec = micro_recall(yhatmic, ymic)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2*(prec*rec)/(prec+rec)
    return f1


def all_macro(yhat, y):
    return macro_accuracy(yhat, y), macro_precision(yhat, y), macro_recall(yhat, y), macro_f1(yhat, y)

def all_micro(yhatmic, ymic):
    return micro_accuracy(yhatmic, ymic), micro_precision(yhatmic, ymic), micro_recall(yhatmic, ymic), micro_f1(yhatmic, ymic)


def all_metrics(yhat, y, yhat_raw=None, calc_auc=True, label_order=[]):
    """
        Inputs:
            yhat: binary predictions matrix 
            y: binary ground truth matrix
            yhat_raw: prediction scores matrix (floats)
        Outputs:
            dict holding relevant metrics
    """
    names = ["acc", "prec", "rec", "f1"]

    metrics = {}
    for ix,label in enumerate(label_order):
        metrics[f"{label2abbrev[label]}-f1"] = f1_score(y[:,ix], yhat[:,ix])

    #macro
    #print("GETTING ALL MACRO")
    macro = all_macro(yhat, y)

    #micro
    #print("GETTING ALL MICRO")
    ymic = y.ravel()
    yhatmic = yhat.ravel()
    micro = all_micro(yhatmic, ymic)

    metrics.update({names[i] + "_macro": macro[i] for i in range(len(macro))})
    metrics.update({names[i] + "_micro": micro[i] for i in range(len(micro))})

    #AUC
    #print("AUC")
    if yhat_raw is not None and calc_auc:
        roc_auc = auc_metrics(yhat_raw, y, ymic)
        metrics.update(roc_auc)

    return metrics

In [22]:
def evaluate(model, dv_loader, device, tokenizer, return_thresholds=False, doc_position=False):
    # apply model to validation set and compute metrics, including identifying best thresholds
    yhat_raw, yhat, y, sentences, doc_ids, sent_ixs = gather_predictions(model, dv_loader, device)


    # unbalanced metrics
    metrics = all_metrics(yhat, y, yhat_raw=yhat_raw, calc_auc=True, label_order=LABEL_TYPES)

    return metrics

In [23]:
# parameters

max_epochs=5
learning_rate=5e-6
max_steps=-1
gradient_accumulation_steps=4
eval_steps = 100


In [24]:
# Training loop
total_t0 = time.time()
tr_loss = 0.0
model.zero_grad()
model.train()
metrics_hist = defaultdict(list)
best_epoch = 0
best_step = 0
step = 0
losses = []
for epoch in range(max_epochs):
        step = 0
        t0 = time.time()
        for x in tqdm(train_dataloader):
            if max_steps > -1 and step > max_steps:
                break
            # transfer to gpu
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            inputs = {'input_ids': x['input_ids'], 'attention_mask': x['attention_mask'], 'labels': x['labels']}
            #if 'token_type_ids' in x:
            #  inputs['token_type_ids'] = x['token_type_ids']
            outputs = model(**inputs)
            #print(outputs)
            loss = outputs['loss']
            pred = outputs['logits']

            # update parameters
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            if (step + 1) % gradient_accumulation_steps == 0 or (
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                len(train_dataloader) <= gradient_accumulation_steps
                and (step + 1) == len(train_dataloader)
            ):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                model.zero_grad()
                losses.append(tr_loss)
                tr_loss = 0.0


            # tensor output
            if (step + 1) % eval_steps == 0:
              elapsed = format_time(time.time() - t0)
              print('  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), np.mean(losses[-10:]), elapsed))

            step += 1

#evaluation metrics
metrics = evaluate(model, validation_dataloader, device, tokenizer=tokenizer1)
for name, metric in metrics.items():
  metrics_hist[name].append(metric)

#test metrics
metrics_hist_test = defaultdict(list)
metrics_test = evaluate(model, test_dataloader, device, tokenizer=tokenizer1)
for name, metric in metrics_test.items():
  metrics_hist_test[name].append(metric)

#save model
out_dir = "/content/drive/My Drive/Deep Learning Project"
sd = model.state_dict()
torch.save(sd, out_dir + "/model_final_RoBERTa.pth")

  

  labels = torch.Tensor(labels)
  8%|▊         | 100/1198 [01:04<13:42,  1.34it/s]

  Batch    99  of  1,198. Loss: 0.09901929944753647.   Elapsed: 0:01:05.


 17%|█▋        | 200/1198 [01:58<09:06,  1.83it/s]

  Batch   199  of  1,198. Loss: 0.09432009076699614.   Elapsed: 0:01:58.


 25%|██▌       | 300/1198 [02:58<07:20,  2.04it/s]

  Batch   299  of  1,198. Loss: 0.08242242806591094.   Elapsed: 0:02:58.


 33%|███▎      | 401/1198 [03:51<05:30,  2.42it/s]

  Batch   399  of  1,198. Loss: 0.07877514567226171.   Elapsed: 0:03:51.


 42%|████▏     | 500/1198 [04:42<03:23,  3.42it/s]

  Batch   499  of  1,198. Loss: 0.08629403784871101.   Elapsed: 0:04:42.


 50%|█████     | 600/1198 [05:40<05:47,  1.72it/s]

  Batch   599  of  1,198. Loss: 0.07078072384465486.   Elapsed: 0:05:41.


 58%|█████▊    | 700/1198 [06:39<03:27,  2.40it/s]

  Batch   699  of  1,198. Loss: 0.057235073600895706.   Elapsed: 0:06:39.


 67%|██████▋   | 800/1198 [07:33<04:01,  1.65it/s]

  Batch   799  of  1,198. Loss: 0.06821290239458903.   Elapsed: 0:07:33.


 75%|███████▌  | 900/1198 [08:24<03:40,  1.35it/s]

  Batch   899  of  1,198. Loss: 0.054291850293520835.   Elapsed: 0:08:25.


 83%|████████▎ | 1000/1198 [09:12<01:55,  1.72it/s]

  Batch   999  of  1,198. Loss: 0.05659884014166892.   Elapsed: 0:09:13.


 92%|█████████▏| 1100/1198 [10:02<00:36,  2.66it/s]

  Batch 1,099  of  1,198. Loss: 0.054176168772391974.   Elapsed: 0:10:03.


100%|██████████| 1198/1198 [10:51<00:00,  1.84it/s]
  8%|▊         | 100/1198 [01:01<13:41,  1.34it/s]

  Batch    99  of  1,198. Loss: 0.05496748994337395.   Elapsed: 0:01:02.


 17%|█▋        | 200/1198 [01:55<09:06,  1.83it/s]

  Batch   199  of  1,198. Loss: 0.04022048485348932.   Elapsed: 0:01:55.


 25%|██▌       | 300/1198 [02:55<07:20,  2.04it/s]

  Batch   299  of  1,198. Loss: 0.03685348975704983.   Elapsed: 0:02:55.


 33%|███▎      | 401/1198 [03:47<05:29,  2.42it/s]

  Batch   399  of  1,198. Loss: 0.03153073297580704.   Elapsed: 0:03:48.


 42%|████▏     | 500/1198 [04:39<03:24,  3.42it/s]

  Batch   499  of  1,198. Loss: 0.041599382890854034.   Elapsed: 0:04:39.


 50%|█████     | 600/1198 [05:37<05:47,  1.72it/s]

  Batch   599  of  1,198. Loss: 0.03675141327548772.   Elapsed: 0:05:38.


 58%|█████▊    | 700/1198 [06:36<03:27,  2.40it/s]

  Batch   699  of  1,198. Loss: 0.030255576816853137.   Elapsed: 0:06:36.


 67%|██████▋   | 800/1198 [07:30<04:01,  1.65it/s]

  Batch   799  of  1,198. Loss: 0.0366525661200285.   Elapsed: 0:07:30.


 75%|███████▌  | 900/1198 [08:21<03:40,  1.35it/s]

  Batch   899  of  1,198. Loss: 0.033057827455922964.   Elapsed: 0:08:22.


 83%|████████▎ | 1000/1198 [09:09<01:54,  1.72it/s]

  Batch   999  of  1,198. Loss: 0.03622100084321574.   Elapsed: 0:09:09.


 92%|█████████▏| 1100/1198 [09:59<00:36,  2.66it/s]

  Batch 1,099  of  1,198. Loss: 0.036062398680951444.   Elapsed: 0:09:59.


100%|██████████| 1198/1198 [10:48<00:00,  1.85it/s]
  8%|▊         | 100/1198 [01:01<13:42,  1.34it/s]

  Batch    99  of  1,198. Loss: 0.038731008529430254.   Elapsed: 0:01:02.


 17%|█▋        | 200/1198 [01:55<09:06,  1.83it/s]

  Batch   199  of  1,198. Loss: 0.032304461498279126.   Elapsed: 0:01:55.


 25%|██▌       | 300/1198 [02:55<07:20,  2.04it/s]

  Batch   299  of  1,198. Loss: 0.03142259362211917.   Elapsed: 0:02:55.


 33%|███▎      | 401/1198 [03:48<05:30,  2.41it/s]

  Batch   399  of  1,198. Loss: 0.026612240230315366.   Elapsed: 0:03:48.


 42%|████▏     | 500/1198 [04:39<03:23,  3.42it/s]

  Batch   499  of  1,198. Loss: 0.03230295210960321.   Elapsed: 0:04:39.


 50%|█████     | 600/1198 [05:37<05:48,  1.72it/s]

  Batch   599  of  1,198. Loss: 0.0301882273866795.   Elapsed: 0:05:38.


 58%|█████▊    | 700/1198 [06:36<03:27,  2.40it/s]

  Batch   699  of  1,198. Loss: 0.0236859295720933.   Elapsed: 0:06:36.


 67%|██████▋   | 800/1198 [07:30<04:00,  1.65it/s]

  Batch   799  of  1,198. Loss: 0.029830284329364076.   Elapsed: 0:07:30.


 75%|███████▌  | 900/1198 [08:21<03:40,  1.35it/s]

  Batch   899  of  1,198. Loss: 0.026211817853618414.   Elapsed: 0:08:22.


 83%|████████▎ | 1000/1198 [09:09<01:55,  1.72it/s]

  Batch   999  of  1,198. Loss: 0.032009254273725675.   Elapsed: 0:09:10.


 92%|█████████▏| 1100/1198 [09:59<00:36,  2.66it/s]

  Batch 1,099  of  1,198. Loss: 0.034011500063934365.   Elapsed: 0:10:00.


100%|██████████| 1198/1198 [10:48<00:00,  1.85it/s]
  8%|▊         | 100/1198 [01:01<13:42,  1.34it/s]

  Batch    99  of  1,198. Loss: 0.036892242889734916.   Elapsed: 0:01:02.


 17%|█▋        | 201/1198 [01:55<07:22,  2.25it/s]

  Batch   199  of  1,198. Loss: 0.026841276927734727.   Elapsed: 0:01:55.


 25%|██▌       | 300/1198 [02:55<07:21,  2.04it/s]

  Batch   299  of  1,198. Loss: 0.027082406289991923.   Elapsed: 0:02:55.


 33%|███▎      | 401/1198 [03:47<05:30,  2.41it/s]

  Batch   399  of  1,198. Loss: 0.022827820846578105.   Elapsed: 0:03:48.


 42%|████▏     | 500/1198 [04:39<03:24,  3.42it/s]

  Batch   499  of  1,198. Loss: 0.03025229745253455.   Elapsed: 0:04:39.


 50%|█████     | 600/1198 [05:37<05:47,  1.72it/s]

  Batch   599  of  1,198. Loss: 0.028070085868239402.   Elapsed: 0:05:38.


 58%|█████▊    | 700/1198 [06:36<03:27,  2.40it/s]

  Batch   699  of  1,198. Loss: 0.01965203319559805.   Elapsed: 0:06:36.


 67%|██████▋   | 800/1198 [07:30<04:00,  1.65it/s]

  Batch   799  of  1,198. Loss: 0.024227731704013423.   Elapsed: 0:07:30.


 75%|███████▌  | 900/1198 [08:21<03:40,  1.35it/s]

  Batch   899  of  1,198. Loss: 0.024309650843497364.   Elapsed: 0:08:22.


 83%|████████▎ | 1000/1198 [09:09<01:54,  1.72it/s]

  Batch   999  of  1,198. Loss: 0.028404901789326687.   Elapsed: 0:09:09.


 92%|█████████▏| 1100/1198 [09:59<00:36,  2.66it/s]

  Batch 1,099  of  1,198. Loss: 0.029783229934400877.   Elapsed: 0:09:59.


100%|██████████| 1198/1198 [10:48<00:00,  1.85it/s]
  8%|▊         | 100/1198 [01:01<13:41,  1.34it/s]

  Batch    99  of  1,198. Loss: 0.03301217330008512.   Elapsed: 0:01:02.


 17%|█▋        | 200/1198 [01:55<09:05,  1.83it/s]

  Batch   199  of  1,198. Loss: 0.02457050334342057.   Elapsed: 0:01:55.


 25%|██▌       | 300/1198 [02:55<07:20,  2.04it/s]

  Batch   299  of  1,198. Loss: 0.025149927743768785.   Elapsed: 0:02:55.


 33%|███▎      | 401/1198 [03:47<05:29,  2.42it/s]

  Batch   399  of  1,198. Loss: 0.020662603473465425.   Elapsed: 0:03:48.


 42%|████▏     | 500/1198 [04:39<03:23,  3.42it/s]

  Batch   499  of  1,198. Loss: 0.026560974700259976.   Elapsed: 0:04:39.


 50%|█████     | 600/1198 [05:37<05:47,  1.72it/s]

  Batch   599  of  1,198. Loss: 0.025543645808647854.   Elapsed: 0:05:38.


 58%|█████▊    | 700/1198 [06:36<03:28,  2.39it/s]

  Batch   699  of  1,198. Loss: 0.018758977900142783.   Elapsed: 0:06:36.


 67%|██████▋   | 800/1198 [07:30<04:01,  1.65it/s]

  Batch   799  of  1,198. Loss: 0.02169586934178369.   Elapsed: 0:07:30.


 75%|███████▌  | 900/1198 [08:21<03:41,  1.35it/s]

  Batch   899  of  1,198. Loss: 0.021901934748166242.   Elapsed: 0:08:22.


 83%|████████▎ | 1000/1198 [09:09<01:54,  1.72it/s]

  Batch   999  of  1,198. Loss: 0.026541735412320123.   Elapsed: 0:09:09.


 92%|█████████▏| 1100/1198 [09:59<00:36,  2.66it/s]

  Batch 1,099  of  1,198. Loss: 0.028140606447414028.   Elapsed: 0:10:00.


100%|██████████| 1198/1198 [10:48<00:00,  1.85it/s]


In [25]:
print(metrics)
print(metrics_test)

{'Imaging-f1': 0.6, 'Appointment-f1': 0.8419150858175248, 'Medication-f1': 0.6516464471403813, 'Procedure-f1': 0.525, 'Lab-f1': 0.5433526011560694, 'Patient instructions-f1': 0.8373831775700935, 'Other-f1': 0.051948051948051945, 'acc_macro': 0.4449604249090422, 'prec_macro': 0.714806719303696, 'rec_macro': 0.5261502009697084, 'f1_macro': 0.6061382033040715, 'acc_micro': 0.6358695652173914, 'prec_micro': 0.84346035015448, 'rec_micro': 0.7209507042253521, 'f1_micro': 0.7774086378737542, 'auc_macro': 0.9676228925295449, 'auc_micro': 0.9839726059162452}
{'Imaging-f1': 0.65625, 'Appointment-f1': 0.8427876823338737, 'Medication-f1': 0.6213151927437642, 'Procedure-f1': 0.6, 'Lab-f1': 0.6331658291457287, 'Patient instructions-f1': 0.7797397769516728, 'Other-f1': 0.25806451612903225, 'acc_macro': 0.4780386919740847, 'prec_macro': 0.8170285927056015, 'rec_macro': 0.5428274634668244, 'f1_macro': 0.6522830950307811, 'acc_micro': 0.6114698063045955, 'prec_micro': 0.8438155136268344, 'rec_micro': 0.