In [1]:
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
import pprint
import pickle
import matplotlib.pyplot as plt
import numpy as np
import torch
import pandas as pd
import itertools
import os
import datetime
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
from tqdm.notebook import trange, tqdm
from torch.utils.data import DataLoader
from torch.optim import Adam
from sklearn.metrics import f1_score
from functools import partial

In [2]:
# Location of the pickled dataset. Set STUDY_INDICATION_DATA_DIR to run the
# notebook elsewhere; the default is the original cluster path.
data_dir = Path(os.environ.get(
  'STUDY_INDICATION_DATA_DIR',
  '/gpfs/data/geraslab/ekr6072/projects/study_indication/data',
))
data_path = data_dir / 'dataset.pkl'

In [3]:
# Tokenizer for the same "emilyalsentzer/Bio_ClinicalBERT" checkpoint that the
# ClinicalBERT model class in this notebook loads.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

In [4]:
# NOTE: pickle.load can execute arbitrary code while unpickling, so data_path
# must only ever point at a trusted file.
with open(data_path, 'rb') as f:
  dataset = pickle.load(f)

In [5]:
def clean_dataset(dataset):
  """Return a copy of `dataset` without the samples labelled 'unknown'.

  Args:
    dataset: mapping from subset name to a list of samples, where every
      sample is a mapping with a 'label' entry.

  Returns:
    A new dict with the same subset names. Each value is a new list holding
    the original sample objects (not copies) whose label is not 'unknown'.
  """
  return {
    name: [sample for sample in subset if sample['label'] != 'unknown']
    for name, subset in dataset.items()
  }

In [6]:
# Rebind `dataset` to the filtered version: samples labelled 'unknown' are dropped from every subset.
dataset = clean_dataset(dataset)

In [7]:
# Label string -> integer class id, numbered in the order listed below.
category2id = {
  name: index
  for index, name in enumerate((
    '(high-risk) screening',
    'extent of disease / pre-operative planning',
    'additional workup',
    '6-month follow-up / surveillance',
    'exclude',
    'unknown',
  ))
}

In [8]:
# Training inputs: each record's 'longText' string and its label mapped to an integer id.
train_texts = [data['text']['longText'] for data in dataset['train']]
train_labels = [category2id[data['label']] for data in dataset['train']]

In [9]:
# Validation inputs, built the same way as the training lists.
val_texts = [data['text']['longText'] for data in dataset['val']]
val_labels = [category2id[data['label']] for data in dataset['val']]

In [10]:
# Tokenize without truncation or padding so every document keeps its full
# token sequence; collate_fn later cuts each sequence into fixed-size windows.
train_encodings = tokenizer(train_texts, truncation=False, padding=False)
val_encodings = tokenizer(val_texts, truncation=False, padding=False)

In [11]:
# Show which fields the tokenizer produced.
train_encodings.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [12]:
def split_pad(data, padding_idx, max_length=512):
  """Cut a token sequence into rows of `max_length`, padding the last row.

  Args:
    data: sequence of token values (anything `torch.tensor` accepts).
    padding_idx: value used to fill the unused tail of the final row.
    max_length: number of tokens per row.

  Returns:
    A tensor with ceil(len(data) / max_length) rows and `max_length`
    columns. An empty `data` yields a tensor with zero rows.
  """
  total = len(data)
  tokens = torch.tensor(data)
  full_rows, remainder = divmod(total, max_length)
  num_rows = full_rows + 1 if remainder else full_rows
  # Start from an all-padding grid, then overwrite its first `total` slots
  # (in row-major order) with the real tokens.
  grid = padding_idx + torch.zeros((num_rows, max_length), dtype=torch.int64)
  if total:
    grid.view(num_rows * max_length)[:total] = tokens
  return grid

In [13]:
def sliding_window_transform(encoding, dataset, padding_idx, max_length=512, return_metadata=False):
  """Window every sequence in `encoding` and stack all windows row-wise.

  NOTE: later cells of this notebook define `sliding_window_transform` again,
  and those definitions return the metadata in a different form.

  Args:
    encoding: iterable of token sequences, one per sample.
    dataset: indexable collection aligned with `encoding`; `dataset[i]['id']`
      is read only when `return_metadata` is true.
    padding_idx: fill value passed to `split_pad`.
    max_length: window width passed to `split_pad`.
    return_metadata: when true, also return the row grouping described below.

  Returns:
    The concatenated windows, or `(windows, groups)` where `groups` is a list
    with one array of row indices per distinct sample id.
  """
  windows = []
  owner_ids = []
  for position, sequence in enumerate(encoding):
    rows = split_pad(sequence, padding_idx, max_length=max_length)
    if return_metadata:
      # One id entry per window produced for this sample.
      owner_ids.extend(dataset[position]['id'] for _ in rows)
    windows.append(rows)
  if not return_metadata:
    return torch.cat(windows)
  grouped = pd.DataFrame({"id": owner_ids}).groupby("id")
  return torch.cat(windows), [frame.index.values for _, frame in grouped]

In [14]:
def sliding_window_transform(encoding, dataset, padding_idx, max_length=512, return_metadata=False):
  """Window every sequence in `encoding` and stack all windows row-wise.

  Args:
    encoding: iterable of token sequences, one per sample.
    dataset: indexable collection aligned with `encoding`; `dataset[i]['id']`
      is read only when `return_metadata` is true.
    padding_idx: fill value passed to `split_pad`.
    max_length: window width passed to `split_pad`.
    return_metadata: when true, also return the per-row id table.

  Returns:
    The concatenated windows, or `(windows, metadata)` where `metadata` is a
    DataFrame with a single "id" column holding the sample id of each row.
  """
  windows = []
  owner_ids = []
  for position, sequence in enumerate(encoding):
    rows = split_pad(sequence, padding_idx, max_length=max_length)
    if return_metadata:
      # One id entry per window produced for this sample.
      owner_ids.extend(dataset[position]['id'] for _ in rows)
    windows.append(rows)
  if not return_metadata:
    return torch.cat(windows)
  id_table = pd.DataFrame({"id": owner_ids})
  return torch.cat(windows), id_table

In [15]:
def transform_encodings(encodings, dataset, padding_idx, max_length=512):
  """Apply the sliding-window transform to all three tokenizer fields.

  Args:
    encodings: mapping with 'input_ids', 'token_type_ids' and
      'attention_mask' entries, each an iterable of per-sample sequences.
    dataset: samples aligned with the encodings; supplies the ids.
    padding_idx: fill value for 'input_ids' (the other two fields use 0).
    max_length: window width.

  Returns:
    Dict with the three windowed fields plus "metadata", the metadata value
    returned by the 'input_ids' transform.
  """
  windowed = {}
  windowed["input_ids"], window_ids = sliding_window_transform(
    encodings['input_ids'], dataset, padding_idx,
    max_length=max_length, return_metadata=True,
  )
  for field in ("token_type_ids", "attention_mask"):
    windowed[field] = sliding_window_transform(
      encodings[field], dataset, 0, max_length=max_length
    )
  windowed["metadata"] = window_ids
  return windowed

In [16]:
import torch

class IndicationDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer output with an id and a label.

    Args:
        encodings: mapping from field name to a per-sample indexable.
        metadata: per-sample indexable of mappings that contain an 'id' entry.
        labels: per-sample indexable of labels; its length is the dataset length.
    """

    def __init__(self, encodings, metadata, labels):
        self.encodings, self.metadata, self.labels = encodings, metadata, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the encoding fields of sample `idx` plus its 'id' and 'label'."""
        item = {}
        for field, values in self.encodings.items():
            item[field] = values[idx]
        label = self.labels[idx]
        sample_id = self.metadata[idx]['id']
        item.update(id=sample_id, label=label)
        return item

# The raw subset lists are passed as `metadata`; IndicationDataset reads each record's 'id' from them.
train_dataset = IndicationDataset(train_encodings, dataset['train'], train_labels)
val_dataset = IndicationDataset(val_encodings, dataset['val'], val_labels)

In [17]:
def sliding_window_transform(encoding, dataset, padding_idx, max_length=512, return_metadata=False):
  """Window every sequence in `encoding` and stack all windows row-wise.

  NOTE: this cell re-declares a function that an earlier cell already defines
  with the same behaviour; the redefinition is redundant.

  Args:
    encoding: iterable of token sequences, one per sample.
    dataset: indexable collection aligned with `encoding`; `dataset[i]['id']`
      is read only when `return_metadata` is true.
    padding_idx: fill value passed to `split_pad`.
    max_length: window width passed to `split_pad`.
    return_metadata: when true, also return the per-row id table.

  Returns:
    The concatenated windows, or `(windows, metadata)` where `metadata` is a
    DataFrame with a single "id" column holding the sample id of each row.
  """
  windows = []
  owner_ids = []
  for position, sequence in enumerate(encoding):
    rows = split_pad(sequence, padding_idx, max_length=max_length)
    if return_metadata:
      # One id entry per window produced for this sample.
      owner_ids.extend(dataset[position]['id'] for _ in rows)
    windows.append(rows)
  if not return_metadata:
    return torch.cat(windows)
  id_table = pd.DataFrame({"id": owner_ids})
  return torch.cat(windows), id_table

In [18]:
def collate_fn(batch, padding_idx=0, max_length=512):
  """Collate variable-length samples into fixed-width windows.

  Each sample's token sequence is cut into rows of `max_length` tokens with
  `split_pad`, so one sample may contribute several rows. The sample's label
  and id are repeated once per row so rows can be regrouped afterwards.

  Args:
    batch: list of samples, each a mapping with 'input_ids',
      'token_type_ids', 'attention_mask', 'id' and 'label' entries.
    padding_idx: fill value for the tail of the last 'input_ids' window.
      'token_type_ids' and 'attention_mask' are always padded with 0.
    max_length: number of tokens per window.

  Returns:
    Dict with 'input_ids', 'token_type_ids', 'attention_mask' and 'labels'
    tensors (one row / entry per window) and 'metadata', a DataFrame whose
    "id" column gives the owning sample id of every row.
  """
  outputs = defaultdict(list)
  metadata = []
  for sample in batch:
    # Honour the caller's padding_idx / max_length instead of fixed 0 / 512.
    input_ids = split_pad(sample['input_ids'], padding_idx, max_length=max_length)
    token_type_ids = split_pad(sample['token_type_ids'], 0, max_length=max_length)
    attention_mask = split_pad(sample['attention_mask'], 0, max_length=max_length)
    repeat = input_ids.shape[0]
    metadata.extend([sample['id']] * repeat)
    label = torch.tensor(sample['label']).repeat(repeat)

    outputs['input_ids'].append(input_ids)
    outputs['token_type_ids'].append(token_type_ids)
    outputs['attention_mask'].append(attention_mask)
    outputs['labels'].append(label)

  outputs = {key: torch.cat(val) for key, val in outputs.items()}
  outputs['metadata'] = pd.DataFrame({"id": metadata})
  return outputs

In [19]:
# Batches hold 4 documents each; collate_fn emits one row per window, so a
# batch can contain more than 4 rows. Only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, 4, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, 4, shuffle=False, collate_fn=collate_fn)

In [20]:
from torch import nn

class ClinicalBERT(nn.Module):
  """Bio_ClinicalBERT encoder with a linear head over windowed documents.

  Every input row is one window of a document. Window logits are reduced per
  document (grouped by the "id" column of the metadata frame) before the
  cross-entropy loss is computed.

  Args:
    num_classes: size of the classification head's output.
    reduction: how window logits of one document are combined, 'mean' or 'max'.

  Raises:
    ValueError: if `reduction` is not 'mean' or 'max' (checked at construction
      and again in `forward`, in case the attribute was changed afterwards).
  """

  _REDUCTIONS = ('mean', 'max')

  def __init__(self, num_classes, reduction='mean'):
    super(ClinicalBERT, self).__init__()
    # Fail before loading the checkpoint if the reduction is unusable.
    if reduction not in self._REDUCTIONS:
      raise ValueError(f'invalid reduction value {reduction} entered')
    self.bert = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    # Size the head from the encoder config rather than a fixed 768.
    self.linear = nn.Linear(self.bert.config.hidden_size, num_classes)
    self.loss_func = nn.CrossEntropyLoss()
    self.reduction = reduction

  def forward(self, **kwargs):
    """Classify each document in the batch.

    Expects 'input_ids', 'attention_mask', 'labels' (one entry per window)
    and 'metadata' (DataFrame with an "id" column, one row per window).
    Other keys, e.g. 'token_type_ids', are accepted but not used.

    Returns:
      Dict with "loss", per-document "logits" and per-document "labels".
      Documents appear in the order produced by `groupby('id')`.
    """
    if self.reduction not in self._REDUCTIONS:
      raise ValueError(f'invalid reduction value {self.reduction} entered')
    x = self.bert(input_ids=kwargs['input_ids'], attention_mask=kwargs['attention_mask'])
    logits = self.linear(x['pooler_output'])
    reduced_logits = []
    labels = []
    for _, meta in kwargs['metadata'].groupby('id'):
      indices = meta.index.values
      window_logits = logits[indices]
      if self.reduction == 'mean':
        reduced_logits.append(window_logits.mean(axis=0))
      else:
        reduced_logits.append(window_logits.max(axis=0).values)
      # All windows of a document carry the same label; take the first.
      labels.append(kwargs['labels'][indices][0])
    labels = torch.stack(labels)
    reduced_logits = torch.stack(reduced_logits)
    loss = self.loss_func(reduced_logits, labels)
    return {
      "loss": loss,
      "logits": reduced_logits,
      "labels": labels
    }

In [25]:
def epoch_iter(num_epochs, dataloader):
    """Yield the epoch index once per training step.

    Produces `len(dataloader)` copies of 0, then of 1, and so on up to
    `num_epochs - 1`.
    """
    steps_per_epoch = len(dataloader)
    for global_step in range(num_epochs * steps_per_epoch):
        yield global_step // steps_per_epoch

In [26]:
@torch.no_grad()
def eval_loop(model, dataloader, device):
    """Run one full pass over `dataloader` in eval mode and report metrics.

    Args:
        model: callable as `model(**batch)` that returns a mapping with
            "loss", "logits" and "labels" entries.
        dataloader: iterable of batches (mappings); tensor values are moved
            to `device`, all other values are passed through unchanged.
        device: target device for the batch tensors.

    Returns:
        Dict with the mean "loss" per prediction, the "accuracy" and the
        macro-averaged "f1_score" over every prediction in the pass.

    Raises:
        ValueError: if the dataloader yields no predictions.
    """
    model.eval()

    total_loss = 0.0
    total_correct = 0
    total_count = 0
    all_labels = []
    all_preds = []

    for batch in dataloader:
        batch = {key: val.to(device) if isinstance(val, torch.Tensor) else val for key, val in batch.items()}
        outputs = model(**batch)

        preds = outputs['logits'].argmax(-1)
        labels = outputs['labels']
        batch_size = preds.shape[0]
        all_labels.append(labels.cpu())
        all_preds.append(preds.cpu())

        # The loss is assumed to be a per-batch mean (as with ClinicalBERT's
        # default nn.CrossEntropyLoss), so weight it by the batch size before
        # dividing by the total count below.
        total_loss += outputs["loss"].item() * batch_size
        total_correct += (labels == preds).sum().item()
        total_count += batch_size

    if total_count == 0:
        raise ValueError('eval_loop received an empty dataloader')

    all_labels = torch.cat(all_labels)
    all_preds = torch.cat(all_preds)
    return {
        "loss": total_loss / total_count,
        "accuracy": total_correct / total_count,
        "f1_score": f1_score(all_labels, all_preds, average='macro')
    }

def train_step(optimizer, model, batch):
    """Do one optimisation step on `batch` and return the loss as a float.

    Puts the model in training mode, clears old gradients, backpropagates
    the "loss" entry of `model(**batch)` and applies the optimizer update.
    """
    model.train()
    optimizer.zero_grad()
    batch_loss = model(**batch)["loss"]
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()

In [27]:
def lr_lambda(current_step: int, warmup_steps: int, total_steps: int, decay_type='linear'):
    """Learning-rate multiplier with linear warmup and optional decay.

    Args:
        current_step: zero-based scheduler step.
        warmup_steps: number of steps over which the multiplier rises
            linearly from 0 to 1.
        total_steps: step at which the decay reaches 0.
        decay_type: None (hold at 1.0 after warmup), 'linear' or 'cosine'.

    Returns:
        The multiplier as a float in [0, 1]. It stays at 0 once `current_step`
        passes `total_steps`, and stays at 1.0 after warmup when there is no
        decay window (`total_steps <= warmup_steps`).

    Raises:
        ValueError: if `decay_type` is not one of the supported values
            (checked only once warmup is over).
    """
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    if decay_type is None:
        return 1.0
    if decay_type not in ('linear', 'cosine'):
        raise ValueError('invalid decay_type {} entered'.format(decay_type))

    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        # No decay window: avoid dividing by zero and keep the peak rate.
        return 1.0
    # Clamp so steps beyond total_steps cannot produce a negative (linear)
    # or rising (cosine) multiplier.
    progress = min(1.0, (current_step - warmup_steps) / decay_steps)
    if decay_type == 'linear':
        return 1.0 - progress
    return float(0.5 * np.cos(np.pi * progress) + 0.5)

In [28]:
NUM_EPOCHS = 10
EVAL_EVERY = 5  # run validation every EVAL_EVERY training steps
# 5 output classes: clean_dataset removed 'unknown' (id 5), leaving ids 0-4.
model = ClinicalBERT(5)
optimizer = Adam(model.parameters(), lr=1e-4, weight_decay=1e-6)
warmup_steps = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timestamp = datetime.datetime.now()
date = timestamp.strftime("%Y%m%d")
time = timestamp.strftime("%H%M%S")
save_path = f'./results/{date}/{time}'
total_steps = NUM_EPOCHS * len(train_dataloader)
lr_scheduler = None if warmup_steps is None \
                    else torch.optim.lr_scheduler.LambdaLR(optimizer, partial(lr_lambda, warmup_steps=warmup_steps,
                                                                                         total_steps=total_steps))
train_loss_list = []
writer = SummaryWriter(log_dir=os.path.join(save_path, 'tb_logs'))
model.to(device)

# Iterate the dataloader afresh for every epoch. itertools.cycle would cache
# the first epoch's batches and replay them, so shuffle=True would have no
# effect after epoch 0 and every batch would be kept in memory.
epoch_batches = ((epoch, batch) for epoch in range(NUM_EPOCHS) for batch in train_dataloader)
try:
    for step, (epoch, batch) in enumerate(epoch_batches):
        batch = {key: val.to(device) if isinstance(val, torch.Tensor) else val for key, val in batch.items()}
        loss_val = train_step(
            optimizer=optimizer,
            model=model,
            batch=batch,
        )
        # Log the rate that was used for this step, then advance the schedule.
        writer.add_scalar("learning_rate", optimizer.param_groups[0]['lr'], step)
        if lr_scheduler is not None:
            lr_scheduler.step()

        writer.add_scalar("epoch", epoch, step)
        writer.add_scalar("loss/train", loss_val, step)
        train_loss_list.append(loss_val)
        if step % EVAL_EVERY == 0 and step != 0:
            val_results = eval_loop(
                model=model,
                dataloader=val_dataloader,
                device=device
            )
            for key, value in val_results.items():
                writer.add_scalar(f"{key}/val", value, step)
            print("Step: {}/{}, val acc: {:.3f}, val f1: {:.3f}".format(
                step,
                total_steps,
                val_results["accuracy"],
                val_results["f1_score"])
            )
finally:
    # Flush and close the TensorBoard writer even if the run is interrupted.
    writer.flush()
    writer.close()

Some weights of the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Step: 5/450, val acc: 0.177, val f1: 0.133
Step: 10/450, val acc: 0.531, val f1: 0.228
Step: 15/450, val acc: 0.531, val f1: 0.228
Step: 20/450, val acc: 0.521, val f1: 0.171
Step: 25/450, val acc: 0.521, val f1: 0.171
Step: 30/450, val acc: 0.521, val f1: 0.171
Step: 35/450, val acc: 0.521, val f1: 0.171
Step: 40/450, val acc: 0.531, val f1: 0.210
Step: 45/450, val acc: 0.479, val f1: 0.212
Step: 50/450, val acc: 0.500, val f1: 0.209
Step: 55/450, val acc: 0.458, val f1: 0.196
Step: 60/450, val acc: 0.521, val f1: 0.174
Step: 65/450, val acc: 0.521, val f1: 0.174
Step: 70/450, val acc: 0.510, val f1: 0.214
Step: 75/450, val acc: 0.531, val f1: 0.193
Step: 80/450, val acc: 0.562, val f1: 0.274
Step: 85/450, val acc: 0.635, val f1: 0.345
Step: 90/450, val acc: 0.594, val f1: 0.369
Step: 95/450, val acc: 0.594, val f1: 0.306
Step: 100/450, val acc: 0.542, val f1: 0.257
Step: 105/450, val acc: 0.510, val f1: 0.197
Step: 110/450, val acc: 0.500, val f1: 0.193
Step: 115/450, val acc: 0.500,

KeyboardInterrupt: 