# Autoscaling Ray on Databricks and Apache Spark

With the release of **Ray 2.8.0**, Ray auto-scaling is available for Ray on Databricks and Apache Spark. Below, we showcase the functionality by going through an example of hyper-parameter tuning for a deep learning model on the CIFAR-10 dataset.

Ray auto-scaling works with **Databricks Runtime (DBR) 14+**, and the code has been tested with the following cluster configuration:

**Azure**: Driver NC6s_v3 and autoscaling with 4 worker nodes NC6s_v3.


**WORK-IN-PROGRESS**

## Install the Ray library and any other Python dependencies
Once the libraries are installed here, you do not need to specify them again during Ray initialization.

In [0]:
%pip install ray['default,tune'] >=2.8.0

In [0]:
# Restart the Python process so that the libraries installed with %pip are picked up.
dbutils.library.restartPython()

## Start the Ray cluster
Use the Ray on Spark APIs to start the cluster. Refer to the [Ray on Spark API documentation](https://docs.ray.io/en/latest/cluster/vms/user-guides/community/spark.html?highlight=ray.util.spark#ray-on-spark-apis) for more details on the parameters.

In [0]:
from ray.util.spark import setup_ray_cluster, shutdown_ray_cluster

# Resources that each node contributes to the Ray cluster.
num_cpu_cores_per_worker = 4  # CPU cores added by each worker node
num_cpus_head_node = 4  # CPU cores on the driver node that Ray jobs may use
num_gpu_per_worker = 1  # GPUs added by each worker node
num_gpus_head_node = 1  # GPUs on the driver node; only set on GPU-enabled clusters

# With autoscale=True, num_worker_nodes is the maximum number of worker
# nodes the cluster is allowed to scale up to.
ray_conf = setup_ray_cluster(
    num_worker_nodes=4,
    num_cpus_head_node=num_cpus_head_node,
    num_gpus_head_node=num_gpus_head_node,
    num_cpus_per_node=num_cpu_cores_per_worker,
    num_gpus_per_node=num_gpu_per_worker,
    autoscale=True,
)


In [0]:
# In case you want to restart the cluster, use `shutdown_ray_cluster`; this will not restart the interpreter or REPL.
# shutdown_ray_cluster()

## Import all the libraries

In [0]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from filelock import FileLock
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
import ray
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import ASHAScheduler
import time

In [0]:
def load_data(data_dir="./data"):
    """Return the CIFAR-10 ``(trainset, testset)`` pair rooted at ``data_dir``.

    Both datasets are downloaded on demand and share one transform: conversion
    to a tensor followed by per-channel normalisation with mean 0.5 and
    standard deviation 0.5.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    # Several workers may call this at once; hold a file lock while the
    # datasets are created so concurrent downloads do not overwrite each other.
    lock_path = os.path.expanduser("~/.data.lock")
    with FileLock(lock_path):
        trainset, testset = (
            torchvision.datasets.CIFAR10(
                root=data_dir, train=is_train, download=True, transform=transform)
            for is_train in (True, False)
        )

    return trainset, testset

In [0]:
class Net(nn.Module):
    """Small convolutional classifier with two tunable hidden layer widths.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed three
    fully connected layers that end in 10 output logits. The flattened feature
    size of 16 * 5 * 5 matches 3-channel 32x32 inputs.

    Args:
        l1: Number of units in the first fully connected layer.
        l2: Number of units in the second fully connected layer.
    """

    def __init__(self, l1=120, l2=84):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of class logits."""
        # Convolutional feature extractor: conv -> ReLU -> pool, twice.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))

        # Flatten the feature maps to one row per sample.
        x = x.view(-1, 16 * 5 * 5)

        # Two hidden layers with ReLU, then the linear output layer.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

## The train function
Now it gets interesting, because we introduce some changes to the example from the [PyTorch documentation](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).

The full code example looks like this:

In [0]:
def train_cifar(config, loc):
    """Ray Tune trainable: train ``Net`` on CIFAR-10 and report metrics every epoch.

    After each epoch the validation loss and accuracy are reported to Ray Tune
    together with a checkpoint holding the model and optimizer state.

    Args:
        config: Trial hyper-parameters. The keys ``l1``, ``l2``, ``lr``,
            ``batch_size`` and ``max_epoch`` are read.
        loc: Directory under which this trial stages its checkpoint files
            before handing them to Ray Tune.
    """
    trial_context = train.get_context()
    num_cpus = int(trial_context.get_trial_resources().head_cpus)
    print("num_cpus:", num_cpus)
    torch.set_num_threads(num_cpus)
    net = Net(config["l1"], config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # To restore a checkpoint, use `train.get_checkpoint()`.
    loaded_checkpoint = train.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            # map_location lets a checkpoint written on a GPU be restored by a
            # trial that is running on a different device.
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt"),
                map_location=device)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    data_dir = os.path.abspath("./data")
    trainset, testset = load_data(data_dir)

    # Hold out 20% of the training set for validation.
    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs])

    # NOTE: each loader uses 8 worker processes regardless of the number of
    # CPUs allocated to the trial.
    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=8)
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=8)

    # Stage checkpoints in a directory private to this trial. Concurrent trials
    # share `loc`, so a common directory would let one trial overwrite another
    # trial's file before it has been persisted by `train.report`.
    checkpoint_dir = os.path.join(loc, "mymodel", trial_context.get_trial_id())

    for epoch in range(config['max_epoch']):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                # Reset both accumulators so the next printed value is the
                # average over the following 2000 mini-batches only.
                running_loss = 0.0
                epoch_steps = 0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        with torch.no_grad():
            for inputs, labels in valloader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                # .item() keeps the accumulated metric a plain Python float.
                val_loss += loss.item()
                val_steps += 1

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune and can be accessed through `train.get_checkpoint()`
        # API in future iterations.
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(
            (net.state_dict(), optimizer.state_dict()),
            os.path.join(checkpoint_dir, "checkpoint.pt"))
        checkpoint = Checkpoint.from_directory(checkpoint_dir)
        train.report(
            {"loss": (val_loss / val_steps), "try_gpu": False, "accuracy": correct / total},
            checkpoint=checkpoint)
    print("Finished Training")

In [0]:
def main(num_samples=10, max_num_epochs=10,
         grace_period=5, cpus_per_trial=1,
         gpus_per_trial=0, loc='/dbfs/pj/ray/'):
    """Run the hyper-parameter search and evaluate the best trial on the test set.

    Args:
        num_samples: Number of hyper-parameter configurations to sample.
        max_num_epochs: Maximum number of epochs any trial is trained for; also
            used as the ASHA scheduler's ``max_t``.
        grace_period: Minimum number of epochs a trial runs before the ASHA
            scheduler may stop it. Capped at ``max_num_epochs``.
        cpus_per_trial: CPUs requested for each trial.
        gpus_per_trial: GPUs requested for each trial; may be fractional.
        loc: Storage path for the Tune run, also passed to ``train_cifar`` as
            its checkpoint staging location.
    """
    config = {
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16]),
        # Honour the caller's epoch budget instead of a hard-coded value.
        "max_epoch": max_num_epochs,
    }
    scheduler = ASHAScheduler(
        max_t=max_num_epochs,
        # The grace period cannot usefully exceed the epoch budget.
        grace_period=min(grace_period, max_num_epochs),
        reduction_factor=2)

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_cifar, loc=loc),
            resources={"cpu": cpus_per_trial, "gpu": gpus_per_trial},
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=num_samples,
        ),
        run_config=train.RunConfig(
            storage_path=os.path.expanduser(loc),
            name="tune_checkpointing_location",
        ),
        param_space=config,
    )
    results = tuner.fit()

    best_result = results.get_best_result("loss", "min")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(
        best_result.metrics["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_result.metrics["accuracy"]))

    test_best_model(best_result)


In [0]:
def test_best_model(best_result):
    """Print the test-set accuracy of the model stored in ``best_result``.

    Args:
        best_result: Tune result of the best trial. Its ``config`` supplies the
            layer sizes ``l1`` and ``l2``; its ``checkpoint`` must contain a
            ``checkpoint.pt`` file holding ``(model_state, optimizer_state)``.
    """
    best_trained_model = Net(best_result.config["l1"], best_result.config["l2"])
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    best_trained_model.to(device)

    # `as_directory` is a context manager, matching how checkpoints are read in
    # `train_cifar`. map_location lets a checkpoint written on a GPU be loaded
    # when no GPU is available here.
    with best_result.checkpoint.as_directory() as checkpoint_dir:
        model_state, _optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint.pt"), map_location=device)
    best_trained_model.load_state_dict(model_state)

    _, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = best_trained_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print("Best trial test set accuracy: {}".format(correct / total))

In [0]:
# Run a CPU-only search: every trial requests 3 CPUs and no GPU.
main(
    num_samples=8,
    max_num_epochs=10,
    grace_period=5,
    cpus_per_trial=3,
    gpus_per_trial=0,
    loc='/dbfs/pj/ray/',
)

In [0]:
# Run a GPU search: every trial requests 1 CPU and half a GPU.
main(
    num_samples=8,
    max_num_epochs=10,
    grace_period=5,
    cpus_per_trial=1,
    gpus_per_trial=0.5,
    loc='/dbfs/pj/ray/',
)