### Finetuning our BERT-based model for NAT Prediction

We identified optimal hyperparameters using a Bayesian search.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
import traceback
from typing import Tuple, Dict, Union

import numpy as np
import pandas as pd
from tqdm import tqdm
from smart_open import open
import pickle

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from sklearn.metrics import roc_auc_score, average_precision_score
from transformers import BertConfig, get_linear_schedule_with_warmup
import wandb

from common.constants import (
    S3_FINETUNE_UNRESTRICTED_PREPROCESSED_PATH,
    S3_MODEL_OUTPUT_PATH,
)
from common.models import BertFinetune
from common.datasets import FinetuneDataset
from common.utilities import load_vocab, create_dataloader


Create training and evaluation loops.

In [None]:
def evaluate_finetune(
    model: BertFinetune, dataloader: DataLoader, device: torch.device
) -> Tuple[float, float, float]:
    """Evaluate the model on a validation dataset.

    The model is switched to eval mode, every batch of ``dataloader`` is run
    through it with gradient tracking disabled, and the model is switched
    back to train mode before returning (also when an error is raised).

    Args:
        model: Model to evaluate. It is called with ``feature_ids``,
            ``time_ids``, ``code_type_ids``, ``attention_mask`` and
            ``labels``; element 0 of its output is used as the loss and
            element 1 as the logits.
        dataloader: Yields dict batches of tensors containing the keys
            ``feature_tokens``, ``time_from_prediction_tokens``,
            ``code_type_tokens``, ``attention_mask`` and ``label``.
        device: Device every tensor of a batch is moved to.

    Returns:
        A ``(loss, auroc, auprc)`` tuple, where ``loss`` is the mean of the
        per-batch losses.

    Raises:
        ValueError: If ``dataloader`` contains no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("Cannot evaluate on an empty dataloader.")

    model.eval()

    # Keep running values for loss, logits, and labels
    eval_loss = 0.0
    eval_logits, eval_labels = [], []

    try:
        with torch.no_grad():
            for batch in tqdm(dataloader):
                batch = {key: seq.to(device) for key, seq in batch.items()}

                outputs = model(
                    feature_ids=batch["feature_tokens"],
                    time_ids=batch["time_from_prediction_tokens"],
                    code_type_ids=batch["code_type_tokens"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"],
                )

                eval_loss += outputs[0].item()
                # Flatten with reshape(-1) rather than squeeze(): squeezing a
                # batch that holds a single sample yields a 0-d tensor whose
                # tolist() is a bare float, which cannot extend a list.
                eval_logits.extend(outputs[1].reshape(-1).tolist())
                eval_labels.extend(batch["label"].reshape(-1).tolist())

        # Calculate metrics
        eval_loss /= num_batches
        eval_auroc = roc_auc_score(y_true=eval_labels, y_score=eval_logits)
        eval_auprc = average_precision_score(
            y_true=eval_labels, y_score=eval_logits
        )
    finally:
        # Always hand the model back in train mode, even if a metric raised
        model.train()

    return eval_loss, eval_auroc, eval_auprc


def training(
    model: BertFinetune,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    device: torch.device,
    config: wandb.Config,
) -> None:
    """Run the model training loop.

    Trains for ``config.epochs`` epochs with gradient accumulation over
    ``config.num_accumulation_steps`` batches. Every ``config.log_steps``
    batches the model is evaluated on ``val_dataloader``, training and
    validation metrics are logged to wandb, and the model weights are saved
    whenever the validation AUROC beats the best value seen so far.

    Args:
        model: Model to train. Element 0 of its output is used as the loss
            and element 1 as the logits.
        train_dataloader: Yields dict batches of tensors containing the keys
            ``feature_tokens``, ``time_from_prediction_tokens``,
            ``code_type_tokens``, ``attention_mask`` and ``label``.
        val_dataloader: Validation batches, passed to ``evaluate_finetune``.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: Learning rate scheduler stepped together with the
            optimizer.
        device: Device every tensor of a batch is moved to.
        config: Run configuration; ``epochs``, ``num_accumulation_steps`` and
            ``log_steps`` are read here.
    """
    best_val_auroc = 0

    for epoch in range(config.epochs):
        # Discard gradients left over from an incomplete accumulation window
        # at the end of the previous epoch, so they are not silently added to
        # the first optimizer step of this epoch.
        optimizer.zero_grad()

        # Track performance on segments of the training dataset
        running_loss = 0
        running_logits, running_labels = [], []

        for i, batch in enumerate(tqdm(train_dataloader)):
            batch = {key: seq.to(device) for key, seq in batch.items()}

            outputs = model(
                feature_ids=batch["feature_tokens"],
                time_ids=batch["time_from_prediction_tokens"],
                code_type_ids=batch["code_type_tokens"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )

            loss = outputs[0]

            # Track training statistics. Flatten with reshape(-1) rather than
            # squeeze(): squeezing a batch that holds a single sample yields a
            # 0-d tensor whose tolist() is a bare float, which cannot extend
            # a list.
            running_loss += loss.item()
            running_logits.extend(outputs[1].detach().reshape(-1).tolist())
            running_labels.extend(batch["label"].reshape(-1).tolist())

            # Normalize for gradient accumulation
            loss = loss / config.num_accumulation_steps
            loss.backward()

            # Optimize with accumulated gradients
            if (i + 1) % config.num_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Evaluate on validation dataset
            if (i + 1) % config.log_steps == 0:
                val_loss, val_auroc, val_auprc = evaluate_finetune(
                    model=model,
                    dataloader=val_dataloader,
                    device=device,
                )
                wandb.log(
                    {
                        "epoch": epoch,
                        "steps": i,
                        "train_loss": running_loss / config.log_steps,
                        "train_auroc": roc_auc_score(
                            y_true=running_labels, y_score=running_logits
                        ),
                        "train_auprc": average_precision_score(
                            y_true=running_labels, y_score=running_logits
                        ),
                        "val_loss": val_loss,
                        "val_auroc": val_auroc,
                        "val_auprc": val_auprc,
                    }
                )
                # Start a fresh window of training statistics
                running_loss = 0
                running_logits, running_labels = [], []

                # Save model if performance improved
                if val_auroc > best_val_auroc:
                    wandb_save_path = os.path.join(
                        S3_MODEL_OUTPUT_PATH,
                        "finetuned",
                        wandb.run.name + ".pt",
                    )
                    # `open` is the smart_open version imported at the top of
                    # the file, not the builtin
                    with open(wandb_save_path, "wb") as f:
                        torch.save(model.state_dict(), f)
                    print(
                        f"Saved model weights (with val auroc of {val_auroc}) to:",
                        wandb_save_path,
                    )
                    best_val_auroc = val_auroc


Create finetuning pipeline.

In [None]:
def finetune_pipeline(config: Dict[str, Union[int, float]]):
    """Run the finetuning pipeline inside a wandb run.

    Loads the vocabularies, builds the training and validation dataloaders,
    creates the BERT model, loads the pretrained weights from
    ``config.pretrained_path``, sets up the AdamW optimizer and the linear
    warmup scheduler, and runs the training loop.

    Args:
        config: Hyperparameters passed to ``wandb.init``. After
            initialization the values are read back from ``wandb.config``.

    Raises:
        SystemExit: With status 1 if any step of the pipeline raises; the
            traceback is printed to stderr first.
    """
    with wandb.init(project="transformer-nat-finetune", config=config):
        try:
            # Configuration is filled out by WandB according to the sweep configuration
            config = wandb.config
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )

            # Load vocabulary encoders
            feature_vocab = load_vocab(
                "feature_vocab_with_cls_pad_mask.pickle"
            )
            time_vocab = load_vocab("time_vocab_with_pad.pickle")
            code_type_vocab = load_vocab("type_vocab_with_pad.pickle")

            # Load and create training dataloader
            train_dataloader = create_dataloader(
                dataset_path=os.path.join(
                    S3_FINETUNE_UNRESTRICTED_PREPROCESSED_PATH, "train.parquet"
                ),
                dataset_constructor=FinetuneDataset,
                feature_vocab=feature_vocab,
                time_vocab=time_vocab,
                code_type_vocab=code_type_vocab,
                config=config,
                truncate_to=-1,
                weighted_sampling=True,
            )

            # Load and create validation dataloader
            val_dataloader = create_dataloader(
                dataset_path=os.path.join(
                    S3_FINETUNE_UNRESTRICTED_PREPROCESSED_PATH, "val.parquet"
                ),
                dataset_constructor=FinetuneDataset,
                feature_vocab=feature_vocab,
                time_vocab=time_vocab,
                code_type_vocab=code_type_vocab,
                config=config,
                truncate_to=-1,
                weighted_sampling=False,
            )

            # Create BERT configuration
            bert_config = BertConfig(
                # Native configurations
                pad_token_id=None,
                position_embedding_type=None,
                type_vocab_size=None,
                vocab_size=None,
                max_position_embeddings=None,
                # Values tuned by hyperparameter sweep
                classifier_dropout=config.classifier_dropout,
                hidden_dropout_prob=config.hidden_dropout_prob,
                attention_probs_dropout_prob=config.attention_probs_dropout_prob,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                num_attention_heads=config.num_attention_heads,
                num_hidden_layers=config.num_hidden_layers,
                pos_weight=config.pos_weight,
                # Custom configurations
                feature_vocab_size=len(feature_vocab),
                time_vocab_size=len(time_vocab),
                code_type_vocab_size=len(code_type_vocab),
                feature_pad_id=feature_vocab["PAD"],
                time_pad_id=time_vocab["PAD"],
                code_type_pad_id=code_type_vocab["PAD"],
            )

            # Create BERT model
            model = BertFinetune(config=bert_config).to(device)

            # Load pretrained model weights. map_location lets a checkpoint
            # saved on a GPU be restored on a CPU-only machine.
            with open(config.pretrained_path, "rb") as f:
                load_result = model.load_state_dict(
                    torch.load(f, map_location=device), strict=False
                )
            print("Loaded pretrained weights from:", config.pretrained_path)
            # strict=False tolerates key mismatches between the checkpoint
            # and this model; report them so a wrong checkpoint does not go
            # unnoticed.
            if load_result.missing_keys:
                print(
                    "Parameters not found in checkpoint:",
                    load_result.missing_keys,
                )
            if load_result.unexpected_keys:
                print(
                    "Checkpoint entries not used by the model:",
                    load_result.unexpected_keys,
                )

            # Create Adam optimizer
            optimizer = AdamW(
                params=model.parameters(),
                lr=config.learning_rate,
                weight_decay=config.adam_weight_decay,
            )

            # Create learning rate scheduler; one scheduler step is taken per
            # optimizer step, i.e. once per accumulation window
            scheduler = get_linear_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=config.scheduler_warmup_steps,
                num_training_steps=(
                    len(train_dataloader) // config.num_accumulation_steps
                )
                * config.epochs,
            )

            # Run training loop
            training(
                model=model,
                train_dataloader=train_dataloader,
                val_dataloader=val_dataloader,
                optimizer=optimizer,
                scheduler=scheduler,
                device=device,
                config=config,
            )

        # Handle errors without quitting wandb sweep
        except Exception:
            # print_exc() writes the traceback to stderr itself and returns
            # None, so its result must not be passed to print().
            traceback.print_exc()
            # sys.exit instead of the interactive-only `exit` helper
            sys.exit(1)


### Bayesian hyperparameter search

The following parameters were tuned:
- `learning_rate`
- `adam_weight_decay`
- `scheduler_warmup_steps`
- `pos_weight`
- `classifier_dropout`
- `hidden_dropout_prob`
- `attention_probs_dropout_prob`
- `num_accumulation_steps`
- `sample_weight`

We used Hyperband early termination.

In [None]:
# Sweep definition: Bayesian search maximizing validation AUROC, with
# hyperband early termination.
sweep_config = dict(
    name="finetune_sweep",
    method="bayes",
    metric=dict(name="val_auroc", goal="maximize"),
    parameters=dict(
        # Settings held at a single value for every run
        epochs=dict(value=1),
        max_seq_length=dict(value=512),
        log_steps=dict(value=1000),
        hidden_size=dict(value=780),
        intermediate_size=dict(value=800),
        num_attention_heads=dict(value=3),
        num_hidden_layers=dict(value=6),
        batch_size=dict(value=32),
        pretrained_path=dict(
            value="s3://transformer-v5/saved_models/pretrained/deep-cosmos-5.pt"
        ),
        # Settings sampled from a distribution between min and max
        learning_rate=dict(
            distribution="log_uniform_values", min=5e-6, max=5e-5
        ),
        adam_weight_decay=dict(
            distribution="log_uniform_values", min=1e-4, max=1e-2
        ),
        scheduler_warmup_steps=dict(
            distribution="int_uniform", min=10, max=100
        ),
        pos_weight=dict(distribution="int_uniform", min=20, max=80),
        classifier_dropout=dict(distribution="uniform", min=0.1, max=0.3),
        hidden_dropout_prob=dict(distribution="uniform", min=0.2, max=0.4),
        attention_probs_dropout_prob=dict(
            distribution="uniform", min=0.2, max=0.4
        ),
        num_accumulation_steps=dict(
            distribution="int_uniform", min=1, max=10
        ),
        sample_weight=dict(distribution="int_uniform", min=20, max=80),
    ),
    early_terminate=dict(type="hyperband", min_iter=1, eta=2),
)


In [None]:
# Register the sweep defined by sweep_config and keep its id for the agent
sweep_id = wandb.sweep(sweep=sweep_config, project="finetune-sweep")


In [None]:
# Run the sweep from this machine, calling finetune_pipeline for each trial
wandb.agent(
    sweep_id=sweep_id,
    function=finetune_pipeline,
    project="finetune-sweep",
    count=50,
)


### Run using optimal hyperparameters

After identifying optimal hyperparameters, we ran finetuning again for a longer duration.

In [None]:
# Fixed hyperparameters for the final finetuning run
final_config = dict(
    epochs=1,
    max_seq_length=512,
    log_steps=5000,
    learning_rate=2e-5,
    adam_weight_decay=1e-3,
    batch_size=32,
    scheduler_warmup_steps=40,
    hidden_size=780,
    intermediate_size=800,
    num_attention_heads=3,
    num_hidden_layers=6,
    classifier_dropout=0.2,
    hidden_dropout_prob=0.3,
    attention_probs_dropout_prob=0.3,
    num_accumulation_steps=1,
    pretrained_path=(
        "s3://transformer-v5/saved_models/pretrained/deep-cosmos-5.pt"
    ),
    pos_weight=80,
    sample_weight=2,
)


In [None]:
# Run the finetuning pipeline once with the fixed final_config values
finetune_pipeline(config=final_config)
