# Notebook for all tokcls (EBM PICO, JNLPBA, NCBI-disease, BC2GM, BC5CDR-disease, BC5CDR-chem)


## Download BLURB, install and import libs, class definitions

### Download BLURB data

In [None]:
# Download the data archive from the LinkBERT project page and unpack it quietly.
# Later cells read the splits from data/tokcls/<dataset>_hf/ (presumably created by this archive).
!wget https://nlp.stanford.edu/projects/myasu/LinkBERT/data.zip
!unzip -q data.zip

### Install libraries

In [None]:
!pip install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
!pip install transformers==4.9.1 datasets==1.11.0 fairscale==0.4.0 wandb sklearn seqeval
!pip install ray

In [None]:
import os

# To log to Weights & Biases instead, drop the WANDB_DISABLED line below and run:
#
#   !pip install wandb
#   import wandb
#   wandb.login()  # asks for the API key, or reads WANDB_API_KEY from the environment
#   wandb.init(project="my-test-project", entity="nomisto")
#
# Never paste the API key into the notebook: it ends up in version control and in every
# exported copy. The key that used to be hard-coded here must be treated as leaked and revoked.

# Read again when TrainingArguments is built (report_to is "none" when this is "true").
os.environ["WANDB_DISABLED"] = "true"
# -1 is the value the logging cell reports as "not distributed".
os.environ["LOCAL_RANK"] = "-1"

### Import dependencies

In [None]:
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
from datasets import ClassLabel, load_dataset, load_metric

import ray
from ray import tune
from ray.tune import JupyterNotebookReporter

import transformers
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    HfArgumentParser,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

## Configuration
The following has to be configured for each dataset

In [None]:
def _tokcls_args(batch_size, learning_rate, epochs, warmup_ratio=None):
    """Build the TrainingArguments overrides for one dataset.

    Every dataset trains with fp16 and without gradient accumulation;
    `warmup_ratio` is only included when a value is given.
    """
    overrides = {
        "per_device_train_batch_size": batch_size,
        "gradient_accumulation_steps": 1,
        "fp16": True,
        "learning_rate": learning_rate,
    }
    if warmup_ratio is not None:
        overrides["warmup_ratio"] = warmup_ratio
    overrides["num_train_epochs"] = epochs
    return overrides


# Per-dataset fine-tuning settings, keyed by dataset name (see `dataset_name`).
trainargs = {
    "ebmnlp": _tokcls_args(32, 5e-5, 1),
    "JNLPBA": _tokcls_args(16, 1e-5, 5, warmup_ratio=0.1),
    "NCBI-disease": _tokcls_args(32, 5e-5, 20, warmup_ratio=0.1),
    "BC2GM": _tokcls_args(32, 6e-5, 50, warmup_ratio=0.1),
    "BC5CDR-disease": _tokcls_args(16, 5e-5, 8, warmup_ratio=0.1),
    "BC5CDR-chem": _tokcls_args(32, 5e-5, 20, warmup_ratio=0.1),
}

In [None]:
# Dataset to fine-tune on: one of [ebmnlp, JNLPBA, NCBI-disease, BC2GM, BC5CDR-disease, BC5CDR-chem]
dataset_name = "JNLPBA"

# JSON file of each split for the selected dataset.
data_files = dict(
    train=f"data/tokcls/{dataset_name}_hf/train.json",
    validation=f"data/tokcls/{dataset_name}_hf/dev.json",
    test=f"data/tokcls/{dataset_name}_hf/test.json",
)

# Tokenizer limits: truncate at max_seq_length tokens; pad to that length only if requested.
max_seq_length = 512
pad_to_max_length = False

# Hugging Face training arguments, see
# https://huggingface.co/docs/transformers/v4.16.2/en/main_classes/trainer#transformers.TrainingArguments
training_args = TrainingArguments(
    output_dir=f"./runs/{dataset_name}",
    do_train=True,
    do_eval=True,
    do_predict=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    max_steps=-1,
    per_device_eval_batch_size=8,
    logging_dir="./logs",
    logging_steps=100,  # logging steps for train loss
    skip_memory_metrics=True,
    report_to="none" if os.environ["WANDB_DISABLED"] == "true" else "wandb",
    # LinkBERT does not reload the best checkpoint; enabling this would also require
    # metric_for_best_model and greater_is_better.
    load_best_model_at_end=False,
    # Index rather than .get(): an unknown dataset_name then fails with a KeyError naming it
    # instead of "argument after ** must be a mapping, not NoneType".
    **trainargs[dataset_name]
)


### HPO 
# maximize if metric is bigger_is_better, else: minimize
direction = "maximize"
# Number of trials for HPO
n_trials = 10


def hp_space_ray(trial):
    """Return the Ray Tune search space; sampled values overwrite training_args.

    `trial` is part of the signature because the function is handed to the
    hyperparameter search as `hp_space`; it is not used in the body.
    """
    search_space = {}
    # NOTE(review): this range lies above the per-dataset learning rates in `trainargs`
    # (1e-5 to 6e-5) - confirm it is intended.
    search_space["learning_rate"] = tune.loguniform(1e-4, 1e-2)
    search_space["num_train_epochs"] = tune.choice(range(1, 6))
    # Open question from the original: also search "seed" via tune.choice(range(1, 41))?
    # Check against the set_seed call; may not be needed.
    search_space["per_device_train_batch_size"] = tune.choice([4, 8, 16, 32, 64])
    return search_space

### Seeding
# No seed is passed to TrainingArguments in this notebook, so training_args.seed is its default value.
set_seed(training_args.seed) # Set seed before initializing model.

## Set up logging

In [None]:
logger = logging.getLogger(__name__)
# Setup logging: send records to stdout so they show up in the notebook output.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

# Use the log level derived from the training arguments for this logger and for both libraries.
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log on each process the small summary.
# The first fragment ends with ", " so the two halves do not run together
# (previously rendered as "n_gpu: 1distributed training: ...").
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

## Load dataset

- Load raw dataset from json files
- Set the token and label column names (`tokens`, `ner_tags`)
- Create list of labels 

In [None]:
# Loading a dataset from your local files.
# data_files maps the split names train/validation/test to the JSON files configured above.
raw_datasets = load_dataset("json", data_files=data_files)

# Column names and feature types of the train split; `features` is inspected below to decide
# how the label list is built.
column_names = raw_datasets["train"].column_names
features = raw_datasets["train"].features

# Columns holding the pre-split words and the label of each word.
text_column_name = "tokens"
label_column_name = "ner_tags"

def get_label_list(labels):
    """Return the sorted list of distinct labels found in `labels`.

    `labels` is an iterable of label sequences, one per example. This is needed
    when the label column is not a `Sequence[ClassLabel]`, so the label set has
    to be collected from the data itself.
    """
    return sorted({tag for example_tags in labels for tag in example_tags})

# Decide where the label inventory comes from: the dataset schema or the training data.
label_feature = features[label_column_name].feature
if not isinstance(label_feature, ClassLabel):
    # Labels are not declared in the schema: collect them from the train split and number them.
    label_list = get_label_list(raw_datasets["train"][label_column_name])
    label_to_id = {name: index for index, name in enumerate(label_list)}
else:
    label_list = label_feature.names
    # No need to convert the labels since they are already ints.
    label_to_id = {index: index for index in range(len(label_list))}
num_labels = len(label_list)

## Initialize model, tokenizer, config

In [None]:
model_name = "michiyasunaga/BioLinkBERT-base"
# model_name = "sshleifer/tiny-distilroberta-base"

# Build both label maps from label_list so the config always stores label names.
# label_to_id is not reused here: when the dataset labels are already ints it is the identity
# map {i: i}, which would put ids instead of names into label2id/id2label. For string labels
# the result is the same as before.
config = AutoConfig.from_pretrained(
    model_name,
    num_labels=num_labels,
    label2id={label: index for index, label in enumerate(label_list)},
    id2label=dict(enumerate(label_list)),
    finetuning_task="ner",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
## Needs to be encapsulated for hpo
def model_init():
    """Load a token-classification model for `model_name` using the shared `config`."""
    return AutoModelForTokenClassification.from_pretrained(model_name, config=config)

## Preprocessing of dataset

In [None]:
# Padding strategy: pad to max_seq_length only when pad_to_max_length is set.
padding = False if not pad_to_max_length else "max_length"


def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align the word labels with the tokens.

    Reads the word lists from `examples[text_column_name]` and the per-word labels from
    `examples[label_column_name]`. Returns the tokenizer output with two added entries:
    "labels" (one id per token) and "word_ids" (the word index of each token, None for
    special tokens).

    Only the first token of every word receives the word's label id; special tokens and
    the remaining tokens of a word get -100 so they are ignored in the loss function.
    """
    encoded = tokenizer(
        examples[text_column_name],
        padding=padding,
        truncation=True,
        max_length=max_seq_length,
        # The texts in our dataset are lists of words (with a label for each word).
        is_split_into_words=True,
    )

    aligned_labels = []
    all_word_ids = []
    for batch_index, word_labels in enumerate(examples[label_column_name]):
        token_word_ids = encoded.word_ids(batch_index=batch_index)
        all_word_ids.append(token_word_ids)

        token_labels = []
        last_word = None
        for current_word in token_word_ids:
            # A token is labelled only if it belongs to a word and opens that word.
            starts_new_word = current_word is not None and current_word != last_word
            if starts_new_word:
                token_labels.append(label_to_id[word_labels[current_word]])
            else:
                token_labels.append(-100)
            last_word = current_word
        aligned_labels.append(token_labels)

    encoded["labels"] = aligned_labels
    encoded["word_ids"] = all_word_ids
    return encoded


def _tokenize_split(split, context_desc, map_desc):
    """Apply tokenize_and_align_labels to one raw split and return the processed dataset."""
    raw_split = raw_datasets[split]
    with training_args.main_process_first(desc=context_desc):
        return raw_split.map(tokenize_and_align_labels, batched=True, desc=map_desc)


train_dataset = _tokenize_split(
    "train",
    "train dataset map pre-processing",
    "Running tokenizer on train dataset",
)
eval_dataset = _tokenize_split(
    "validation",
    "validation dataset map pre-processing",
    "Running tokenizer on validation dataset",
)
predict_dataset = _tokenize_split(
    "test",
    "prediction dataset map pre-processing",
    "Running tokenizer on prediction dataset",
)

## Init trainer

In [None]:
# Data collator
# pad_to_multiple_of=8 is only requested when fp16 training is enabled; otherwise it is left unset.
data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)

# Metrics
# seqeval results are consumed by compute_metrics below (overall_* keys plus one entry per entity type).
metric = load_metric("seqeval")

def compute_metrics(p):
    """Turn model outputs into seqeval scores.

    `p` unpacks into (predictions, labels). The highest-scoring class along axis 2 is taken
    as the predicted label id, positions whose gold label is -100 are dropped, and the
    remaining ids are mapped to label names via `label_list` before scoring.

    Returns macro-averaged precision/recall/F1 over the per-type results for "ebmnlp",
    and the overall precision/recall/F1/accuracy for every other dataset.
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)

    # Remove ignored index (-100) and convert the surviving ids to label names.
    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        kept = [(pred_id, gold_id) for pred_id, gold_id in zip(predicted_row, gold_row) if gold_id != -100]
        true_predictions.append([label_list[pred_id] for pred_id, _ in kept])
        true_labels.append([label_list[gold_id] for _, gold_id in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)

    if dataset_name != "ebmnlp":
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    # ebmnlp: average the per-type scores, skipping the "overall_*" summary entries.
    precisions, recalls, f1_scores = [], [], []
    for type_name in results:
        if type_name.startswith("overall"):
            continue
        print("type_name", type_name)
        type_scores = results[type_name]
        precisions.append(type_scores["precision"])
        recalls.append(type_scores["recall"])
        f1_scores.append(type_scores["f1"])
    return {
        "macro_precision": np.mean(precisions),
        "macro_recall": np.mean(recalls),
        "macro_f1": np.mean(f1_scores),
    }

# Initialize our Trainer
# The model is supplied through model_init (a factory) rather than as an instance, because
# the hyperparameter search needs it encapsulated that way (see the note on model_init).
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Train/Eval/Test

## With HPO 

In [None]:
## needed only for google colab (also requires `import psutil`)
# ray._private.utils.get_system_memory = lambda: psutil.virtual_memory().total

# Pass `overwrite` by keyword: it is the first positional parameter in older Ray releases,
# and the keyword form also works in releases where it is declared keyword-only.
reporter = JupyterNotebookReporter(overwrite=False)

# NOTE(review): no compute_objective is passed, so the Trainer's default objective is used -
# confirm that this default is the quantity `direction` is meant to optimise.
best_trial = trainer.hyperparameter_search(
    direction=direction,
    backend="ray",
    hp_space=hp_space_ray,
    keep_checkpoints_num=1,
    n_trials=n_trials,
    # local_dir + name give the directory read back when loading the best model.
    local_dir=f"./runs/{dataset_name}/ray_results/",
    name=dataset_name,
    progress_reporter=reporter,
)

### Load best model from HPO

In [None]:
def recover_checkpoint(tune_checkpoint_dir, model_name=None):
    """Return the single sub-directory of a Ray Tune checkpoint directory.

    Falls back to `model_name` when `tune_checkpoint_dir` is None or empty.
    Plain files inside the directory are ignored; an AssertionError (carrying the
    list of sub-directories found) is raised unless exactly one sub-directory exists.
    """
    nothing_to_recover = tune_checkpoint_dir is None or len(tune_checkpoint_dir) == 0
    if nothing_to_recover:
        return model_name

    # Get subdirectory used for Huggingface.
    subdirs = []
    for entry_name in os.listdir(tune_checkpoint_dir):
        candidate = os.path.join(tune_checkpoint_dir, entry_name)
        if os.path.isdir(candidate):
            subdirs.append(candidate)

    # There should only be 1 subdir.
    assert len(subdirs) == 1, subdirs
    return subdirs[0]

# Same location the hyperparameter search wrote to: its local_dir joined with its experiment name.
ray_result_dir = f"./runs/{dataset_name}/ray_results/{dataset_name}"

from ray.tune import ExperimentAnalysis
analysis = ExperimentAnalysis(ray_result_dir)
# Pick the trial with the best "objective" (max or min, following `direction`) and resolve
# its checkpoint directory to the sub-directory holding the saved model.
best_checkpoint = recover_checkpoint(
    analysis.get_best_trial(metric="objective",
                            mode="max" if direction=="maximize" else "min").checkpoint.value
)
best_model = AutoModelForTokenClassification.from_pretrained(
    best_checkpoint)

# Initialize our Trainer
# This replaces the earlier `trainer`: it wraps the loaded best model (model=) instead of
# the model_init factory, with otherwise the same arguments.
trainer = Trainer(
    model=best_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

### Evaluate

In [None]:
# Evaluate the current trainer's model on the validation split and record the number of examples.
metrics = trainer.evaluate(eval_dataset=eval_dataset)
metrics["eval_samples"] = len(eval_dataset)
# Log and save the metrics under the "eval" split name.
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

### Predict

In [None]:
import json

# Run the current trainer's model on the test split; metric keys get the "test" prefix.
results = trainer.predict(predict_dataset, metric_key_prefix="test")
predictions = results.predictions
metrics = results.metrics
metrics["test_samples"] = len(predict_dataset)

trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log(metrics)

# Store the raw predictions and the gold label ids in the run's output directory.
output_dir = training_args.output_dir
output_path = f"{output_dir}/test_outputs.json"
# The context manager closes (and thereby flushes) the file; the bare open() used before
# left the handle open for the rest of the session.
with open(output_path, "w", encoding="utf-8") as output_file:
    json.dump(
        {"predictions": results.predictions.tolist(), "label_ids": results.label_ids.tolist()},
        output_file,
    )

## Simple (should not be used/only for reproducing LinkBERT numbers)

To reproduce exact numbers

- Call `model = model_init()` right after the `model_init` function definition (in "Initialize model, tokenizer, config") and, in the Trainer initialization, change `model_init=model_init` to `model=model` (otherwise the RNG state differs from the original script)
- Install same library versions as LinkBERT

```
!pip install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
!pip install transformers==4.9.1 datasets==1.11.0 fairscale==0.4.0 wandb sklearn seqeval
```

### Train

In [None]:
logger.info("*** Train ***")

# Trains with whichever `trainer` was created last in this session.
train_result = trainer.train()
metrics = train_result.metrics
metrics["train_samples"] = len(train_dataset)

trainer.save_model()  # Saves the tokenizer too for easy upload

# Log and save the training metrics under the "train" split name, then save the trainer state.
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

### Evaluate

In [None]:
logger.info("*** Evaluate ***")

# Evaluate on the validation split and record the number of examples.
metrics = trainer.evaluate(eval_dataset=eval_dataset)
metrics["eval_samples"] = len(eval_dataset)

# Log and save the metrics under the "eval" split name.
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

### Predict

In [None]:
import json

logger.info("*** Predict ***")

# Run the trained model on the test split; metric keys get the "test" prefix.
results = trainer.predict(predict_dataset, metric_key_prefix="test")
predictions = results.predictions
metrics = results.metrics
metrics["test_samples"] = len(predict_dataset)

trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log(metrics)

# Store the raw predictions and the gold label ids in the run's output directory.
output_dir = training_args.output_dir
output_path = f"{output_dir}/test_outputs.json"
# The context manager closes (and thereby flushes) the file; the bare open() used before
# left the handle open for the rest of the session.
with open(output_path, "w", encoding="utf-8") as output_file:
    json.dump(
        {"predictions": results.predictions.tolist(), "label_ids": results.label_ids.tolist()},
        output_file,
    )