# Notebook for all seqcls (PubMedQA, BioASQ, BIOSSES, HOC, ChemProt, DDI, GAD)

## Download BLURB, install and import libs, class definitions

### Download BLURB data

In [None]:
!wget https://nlp.stanford.edu/projects/myasu/LinkBERT/data.zip
!unzip -q data.zip

### Install libraries

In [None]:
!pip install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
!pip install transformers==4.9.1 datasets==1.11.0 fairscale==0.4.0 wandb sklearn seqeval
!pip install ray

In [None]:
import os

# Weights & Biases logging is switched off by default.
# To enable it, drop the WANDB_DISABLED line below and provide the API key through the
# environment or an interactive prompt -- never hardcode it in the notebook:
#
#   import getpass
#   import wandb
#   os.environ.setdefault("WANDB_API_KEY", getpass.getpass("W&B API key: "))
#   wandb.init(project="my-test-project", entity="nomisto")
#
# NOTE(review): an API key used to be stored in this cell in plain text; it should be revoked.

os.environ["WANDB_DISABLED"] = "true"
# -1 is the rank value this notebook treats as "not distributed" when it logs the process summary.
os.environ["LOCAL_RANK"] = "-1"

### Import dependencies

In [None]:
from transformers import Trainer, is_torch_tpu_available
from transformers.trainer_utils import PredictionOutput

import logging
import os
import random
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
from datasets import load_dataset, load_metric

import ray
from ray import tune
from ray.tune import JupyterNotebookReporter

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

### Define SeqClsTrainer
Copied straight from https://raw.githubusercontent.com/michiyasunaga/LinkBERT/main/src/seqcls/trainer_seqcls.py

In [None]:
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A subclass of `Trainer` specific to sequence-classification tasks
"""

class SeqClsTrainer(Trainer):
    """`Trainer` subclass whose metric callback also receives the evaluated dataset.

    `compute_metrics` is invoked as `compute_metrics(output, dataset)`, where `output` is the
    value returned by `self.evaluation_loop` and `dataset` is the dataset that was passed in.
    To make that possible the callback is detached while the loop runs and called afterwards.
    """

    def _loop_without_metrics(self, dataloader, description, ignore_keys, metric_key_prefix):
        """Run `self.evaluation_loop` with `self.compute_metrics` temporarily set to None.

        The callback is restored in a `finally` clause, so a failing loop cannot leave the
        trainer without its metric function.
        """
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        try:
            return self.evaluation_loop(
                dataloader,
                description=description,
                prediction_loss_only=None,
                ignore_keys=ignore_keys,
                # Forward the prefix so that the loop reports its loss under
                # f"{metric_key_prefix}_loss" rather than always under "eval_loss".
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            self.compute_metrics = compute_metrics

    def _task_metrics(self, output, dataset, metric_key_prefix):
        """Compute the task metrics and make sure every key starts with `metric_key_prefix + '_'`."""
        if self.compute_metrics is None:
            # No callback configured: fall back to whatever the loop reported (e.g. the loss).
            metrics = dict(output.metrics)
        else:
            metrics = self.compute_metrics(output, dataset)

        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
        return metrics

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        """Evaluate on `eval_dataset` (default: `self.eval_dataset`).

        Returns:
            dict with the prefixed task metrics plus f"{metric_key_prefix}_loss" when the loop
            reported a loss. The metrics are logged and passed to the `on_evaluate` callbacks.
        """
        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self._loop_without_metrics(eval_dataloader, "Evaluation", ignore_keys, metric_key_prefix)
        metrics = self._task_metrics(output, eval_dataset, metric_key_prefix)

        loss_key = f"{metric_key_prefix}_loss"
        if loss_key in output.metrics:
            metrics[loss_key] = output.metrics[loss_key]

        self.log(metrics)

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
        return metrics

    def predict(self, predict_dataset, ignore_keys=None, metric_key_prefix: str = "test"):
        """Run prediction on `predict_dataset`.

        Returns:
            `PredictionOutput` holding the raw predictions, the label ids and the prefixed task
            metrics (which are also logged).
        """
        predict_dataloader = self.get_test_dataloader(predict_dataset)

        output = self._loop_without_metrics(predict_dataloader, "Prediction", ignore_keys, metric_key_prefix)
        metrics = self._task_metrics(output, predict_dataset, metric_key_prefix)

        self.log(metrics)

        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=metrics)


## Configuration

In [None]:
### Paths and metric names of datasets
def _seqcls_entry(metric_name, folder, max_seq_length):
    """Build the configuration of one dataset stored in the directory data/seqcls/<folder>/."""
    split_files = (("train", "train"), ("validation", "dev"), ("test", "test"))
    # Key order (metric_name, data_files, max_seq_length) is kept on purpose: the configuration
    # cell below unpacks `.values()` positionally.
    return {
        "metric_name": metric_name,
        "data_files": {split: f"data/seqcls/{folder}/{stem}.json" for split, stem in split_files},
        "max_seq_length": max_seq_length,
    }


# NOTE: this dict reuses the name of the `datasets` module imported at the top of the notebook.
datasets = {
    name: _seqcls_entry(metric, folder, length)
    for name, metric, folder, length in (
        ("PubMedQA", "accuracy", "pubmedqa_hf", 512),
        ("BioASQ", "accuracy", "bioasq_hf", 512),
        ("Biosses", "pearsonr", "BIOSSES_hf", 512),
        ("hoc", "hoc", "HoC_hf", 512),
        ("ChemProt", "PRF1", "chemprot_hf", 256),
        ("DDI", "PRF1", "DDI_hf", 256),
        ("GAD", "PRF1", "GAD_hf", 256),
    )
}

### LinkBERT hyperparameters
# Values that are identical for every task.
_SHARED_TRAINARGS = {"gradient_accumulation_steps": 1, "fp16": True}

# Per-task values; only PubMedQA and BioASQ set warmup steps, only Biosses sets a seed.
trainargs = {
    "PubMedQA": {
        **_SHARED_TRAINARGS,
        "per_device_train_batch_size": 16,
        "learning_rate": 2e-5,
        "warmup_steps": 100,
        "num_train_epochs": 30,
    },
    "BioASQ": {
        **_SHARED_TRAINARGS,
        "per_device_train_batch_size": 16,
        "learning_rate": 2e-5,
        "warmup_steps": 100,
        "num_train_epochs": 20,
    },
    "Biosses": {
        **_SHARED_TRAINARGS,
        "per_device_train_batch_size": 16,
        "learning_rate": 1e-5,
        "num_train_epochs": 30,
        "seed": 5,
    },
    "hoc": {
        **_SHARED_TRAINARGS,
        "per_device_train_batch_size": 32,
        "learning_rate": 4e-5,
        "num_train_epochs": 40,
    },
    "ChemProt": {
        **_SHARED_TRAINARGS,
        "per_device_train_batch_size": 32,
        "learning_rate": 3e-5,
        "num_train_epochs": 10,
    },
    "DDI": {
        **_SHARED_TRAINARGS,
        "per_device_train_batch_size": 32,
        "learning_rate": 2e-5,
        "num_train_epochs": 5,
    },
    "GAD": {
        **_SHARED_TRAINARGS,
        "per_device_train_batch_size": 32,
        "learning_rate": 3e-5,
        "num_train_epochs": 10,
    },
}

### The following has to be configured for each dataset

In [None]:
dataset_name = "BioASQ"
if dataset_name not in datasets or dataset_name not in trainargs:
    raise KeyError(f"Unknown dataset {dataset_name!r}; choose one of {sorted(datasets)}")

# Look the settings up by key instead of unpacking `.values()` positionally, so the result does
# not depend on the order in which the keys of the entry happen to be written.
dataset_config = datasets[dataset_name]
metric_name = dataset_config["metric_name"]
data_files = dataset_config["data_files"]
max_seq_length = dataset_config["max_seq_length"]
pad_to_max_length = True

training_args = TrainingArguments( # huggingface training arguments https://huggingface.co/docs/transformers/v4.16.2/en/main_classes/trainer#transformers.TrainingArguments
        output_dir=f"./runs/{dataset_name}",
        do_train=True,
        do_eval=True,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        max_steps=-1,
        per_device_eval_batch_size=8,
        logging_dir="./logs",
        skip_memory_metrics=True,
        # `.get` keeps this cell usable when WANDB_DISABLED was never set in the environment.
        report_to="none" if os.environ.get("WANDB_DISABLED") == "true" else "wandb",
        logging_steps=100, # logging steps for train loss
        do_predict=True,
        load_best_model_at_end=False, # do we do this here, linkbert doesn't, would require metric_for_best_model and greater_is_better
        **trainargs[dataset_name]
    )

### HPO
direction="maximize" # maximize if metric is bigger_is_better, else: minimize
n_trials = 10 # Number of trials for HPO

# Hyperparameter search space, overwriting training_args
# see https://docs.ray.io/en/latest/tune/key-concepts.html#search-spaces
def hp_space_ray(trial):
    """Return the Ray Tune search space as a dict keyed by training-argument name.

    `trial` is accepted for signature compatibility and is not used.
    """
    search_space = {}
    search_space["learning_rate"] = tune.loguniform(1e-5, 5e-5)
    search_space["num_train_epochs"] = tune.choice(range(10, 30))
    # The seed is not searched; open question from the author: tune.choice(range(1, 41)),
    # check against the set_seed call, is it needed at all?
    search_space["per_device_train_batch_size"] = tune.choice([4, 8, 16])
    return search_space

### Seeding
# Uses the seed stored on `training_args` (the Biosses entry of `trainargs` overrides it with 5).
set_seed(training_args.seed) # Set seed before initializing model.

## Set up logging

In [None]:
logger = logging.getLogger(__name__)
# Send log records to stdout so that they show up in the cell output.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

# Use the same verbosity for this notebook's logger and for the transformers library.
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()

# Log on each process the small summary:
# (the two message halves are joined with ", "; previously "n_gpu: 1distributed training" was printed)
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

## Load dataset

- Load raw dataset from json files
- Set `is_regression` (BIOSSES) and `is_multiclass_binary` (HOC)
- Create list of labels 

In [None]:
# Read the train/validation/test json files configured in `data_files`.
raw_datasets = load_dataset("json", data_files=data_files)

# The dtype of the "label" feature decides which kind of task this is:
# float -> regression, list -> one binary decision per class, anything else -> single-label classification.
label_dtype = raw_datasets["train"].features["label"].dtype
is_regression = label_dtype in ("float32", "float64")
is_multiclass_binary = label_dtype == "list"

if is_regression:
    print ('is_regression')
    num_labels = 1
elif is_multiclass_binary:
    print ('is_multiclass_binary')
    assert metric_name.startswith("hoc")
    # One output per entry of the label vector of the first training example.
    num_labels = len(raw_datasets["train"][0]["label"])
    label_list = list(range(num_labels))
else:
    # Distinct label values of the training split, sorted for determinism.
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
    label_list = sorted(raw_datasets["train"].unique("label"))
    print ('\nlabel_list', label_list)
    num_labels = len(label_list)

# Regression has no label vocabulary; otherwise map each label to its position in `label_list`.
label_to_id = None if is_regression else {label: index for index, label in enumerate(label_list)}

## Initialize model, tokenizer, config

In [None]:
model_name = "michiyasunaga/BioLinkBERT-base"
# model_name = "sshleifer/tiny-distilroberta-base"

config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained(model_name)

## Needs to be encapsulated for hpo
def model_init():
    """Load a fresh sequence-classification model from `model_name` using the shared `config`.

    For non-regression tasks the label <-> id mappings are stored on the model config.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
    if is_regression:
        return model

    model.config.label2id = label_to_id
    model.config.id2label = {index: label for label, index in model.config.label2id.items()}
    return model

## Preprocess dataset

In [None]:
# Pick the text column(s) to tokenize: the conventional sentence1/sentence2 or sentence names
# when present, otherwise the first one or two columns that are not the label.
non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
    sentence1_key, sentence2_key = "sentence1", "sentence2"
elif "sentence" in non_label_column_names:
    sentence1_key, sentence2_key = "sentence", None
else:
    if len(non_label_column_names) >= 2:
        sentence1_key, sentence2_key = non_label_column_names[:2]
    else:
        sentence1_key, sentence2_key = non_label_column_names[0], None

# Padding strategy
if pad_to_max_length:
    padding = "max_length"
else:
    # We will pad later, dynamically at batch creation, to the max sequence length in each batch
    padding = False

if max_seq_length > tokenizer.model_max_length:
    # The first literal ends with a space; without it the message read "for themodel".
    logger.warning(
        f"The max_seq_length passed ({max_seq_length}) is larger than the maximum length for the "
        f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
    )
max_seq_length = min(max_seq_length, tokenizer.model_max_length)

def preprocess_function(examples):
    """Tokenize one batch of examples and attach the labels the model should be trained on.

    Args:
        examples: batch as a mapping from column name to list of values.

    Returns:
        The tokenizer output. When a label mapping exists and the batch has a "label" column,
        "label" holds the label vectors unchanged (multi-label case) or the mapped label ids,
        with -1 passed through unmapped.
    """
    texts = [examples[sentence1_key]]
    if sentence2_key is not None:
        texts.append(examples[sentence2_key])

    result = tokenizer(*texts, padding=padding, max_length=max_seq_length, truncation=True)

    # Nothing to map for regression (no label vocabulary) or unlabeled batches.
    if label_to_id is None or "label" not in examples:
        return result

    if is_multiclass_binary:
        result["label"] = examples["label"]
    else:
        result["label"] = [-1 if raw == -1 else label_to_id[raw] for raw in examples["label"]]
    return result

# Bind the mapped result to a new name instead of overwriting `raw_datasets`: re-running this cell
# then always starts from the raw data rather than re-mapping labels that were already mapped.
with training_args.main_process_first(desc="dataset map pre-processing"):
    tokenized_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        desc="Running tokenizer on dataset",
    )

train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["validation"]
predict_dataset = tokenized_datasets["test"]

## Custom evaluation for HOC (https://raw.githubusercontent.com/michiyasunaga/LinkBERT/main/src/seqcls/utils_hoc.py)

In [None]:
import numpy as np

# HoC class names; `eval_hoc` uses the length of this list as the number of classes.
LABELS = [
    'activating invasion and metastasis',
    'avoiding immune destruction',
    'cellular energetics',
    'enabling replicative immortality',
    'evading growth suppressors',
    'genomic instability and mutation',
    'inducing angiogenesis',
    'resisting cell death',
    'sustaining proliferative signaling',
    'tumor promoting inflammation',
]


def divide(x, y):
    """Return x / y element-wise, with 0.0 wherever y is 0.

    Positions with a zero divisor are skipped by `where=` and keep the value of the
    zero-initialised output buffer.
    """
    # `np.float` was an alias of the builtin `float` and no longer exists in current NumPy.
    return np.true_divide(x, y, out=np.zeros_like(x, dtype=float), where=y != 0)


def compute_p_r_f(preds, labels):
    """Precision, recall and F1 from counts taken over all entries of the two arrays.

    An entry counts as a true positive when prediction and label agree and the prediction is
    non-zero. Ratios with a zero denominator come out as 0 (see `divide`).

    Returns:
        (precision, recall, f1)
    """
    predicted_positive = preds != 0
    gold_positive = labels != 0

    true_positives = ((preds == labels) & predicted_positive).astype(int).sum()
    n_predicted = predicted_positive.astype(int).sum()
    n_gold = gold_positive.astype(int).sum()

    precision = divide(true_positives, n_predicted).mean()
    recall = divide(true_positives, n_gold).mean()
    f1 = divide(2 * precision * recall, (precision + recall)).mean()
    return precision, recall, f1


def eval_hoc(true_list, pred_list, id_list):
    """Score HoC predictions per document.

    Rows are grouped by the part of their id before the first '_'. Within a group, a class is
    gold (predicted) if any row has a 1 for it in `true_list` (`pred_list`). The resulting
    document-level 0/1 matrices are passed to `compute_p_r_f`.

    Args:
        true_list: one gold 0/1 vector per row, `len(LABELS)` entries each.
        pred_list: one predicted 0/1 vector per row, same layout.
        id_list: one string id per row.

    Returns:
        dict with the keys "precision", "recall" and "F1".
    """
    assert len(true_list) == len(pred_list) == len(id_list), \
        f'Gold line no {len(true_list)} vs Prediction line no {len(pred_list)} vs Id line no {len(id_list)}'

    n_classes = len(LABELS)
    assert n_classes == len(true_list[0]) == len(pred_list[0])

    # document key -> (gold class indices, predicted class indices)
    per_document = {}
    for gold_row, pred_row, example_id in zip(true_list, pred_list, id_list):
        doc_key = example_id.split('_')[0]
        gold_classes, pred_classes = per_document.setdefault(doc_key, (set(), set()))
        for class_index in range(n_classes):
            if gold_row[class_index] == 1:
                gold_classes.add(class_index)
            if pred_row[class_index] == 1:
                pred_classes.add(class_index)

    print (f"There are {len(per_document)} documents in the data set")

    # Turn the index sets back into 0/1 rows, one per document, in first-seen document order.
    y_test = np.array([
        [1 if class_index in gold_classes else 0 for class_index in range(n_classes)]
        for gold_classes, _ in per_document.values()
    ])
    y_pred = np.array([
        [1 if class_index in pred_classes else 0 for class_index in range(n_classes)]
        for _, pred_classes in per_document.values()
    ])

    p, r, f1 = compute_p_r_f(y_pred, y_test)
    return {"precision": p, "recall": r, "F1": f1}

## Init trainer

In [None]:
def compute_metrics(p: "EvalPrediction", eval_dataset):
    """Compute the metric selected by the notebook-level `metric_name`.

    Args:
        p: object with `predictions` (array, or tuple whose first item is used) and `label_ids`.
        eval_dataset: the dataset that was evaluated; only its "id" column is read, and only
            for the "hoc" metric.

    Returns:
        dict of metric name to value:
        - "hoc": result of `eval_hoc` on predictions thresholded at 0.
        - "pearsonr": Pearson correlation between labels and predictions.
        - "PRF1": precision / recall / F1 where class 0 is not counted as a positive.
        - otherwise: "mse" for regression, else "accuracy".
    """
    # The unused `load_metric("accuracy")` call that used to run here on every evaluation was dropped.
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    if metric_name == "hoc":
        labels = np.array(p.label_ids).astype(int) #[num_ex, num_class]
        preds = (np.array(preds) > 0).astype(int)  #[num_ex, num_class]
        ids = eval_dataset["id"]
        return eval_hoc(labels.tolist(), preds.tolist(), list(ids))

    preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)

    if metric_name == "pearsonr":
        from scipy.stats import pearsonr as scipy_pearsonr
        pearsonr = float(scipy_pearsonr(p.label_ids, preds)[0])
        return {"pearsonr": pearsonr}
    elif metric_name == "PRF1":
        TP = ((preds == p.label_ids) & (preds != 0)).astype(int).sum().item()
        P_total = (preds != 0).astype(int).sum().item()
        L_total = (p.label_ids != 0).astype(int).sum().item()
        # Guard every ratio against an empty denominator.
        P = TP / P_total if P_total else 0
        R = TP / L_total if L_total else 0
        F1 = 2 * P * R / (P + R) if (P + R) else 0
        return {"precision": P, "recall": R, "F1": F1}
    elif is_regression:
        return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
    else:
        return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

# Collator choice: examples already padded to max length only need stacking; otherwise pad per
# batch (to a multiple of 8 under fp16), or pass None so the Trainer picks its own default.
if pad_to_max_length:
    data_collator = default_data_collator
else:
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) if training_args.fp16 else None

# The model is supplied through `model_init` rather than as an instance (needed for HPO).
trainer = SeqClsTrainer(
    args=training_args,
    model_init=model_init,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

# Train/Eval/Test

## With HPO 

In [None]:
## needed only for google colab
# ray._private.utils.get_system_memory = lambda: psutil.virtual_memory().total

reporter = JupyterNotebookReporter(overwrite=False)

# Run `n_trials` trials over the space defined by `hp_space_ray`; results and checkpoints go to
# ./runs/<dataset_name>/ray_results/<dataset_name>, the path the next cell reads them from.
best_trial = trainer.hyperparameter_search(
    hp_space=hp_space_ray,
    n_trials=n_trials,
    direction=direction,
    backend="ray",
    name=dataset_name,
    local_dir=f"./runs/{dataset_name}/ray_results/",
    keep_checkpoints_num=1,
    progress_reporter=reporter,
)

### Load best model from HPO

In [None]:
def recover_checkpoint(tune_checkpoint_dir, model_name=None):
    """Return the single sub-directory of a Tune checkpoint directory.

    Args:
        tune_checkpoint_dir: directory to inspect; None or an empty value means "no checkpoint".
        model_name: value returned when there is no checkpoint directory.

    Returns:
        Path of the only sub-directory, or `model_name` when no directory was given.

    Raises:
        AssertionError: if the directory does not contain exactly one sub-directory.
    """
    if tune_checkpoint_dir is None or len(tune_checkpoint_dir) == 0:
        return model_name

    # Collect the sub-directories (plain files are ignored).
    candidates = []
    for entry in os.listdir(tune_checkpoint_dir):
        full_path = os.path.join(tune_checkpoint_dir, entry)
        if os.path.isdir(full_path):
            candidates.append(full_path)

    # There should only be 1 subdir.
    assert len(candidates) == 1, candidates
    (checkpoint_dir,) = candidates
    return checkpoint_dir

# Directory the HPO cell wrote to: <local_dir>/<name>.
ray_result_dir = f"./runs/{dataset_name}/ray_results/{dataset_name}"

from ray.tune import ExperimentAnalysis
analysis = ExperimentAnalysis(ray_result_dir)
# Pick the trial with the best "objective" value, in the same direction that was used for the search.
best_checkpoint = recover_checkpoint(
    analysis.get_best_trial(metric="objective",
                            mode="max" if direction=="maximize" else "min").checkpoint.value
)
best_model = AutoModelForSequenceClassification.from_pretrained(
    best_checkpoint)

# Re-create the Trainer around the best model. `args=training_args` is passed so that evaluation
# uses the configured arguments and `save_metrics` writes into the run's output directory, the
# same directory the prediction cell writes test_outputs.json to.
trainer = SeqClsTrainer(
    model=best_model,
    args=training_args,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

### Evaluate

In [None]:
# Evaluate on the validation split with the trainer built in the previous cell,
# add the sample count, then print and save the metrics.
metrics = trainer.evaluate(eval_dataset=eval_dataset)
metrics["eval_samples"] = len(eval_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

### Predict

In [None]:
# Predict on the test split; metric keys are prefixed with "test_".
results = trainer.predict(predict_dataset, metric_key_prefix="test")
predictions = results.predictions
metrics = results.metrics
metrics["test_samples"] = len(predict_dataset)

trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log(metrics)

# Dump the raw predictions and gold label ids next to the other run artefacts.
import json
output_dir = training_args.output_dir
os.makedirs(output_dir, exist_ok=True)  # the directory may not exist yet
output_path = f"{output_dir}/test_outputs.json"
# The context manager closes (and thereby flushes) the file; it used to be left open.
with open(output_path, "w") as output_file:
    json.dump({"predictions": results.predictions.tolist(), "label_ids": results.label_ids.tolist()},
              output_file)

## Simple (should not be used/only for reproducing LinkBERT numbers)

To reproduce exact numbers

- do `model=model_init()` right after model_init function definition (in "Initialize model, tokenizer, config") and change in Trainer init `model_init=model_init` to `model=model` (Otherwise the rng is not the same with the original script)
- Install same library versions as LinkBERT

```
!pip install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
!pip install transformers==4.9.1 datasets==1.11.0 fairscale==0.4.0 wandb sklearn seqeval
```

### Train

In [None]:
## To get exactly the same numbers, do `model=model_init()` right after the model_init function definition (in "Initialize model, tokenizer, config")
## and change `model_init=model_init` to `model=model` in the Trainer init (otherwise the RNG state differs from the original script).
##
## NOTE(review): `trainer` is whatever the most recently executed cell bound to that name. If the
## "Load best model from HPO" cell was run, it is the trainer built around the best checkpoint --
## re-run the "Init trainer" cell before training from scratch.

train_result = trainer.train()
metrics = train_result.metrics
metrics["train_samples"] = len(train_dataset)

trainer.save_model()  # Saves the tokenizer too for easy upload

# Print and save the training metrics, then save the trainer state.
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

### Evaluate

In [None]:
logger.info("*** Evaluate ***")

# Evaluate on the validation split, add the sample count, then print and save the metrics.
metrics = trainer.evaluate(eval_dataset=eval_dataset)
metrics["eval_samples"] = len(eval_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

### Predict

In [None]:
logger.info("*** Predict ***")

# Predict on the test split; metric keys are prefixed with "test_".
results = trainer.predict(predict_dataset, metric_key_prefix="test")
predictions = results.predictions
metrics = results.metrics
metrics["test_samples"] = len(predict_dataset)

trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log(metrics)

# Dump the raw predictions and gold label ids next to the other run artefacts.
import json
output_dir = training_args.output_dir
os.makedirs(output_dir, exist_ok=True)  # the directory may not exist yet
output_path = f"{output_dir}/test_outputs.json"
# The context manager closes (and thereby flushes) the file; it used to be left open.
with open(output_path, "w") as output_file:
    json.dump({"predictions": results.predictions.tolist(), "label_ids": results.label_ids.tolist()},
              output_file)