## Introduction to Plyto with PyTorch
#### Python Machine Learning Visualization Toolkit
This notebook demonstrates how to use our example PyTorch loss callback class with Plyto to visualize model loss throughout training, and then walks through how to create your own callback class.

The <img src='style/icons/machinelearning-blue.svg'> 
toolbar item opens the Plyto model visualizer for this notebook!

#### Running a model
To demonstrate how Plyto works, we will use the CIFAR-10 tiny image dataset, which can be loaded from torchvision.datasets.

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

Files already downloaded and verified
Files already downloaded and verified


#### How it works

A Plyto instance requires an Altair spec (Vega-Lite JSON) to define its plots. Below is an example of a simple Altair spec that creates a line graph of loss versus samples.

In [2]:
# an array of Altair specs with one plot of samples versus loss
spec = [
    {
        # Vega-Lite schema that the Altair spec follows
        "$schema": "https://vega.github.io/schema/vega-lite/v2.json",
        "name": "lossGraph",
        
        # size of the plot
        "config": {
            "view": {
                "height": 300,
                "width": 300
            }
        },
        
        # name of the dataset must be "dataSet"
        "data": {
            "name": "dataSet"
        },
        
        # visual encodings of the plot
        "encoding": {
            "x": {
                "field": "samples",
                "type": "quantitative"
            },
            "y": {
                "field": "loss",
                "type": "quantitative"
            }
        },
        
        "mark": "line"
    }
]
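If you need more than one plot, writing each spec by hand gets repetitive. The following is a minimal sketch of a helper that builds the same dictionary as above for any pair of fields. The name `line_spec` and its parameters are hypothetical and not part of Plyto; the sketch assumes only the structure shown in the spec above, including the required dataset name "dataSet".

```python
import json


def line_spec(name, x_field, y_field, width=300, height=300):
    """Hypothetical helper: build one Vega-Lite line-plot spec as a dict."""
    return {
        "$schema": "https://vega.github.io/schema/vega-lite/v2.json",
        "name": name,
        # size of the plot
        "config": {"view": {"height": height, "width": width}},
        # name of the dataset must be "dataSet"
        "data": {"name": "dataSet"},
        # visual encodings of the plot
        "encoding": {
            "x": {"field": x_field, "type": "quantitative"},
            "y": {"field": y_field, "type": "quantitative"},
        },
        "mark": "line",
    }


# Rebuild the single-plot spec list used in this notebook
spec = [line_spec("lossGraph", "samples", "loss")]

# json.dumps raises TypeError if the spec contains anything non-serializable
serialized = json.dumps(spec)
```

Round-tripping the spec through `json` is a cheap way to catch stray Python objects before handing the spec to the visualizer.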

A callback class takes a Plyto instance as a constructor parameter, and its methods are called from inside the training loop.

However you structure your network, call the callback's update methods every N iterations of the training loop to refresh the data, then open the Plyto model visualizer to watch your statistics and plots update.

In [3]:
from time import time
from plyto import PytorchLossCallback, PlytoAPI

plyto_instance = PlytoAPI(spec)

callback = PytorchLossCallback(plyto_instance, 2, 12400) # 2 epochs of 
                                                         # 12400 mini-batches each

for epoch in range(2):  # loop over the dataset multiple times
    callback.update_step_number(epoch + 1) # update the current epoch

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 0 and i != 0:    # report every 100 mini-batches
            callback.update_total_progress(100) # update total progress
            callback.update_data(i, running_loss / 100) # update current progress,
                                                        # loss, and send data
            running_loss = 0.0
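
The value 12400 passed to the callback can be checked with a little arithmetic. The sketch below assumes the standard CIFAR-10 training split of 50,000 images and the batch size of 4 used in this notebook; `reporting_steps` is a hypothetical helper, not part of Plyto. It lists the batch indices at which the reporting condition in the loop fires and sums the progress reported per epoch, which appears to be why 12400 is used rather than the full 12500 mini-batches.

```python
def reporting_steps(num_batches, every):
    """Batch indices at which `i % every == 0 and i != 0` is true."""
    return [i for i in range(num_batches) if i % every == 0 and i != 0]


num_batches = 50000 // 4          # training images / batch_size = 12500
steps = reporting_steps(num_batches, 100)

# Each report adds 100 to the total progress
progress_per_epoch = 100 * len(steps)

print(len(steps), steps[0], steps[-1], progress_per_epoch)
# prints: 124 100 12400 12400
```

The last 99 mini-batches of each epoch never trigger a report, so the progress reported per epoch tops out at 12400.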


#### Writing your own callback class

A callback class for Plyto is a class that takes a Plyto instance as a parameter.

Within this custom class, you can define methods that execute or update data at specific points while the network is running.

For the progress bars in the status bar to work correctly, your callback class must send epochs, sample_amount, total_progress, current_progress, and epoch_number through Plyto. In addition, start_time is required for the panel to display the runtime once the model has finished training. Below is a base to work from that contains only these variables for basic functionality and passes no Altair spec for plots.

In [None]:
from time import time


class PytorchBasicCallback:
    """
    Create a callback that will track and display training progress

    The start time is recorded at construction and used to calculate runtime.

    :param plyto_instance: an instance of a PlytoAPI class

    :param steps: number of epochs/steps

    :param sample_amount: number of samples/steps per epoch
    """

    def __init__(self, plyto_instance, steps=0, sample_amount=0):
        self.total_progress = 0
        self.start_time = time()
        self.plyto = plyto_instance
        self.initialize_plyto(steps, sample_amount)

    def initialize_plyto(self, steps, sample_amount):
        """
        Initialize the Plyto instance's total steps and step size
        
        :param steps: total number of steps

        :param sample_amount: number of samples/batches per step
        """
        self.plyto.update_total_steps(steps)
        self.plyto.update_size(sample_amount)

    def update_step_number(self, new_step):
        """
        Update the current step/epoch

        :param new_step: the current step/epoch
        """
        self.plyto.update_current_step(new_step)

    def update_total_progress(self, progress):
        """
        Update the total training progress

        :param progress: the amount to increment the total progress by
        """
        self.total_progress += progress

    def update_data(self, current_progress):
        """
        Update progress, total progress, and runtime before sending data to frontend

        :param current_progress: the progress of training the current step/epoch
        """
        self.plyto.update_current_progress(current_progress)
        self.plyto.update_total_progress(self.total_progress)
        self.plyto.update_runtime(time() - self.start_time)
        self.plyto.send_data()
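
When writing your own callback, it can help to exercise it without opening the visualizer. The following is a rough sketch of a stand-in object that exposes only the methods the base callback above calls on its Plyto instance and records what it receives. `RecordingPlyto` is hypothetical and not part of Plyto, and the dictionary keys are illustrative names chosen here, not the field names Plyto sends to the frontend.

```python
from time import time


class RecordingPlyto:
    """Hypothetical stand-in for a Plyto instance (not part of Plyto).

    Stores the latest value passed to each update method and takes a
    snapshot on every send_data() call.
    """

    def __init__(self):
        self.state = {}   # latest value of each field
        self.sent = []    # one snapshot per send_data() call

    def update_total_steps(self, steps):
        self.state["total_steps"] = steps

    def update_size(self, sample_amount):
        self.state["size"] = sample_amount

    def update_current_step(self, step):
        self.state["current_step"] = step

    def update_current_progress(self, progress):
        self.state["current_progress"] = progress

    def update_total_progress(self, progress):
        self.state["total_progress"] = progress

    def update_runtime(self, seconds):
        self.state["runtime"] = seconds

    def send_data(self):
        self.sent.append(dict(self.state))


# Drive the stand-in with the same sequence of calls the base callback makes:
# 2 epochs of 300 mini-batches, reporting every 100 mini-batches.
plyto = RecordingPlyto()
start_time = time()
plyto.update_total_steps(2)
plyto.update_size(200)   # progress reported per epoch: reports at i = 100, 200
total_progress = 0
for epoch in range(2):
    plyto.update_current_step(epoch + 1)
    for i in range(300):
        if i % 100 == 0 and i != 0:
            total_progress += 100
            plyto.update_current_progress(i)
            plyto.update_total_progress(total_progress)
            plyto.update_runtime(time() - start_time)
            plyto.send_data()

print(len(plyto.sent), plyto.sent[-1]["total_progress"])
# prints: 4 400
```

Passing an object like this to your callback's constructor in place of a real Plyto instance lets you inspect each snapshot in `sent` and confirm that every required value is set before the first send.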

In [None]:
from time import time

callback = PytorchBasicCallback(plyto_instance, 5, 12400) # 5 epochs of 
                                                         # 12400 mini-batches each

for epoch in range(5):  # loop over the dataset multiple times
    callback.update_step_number(epoch + 1) # update the current epoch

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if i % 100 == 0 and i != 0:    # report every 100 mini-batches
            callback.update_total_progress(100) # update total progress
            callback.update_data(i) # update current progress and send data


*Note: if you stop and re-run the model, plyto_instance and callback must be re-initialized. We recommend initializing them in the same cell as the training loop to ensure this works properly.*