### Pretraining our BERT-based model
We pretrained the model with two objectives: masked language modeling (MLM) and next-visit diagnosis prediction. We identified optimal hyperparameters using a Bayesian search.

In [None]:
import os
import pickle
import sys
import traceback
from typing import Dict, Tuple, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from smart_open import open
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AdamW, BertConfig, get_linear_schedule_with_warmup

import wandb
from common.constants import (
    S3_MODEL_OUTPUT_PATH,
    S3_PRETRAIN_PREPROCESSED_PATH,
)
from common.datasets import PretrainDataset
from common.models import BertPretrain
from common.utilities import create_dataloader, load_vocab


Create training and evaluation loops.

In [None]:
def evaluate_pretrain(
    model: BertPretrain, dataloader: DataLoader, device: torch.device
) -> Tuple[float, float, float]:
    """Evaluate the model on a validation dataset.

    Runs one full pass over ``dataloader`` with gradients disabled and
    averages, over the number of batches, the three values the model returns
    at ``outputs[0]``, ``outputs[1]`` and ``outputs[2]`` (accumulated here as
    the total, masked LM and next visit losses respectively).

    Args:
        model: Model to evaluate. It is switched to eval mode for the duration
            of the call and put back into whichever mode it was in before.
        dataloader: Yields batches as dictionaries of tensors containing the
            keys ``feature_tokens``, ``time_from_prediction_tokens``,
            ``code_type_tokens``, ``attention_mask``, ``mask_labels`` and
            ``label``.
        device: Device every tensor of a batch is moved to before the
            forward pass.

    Returns:
        A ``(total_loss, mlm_loss, next_visit_loss)`` tuple of per-batch
        averages.

    Raises:
        ValueError: If ``dataloader`` contains no batches, since an average
            cannot be computed.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("Cannot evaluate on an empty dataloader.")

    # Remember the caller's mode so evaluation does not silently flip a model
    # that was already in eval mode back into train mode.
    was_training = model.training
    model.eval()

    # Keep running sums of each loss component
    eval_total_loss, eval_mlm_loss, eval_next_visit_loss = 0.0, 0.0, 0.0

    try:
        with torch.no_grad():
            for batch in tqdm(dataloader):
                batch = {key: seq.to(device) for key, seq in batch.items()}
                outputs = model(
                    feature_ids=batch["feature_tokens"],
                    time_ids=batch["time_from_prediction_tokens"],
                    code_type_ids=batch["code_type_tokens"],
                    attention_mask=batch["attention_mask"],
                    mask_labels=batch["mask_labels"],
                    labels=batch["label"],
                )
                eval_total_loss += outputs[0].item()
                eval_mlm_loss += outputs[1].item()
                eval_next_visit_loss += outputs[2].item()
    finally:
        # Restore the previous mode even if a batch fails part-way through
        model.train(was_training)

    # Calculate average loss per batch
    return (
        eval_total_loss / num_batches,
        eval_mlm_loss / num_batches,
        eval_next_visit_loss / num_batches,
    )


def training(
    model: BertPretrain,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    device: torch.device,
    config: wandb.Config,
) -> None:
    """Model training loop with gradient accumulation and periodic validation.

    Every ``config.log_steps`` batches the model is evaluated on
    ``val_dataloader``, the averaged training and validation losses are logged
    to W&B, and the weights are saved whenever the total validation loss
    improves on the best value seen so far.

    Args:
        model: Model to train; ``outputs[0]`` of its forward pass is the loss
            that is backpropagated.
        train_dataloader: Batches used for optimization.
        val_dataloader: Batches used for the periodic evaluation.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: Learning rate scheduler stepped together with the optimizer.
        device: Device each batch is moved to.
        config: Run configuration. This function reads ``epochs``,
            ``num_accumulation_steps`` and ``log_steps``.
    """
    best_val_total_loss = float("inf")

    # Do not rely on the caller having left the model in training mode
    model.train()

    for epoch in range(config.epochs):
        # Track performance on segments of the training dataset
        running_total_loss, running_mlm_loss, running_next_visit_loss = 0, 0, 0

        for i, batch in enumerate(tqdm(train_dataloader)):
            batch = {key: seq.to(device) for key, seq in batch.items()}

            outputs = model(
                feature_ids=batch["feature_tokens"],
                time_ids=batch["time_from_prediction_tokens"],
                code_type_ids=batch["code_type_tokens"],
                attention_mask=batch["attention_mask"],
                mask_labels=batch["mask_labels"],
                labels=batch["label"],
            )

            loss = outputs[0]

            # Track training statistics (before scaling for accumulation)
            running_total_loss += loss.item()
            running_mlm_loss += outputs[1].item()
            running_next_visit_loss += outputs[2].item()

            # Normalize for gradient accumulation
            loss = loss / config.num_accumulation_steps
            loss.backward()

            # Optimize with accumulated gradients
            if (i + 1) % config.num_accumulation_steps == 0:
                # Update Optimizer
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Evaluate on validation dataset
            if (i + 1) % config.log_steps == 0:
                (
                    val_total_loss,
                    val_mlm_loss,
                    val_next_visit_loss,
                ) = evaluate_pretrain(
                    model=model,
                    dataloader=val_dataloader,
                    device=device,
                )
                wandb.log(
                    {
                        "epoch": epoch,
                        "steps": i,
                        "train_total_loss": running_total_loss
                        / config.log_steps,
                        "train_mlm_loss": running_mlm_loss / config.log_steps,
                        "train_next_visit_loss": running_next_visit_loss
                        / config.log_steps,
                        "val_total_loss": val_total_loss,
                        "val_mlm_loss": val_mlm_loss,
                        "val_next_visit_loss": val_next_visit_loss,
                    }
                )
                running_total_loss = 0
                running_mlm_loss = 0
                running_next_visit_loss = 0

                # Save model
                if val_total_loss < best_val_total_loss:
                    wandb_save_path = os.path.join(
                        S3_MODEL_OUTPUT_PATH,
                        "pretrained",
                        wandb.run.name + ".pt",
                    )
                    with open(wandb_save_path, "wb") as f:
                        torch.save(model.state_dict(), f)
                    print(
                        f"Saved model weights (with total val loss of {val_total_loss}) to:",
                        wandb_save_path,
                    )
                    best_val_total_loss = val_total_loss

        # The batch index restarts at 0 every epoch, so gradients from an
        # incomplete trailing accumulation window would otherwise be added to
        # the next epoch's first window and inflate that optimizer step.
        # Discarding them keeps the number of optimizer steps per epoch at
        # len(train_dataloader) // num_accumulation_steps.
        optimizer.zero_grad()


Create pretraining pipeline.

In [None]:
def pretrain_pipeline(
    config: Union[Dict[str, Union[int, float]], None] = None
):
    """Run the pretraining pipeline inside a W&B run.

    Loads the vocabularies, builds the training and validation dataloaders,
    creates the model, optimizer and learning rate scheduler, and then runs
    the training loop.

    Args:
        config: Hyperparameters passed to ``wandb.init``. It defaults to
            ``None`` so that the function can be called without arguments,
            which is how ``wandb.agent`` invokes it during a sweep; the values
            are then read back from ``wandb.config``.

    Raises:
        SystemExit: With exit code 1 if any exception occurs in the pipeline,
            after its traceback has been printed to stderr.
    """
    with wandb.init(project="ehr-transformer-v5-pretrain", config=config):
        try:
            # Configuration is filled out by WandB according to the sweep configuration
            config = wandb.config
            device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )

            # Load vocabulary encoders
            feature_vocab = load_vocab(
                "feature_vocab_with_cls_pad_mask.pickle"
            )
            time_vocab = load_vocab("time_vocab_with_pad.pickle")
            code_type_vocab = load_vocab("type_vocab_with_pad.pickle")
            # Only the size of this vocabulary is used below
            next_visit_diagnosis_vocab = load_vocab(
                "last_visit_dx_category_vocab.pickle"
            )

            # Load and create training dataloader
            train_dataloader = create_dataloader(
                dataset_path=os.path.join(
                    S3_PRETRAIN_PREPROCESSED_PATH, "train.parquet"
                ),
                dataset_constructor=PretrainDataset,
                feature_vocab=feature_vocab,
                time_vocab=time_vocab,
                code_type_vocab=code_type_vocab,
                config=config,
                truncate_to=-1,
            )

            # Load and create validation dataloader
            # NOTE(review): validation reads "test.parquet" - confirm this is
            # the intended split for model selection.
            val_dataloader = create_dataloader(
                dataset_path=os.path.join(
                    S3_PRETRAIN_PREPROCESSED_PATH, "test.parquet"
                ),
                dataset_constructor=PretrainDataset,
                feature_vocab=feature_vocab,
                time_vocab=time_vocab,
                code_type_vocab=code_type_vocab,
                config=config,
                truncate_to=-1,
            )

            # Create BERT configuration
            bert_config = BertConfig(
                # Native configurations
                pad_token_id=None,
                position_embedding_type=None,
                type_vocab_size=None,
                vocab_size=len(feature_vocab),
                max_position_embeddings=None,
                # Values tuned by hyperparameter sweep
                classifier_dropout=config.classifier_dropout,
                hidden_dropout_prob=config.hidden_dropout_prob,
                attention_probs_dropout_prob=config.attention_probs_dropout_prob,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                num_attention_heads=config.num_attention_heads,
                num_hidden_layers=config.num_hidden_layers,
                # Custom configurations
                feature_vocab_size=len(feature_vocab),
                time_vocab_size=len(time_vocab),
                code_type_vocab_size=len(code_type_vocab),
                feature_pad_id=feature_vocab["PAD"],
                time_pad_id=time_vocab["PAD"],
                code_type_pad_id=code_type_vocab["PAD"],
                next_visit_diagnosis_labels_size=len(
                    next_visit_diagnosis_vocab
                ),
            )

            # Create BERT model
            model = BertPretrain(config=bert_config).to(device)

            # Create Adam optimizer
            optimizer = AdamW(
                params=model.parameters(),
                lr=config.learning_rate,
                weight_decay=config.adam_weight_decay,
            )

            # Create learning rate scheduler. One scheduler step is taken per
            # optimizer step, i.e. once per full accumulation window.
            scheduler = get_linear_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=config.scheduler_warmup_steps,
                num_training_steps=len(train_dataloader)
                // config.num_accumulation_steps
                * config.epochs,
            )

            # Run training loop
            training(
                model=model,
                train_dataloader=train_dataloader,
                val_dataloader=val_dataloader,
                optimizer=optimizer,
                scheduler=scheduler,
                device=device,
                config=config,
            )

        # Report the failure and exit with a non-zero status
        except Exception:
            # print_exc() writes the traceback to stderr itself and returns
            # None, so wrapping it in print() would emit a stray "None".
            traceback.print_exc()
            # sys.exit is used rather than the interactive-only exit() helper
            sys.exit(1)


### Bayesian hyperparameter search

Parameters tuned were: 
- `learning_rate`
- `adam_weight_decay`
- `hidden_size`
- `intermediate_size`
- `num_attention_heads`
- `num_hidden_layers`
- `classifier_dropout`
- `hidden_dropout_prob`
- `attention_probs_dropout_prob`
- `num_accumulation_steps`

We used Hyperband early termination.

In [None]:
# Hyperparameters pinned to a single value for every run of the sweep.
_fixed_sweep_values = {
    "epochs": 1,
    "max_seq_length": 512,
    "log_steps": 600,
    "batch_size": 16,
    "scheduler_warmup_steps": 0,
}

# Hyperparameters explored by the Bayesian search.
_searched_sweep_parameters = {
    "learning_rate": {
        "distribution": "log_uniform_values",
        "min": 1e-4,
        "max": 1e-3,
    },
    "adam_weight_decay": {
        "distribution": "log_uniform_values",
        "min": 1e-3,
        "max": 1,
    },
    # hidden (embedding) size needs to be a multiple of num_attention_heads;
    # every size listed here is a multiple of 60, so any pairing with the
    # head counts below is valid.
    "hidden_size": {"values": [420, 540, 600, 660, 720, 780, 840, 900, 960]},
    "num_attention_heads": {"values": [2, 3, 4, 5, 6, 10, 12]},
    "intermediate_size": {
        "distribution": "int_uniform",
        "min": 700,
        "max": 3072,
    },
    "num_hidden_layers": {"distribution": "int_uniform", "min": 3, "max": 8},
    "classifier_dropout": {"distribution": "uniform", "min": 0, "max": 0.3},
    "hidden_dropout_prob": {"distribution": "uniform", "min": 0, "max": 0.3},
    "attention_probs_dropout_prob": {
        "distribution": "uniform",
        "min": 0,
        "max": 0.5,
    },
    "num_accumulation_steps": {
        "distribution": "int_uniform",
        "min": 12,
        "max": 24,
    },
}

# Sweep definition: Bayesian search minimizing the total validation loss,
# with hyperband early termination.
sweep_config = {
    "name": "pretrain_sweep",
    "method": "bayes",
    "metric": {"name": "val_total_loss", "goal": "minimize"},
    "parameters": {
        # Constant entries use the {"value": ...} form
        **{key: {"value": fixed} for key, fixed in _fixed_sweep_values.items()},
        **_searched_sweep_parameters,
    },
    "early_terminate": {"type": "hyperband", "min_iter": 1, "eta": 2},
}


In [None]:
# Both calls below target the same W&B project, so its name is defined once.
_sweep_project = "pretrain-sweeps"

# Register the sweep and keep its id
sweep_id = wandb.sweep(sweep_config, project=_sweep_project)

# Run the sweep agent on this machine, with count set to 50
wandb.agent(
    sweep_id,
    function=pretrain_pipeline,
    project=_sweep_project,
    count=50,
)


### Run using optimal hyperparameters

After identifying optimal hyperparameters, we ran pretraining again for a longer duration.

In [None]:
# Hyperparameters for the final, longer pretraining run.
final_config = dict(
    # Training schedule and logging
    epochs=5,
    log_steps=20000,
    batch_size=32,
    num_accumulation_steps=7,
    max_seq_length=512,
    # Optimizer and scheduler
    learning_rate=4e-4,
    adam_weight_decay=0.5,
    scheduler_warmup_steps=0,
    # Model architecture
    hidden_size=780,
    intermediate_size=800,
    num_attention_heads=3,
    num_hidden_layers=6,
    # Regularization
    classifier_dropout=0.1,
    hidden_dropout_prob=0.2,
    attention_probs_dropout_prob=0.1,
)


In [None]:
# Launch the final pretraining run with the hyperparameters in final_config.
pretrain_pipeline(final_config)
