# Train the model

This task trains the model with distributed data parallel (DDP) training on
[Ray on Databricks](https://docs.databricks.com/aws/en/machine-learning/ray/).

It demonstrates multi-node, multi-GPU training and model logging using a
Hugging Face `Trainer` wrapped in Ray Train. For hyperparameter sweeps, you
can further wrap the Ray trainer with Ray Tune and run it the same way across
the cluster, optionally with multiple trials running in parallel.

As an alternative to Ray Train, you can use TorchDistributor to distribute
training over a Spark cluster. Likewise, there are other options for the
trainer and data loader, such as MosaicML Composer and MosaicML Streaming
(`StreamingDataset`).

In [0]:
import numpy as np
import pandas as pd

import torch

from functools import partial
from itertools import batched
from dataclasses import dataclass

import mlflow
import mlflow.pyfunc
import mlflow.transformers
from mlflow.utils.databricks_utils import get_databricks_env_vars
from mlflow.pyfunc import PythonModel
from mlflow.models.signature import infer_signature

from ray.util.spark import setup_ray_cluster
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig, RunConfig
from ray.train.torch import TorchTrainer

import ray.train.huggingface.transformers
from ray.data.context import DataContext

import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TextClassificationPipeline,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    pipeline
)

import evaluate

import os

# Turn off Ray Data progress bars so they don't clutter the notebook output.
_ray_data_ctx = DataContext.get_current()
_ray_data_ctx.enable_progress_bars = False

In [0]:
# Declare the task parameters as notebook widgets (empty by default).
dbutils.widgets.text("catalog_name", "")
dbutils.widgets.text("schema_name", "")
dbutils.widgets.text("data_dir", "")
dbutils.widgets.text("experiment_name", "")
dbutils.widgets.text("ray_results_dir", "")
dbutils.widgets.text("hugging_face_id", "")
dbutils.widgets.text("ray_collect_log_to_path", "")

catalog_name = dbutils.widgets.get("catalog_name")
schema_name = dbutils.widgets.get("schema_name")
data_dir = dbutils.widgets.get("data_dir")
experiment_name = dbutils.widgets.get("experiment_name")
ray_results_dir = dbutils.widgets.get("ray_results_dir")
hugging_face_id = dbutils.widgets.get("hugging_face_id")
ray_collect_log_to_path = dbutils.widgets.get("ray_collect_log_to_path")

# Fail fast, naming every missing parameter at once. This is an explicit raise
# rather than `assert`, because asserts are stripped when Python runs with -O.
_missing_params = [
    name
    for name, value in (
        ("catalog_name", catalog_name),
        ("schema_name", schema_name),
        ("data_dir", data_dir),
        ("experiment_name", experiment_name),
        ("ray_results_dir", ray_results_dir),
        ("hugging_face_id", hugging_face_id),
        ("ray_collect_log_to_path", ray_collect_log_to_path),
    )
    if not value
]
if _missing_params:
    raise ValueError(
        "Missing required task parameter(s): " + ", ".join(_missing_params)
    )

spark.sql(f"USE CATALOG {catalog_name}")
spark.sql(f"USE SCHEMA {schema_name}")

mlflow.set_experiment(experiment_name)

source_table_name = "yelp_reviews_silver"
splits = ["train", "test"]

# Use a node-local scratch directory as the driver's working directory.
working_dir = "/local_disk0/tmp/hf_fine_tuning_example"
os.makedirs(working_dir, exist_ok=True)
os.chdir(working_dir)

os.makedirs(ray_results_dir, exist_ok=True)

notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
databricks_host = spark.conf.get("spark.databricks.workspaceUrl")
databricks_token = notebook_context.apiToken().get()

# Clusters without an explicit profile setting are treated as multi-node.
cluster_profile = spark.conf.get("spark.databricks.cluster.profile", "multiNode")

print(f"catalog_name: {catalog_name}")
print(f"schema_name: {schema_name}")
print(f"data_dir: {data_dir}")
print(f"experiment_name: {experiment_name}")
print(f"source_table_name: {source_table_name}")
print(f"hugging_face_id: {hugging_face_id}")
print(f"splits: {splits}")
print(f"working_dir: {working_dir}")
print(f"databricks_host: {databricks_host}")
print("databricks_token: ****************")  # Never print the real token.
print(f"ray_results_dir: {ray_results_dir}")
print(f"cluster_profile: {cluster_profile}")

In [0]:
# Distribute training with open-source Ray on Databricks across multiple GPUs
# and nodes. On a single-node cluster, start Ray locally with ray.init(). On a
# multi-node cluster, call setup_ray_cluster() so Ray starts on the required
# nodes, then connect to that cluster with ray.init().
if not ray.is_initialized():
    if cluster_profile == "singleNode":
        # Binding the dashboard to 0.0.0.0 makes it reachable through the
        # driver proxy port. Multi-node clusters already handle this.
        ray.init(dashboard_host="0.0.0.0", dashboard_port=8265)
    else:
        max_worker_nodes = 2  # Maximum number of Ray worker nodes
        min_worker_nodes = 2  # Minimum number of Ray worker nodes
        # CPUs given to Ray on each worker node. The remainder stays available
        # to Spark (intended as roughly half the node; confirm for the node type).
        num_cpus_worker_node = 32
        # GPUs given to Ray on each worker node. Setting 0 would keep Ray
        # workers off the GPUs entirely.
        num_gpus_worker_node = 4
        num_cpus_head_node = 4  # CPUs given to Ray on the driver (head) node
        num_gpus_head_node = 0  # No GPUs for Ray on the head node

        setup_ray_cluster(
            max_worker_nodes=max_worker_nodes,
            min_worker_nodes=min_worker_nodes,
            num_cpus_worker_node=num_cpus_worker_node,
            num_gpus_worker_node=num_gpus_worker_node,
            num_cpus_head_node=num_cpus_head_node,
            num_gpus_head_node=num_gpus_head_node,
            collect_log_to_path=ray_collect_log_to_path,
        )

        # setup_ray_cluster() exports the new cluster's address, so a bare
        # ray.init() connects to it instead of starting a separate local cluster.
        ray.init()

In [0]:
# Pre-tokenized train/test splits stored as parquet files under data_dir.
train_path, test_path = (
    os.path.join(data_dir, "train.parquet"),
    os.path.join(data_dir, "test.parquet"),
)

# Carve a fixed-size, seeded validation set out of the training data so the
# split is reproducible between runs.
full_train_ds = ray.data.read_parquet(train_path)
train_ds, val_ds = full_train_ds.train_test_split(
    test_size=50_000,
    shuffle=True,
    seed=42,
)

# The test split is loaded only to show where it comes from; the run below
# does not use it yet. A final evaluation on it could be added here or run
# separately as part of the deployment job. That is skipped for simplicity,
# since it is not the focus of this example.
test_ds = ray.data.read_parquet(test_path)

In [0]:
class CustomMLflowCallback(TrainerCallback):
    """
    Minimal Hugging Face ``TrainerCallback`` that logs into an existing MLflow run.

    Every local-rank-0 process attaches to the run identified by ``run_id``, so
    MLflow system metrics can be collected per node (tagged with ``node_id``).
    Only the global-rank-0 process logs params, metrics, and checkpoints.

    Args:
        run_id: ID of the MLflow run to attach to on each node.
        node_id: Identifier passed to ``mlflow.config.set_system_metrics_node_id``.
    """

    def __init__(self, run_id: str | None = None, node_id: str | None = None):
        super().__init__()
        self.run_id = run_id
        self.node_id = node_id
        # Active MLflow run on local rank 0 while training; None otherwise.
        self.run = None

    @staticmethod
    def _numeric_only(values):
        """Return the entries of *values* whose value is an int or a float."""
        return {k: v for k, v in values.items() if isinstance(v, (float, int))}

    def on_train_begin(self, args, state, control, **kwargs):
        # Most logging happens on global rank 0, but each local rank 0 also
        # starts the run so system metrics are collected on every node.
        if state.is_local_process_zero:
            mlflow.config.set_system_metrics_node_id(self.node_id)
            self.run = mlflow.start_run(run_id=self.run_id)
        if state.is_world_process_zero:
            # to_dict() (unlike vars()) only includes the declared arguments,
            # converts enums to plain values, and redacts *_token fields such
            # as hub_token, so secrets are not written to MLflow params.
            mlflow.log_params(args.to_dict())

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero and logs:
            mlflow.log_metrics(self._numeric_only(logs), step=state.global_step)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # NOTE(review): Trainer.evaluate() normally also routes these metrics
        # through on_log, so they may be recorded twice at the same step.
        # Confirm before removing either path.
        if state.is_world_process_zero and metrics:
            mlflow.log_metrics(self._numeric_only(metrics), step=state.global_step)

    def on_save(self, args, state, control, **kwargs):
        # Upload the whole Trainer output directory under "checkpoints".
        if state.is_world_process_zero:
            mlflow.log_artifacts(args.output_dir, artifact_path="checkpoints")

    def on_train_end(self, args, state, control, **kwargs):
        # Mirror on_train_begin: every process that started the run ends it.
        if state.is_local_process_zero:
            mlflow.end_run()
            self.run = None


class ReviewClassifierModel(PythonModel):
    """
    MLflow pyfunc wrapper around the fine-tuned review classifier.

    Owns the serving-time pre- and post-processing. It tokenizes raw text with
    the tokenizer logged alongside the model, runs inference in batches, and
    converts logits into the top label and its probability.
    """

    # Used when the logged model_config has no (or an empty) "batch_size".
    DEFAULT_BATCH_SIZE = 32

    def load_context(self, context):
        """
        Load the tokenizer and model from the logged artifacts.

        Reads ``batch_size`` from ``context.model_config``, falling back to
        ``DEFAULT_BATCH_SIZE``. Puts the model in eval mode on the GPU when one
        is available, and derives the maximum tokenized sequence length.
        """
        tokenizer_dir = context.artifacts["tokenizer_dir"]
        model_dir = context.artifacts["model_dir"]
        model_config = context.model_config or {}
        # itertools.batched() rejects None and 0, so fall back to a default
        # here instead of failing on the first predict() call.
        self.batch_size = int(model_config.get("batch_size") or self.DEFAULT_BATCH_SIZE)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Truncate to the tighter of the tokenizer's and the model's limits.
        # Some models (e.g. RoBERTa) report more position embeddings than
        # usable tokens, and some configs lack the attribute entirely.
        limits = [self.tokenizer.model_max_length]
        model_limit = getattr(self.model.config, "max_position_embeddings", None)
        if model_limit is not None:
            limits.append(model_limit)
        self.max_length = min(limits)

    def predict(self, context, model_input):
        """
        Classify each string in the ``"text"`` column of ``model_input``.

        Returns:
            A list with one ``{"label": ..., "score": ...}`` dict per input row,
            in input order. ``label`` is the highest-probability label from the
            model's ``id2label`` mapping; ``score`` is its softmax probability
            as a float.
        """
        texts = model_input["text"].tolist()
        id2label = self.model.config.id2label  # Loop-invariant.
        results = []
        for batch in batched(texts, self.batch_size):
            inputs = self.tokenizer(
                list(batch),
                truncation=True,
                max_length=self.max_length,
                padding=True,
                return_tensors="pt",
            )
            inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
            with torch.inference_mode():
                logits = self.model(**inputs).logits
                probs = torch.softmax(logits, dim=-1).cpu().numpy()
            for sample_scores in probs:
                top_idx = int(sample_scores.argmax())
                results.append({
                    "label": id2label[top_idx],
                    "score": float(sample_scores[top_idx]),
                })
        return results


def collate_fn(batch):
    """
    Pad a batch of pre-tokenized examples into model-ready tensors.

    Args:
        batch: Mapping with ``"input_ids"`` (a sequence of variable-length
            token-id sequences) and ``"label"`` (one integer class per
            example). The mapping is not modified.

    Returns:
        Dict with ``"input_ids"`` and ``"attention_mask"`` as
        ``(batch_size, max_len)`` int64 tensors right-padded with 0, and
        ``"labels"`` as an int64 tensor of shape ``(batch_size,)``.
    """
    input_ids = batch["input_ids"]
    lengths = np.array([len(ids) for ids in input_ids], dtype=np.int64)
    max_length = int(lengths.max()) if lengths.size else 0

    # Right-pad with token id 0 (BERT's pad id, presumably the tokenizer used
    # upstream; revisit if switching to a tokenizer with a different pad id).
    padded = np.zeros((len(input_ids), max_length), dtype=np.int64)
    for row, ids in zip(padded, input_ids):
        row[: len(ids)] = ids  # `row` is a view into `padded`.

    # 1 for real tokens, 0 for padding positions.
    attention_mask = (np.arange(max_length) < lengths[:, None]).astype(np.int64)

    return {
        "input_ids": torch.from_numpy(padded),
        "attention_mask": torch.from_numpy(attention_mask),
        "labels": torch.as_tensor(batch["label"], dtype=torch.long),
    }


# Capture this (driver) environment's Databricks MLflow credentials as a dict
# of environment variables. train_func closes over this dict so that Ray
# workers can authenticate to the Databricks-managed MLflow tracking server.
mlflow_db_creds = get_databricks_env_vars("databricks")


def train_func(config):
    """
    Ray Train per-worker training loop for a Hugging Face sequence classifier.

    Each Ray Train worker runs this function. It reads the worker's shard of
    the ``"train"`` and ``"val"`` datasets, builds a Hugging Face ``Trainer``
    with an MLflow logging callback and Ray's reporting callback, and trains
    for ``max_steps`` steps.

    Args:
        config: The ``train_loop_config`` dict passed to ``TorchTrainer``.
            Required keys: ``batch_size``, ``hugging_face_name``,
            ``eval_steps``, ``max_steps``, ``experiment_name``, ``run_id``.
            Optional key: ``num_labels`` (number of output classes, default 5).
    """
    # Imported inside the function so they are resolved on each Ray worker.
    import os
    import mlflow

    # Unpack the Ray Train config.
    batch_size = config["batch_size"]
    hugging_face_name = config["hugging_face_name"]
    eval_steps = config["eval_steps"]
    max_steps = config["max_steps"]
    experiment_name = config["experiment_name"]
    run_id = config["run_id"]
    # The default of 5 keeps the previously hard-coded value (presumably the
    # 1-5 star Yelp review ratings; confirm against the label column).
    num_labels = config.get("num_labels", 5)

    # Per-node ID so MLflow system metrics from different nodes stay separate.
    node_rank = ray.train.get_context().get_node_rank()
    node_id = f"node-{node_rank}"

    # Expose the driver's MLflow credentials on this worker, then select the
    # experiment that the driver-side run belongs to.
    os.environ.update(mlflow_db_creds)
    mlflow.set_experiment(experiment_name)

    # This worker's shards of the datasets passed to TorchTrainer(datasets=...).
    train_shard = ray.train.get_dataset_shard("train")
    val_shard = ray.train.get_dataset_shard("val")

    # collate_fn pads the pre-tokenized rows into model-ready tensors.
    train_iter_ds = train_shard.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    val_iter_ds = val_shard.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)

    model = AutoModelForSequenceClassification.from_pretrained(
        hugging_face_name, num_labels=num_labels)

    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        """Return the accuracy of argmax predictions against the reference labels."""
        logits, labels = eval_pred
        # Convert the logits to their predicted class.
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(predictions=predictions, references=labels)

    training_args = TrainingArguments(
        output_dir="yelp_review_classifier",
        logging_strategy="steps",
        logging_steps=1,
        logging_first_step=True,
        eval_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="epoch",
        # Built-in integrations are disabled; CustomMLflowCallback logs to MLflow.
        report_to="none",
        # Ray's batch iterators have no length, so the Trainer needs an
        # explicit step budget to know when to stop.
        max_steps=max_steps,
        disable_tqdm=True
    )

    mlflow_callback = CustomMLflowCallback(
        run_id=run_id,
        node_id=node_id
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_iter_ds,
        eval_dataset=val_iter_ds,
        compute_metrics=compute_metrics,
        callbacks=[mlflow_callback]
    )

    # Reports metrics and checkpoints to Ray Train whenever the Trainer saves.
    trainer.add_callback(ray.train.huggingface.transformers.RayTrainReportCallback())

    # Let Ray adapt the Trainer for distributed execution, then train.
    trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)
    trainer.train()

In [0]:
with mlflow.start_run() as run:
    # Many of these settings (step counts, batch sizes, ...) would normally be
    # task parameters so they can be configured externally. They are fixed here
    # for simplicity.
    ray_trainer = TorchTrainer(
        train_func,
        train_loop_config={
            "batch_size": 8,
            # Fine-tune the same checkpoint whose tokenizer is logged with the
            # model below. A separately hard-coded name could silently diverge
            # from the hugging_face_id task parameter.
            "hugging_face_name": hugging_face_id,
            "eval_steps": 25,
            "max_steps": 50,
            "experiment_name": experiment_name,
            # Workers attach to this run, so everything lands in one MLflow run.
            "run_id": run.info.run_id
        },
        scaling_config=ScalingConfig(
            num_workers=4,
            use_gpu=True
        ),
        # On multi-node clusters, ray_results_dir must be persistent storage
        # reachable from every worker node (e.g. a shared volume).
        run_config=RunConfig(
            storage_path=ray_results_dir
        ),
        datasets={
            "train": train_ds,
            "val": val_ds
        },
    )

    # For a hyperparameter sweep, pass this trainer to Ray Tune and fit through
    # the tuner instead; the rest of the flow stays largely the same.
    result: ray.train.Result = ray_trainer.fit()

    # Log the best checkpoint seen during the run, together with the tokenizer.
    # Training used pre-tokenized data, but the served model should accept raw
    # text and tokenize it itself.
    best_checkpoint = result.get_best_checkpoint(metric="eval_accuracy", mode="max")
    if best_checkpoint is None:
        raise RuntimeError(
            "Training did not report any checkpoint with an 'eval_accuracy' "
            "metric, so there is no model to log."
        )
    # NOTE(review): assumes best_checkpoint.path is a locally accessible path
    # (true when ray_results_dir is a local or FUSE-mounted directory). Confirm
    # before pointing storage_path at a remote URI.
    tokenizer_dir = os.path.join(best_checkpoint.path, "tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(hugging_face_id)
    tokenizer.save_pretrained(tokenizer_dir)
    # Presumably where RayTrainReportCallback places the Trainer checkpoint.
    model_dir = os.path.join(best_checkpoint.path, "checkpoint")

    input_example = pd.DataFrame({"text": ["A positive example sentence."]})
    output_example = pd.DataFrame({"label": ["LABEL_4"], "score": [0.9]})
    signature = infer_signature(input_example, output_example)

    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=ReviewClassifierModel(),
        artifacts={
            "model_dir": model_dir,
            "tokenizer_dir": tokenizer_dir
        },
        input_example=input_example,
        signature=signature,
        pip_requirements=[
            f"torch=={torch.__version__}",
            f"transformers=={transformers.__version__}",
            f"mlflow=={mlflow.__version__}"
        ],
        model_config={
            "batch_size": 32
        }
    )
       