# Tune and Train with Push-based Metrics Collection Using MNIST

In this Notebook we are going to do the following:
- Train a PyTorch CNN image classification model.
- Tune the model's HyperParameters with [Kubeflow Katib](https://www.kubeflow.org/docs/components/katib/overview/).
- Use Push-based Metrics Collection to report metrics directly from the training containers.

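As a rough illustration of what Push-based collection changes, the toy sketch below contrasts scraping metrics out of log text with handing them over as a dict. It does not use Katib at all: `METRIC_PATTERN`, `pull_metrics_from_log`, and `push_metrics` are hypothetical names made up for this example, and the regular expression is a simplified stand-in, not Katib's actual log parser.

```python
import re

# Hypothetical, simplified "name=value" pattern; NOT Katib's real parser.
METRIC_PATTERN = re.compile(r"([\w-]+)\s*=\s*([+-]?\d+(?:\.\d+)?)")


def pull_metrics_from_log(log_text, metric_names):
    """Pull style: scrape the requested metrics out of free-form log text."""
    found = {}
    for name, value in METRIC_PATTERN.findall(log_text):
        if name in metric_names:
            found[name] = float(value)  # the last occurrence in the log wins
    return found


def push_metrics(store, metrics):
    """Push style: the training code hands over a dict directly."""
    store.update({name: float(value) for name, value in metrics.items()})
    return store


# A made-up log excerpt that mixes progress lines with final metrics.
log_text = "Train Epoch: 1 [0/60000 (0%)]\tloss=2.3012\naccuracy=0.8276 loss=0.4877\n"

pulled = pull_metrics_from_log(log_text, {"accuracy", "loss"})
pushed = push_metrics({}, {"accuracy": 0.8276, "loss": 0.4877})
print(pulled)
print(pushed)
```

In the pull variant the result depends on the log format staying parseable, while the push variant needs no parsing step. That is the idea behind calling `katib.report_metrics` from the training function later in this Notebook.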
## Install Kubeflow Python SDKs

You need to install the Kubeflow Katib Python SDK to run this Notebook.

In [None]:
# TODO (Electronic-Waste): Change to release version when SDK with the updated `tune()` is published.
%pip install git+https://github.com/kubeflow/katib.git#subdirectory=sdk/python/v1beta1

## Create Train Script for CNN Model

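This is a simple **Convolutional Neural Network (CNN)** model for classifying 28x28 grayscale images. Note that the script below loads Fashion-MNIST through `torchvision.datasets.FashionMNIST`, which has the same image format and number of classes as the hand-written digits in the [MNIST Dataset](https://yann.lecun.com/exdb/mnist/).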

In [1]:
def train_mnist_model(parameters):
    import torch
    import logging
    import kubeflow.katib as katib
    from torchvision import datasets, transforms

    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%SZ",
        level=logging.INFO,
    )
    logging.info("--------------------------------------------------------------------------------------")
    logging.info(f"Input Parameters: {parameters}")
    logging.info("--------------------------------------------------------------------------------------\n\n")

    # Get HyperParameters from the input params dict.
    lr = float(parameters["lr"])
    momentum = float(parameters["momentum"])
    batch_size = int(parameters["batch_size"])
    num_epoch = int(parameters["num_epoch"])
    log_interval = int(parameters["log_interval"])

    # Prepare the dataset (Fashion-MNIST, loaded through torchvision).
    def mnist_train_dataset(batch_size):
        return torch.utils.data.DataLoader(
            datasets.FashionMNIST(
                "./data",
                train=True,
                download=True,
                transform=transforms.Compose([transforms.ToTensor()]),
            ),
            batch_size=batch_size,
            shuffle=True,
        )

    def mnist_test_dataset(batch_size):
        return torch.utils.data.DataLoader(
            datasets.FashionMNIST(
                "./data", train=False, transform=transforms.Compose([transforms.ToTensor()])
            ),
            batch_size=batch_size,
            shuffle=False,
        )
    
    # Build CNN Model.
    def build_and_compile_cnn_model():
        return torch.nn.Sequential(
            torch.nn.Conv2d(1, 20, 5, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2),
        
            torch.nn.Conv2d(20, 50, 5, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2),
        
            torch.nn.Flatten(),
        
            torch.nn.Linear(4 * 4 * 50, 500),
            torch.nn.ReLU(),
        
            torch.nn.Linear(500, 10),
            torch.nn.LogSoftmax(dim=1)
        )
    
    # Train CNN Model.
    def train_cnn_model(model, train_loader, optimizer, epoch):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = torch.nn.functional.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                msg = "Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
                logging.info(msg)
    
    # Test CNN Model and report training metrics
    def test_cnn_model(model, test_loader):
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                output = model(data)
                test_loss += torch.nn.functional.nll_loss(
                    output, target, reduction="sum"
                ).item()  # sum up batch loss
                pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
        
        test_loss /= len(test_loader.dataset)
        test_accuracy = float(correct) / len(test_loader.dataset)
        katib.report_metrics({  # report metrics directly without outputting logs
            "accuracy": test_accuracy, 
            "loss": test_loss,
        })

    # Download dataset and construct loaders for training and testing
    train_loader = mnist_train_dataset(batch_size)
    test_loader = mnist_test_dataset(batch_size)

    # Build Model and Optimizer
    model = build_and_compile_cnn_model()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # Train Model and report metrics
    for epoch_idx in range(1, num_epoch + 1):
        train_cnn_model(model, train_loader, optimizer, epoch_idx)
        test_cnn_model(model, test_loader)

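The `Linear(4 * 4 * 50, 500)` layer in the model above depends on the spatial size left after the two convolution and pooling stages. The sketch below retraces that arithmetic in plain Python, assuming 28x28 input images and no padding. The helpers `conv_out` and `pool_out` are hypothetical names introduced only for this check; they are not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution on a square input."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride):
    """Output width/height of a max-pooling layer on a square input."""
    return (size - kernel) // stride + 1


size = 28                     # assumed input: 28x28 grayscale image
size = conv_out(size, 5)      # Conv2d(1, 20, 5, 1)  -> 24
size = pool_out(size, 2, 2)   # MaxPool2d(2, 2)      -> 12
size = conv_out(size, 5)      # Conv2d(20, 50, 5, 1) -> 8
size = pool_out(size, 2, 2)   # MaxPool2d(2, 2)      -> 4

channels = 50                 # output channels of the second convolution
flattened = size * size * channels
print(size, flattened)        # 4 800
```

If the kernel sizes or the input resolution change, the first `Linear` layer's input size has to be recomputed the same way.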


## Start Model Tuning with Katib

If you want to improve your model, you can run HyperParameter tuning with Katib.

The following example uses the **Random Search** algorithm to tune HyperParameters.

We are going to tune the learning rate (`lr`) and `momentum`; the remaining parameters are passed as fixed values.

In [2]:
import kubeflow.katib as katib

# Set parameters with their distribution for HyperParameter Tuning with Katib.
parameters = {
    "lr": katib.search.double(min=0.01, max=0.03),
    "momentum": katib.search.double(min=0.3, max=0.7),
    "num_epoch": 1,
    "batch_size": 64,
    "log_interval": 10
}

# Start the Katib Experiment.
# TODO (Electronic-Waste): 
# 1. Change `kubeflow-katib` to release version when `0.18.0` is ready.
# 2. Change `base_image` to official image when `kubeflow-katib` release version `0.18.0` is ready.
exp_name = "tune-mnist"
katib_client = katib.KatibClient(namespace="kubeflow")

katib_client.tune(
    name=exp_name,
    objective=train_mnist_model, # Objective function.
    base_image="docker.io/electronicwaste/pytorch:gitv1",
    parameters=parameters, # HyperParameters to tune.
    algorithm_name="random", # Algorithm to use.
    objective_metric_name="accuracy", # Katib is going to optimize "accuracy".
    additional_metric_names=["loss"], # Katib is going to collect these metrics in addition to the objective metric.
    max_trial_count=12, # Trial Threshold.
    parallel_trial_count=2,
    packages_to_install=["git+https://github.com/kubeflow/katib.git@master#subdirectory=sdk/python/v1beta1"],
    metrics_collector_config={"kind": "Push"},
)

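To give an intuition for what Random Search does with the search space above, here is a minimal stand-alone sketch. It only illustrates the idea of drawing each tunable value uniformly from its range; Katib's own suggestion service may sample differently. The function `sample_trial` and the tuple-based `search_space` encoding are assumptions made for this example, not Katib APIs.

```python
import random


def sample_trial(search_space, rng):
    """Draw one assignment: uniform for (min, max) tuples, fixed otherwise."""
    assignment = {}
    for name, spec in search_space.items():
        if isinstance(spec, tuple):
            low, high = spec
            assignment[name] = rng.uniform(low, high)
        else:
            assignment[name] = spec  # constant, passed through unchanged
    return assignment


# Same ranges and fixed values as in the Experiment above.
search_space = {
    "lr": (0.01, 0.03),
    "momentum": (0.3, 0.7),
    "num_epoch": 1,
    "batch_size": 64,
    "log_interval": 10,
}

rng = random.Random(0)  # seeded so the sketch is reproducible
trials = [sample_trial(search_space, rng) for _ in range(12)]
for trial in trials[:3]:
    print(trial)
```

Each of the 12 sampled assignments corresponds to what one Trial would receive as its `parameters` argument.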
### Access the Katib UI

You can check the created Experiment in the Katib UI.



### Get the Best HyperParameters from the Katib Experiment

You can get the best HyperParameters from the optimal Katib Trial.

In [4]:
status = katib_client.is_experiment_succeeded(exp_name)
print(f"Katib Experiment is Succeeded: {status}\n")

best_hps = katib_client.get_optimal_hyperparameters(exp_name)
print(f"Current Optimal Trial:\n{best_hps}")

Katib Experiment is Succeeded: True

Current Optimal Trial:
{'best_trial_name': 'tune-mnist-xqwfhr9w',
 'observation': {'metrics': [{'latest': '0.8276',
                              'max': '0.8276',
                              'min': '0.8276',
                              'name': 'accuracy'},
                             {'latest': '0.48769191679954527',
                              'max': '0.48769191679954527',
                              'min': '0.48769191679954527',
                              'name': 'loss'}]},
 'parameter_assignments': [{'name': 'lr', 'value': '0.024527727574297616'},
                           {'name': 'momentum', 'value': '0.6490973329748595'}]}

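The tuned values come back as strings, so they need converting before they can be reused for a final training run. The sketch below assumes the result has been turned into a plain dict shaped like the printed output above; the SDK may return a model object instead, in which case attribute access would replace the key lookups. `to_training_parameters` is a hypothetical helper, not part of the Katib SDK.

```python
# Copied from the printed output above and assumed to be a plain dict.
best_trial = {
    "best_trial_name": "tune-mnist-xqwfhr9w",
    "parameter_assignments": [
        {"name": "lr", "value": "0.024527727574297616"},
        {"name": "momentum", "value": "0.6490973329748595"},
    ],
}


def to_training_parameters(trial, fixed):
    """Merge tuned values (converted to float) with the fixed parameters."""
    tuned = {a["name"]: float(a["value"]) for a in trial["parameter_assignments"]}
    return {**fixed, **tuned}


final_parameters = to_training_parameters(
    best_trial,
    {"num_epoch": 1, "batch_size": 64, "log_interval": 10},
)
print(final_parameters)
```

The resulting dict has the same keys that `train_mnist_model` reads from its `parameters` argument.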

## Delete Katib Experiment

When the jobs are finished, you can delete the Experiment and its resources.

In [5]:
katib_client.delete_experiment(exp_name)