In [1]:
!pip install datasets transformers > /dev/null

In [2]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

In [3]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

Mounted at /content/drive/


In [4]:
from torch.utils.data import Dataset, DataLoader
from nltk import word_tokenize
from collections import Counter, defaultdict

class SentDataset(Dataset):
    def __init__(self, fname, word2ix=None, df=None):
        if df is None:
            self.sents = pd.read_csv(fname)
        else:
            self.sents = df.dropna(subset=['sentence'])
        # reset index necessary for getting context sentences when using unlabeled data
        #self.sents = self.sents.dropna(subset=['sentence'])
        if word2ix is not None:
            self.word2ix = word2ix
            self.ix2word = {ix:word for word,ix in self.word2ix.items()}
        else:
            self.word2ix = {'<PAD>': 0}
            self.ix2word = {}
            self.min_cnt = 0

    def __len__(self):
        return len(self.sents)

    def __getitem__(self, idx):
        row = self.sents.iloc[idx]
        if row.sentence.startswith('[') and row.sentence.endswith(']'):
            try:
                sentence = eval(row.sentence)
            except:
                sentence = word_tokenize(row.sentence)
        else:
            assert isinstance(row.sentence, str)
            sentence = word_tokenize(row.sentence)
        sent = [x.lower() for x in sentence]
        if 'labels' in row:
          labels = eval(row.labels)
        else:
          labels = []
        doc_id = row.doc_id
        sent_ix = None
        if 'sent_index' in row:
            sent_ix = row.sent_index
        elif 'sent_ix' in row:
            sent_ix = row.sent_ix
        return sent, labels, doc_id, sent_ix

In [5]:
# load dataset
import ast
def str_to_list(s):
    # Use the ast library to safely evaluate the string as a list
    return ast.literal_eval(s)

sents = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/sentence_level.csv")

# Apply the function to the column to convert the strings to lists
#sents['sentence'] = sents['sentence'].apply(str_to_list)

# Apply the function to the column to convert the strings to lists
#sents['labels'] = sents['labels'].apply(str_to_list)
train_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/train_ids.csv", header=None)
test_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/test_ids.csv", header=None)
eval_dataset = pd.read_csv("/content/drive/My Drive/Deep Learning Project/Dataset/clip-a-dataset-for-extracting-action-items-for-physicians-from-hospital-discharge-notes-1.0.0/val_ids.csv", header=None)

In [6]:
train_dataset.columns = ['doc_id']
train_dataset = train_dataset.merge(sents, on = "doc_id")
test_dataset.columns = ['doc_id']
test_dataset = test_dataset.merge(sents, on = "doc_id")
eval_dataset.columns = ['doc_id']
eval_dataset = eval_dataset.merge(sents, on = "doc_id")

In [7]:
train_dataset = SentDataset('contet',df =train_dataset)
eval_dataset = SentDataset('content',df=eval_dataset)
test_dataset = SentDataset('content',df=test_dataset)

In [8]:
tokenizer1 =  AutoTokenizer.from_pretrained('AshtonIsNotHere/GatorTron-OG')

Downloading (…)okenizer_config.json:   0%|          | 0.00/418 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/379k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

# Datasets & Dataloaders


In [9]:
batch_size = 64

In [10]:
def collator(batch, tokenizer, eval=False, doc_position=False):
    # standard bert collator that applies the right label for the training task
    sents = []
    labels = []
    doc_poses = []
    doc_ids = []
    sent_ixs = []
    for sent, label, doc_id, sent_ix in batch:
        sents.append(' '.join(sent))
        label = multilabel_labeler(label)
        labels.append(label)
        doc_ids.append(doc_id)
        sent_ixs.append(sent_ix)
    tokd = tokenizer(sents, padding=True, max_length = 512, truncation=True)
    input_ids, token_type_ids, attention_mask = tokd['input_ids'], tokd['token_type_ids'], tokd['attention_mask']
    toks = torch.LongTensor(input_ids)
    mask = torch.LongTensor(attention_mask)
    labels = torch.Tensor(labels)
    # for whatever reason, pytorch's cross entropy requires long labels but BCE doesn't
    if not eval:
        return {'input_ids': toks, 'attention_mask': mask, 'labels': labels, 'doc_ids': doc_ids, 'sent_ixs': sent_ixs}
    else:
        return {'input_ids': toks, 'attention_mask': mask, 'labels': labels, 'sentences': sents, 'doc_ids': doc_ids, 'sent_ixs': sent_ixs}

data_collator = lambda x: collator(x, tokenizer = tokenizer1)
eval_collate_fn = lambda x: collator(x, tokenizer =tokenizer1, eval=True)

In [11]:
# import torch
# import tqdm

# class BertDataset(torch.utils.data.Dataset):
#   def __init__(self, txt_list, tokenizer, max_length=512):
#     self.tokenizer = tokenizer
#     self.sents = []
#     self.labels = []
#     self.doc_poses = []
#     self.doc_ids = []
#     self.sent_ixs = []
#     self.input_ids = []
#     self.token_type_ids = []
#     self.attention_maks = []


#     for doc_id, sent_ix, sent, label in tqdm.tqdm(txt_list.values, desc="Tokenizing data"):
#         #sents.append(' '.join(sent))
#         tokd = tokenizer(sent, padding=True, max_length = 512)
#         self.input_ids.append(torch.Tensor(tokd['input_ids']))
#         self.token_type_ids.append(tokd['token_type_ids'])
#         self.attention_maks.append(torch.Tensor(tokd['attention_mask']))
#         label = multilabel_labeler(label)
#         self.labels.append(label)
#         self.doc_ids.append(doc_id)
#         self.sent_ixs.append(sent_ix)
#     self.labels = torch.Tensor(self.labels)
    
#   def __len__(self):
#     return len(self.sent_ixs)

#   def __getitem__(self, idx):
#     return {'input_ids': self.input_ids[idx], 'attention_mask': self.attention_maks[idx], 'labels': self.labels[idx], 'doc_ids': self.doc_ids[idx], 'sent_ixs': self.sent_ixs[idx]} 

In [12]:
LABEL_TYPES = ['I-Imaging-related followup',
 'I-Appointment-related followup',
 'I-Medication-related followups',
 'I-Procedure-related followup',
 'I-Lab-related followup',
 'I-Case-specific instructions for patient',
 'I-Other helpful contextual information',
 ]

label2abbrev = {'I-Imaging-related followup': 'Imaging',
        'I-Appointment-related followup': 'Appointment',
        'I-Medication-related followups': 'Medication',
        'I-Procedure-related followup': 'Procedure',
        'I-Lab-related followup': 'Lab',
        'I-Case-specific instructions for patient': 'Patient instructions',
        'I-Other helpful contextual information': 'Other',
 }

def multilabel_labeler(label):
    # convert list of label names to a multi-hot array
    if isinstance(label, str):
        label = eval(label)
    label_ixs = [LABEL_TYPES.index(l) for l in label]
    label = np.zeros(len(LABEL_TYPES))
    label[label_ixs] = 1
    return label

In [13]:

# train_dataset = BertDataset(train_dataset, tokenizer, max_length=512, collate_fn=data_collator)
# eval_dataset = BertDataset(eval_dataset, tokenizer, max_length=512, collate_fn=eval_collate_fn)

# print()
# print('{:>5,} training samples'.format(len(train_dataset)))
# print('{:>5,} validation samples'.format(len(eval_dataset)))

In [14]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(
  train_dataset,
  batch_size=batch_size,
  collate_fn=data_collator
)

validation_dataloader = DataLoader(
  eval_dataset,
  shuffle=False,
  batch_size=1,
  collate_fn=eval_collate_fn
)

test_dataloader = DataLoader(
    test_dataset, batch_size=1, shuffle=False, collate_fn=eval_collate_fn
    )

# Finetune BERT Model

In [15]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, KLDivLoss
import torch.nn.functional as F

class GatorTronForSequenceMultilabelClassification(AutoModelForSequenceClassification):
    """
        simple mod of MiniBERT to accept multilabel or binary task
        there may or may not be a class in huggingface that does this but this is here for historical reasons
    """
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.gatorTron = AutoModelForSequenceClassification.from_pretrained('AshtonIsNotHere/GatorTron-OG')
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def add_doc_position_feature(self, config):
        self.classifier = nn.Linear(config.hidden_size+1, config.num_labels)

    def set_task(self, task):
        self.task = task

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        doc_positions=None,
    ):
        outputs = self.gatorTron(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        if doc_positions is not None and len(doc_positions) > 0:
            pooled_output = torch.cat((pooled_output, doc_positions.unsqueeze(1)), dim=1)

        logits = self.classifier(pooled_output)
        #print(logits)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(outputs[0], labels)
            #print(loss)
            outputs = (loss,) + outputs
        
        print(outputs)

        return outputs




In [16]:
label_set = LABEL_TYPES
num_labels = len(label_set)

label2id = {label:ix for ix,label in enumerate(label_set)}
id2label = {ix:label for label,ix in label2id.items()}
config = AutoConfig.from_pretrained(
        'AshtonIsNotHere/GatorTron-OG',
        num_labels=num_labels,
        finetuning_task="text_classification",
        label2id=label2id,
        id2label=id2label,
    )

Downloading (…)lve/main/config.json:   0%|          | 0.00/675 [00:00<?, ?B/s]

In [17]:
import random
model = GatorTronForSequenceMultilabelClassification.from_pretrained('AshtonIsNotHere/GatorTron-OG', config=config)

# Tell pytorch to run this model on the GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Set the seed value all over the place to make this reproducible.
seed_val = 42

random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

Downloading pytorch_model.bin:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of MegatronBertForSequenceClassification were not initialized from the model checkpoint at AshtonIsNotHere/GatorTron-OG and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


freeze weights if needed

In [18]:
for name, params in model.named_parameters():
  if name.startswith('bert.embeddings'):
    params.requires_grad = False
  if name.startswith('bert.encoder'):
    params.requires_grad = False
  # if name.startswith('bert.encoder.layer.13'):
  #   params.requires_grad = True
  # if name.startswith('bert.encoder.layer.14'):
  #   params.requires_grad = True
  # if name.startswith('bert.encoder.layer.15'):
  #   params.requires_grad = True
  # if name.startswith('bert.encoder.layer.16'):
  #   params.requires_grad = True
  # if name.startswith('bert.encoder.layer.17'):
  #   params.requires_grad = True
  if name.startswith('bert.encoder.layer.18'):
    params.requires_grad = True
  if name.startswith('bert.encoder.layer.19'):
    params.requires_grad = True
  if name.startswith('bert.encoder.layer.20'):
    params.requires_grad = True
  if name.startswith('bert.encoder.layer.21'):
    params.requires_grad = True
  if name.startswith('bert.encoder.layer.22'):
    params.requires_grad = True
  if name.startswith('bert.encoder.layer.23'):
    params.requires_grad = True
  if name.startswith('bert.encoder.ln'):
    params.requires_grad = True

In [19]:
for name, params in model.named_parameters():
  print(name)
  print(params.requires_grad)

bert.embeddings.word_embeddings.weight
False
bert.embeddings.position_embeddings.weight
False
bert.embeddings.token_type_embeddings.weight
False
bert.encoder.layer.0.attention.ln.weight
False
bert.encoder.layer.0.attention.ln.bias
False
bert.encoder.layer.0.attention.self.query.weight
False
bert.encoder.layer.0.attention.self.query.bias
False
bert.encoder.layer.0.attention.self.key.weight
False
bert.encoder.layer.0.attention.self.key.bias
False
bert.encoder.layer.0.attention.self.value.weight
False
bert.encoder.layer.0.attention.self.value.bias
False
bert.encoder.layer.0.attention.output.dense.weight
False
bert.encoder.layer.0.attention.output.dense.bias
False
bert.encoder.layer.0.ln.weight
False
bert.encoder.layer.0.ln.bias
False
bert.encoder.layer.0.intermediate.dense.weight
False
bert.encoder.layer.0.intermediate.dense.bias
False
bert.encoder.layer.0.output.dense.weight
False
bert.encoder.layer.0.output.dense.bias
False
bert.encoder.layer.1.attention.ln.weight
False
bert.encoder.lay

In [20]:
import time

timestamp = time.strftime('%b_%d_%H:%M:%S', time.localtime())

# Note: AdamW is a class from the huggingface library (as opposed to pytorch) 
from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=5e-5)



# Compute Predictions

In [27]:
def gather_predictions(model, loader, device, doc_position=False, save_preds=False, max_preds=1e9):
    # run through data loader and get model predictions
    with torch.no_grad():
        model.eval()
        yhat_raw = np.zeros((len(loader), 7))
        yhat = np.zeros((len(loader), 7))
        y = np.zeros((len(loader), 7))
        sentences = []
        doc_ids = []
        sent_ixs = []
        for ix, x in enumerate(loader):
            if ix >= max_preds:
                break
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            sentences.extend(x['sentences'])
            doc_ids.extend(x['doc_ids'])
            sent_ixs.extend(x['sent_ixs'])
            inputs = {'input_ids': x['input_ids'], 'attention_mask': x['attention_mask'], 'labels': x['labels']}
            if 'token_type_ids' in x:
              inputs['token_type_ids'] = x['token_type_ids']
            outputs = model(**inputs)
            loss = outputs['loss']
            pred = outputs['logits']
            pred = torch.sigmoid(pred)
            pred = pred.cpu().numpy()
            yhat_raw[ix] = pred
            yhat[ix] = np.round(pred)
            y[ix] = x['labels'].cpu().numpy()[0]
    return yhat_raw, yhat, y, sentences, doc_ids, sent_ixs

# Evaluation Metrics

In [22]:
import datetime
import time

def format_time(elapsed):
    return str(datetime.timedelta(seconds=int(round((elapsed)))))

In [23]:
from collections import defaultdict
import csv
import json
import numpy as np
import os
import sys

from sklearn.metrics import roc_curve, auc, precision_score, recall_score, f1_score, roc_auc_score, accuracy_score
from tqdm import tqdm


def auc_metrics(yhat_raw, y, ymic):
    if yhat_raw.shape[0] <= 1:
        return
    fpr = {}
    tpr = {}
    roc_auc = {}
    #get AUC for each label individually
    relevant_labels = []
    auc_labels = {}
    for i in range(y.shape[1]):
        #only if there are true positives for this label
        if y[:,i].sum() > 0:
            fpr[i], tpr[i], _ = roc_curve(y[:,i], yhat_raw[:,i])
            if len(fpr[i]) > 1 and len(tpr[i]) > 1:
                auc_score = auc(fpr[i], tpr[i])
                if not np.isnan(auc_score): 
                    auc_labels["auc_%d" % i] = auc_score
                    relevant_labels.append(i)

    #macro-AUC: just average the auc scores
    aucs = []
    for i in relevant_labels:
        aucs.append(auc_labels['auc_%d' % i])
    roc_auc['auc_macro'] = np.mean(aucs)

    #micro-AUC: just look at each individual prediction
    yhatmic = yhat_raw.ravel()
    fpr["micro"], tpr["micro"], _ = roc_curve(ymic, yhatmic) 
    roc_auc["auc_micro"] = auc(fpr["micro"], tpr["micro"])

    return roc_auc

def union_size(yhat, y, axis):
    #axis=0 for label-level union (macro). axis=1 for instance-level
    return np.logical_or(yhat, y).sum(axis=axis).astype(float)

def intersect_size(yhat, y, axis):
    #axis=0 for label-level union (macro). axis=1 for instance-level
    return np.logical_and(yhat, y).sum(axis=axis).astype(float)


#########################################################################
#MACRO METRICS: calculate metric for each label and average across labels
#########################################################################

def macro_accuracy(yhat, y):
    num = intersect_size(yhat, y, 0) / (union_size(yhat, y, 0) + 1e-10)
    return np.mean(num)

def macro_precision(yhat, y):
    num = intersect_size(yhat, y, 0) / (yhat.sum(axis=0) + 1e-10)
    return np.mean(num)

def macro_recall(yhat, y):
    num = intersect_size(yhat, y, 0) / (y.sum(axis=0) + 1e-10)
    return np.mean(num)

def macro_f1(yhat, y):
    prec = macro_precision(yhat, y)
    rec = macro_recall(yhat, y)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2*(prec*rec)/(prec+rec)
    return f1

##########################################################################
#MICRO METRICS: treat every prediction as an individual binary prediction
##########################################################################

def micro_accuracy(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / union_size(yhatmic, ymic, 0)

def micro_precision(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / yhatmic.sum(axis=0)

def micro_recall(yhatmic, ymic):
    return intersect_size(yhatmic, ymic, 0) / ymic.sum(axis=0)

def micro_f1(yhatmic, ymic):
    prec = micro_precision(yhatmic, ymic)
    rec = micro_recall(yhatmic, ymic)
    if prec + rec == 0:
        f1 = 0.
    else:
        f1 = 2*(prec*rec)/(prec+rec)
    return f1


def all_macro(yhat, y):
    return macro_accuracy(yhat, y), macro_precision(yhat, y), macro_recall(yhat, y), macro_f1(yhat, y)

def all_micro(yhatmic, ymic):
    return micro_accuracy(yhatmic, ymic), micro_precision(yhatmic, ymic), micro_recall(yhatmic, ymic), micro_f1(yhatmic, ymic)


def all_metrics(yhat, y, yhat_raw=None, calc_auc=True, label_order=[]):
    """
        Inputs:
            yhat: binary predictions matrix 
            y: binary ground truth matrix
            yhat_raw: prediction scores matrix (floats)
        Outputs:
            dict holding relevant metrics
    """
    names = ["acc", "prec", "rec", "f1"]

    metrics = {}
    for ix,label in enumerate(label_order):
        metrics[f"{label2abbrev[label]}-f1"] = f1_score(y[:,ix], yhat[:,ix])

    #macro
    #print("GETTING ALL MACRO")
    macro = all_macro(yhat, y)

    #micro
    #print("GETTING ALL MICRO")
    ymic = y.ravel()
    yhatmic = yhat.ravel()
    micro = all_micro(yhatmic, ymic)

    metrics.update({names[i] + "_macro": macro[i] for i in range(len(macro))})
    metrics.update({names[i] + "_micro": micro[i] for i in range(len(micro))})

    #AUC
    #print("AUC")
    if yhat_raw is not None and calc_auc:
        roc_auc = auc_metrics(yhat_raw, y, ymic)
        metrics.update(roc_auc)

    return metrics

In [24]:
def evaluate(model, dv_loader, device, tokenizer, return_thresholds=False, doc_position=False):
    # apply model to validation set and compute metrics, including identifying best thresholds
    yhat_raw, yhat, y, sentences, doc_ids, sent_ixs = gather_predictions(model, dv_loader, device)


    # unbalanced metrics
    metrics = all_metrics(yhat, y, yhat_raw=yhat_raw, calc_auc=True, label_order=LABEL_TYPES)

    return metrics

In [25]:
# parameters

max_epochs=5
learning_rate=5e-5
max_steps=-1
gradient_accumulation_steps=4
eval_steps = 100


In [26]:
# Training loop
total_t0 = time.time()
tr_loss = 0.0
model.zero_grad()
model.train()
metrics_hist = defaultdict(list)
best_epoch = 0
best_step = 0
step = 0
losses = []
for epoch in range(max_epochs):
        step = 0
        t0 = time.time()
        for x in tqdm(train_dataloader):
            if max_steps > -1 and step > max_steps:
                break
            # transfer to gpu
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            inputs = {'input_ids': x['input_ids'], 'attention_mask': x['attention_mask'], 'labels': x['labels']}
            if 'token_type_ids' in x:
              inputs['token_type_ids'] = x['token_type_ids']
            
            outputs = model(**inputs)
            loss = outputs['loss']
            pred = outputs['logits']

            # update parameters
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps
            loss.backward()
            tr_loss += loss.item()
            if (step + 1) % gradient_accumulation_steps == 0 or (
                # last step in epoch but step is always smaller than gradient_accumulation_steps
                len(train_dataloader) <= gradient_accumulation_steps
                and (step + 1) == len(train_dataloader)
            ):
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                model.zero_grad()
                losses.append(tr_loss)
                tr_loss = 0.0


            # tensor output
            if (step + 1) % eval_steps == 0:
              elapsed = format_time(time.time() - t0)
              print('  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), np.mean(losses[-10:]), elapsed))

            step += 1

#evaluation metrics
metrics = evaluate(model, validation_dataloader, device, tokenizer=tokenizer1)
for name, metric in metrics.items():
  metrics_hist[name].append(metric)

#test metrics
metrics_hist_test = defaultdict(list)
metrics_test = evaluate(model, test_dataloader, device, tokenizer=tokenizer1)
for name, metric in metrics_test.items():
  metrics_hist_test[name].append(metric)

#save model
out_dir = "/content/drive/My Drive/Deep Learning Project"
sd = model.state_dict()
torch.save(sd, out_dir + "/model_final_GatorTron.pth")

  

  labels = torch.Tensor(labels)
  8%|▊         | 100/1198 [01:04<13:58,  1.31it/s]

  Batch    99  of  1,198. Loss: 0.21516234949231147.   Elapsed: 0:01:05.


 17%|█▋        | 201/1198 [01:58<07:16,  2.28it/s]

  Batch   199  of  1,198. Loss: 0.09060258795507252.   Elapsed: 0:01:58.


 25%|██▌       | 300/1198 [02:58<07:25,  2.01it/s]

  Batch   299  of  1,198. Loss: 0.05997060730587691.   Elapsed: 0:02:58.


 33%|███▎      | 401/1198 [03:50<05:27,  2.43it/s]

  Batch   399  of  1,198. Loss: 0.04441649419022724.   Elapsed: 0:03:50.


 42%|████▏     | 500/1198 [04:41<03:24,  3.41it/s]

  Batch   499  of  1,198. Loss: 0.04455782260047272.   Elapsed: 0:04:41.


 50%|█████     | 600/1198 [05:38<05:43,  1.74it/s]

  Batch   599  of  1,198. Loss: 0.03930751067819074.   Elapsed: 0:05:39.


 58%|█████▊    | 700/1198 [06:36<03:27,  2.40it/s]

  Batch   699  of  1,198. Loss: 0.029923554800916463.   Elapsed: 0:06:37.


 67%|██████▋   | 800/1198 [07:30<03:57,  1.67it/s]

  Batch   799  of  1,198. Loss: 0.03520517628639937.   Elapsed: 0:07:30.


 75%|███████▌  | 900/1198 [08:21<03:37,  1.37it/s]

  Batch   899  of  1,198. Loss: 0.030607649561716244.   Elapsed: 0:08:21.


 83%|████████▎ | 1000/1198 [09:08<01:52,  1.77it/s]

  Batch   999  of  1,198. Loss: 0.0355515445000492.   Elapsed: 0:09:08.


 92%|█████████▏| 1100/1198 [09:57<00:37,  2.65it/s]

  Batch 1,099  of  1,198. Loss: 0.0335644704697188.   Elapsed: 0:09:58.


100%|██████████| 1198/1198 [10:46<00:00,  1.85it/s]
  8%|▊         | 100/1198 [01:01<13:57,  1.31it/s]

  Batch    99  of  1,198. Loss: 0.03827093391446397.   Elapsed: 0:01:02.


 17%|█▋        | 201/1198 [01:55<07:15,  2.29it/s]

  Batch   199  of  1,198. Loss: 0.028936773631721736.   Elapsed: 0:01:55.


 25%|██▌       | 300/1198 [02:54<07:25,  2.02it/s]

  Batch   299  of  1,198. Loss: 0.026755639151087963.   Elapsed: 0:02:54.


 33%|███▎      | 401/1198 [03:46<05:28,  2.43it/s]

  Batch   399  of  1,198. Loss: 0.022948775021359325.   Elapsed: 0:03:46.


 42%|████▏     | 500/1198 [04:37<03:24,  3.41it/s]

  Batch   499  of  1,198. Loss: 0.028313687400077468.   Elapsed: 0:04:38.


 50%|█████     | 600/1198 [05:35<05:43,  1.74it/s]

  Batch   599  of  1,198. Loss: 0.027636762789916246.   Elapsed: 0:05:35.


 58%|█████▊    | 700/1198 [06:33<03:27,  2.40it/s]

  Batch   699  of  1,198. Loss: 0.02008620561682619.   Elapsed: 0:06:33.


 67%|██████▋   | 800/1198 [07:26<03:57,  1.68it/s]

  Batch   799  of  1,198. Loss: 0.025511273767915555.   Elapsed: 0:07:27.


 75%|███████▌  | 900/1198 [08:17<03:37,  1.37it/s]

  Batch   899  of  1,198. Loss: 0.022106173672364095.   Elapsed: 0:08:17.


 83%|████████▎ | 1000/1198 [09:04<01:52,  1.77it/s]

  Batch   999  of  1,198. Loss: 0.02598128249810543.   Elapsed: 0:09:05.


 92%|█████████▏| 1100/1198 [09:54<00:37,  2.64it/s]

  Batch 1,099  of  1,198. Loss: 0.02670395020104479.   Elapsed: 0:09:54.


100%|██████████| 1198/1198 [10:43<00:00,  1.86it/s]
  8%|▊         | 100/1198 [01:01<13:56,  1.31it/s]

  Batch    99  of  1,198. Loss: 0.03089250106277177.   Elapsed: 0:01:02.


 17%|█▋        | 201/1198 [01:55<07:15,  2.29it/s]

  Batch   199  of  1,198. Loss: 0.022768133229692466.   Elapsed: 0:01:55.


 25%|██▌       | 300/1198 [02:54<07:25,  2.02it/s]

  Batch   299  of  1,198. Loss: 0.022022197546903044.   Elapsed: 0:02:54.


 33%|███▎      | 401/1198 [03:46<05:28,  2.43it/s]

  Batch   399  of  1,198. Loss: 0.018891806169995105.   Elapsed: 0:03:46.


 42%|████▏     | 500/1198 [04:37<03:24,  3.42it/s]

  Batch   499  of  1,198. Loss: 0.023527589600416832.   Elapsed: 0:04:38.


 50%|█████     | 600/1198 [05:35<05:43,  1.74it/s]

  Batch   599  of  1,198. Loss: 0.02218239050998818.   Elapsed: 0:05:35.


 58%|█████▊    | 700/1198 [06:33<03:27,  2.40it/s]

  Batch   699  of  1,198. Loss: 0.016056691134872382.   Elapsed: 0:06:33.


 67%|██████▋   | 800/1198 [07:26<03:57,  1.67it/s]

  Batch   799  of  1,198. Loss: 0.02068286201392766.   Elapsed: 0:07:27.


 75%|███████▌  | 900/1198 [08:17<03:37,  1.37it/s]

  Batch   899  of  1,198. Loss: 0.019080042027053424.   Elapsed: 0:08:17.


 83%|████████▎ | 1000/1198 [09:04<01:52,  1.77it/s]

  Batch   999  of  1,198. Loss: 0.021430471933854278.   Elapsed: 0:09:05.


 92%|█████████▏| 1100/1198 [09:54<00:37,  2.64it/s]

  Batch 1,099  of  1,198. Loss: 0.023568763499497436.   Elapsed: 0:09:54.


100%|██████████| 1198/1198 [10:43<00:00,  1.86it/s]
  8%|▊         | 100/1198 [01:01<13:57,  1.31it/s]

  Batch    99  of  1,198. Loss: 0.025728059202083386.   Elapsed: 0:01:02.


 17%|█▋        | 201/1198 [01:55<07:15,  2.29it/s]

  Batch   199  of  1,198. Loss: 0.020470541635586415.   Elapsed: 0:01:55.


 25%|██▌       | 300/1198 [02:54<07:26,  2.01it/s]

  Batch   299  of  1,198. Loss: 0.017555924426415005.   Elapsed: 0:02:54.


 33%|███▎      | 401/1198 [03:46<05:28,  2.43it/s]

  Batch   399  of  1,198. Loss: 0.01680049791757483.   Elapsed: 0:03:46.


 42%|████▏     | 500/1198 [04:37<03:24,  3.42it/s]

  Batch   499  of  1,198. Loss: 0.020002602305612526.   Elapsed: 0:04:38.


 50%|█████     | 600/1198 [05:35<05:43,  1.74it/s]

  Batch   599  of  1,198. Loss: 0.019463209678360727.   Elapsed: 0:05:35.


 58%|█████▊    | 700/1198 [06:33<03:27,  2.40it/s]

  Batch   699  of  1,198. Loss: 0.014322936310782098.   Elapsed: 0:06:33.


 67%|██████▋   | 800/1198 [07:26<03:57,  1.68it/s]

  Batch   799  of  1,198. Loss: 0.01826004815811757.   Elapsed: 0:07:27.


 75%|███████▌  | 900/1198 [08:17<03:37,  1.37it/s]

  Batch   899  of  1,198. Loss: 0.014993847302685026.   Elapsed: 0:08:17.


 83%|████████▎ | 1000/1198 [09:04<01:51,  1.77it/s]

  Batch   999  of  1,198. Loss: 0.018734494037926198.   Elapsed: 0:09:05.


 92%|█████████▏| 1100/1198 [09:54<00:37,  2.64it/s]

  Batch 1,099  of  1,198. Loss: 0.02095568108052248.   Elapsed: 0:09:54.


100%|██████████| 1198/1198 [10:43<00:00,  1.86it/s]
  8%|▊         | 100/1198 [01:01<13:57,  1.31it/s]

  Batch    99  of  1,198. Loss: 0.022711331449681894.   Elapsed: 0:01:02.


 17%|█▋        | 201/1198 [01:55<07:15,  2.29it/s]

  Batch   199  of  1,198. Loss: 0.016319572924112437.   Elapsed: 0:01:55.


 25%|██▌       | 300/1198 [02:54<07:26,  2.01it/s]

  Batch   299  of  1,198. Loss: 0.014710085769183933.   Elapsed: 0:02:54.


 33%|███▎      | 401/1198 [03:46<05:29,  2.42it/s]

  Batch   399  of  1,198. Loss: 0.014051107550767484.   Elapsed: 0:03:46.


 42%|████▏     | 500/1198 [04:37<03:24,  3.42it/s]

  Batch   499  of  1,198. Loss: 0.017383412989147473.   Elapsed: 0:04:38.


 50%|█████     | 600/1198 [05:35<05:43,  1.74it/s]

  Batch   599  of  1,198. Loss: 0.016766855340392794.   Elapsed: 0:05:35.


 58%|█████▊    | 700/1198 [06:33<03:27,  2.40it/s]

  Batch   699  of  1,198. Loss: 0.012025385278684552.   Elapsed: 0:06:33.


 67%|██████▋   | 800/1198 [07:26<03:57,  1.68it/s]

  Batch   799  of  1,198. Loss: 0.014889482178114121.   Elapsed: 0:07:27.


 75%|███████▌  | 900/1198 [08:17<03:37,  1.37it/s]

  Batch   899  of  1,198. Loss: 0.012261831294017611.   Elapsed: 0:08:18.


 83%|████████▎ | 1000/1198 [09:04<01:52,  1.77it/s]

  Batch   999  of  1,198. Loss: 0.015262183621962322.   Elapsed: 0:09:05.


 92%|█████████▏| 1100/1198 [09:54<00:37,  2.64it/s]

  Batch 1,099  of  1,198. Loss: 0.018224094752804378.   Elapsed: 0:09:54.


100%|██████████| 1198/1198 [10:43<00:00,  1.86it/s]


TypeError: ignored

In [28]:
#evaluation metrics
metrics = evaluate(model, validation_dataloader, device, tokenizer=tokenizer1)
for name, metric in metrics.items():
  metrics_hist[name].append(metric)

#test metrics
metrics_hist_test = defaultdict(list)
metrics_test = evaluate(model, test_dataloader, device, tokenizer=tokenizer1)
for name, metric in metrics_test.items():
  metrics_hist_test[name].append(metric)

#save model
out_dir = "/content/drive/My Drive/Deep Learning Project"
sd = model.state_dict()
torch.save(sd, out_dir + "/model_final_GatorTron.pth")

In [29]:
print(metrics)
print(metrics_test)

{'Imaging-f1': 0.6181818181818182, 'Appointment-f1': 0.8465266558966076, 'Medication-f1': 0.6959247648902821, 'Procedure-f1': 0.7179487179487181, 'Lab-f1': 0.637837837837838, 'Patient instructions-f1': 0.842249657064472, 'Other-f1': 0.24242424242424243, 'acc_macro': 0.5155127112888321, 'prec_macro': 0.7400319789561964, 'rec_macro': 0.6130479466846751, 'f1_macro': 0.6705813552960881, 'acc_micro': 0.6611049314052652, 'prec_micro': 0.8075181159420289, 'rec_micro': 0.7847711267605634, 'f1_micro': 0.7959821428571427, 'auc_macro': 0.9714287964205433, 'auc_micro': 0.9869902150215957}
{'Imaging-f1': 0.6999999999999998, 'Appointment-f1': 0.8676470588235294, 'Medication-f1': 0.6720977596741343, 'Procedure-f1': 0.6428571428571428, 'Lab-f1': 0.7111111111111111, 'Patient instructions-f1': 0.7734447539461466, 'Other-f1': 0.2260869565217391, 'acc_macro': 0.5134674992307436, 'prec_macro': 0.7601856257221022, 'rec_macro': 0.59417856464969, 'f1_macro': 0.6670081905146812, 'acc_micro': 0.6275082086829624