# Fine-Tuned Transformer Model

In [1]:
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [2]:
import torch

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = "cpu"

## PubMedBert

### Loading the dataset

In [3]:
from datasets import load_dataset

In [4]:
ds = load_dataset("paul-ww/ei-abstract-significance")

In [5]:
class_labels = ds["train"].features["label"]
label2id = {name: class_labels.str2int(name) for name in class_labels.names}
id2label = {v: k for k, v in label2id.items()}

### Tracking using Weights&Biases

In [6]:
%env WANDB_LOG_MODEL='end'
%env WANDB_WATCH='all'

env: WANDB_LOG_MODEL='end'
env: WANDB_WATCH='all'


In [7]:
config = {
    "model": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
    "batch_size": 64,
    "learning_rate": 1e-5,
    "weight_decay": 0.01,
    "num_epochs": 50,
    "weighted_loss": True,
    "gradient_accumulation_steps": 8,
    "gradient_checkpointing": True,
    "seed": 42,
}

In [8]:
import wandb

wandb.login()
run = wandb.init(
    project="significance_classification", group="transformer_finetuned", config=config
)

[34m[1mwandb[0m: Currently logged in as: [33mpaul_ww[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
from transformers import set_seed

set_seed(42)

### Model Setup

In [10]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    wandb.config["model"], model_max_length=512, truncation_side="left"
)


def tokenize_function(ds):
    return tokenizer(ds["text"], padding="max_length", truncation=True)

In [11]:
ds_tokenized = ds.map(tokenize_function, batched=True)

In [12]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    wandb.config["model"],
    num_labels=class_labels.num_classes,
    id2label=id2label,
    label2id=label2id,
).to(DEVICE)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")
precision = evaluate.load("precision")
recall = evaluate.load("recall")
f1 = evaluate.load("f1")


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    accuracy_score = accuracy.compute(predictions=predictions, references=labels)[
        "accuracy"
    ]
    precision_score = precision.compute(predictions=predictions, references=labels)[
        "precision"
    ]
    recall_score = recall.compute(predictions=predictions, references=labels)["recall"]
    f1_score = f1.compute(predictions=predictions, references=labels)["f1"]
    return {
        "accuracy": accuracy_score,
        "precision": precision_score,
        "recall": recall_score,
        "f1": f1_score,
    }

In [14]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    report_to="wandb",
    output_dir="models/pubmedbert_effect",
    learning_rate=wandb.config["learning_rate"],
    weight_decay=wandb.config["weight_decay"],
    per_device_train_batch_size=wandb.config["batch_size"],
    per_device_eval_batch_size=wandb.config["batch_size"],
    gradient_accumulation_steps=wandb.config["gradient_accumulation_steps"],
    gradient_checkpointing=wandb.config["gradient_checkpointing"],
    num_train_epochs=wandb.config["num_epochs"],
    optim="adamw_torch",
    fp16=True,
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    metric_for_best_model="f1",
    load_best_model_at_end=True,
    overwrite_output_dir=True,
    seed=42,
)

In [15]:
from transformers import Trainer
from transformers import DataCollatorWithPadding
from torch import nn
import torch
import numpy as np

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

class_freqs = np.bincount(ds["train"]["label"], minlength=2)
class_freqs_inv = 1 / class_freqs
CLASS_WEIGHTS_NORM = class_freqs_inv / class_freqs_inv.sum()


class WeightedLossTrainer(Trainer):
    """A trainer with a weighted cross-entropy-loss function."""

    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (weighted by inverse class frequency)
        loss_fct = nn.CrossEntropyLoss(
            weight=torch.tensor(
                CLASS_WEIGHTS_NORM, device=model.device, dtype=torch.float16
            )
        )
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

In [16]:
if wandb.config["weighted_loss"]:
    print(
        f"Using weighted cross entropy loss using normalized inverse class frequency. Weights: significant: {CLASS_WEIGHTS_NORM[0]}, not significant: {CLASS_WEIGHTS_NORM[1]}"
    )
    trainer = WeightedLossTrainer(
        model=model,
        args=training_args,
        train_dataset=ds_tokenized["train"],
        eval_dataset=ds_tokenized["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    wandb.config["class_weights"] = CLASS_WEIGHTS_NORM

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_tokenized["train"],
    eval_dataset=ds_tokenized["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Using weighted cross entropy loss using normalized inverse class frequency. Weights: significant: 0.6264591439688716, not significant: 0.3735408560311284


In [17]:
! nvidia-smi

Thu Nov 23 16:49:26 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.29.05    Driver Version: 495.29.05    CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A40          On   | 00000000:01:00.0 Off |                    0 |
|  0%   38C    P8    21W / 300W |      0MiB / 45634MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A40          On   | 00000000:25:00.0 Off |                    0 |
|  0%   42C    P0    76W / 300W |   1651MiB / 45634MiB |      0%      Default |
|       

In [18]:
from torch.utils.checkpoint import checkpoint

trainer.train()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
0,No log,0.643378,0.635593,0.632479,1.0,0.774869
1,No log,0.638072,0.635593,0.632479,1.0,0.774869
2,No log,0.644999,0.618644,0.643564,0.878378,0.742857
3,No log,0.655623,0.627119,0.702703,0.702703,0.702703
4,0.668600,0.637352,0.677966,0.6875,0.891892,0.776471
5,0.668600,0.620341,0.627119,0.633929,0.959459,0.763441
6,0.668600,0.613242,0.644068,0.642857,0.972973,0.774194
8,0.668600,0.607817,0.70339,0.719101,0.864865,0.785276
9,0.611900,0.600071,0.694915,0.731707,0.810811,0.769231
10,0.611900,0.580637,0.728814,0.723404,0.918919,0.809524


TrainOutput(global_step=100, training_loss=0.4221226191520691, metrics={'train_runtime': 905.6673, 'train_samples_per_second': 56.754, 'train_steps_per_second': 0.11, 'total_flos': 1.27293128583168e+16, 'train_loss': 0.4221226191520691, 'epoch': 47.06})

In [19]:
from pathlib import Path

trainer.save_model(Path(run.dir) / "model_finetuned")

#### Evaluation

In [20]:
predictions_proba = trainer.predict(ds_tokenized["test"]).predictions

In [23]:
"""Module containing useful functions in the significance classification context."""

from typing import Any, Sequence

import numpy as np
import wandb
from sklearn.metrics import classification_report


def log_metrics_to_wandb(
    y_true_num: Sequence[int],
    y_pred_proba: np.ndarray,
    id2label: dict[int, str],
    labels: list[str],
    run: Any,
) -> None:
    """Log binary classification metrics to Weights&Biases."""
    y_pred_num = np.argmax(y_pred_proba, axis=1)
    y_true_str = [id2label[e] for e in y_true_num]
    y_pred_str = [id2label[e] for e in y_pred_num]
    # Confusion Matrix
    cm = wandb.plot.confusion_matrix(
        y_true=y_true_num, preds=y_pred_num, class_names=labels
    )
    wandb.log({"test_cm": cm})
    # PR-Curve
    wandb.log({"test_pr": wandb.plot.pr_curve(y_true_num, y_pred_proba, labels)})
    # ROC Curve
    wandb.log({"test_roc": wandb.plot.roc_curve(y_true_num, y_pred_proba, labels)})
    # Log predicted probabilities
    wandb.log(
        {
            "test_probas": wandb.Table(
                data=y_pred_proba, columns=["prob_significant", "prob_not_significant"]
            )
        }
    )
    # Additional Metrics
    report = classification_report(
        y_pred=y_pred_str, y_true=y_true_str, output_dict=True
    )
    wandb.log({"test": report})
    # Ensure summary metrics are present
    run.summary.update({"test": report})
    run.finish()

In [24]:
# from classification.utils import log_metrics_to_wandb

log_metrics_to_wandb(
    y_pred_proba=predictions_proba,
    y_true_num=ds["test"]["label"],
    id2label=id2label,
    labels=class_labels.names,
    run=run,
)

VBox(children=(Label(value='0.029 MB of 0.033 MB uploaded\r'), FloatProgress(value=0.8845244275249466, max=1.0…

0,1
eval/accuracy,▂▂▁▁▃▂▃▄▃▄▅▅▆▆▆▇▇▇▇█████▇▇█████████▇▇▇▇▇
eval/f1,▄▄▃▁▄▄▄▄▄▅▅▆▆▆▇▇▇▇▇▇▇██▇▇▇█████████▇▇▇▇▇
eval/loss,█████▇▇▇▇▆▅▅▅▄▄▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/precision,▁▁▁▃▃▁▃▄▄▄▄▅▆▆▆▆▇▇██████▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇
eval/recall,██▅▁▅▇▇▅▄▆▆▅▅▅▆▆▆▅▅▅▅▆▆▅▅▅▆▆▆▆▆▆▆▆▆▅▅▅▅▅
eval/runtime,▄▂▅▃▂▂▂▂▂▁▁▂▂▂█▂▂▁▃▃▃▂▂▂▂▂▃▃▂▃▁▂▂▁▂▂▂▄▂▂
eval/samples_per_second,▅▇▄▆▇▇▇▇▇██▇▇▇▁▇▇█▆▆▆▇▇▇▇▇▆▆▇▆█▇▇█▇▇▇▅▇▇
eval/steps_per_second,▅▇▄▆▇▇▇▇▇██▇▇▇▁▇▇█▆▆▆▇▇▇▇▇▆▆▇▆█▇▇█▇▇▇▅▇▇
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███████

0,1
eval/accuracy,0.83051
eval/f1,0.86842
eval/loss,0.3872
eval/precision,0.84615
eval/recall,0.89189
eval/runtime,0.5451
eval/samples_per_second,216.477
eval/steps_per_second,3.669
train/epoch,47.06
train/global_step,100.0
