# Initializations

In [None]:
import torch
from tqdm import tqdm
from google.colab import auth
auth.authenticate_user()

!pip install -q transformers datasets seqeval

# Copy data from Google Cloud Storage

In [None]:
!mkdir i2b2_2014
!gsutil -m -q cp -r gs://deid-data/i2b2_2014/train/ i2b2_2014/
!gsutil -m -q cp -r gs://deid-data/i2b2_2014/test/ i2b2_2014/

# Mount Google Cloud Storage bucket

In [None]:
!echo "deb http://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
!apt -qq update
!apt -qq install gcsfuse

In [None]:
!mkdir googleBucketFolder
!gcsfuse --implicit-dirs deid-data googleBucketFolder

# Clone repo

In [None]:
!git clone https://<user>:<token>@github.com/alistairewj/transformer-deid.git

In [None]:
cd transformer-deid

/content/transformer-deid


In [None]:
from datetime import datetime
import logging
from pathlib import Path
import os
import json
import random

import numpy as np

from transformers import DistilBertTokenizerFast
from transformers import DistilBertForTokenClassification, BertForTokenClassification
from transformers import Trainer, TrainingArguments
from datasets import load_metric

# local packages
from transformer_deid.data import DeidDataset, DeidTask
from transformer_deid.evaluation import compute_metrics
from transformer_deid.tokenization import assign_tags, encode_tags, split_sequences
from transformer_deid.utils import convert_dict_to_native_types

# Root logging configuration: timestamped INFO-level messages for this notebook.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
LOG_DATEFMT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(format=LOG_FORMAT, datefmt=LOG_DATEFMT, level=logging.INFO)

logger = logging.getLogger(__name__)


In [None]:
import random
def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Args:
        seed: value used for every random number generator.

    Note: transformers' Trainer also seeds from ``TrainingArguments.seed`` when
    training starts.
    """
    random.seed(seed)
    # Only affects subprocesses started after this point; the running interpreter's
    # hash randomization was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly non-deterministic algorithms,
    # which defeats `deterministic = True`; it must be off for reproducible runs.
    torch.backends.cudnn.benchmark = False


SEED = 42
seed_everything(SEED)

# Load data

In [None]:
# Dataset configuration.
task_name = 'i2b2_2014'
split_long_sequences = True
label_transform = 'base'

# The data was copied next to the cloned repo, hence the `..` relative to the repo root.
deid_task = DeidTask(
    task_name,
    data_dir=f'../{task_name}',
    label_transform=label_transform,
)

# Hold out the last 20% of the official training documents for validation. There is no
# shuffling, so the split is identical on every run.
all_train_texts = deid_task.train['text']
all_train_labels = deid_task.train['ann']
n_train = int(0.8 * len(all_train_texts))

train_texts, val_texts = all_train_texts[:n_train], all_train_texts[n_train:]
train_labels, val_labels = all_train_labels[:n_train], all_train_labels[n_train:]
test_texts, test_labels = deid_task.test['text'], deid_task.test['ann']

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')


# Data preprocessing

In [None]:

def build_dataset(texts, labels):
    """Convert one split's documents and span annotations into a ``DeidDataset``.

    Steps:
      1. optionally split long documents into several examples (``split_sequences``);
      2. tokenize with offset mappings;
      3. use the offsets to assign entity tags to tokens, then encode them as label ids;
      4. drop ``offset_mapping`` so the remaining encodings ('input_ids',
         'attention_mask') can be used as model kwargs, and wrap for ``Trainer``.

    Args:
        texts: documents of the split.
        labels: annotations aligned with ``texts``.

    Returns:
        A ``DeidDataset`` for this split.
    """
    # Work on local names only: the caller's lists are never reassigned, so re-running
    # this cell cannot split already-split sequences a second time.
    if split_long_sequences:
        texts, labels = split_sequences(tokenizer, texts, labels)

    encodings = tokenizer(
        texts,
        is_split_into_words=False,
        return_offsets_mapping=True,
        padding=True,
        truncation=True,
    )
    tags = assign_tags(encodings, labels)
    tags = encode_tags(tags, encodings, deid_task.label2id)

    encodings.pop("offset_mapping")
    return DeidDataset(encodings, tags)


train_dataset = build_dataset(train_texts, train_labels)
val_dataset = build_dataset(val_texts, val_labels)
test_dataset = build_dataset(test_texts, test_labels)


# Train transformer (skip if loading model)

In [None]:
# Training configuration; several of these names are reused by the evaluation cells below.
epochs = 1
train_batch_size = 8
out_dir = f'../googleBucketFolder/DistilBERTresults{epochs}'

model = DistilBertForTokenClassification.from_pretrained(
    'distilbert-base-cased',
    num_labels=len(deid_task.labels),
)

training_args = TrainingArguments(
    output_dir=out_dir,
    num_train_epochs=epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    save_strategy='epoch',  # one checkpoint directory under out_dir per epoch
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

logger.info("***** Running training *****")
logger.info("  Num examples = %d", len(train_dataset))
logger.info("  Num Epochs = %d", training_args.num_train_epochs)
trainer.train()

# Final model goes next to the checkpoints in the mounted bucket.
save_location = f'{out_dir}/{task_name}_DistilBert_Model_{epochs}'
trainer.save_model(save_location)

trainer.evaluate()


# Run dataset through model

In [None]:
# Disabled: uncomment to predict on the test set with the `trainer` from the training
# cell. Per-checkpoint test predictions are produced by eval_checkpoints further down.
# predictions, labels, _ = trainer.predict(test_dataset)
# predicted_label = np.argmax(predictions, axis=2)

# Evaluation

In [None]:
import gspread
from google.auth import default

# Results are appended to a Google Sheet. google.auth.default() relies on the credentials
# set up by auth.authenticate_user() in the first cell.
creds, _ = default()
gc = gspread.authorize(creds)

results_spreadsheet = gc.open_by_url(
    'https://docs.google.com/spreadsheets/d/1tc_8g2cqBdt6zEobvMUursLzKtWAn0GBCL-iwGvjudA/edit?usp=sharing'
)
worksheet = results_spreadsheet.worksheet("distilbert")

# Column order of a results row (after the leading epoch column). Each name is
# "<entity><metric>", matching the keys produced by flatten_dict on the metric results.
# NOTE(review): presumably this must match the worksheet's header row — verify.
ENTITY_TYPES = ['AGE', 'CONTACT', 'DATE', 'ID', 'LOCATION', 'NAME', 'PROFESSION']
METRIC_SUFFIXES = ['precision', 'recall', 'f1', 'number']
OVERALL_FIELDS = ['overall_precision', 'overall_recall', 'overall_f1', 'overall_accuracy']

multi_class_fields = [entity + suffix for entity in ENTITY_TYPES for suffix in METRIC_SUFFIXES] + OVERALL_FIELDS
binary_fields = ['PHI' + suffix for suffix in METRIC_SUFFIXES] + OVERALL_FIELDS


In [None]:
def flatten_dict(d):
    """Flatten a (possibly nested) dict into a single level.

    Nested keys are joined to their parent key by plain concatenation, e.g.
    ``{'AGE': {'f1': 0.5}}`` -> ``{'AGEf1': 0.5}``, which is the naming used by the
    ``*_fields`` column lists.

    NumPy scalars (e.g. ``np.int64``, which ``json`` cannot serialize) are converted to
    the equivalent built-in Python value at every nesting level, top level included.

    Args:
        d: mapping whose values are scalars or further mappings.

    Returns:
        A new flat dict; ``d`` is not modified.
    """
    out = {}
    for key, value in d.items():
        if isinstance(value, dict):
            # Recursion already converts the child's leaf values.
            for child_key, child_value in flatten_dict(value).items():
                out[key + child_key] = child_value
        elif isinstance(value, np.generic):
            out[key] = value.item()
        else:
            out[key] = value
    return out

In [None]:
def add_row(epochs, results_multiclass, results_binary, multi_class_fields, binary_fields, sheet=None):
    """Append one results row to the worksheet.

    Row layout: ``[epochs] + multi_class_fields + binary_fields``. Each field is looked
    up in the flattened results dict; a field that is missing becomes ``None``.

    Args:
        epochs: training progress for this row (may be fractional).
        results_multiclass: nested metrics dict for the per-entity evaluation.
        results_binary: nested metrics dict for the binary (PHI / non-PHI) evaluation.
        multi_class_fields: column names read from ``results_multiclass``.
        binary_fields: column names read from ``results_binary``.
        sheet: worksheet to append to; defaults to the module-level ``worksheet``.
    """
    if sheet is None:
        sheet = worksheet

    # Flatten each results dict once, not once per field.
    flat_multiclass = flatten_dict(results_multiclass)
    flat_binary = flatten_dict(results_binary)

    row = (
        [epochs]
        + [flat_multiclass.get(field) for field in multi_class_fields]
        + [flat_binary.get(field) for field in binary_fields]
    )
    sheet.append_row(row, table_range='A1')

In [None]:
import pprint

# Metric implemented by a script inside the cloned repo; the path is relative to the
# repo root, which is the current working directory.
metric_path = "transformer_deid/token_evaluation.py"
metric = load_metric(path=metric_path)

# Evaluate every checkpoint

In [None]:
import math
def eval_checkpoints(path, dataset=None):
    """Evaluate one saved checkpoint and append its scores to the results worksheet.

    Args:
        path: checkpoint directory written by the Trainer, named ``checkpoint-<step>``.
        dataset: dataset to predict on; defaults to ``test_dataset``.

    Returns:
        ``(results_multiclass, results_binary)`` as returned by ``compute_metrics``.
    """
    if dataset is None:
        # NOTE(review): this scores the test set. If these rows are used to *choose* a
        # checkpoint, pass val_dataset instead to avoid tuning on test data.
        dataset = test_dataset

    # basename(normpath(...)) tolerates a trailing slash and hyphens elsewhere in the path.
    step = int(os.path.basename(os.path.normpath(path)).rsplit('-', 1)[-1])

    # Convert the optimizer step into a (fractional) epoch: batches per epoch at the
    # effective batch size (per-device size x number of GPUs), divided by the number of
    # gradient-accumulation steps. Identical to the plain batch-size formula on one GPU.
    batches_per_epoch = math.ceil(len(train_dataset) / training_args.train_batch_size)
    steps_per_epoch = max(batches_per_epoch // training_args.gradient_accumulation_steps, 1)
    epoch = step / steps_per_epoch

    model = DistilBertForTokenClassification.from_pretrained(path, num_labels=len(deid_task.labels))
    # Only predict() is used, so the Trainer needs no train/eval datasets; predict() also
    # puts the model in eval mode itself.
    trainer = Trainer(model=model, args=training_args)

    predictions, labels, _ = trainer.predict(dataset)
    predicted_label = np.argmax(predictions, axis=2)

    results_multiclass = compute_metrics(
        predicted_label, labels, deid_task.labels, metric=metric
    )
    results_binary = compute_metrics(
        predicted_label, labels, deid_task.labels, metric=metric, binary_evaluation=True
    )
    add_row(epoch, results_multiclass, results_binary, multi_class_fields, binary_fields)
    return results_multiclass, results_binary

In [None]:
# Evaluate every `checkpoint-<step>` directory the Trainer wrote to out_dir, in step order.
# Only names whose suffix is all digits are kept, so stray entries cannot crash int().
root = out_dir
checkpoint_prefix = 'checkpoint-'
checkpoints = [
    name for name in os.listdir(root)
    if name.startswith(checkpoint_prefix)
    and name[len(checkpoint_prefix):].isdigit()
    and os.path.isdir(os.path.join(root, name))
]
for name in tqdm(sorted(checkpoints, key=lambda n: int(n[len(checkpoint_prefix):]))):
    eval_checkpoints(os.path.join(root, name))