# Custom Optimizer
We have seen in the [previous notebook]() that we can easily create graphs with an internal state using the ```addons.Module``` class. 
Like layers, optimizers often carry state, which makes them a natural use case for the ```Module``` class.

In this notebook we show how to create a custom stateful optimizer, using Adam as the example.

In [None]:
import sys

In [None]:
from functools import partial
from typing import Dict, Mapping, Optional, Union
import numpy as np
import popxl
import popxl_addons as addons
import popxl.ops as ops
from popxl_addons.graph import GraphWithNamedArgs
from popxl_addons.named_tensors import NamedTensors
from popxl_addons.input_factory import NamedInputFactories

In [None]:
!{sys.executable} -m pip install torch
import torch

In [None]:
!{sys.executable} -m pip install torchvision
import torchvision

In [None]:
!{sys.executable} -m pip install tqdm
from tqdm import tqdm

In [None]:
np.random.seed(42)

In [None]:
'''
Adam optimizer.
Defines adam update step for a single variable
'''
class Adam(addons.Module):
    # we need to specify in_sequence because a lot of operations are in place and their order 
    # shouldn't be rearranged 
    @popxl.in_sequence()
    def build(self,
              var: popxl.TensorByRef,
              grad: popxl.Tensor,
              *,
              lr: Union[float, popxl.Tensor],
              beta1: Union[float, popxl.Tensor] = 0.9,
              beta2: Union[float, popxl.Tensor] = 0.999,
              eps: Union[float, popxl.Tensor] = 1e-5,
              weight_decay: Union[float, popxl.Tensor] = 1e-2,
              first_order_dtype: popxl.dtype = popxl.float16,
              bias_correction: bool = True):

        # gradient estimators for the variable var - same shape as the variable 
        first_order = self.add_input_tensor("first_order", partial(np.zeros, var.shape), first_order_dtype, by_ref=True)
        ops.var_updates.accumulate_moving_average_(first_order, grad, f=beta1)

        # variance estimators for the variable var - same shape as the variable 
        second_order = self.add_input_tensor("second_order", partial(np.zeros, var.shape), popxl.float32, by_ref=True)
        ops.var_updates.accumulate_moving_average_square_(second_order, grad, f=beta2)

        # adam is a biased estimator: provide the step to correct bias
        step = None
        if bias_correction:
            step = self.add_input_tensor("step", partial(np.zeros, ()), popxl.float32, by_ref=True)

        # calculate the weight increment with the Adam heuristic
        updater = ops.var_updates.adam_updater(
            first_order, second_order,
            weight=var,
            weight_decay=weight_decay,
            time_step=step,
            beta1=beta1,
            beta2=beta2,
            epsilon=eps)

        # in place weight update: w += (-lr)*dw
        ops.scaled_add_(var, updater, b=-lr)

The first important thing to note is that the entire ```build``` method is executed **in sequence**, because we've added the ```@popxl.in_sequence()``` decorator.
This is necessary because most of the optimizer operations are **in place**, so their order of execution must be strictly preserved.

Note also that the ```var``` input is a ```popxl.TensorByRef```: any change made to this variable will be automatically copied to the parent graph. See [TensorByRef]() for more information.

Allowing a ```Union[float, popxl.Tensor]``` type for the optimizer parameters such as the learning rate or the weight decay gives this module an interesting property.
If the parameter is provided as a ```float```, it will be "baked" into the graph, with no possibility of changing it.
Instead, if the parameter is a ```Tensor``` (or ```TensorSpec```) it will appear as an input to the graph, which needs to be provided when calling the graph. If you plan to change a parameter (for example, because you have a learning rate schedule), this is the way to go.
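
To illustrate why a tensor learning rate is useful, here is a minimal host-side sketch of a schedule: linear warmup followed by a constant value. The function `warmup_lr` and its parameters are hypothetical and are not part of `popxl` or `popxl_addons`. With a `Tensor` learning rate, values like these are what you would supply to the learning-rate input of the graph. How the value reaches the device (for example, through a host-to-device stream) is left as an assumption here.

```python
def warmup_lr(step: int, base_lr: float = 1e-3, warmup_steps: int = 100) -> float:
    """Hypothetical schedule: linear warmup to base_lr, then constant."""
    if step < warmup_steps:
        # ramp linearly from base_lr / warmup_steps up to base_lr
        return base_lr * (step + 1) / warmup_steps
    return base_lr


# learning rate for each of the first 200 training steps
schedule = [warmup_lr(s) for s in range(200)]
```

With a `float` learning rate none of this is possible, because the single value is fixed when the graph is created.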

The rest of the logic is straightforward:

- we update the first moment, an estimate of the mean of the gradient
- we update the second moment, an estimate of the uncentered variance of the gradient
- we optionally correct the estimators, since they are biased towards their zero initial value
- we compute the increment for the variable, ```dw```
- we update the variable ``` w += (-lr) * dw```
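
The steps above can be sketched on the host with NumPy. This is a textbook Adam step with decoupled weight decay, not the code that runs on the device. The function name `adam_step_reference` is hypothetical, and details such as where `eps` is added, how weight decay enters the increment and when the step counter is incremented are assumptions that may differ from what `ops.var_updates.adam_updater` does.

```python
import numpy as np


def adam_step_reference(w, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
                        eps=1e-5, weight_decay=0.0, bias_correction=True):
    """Textbook Adam step on NumPy arrays (assumed formulation).

    Returns the updated (w, m, v, step) without modifying the inputs.
    """
    # 1. first moment: moving average of the gradient
    m = beta1 * m + (1.0 - beta1) * g
    # 2. second moment: moving average of the squared gradient
    v = beta2 * v + (1.0 - beta2) * g * g
    # 3. optional bias correction (assumes the step counts from 1)
    step = step + 1
    if bias_correction:
        m_hat = m / (1.0 - beta1 ** step)
        v_hat = v / (1.0 - beta2 ** step)
    else:
        m_hat, v_hat = m, v
    # 4. increment dw, with decoupled weight decay (an assumption)
    dw = m_hat / (np.sqrt(v_hat) + eps) + weight_decay * w
    # 5. weight update: w += (-lr) * dw
    w = w - lr * dw
    return w, m, v, step


w = np.ones((2, 2))
g = np.full((2, 2), 0.1)
m = np.zeros_like(w)
v = np.zeros_like(w)
w_new, m_new, v_new, step = adam_step_reference(w, g, m, v, step=0)
```

A host-side reference like this can be handy for sanity-checking the values read back from the device, keeping in mind the assumptions listed above.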

The ```ops.var_updates``` module contains several useful pre-made update rules, but you can also make your own. In this example we are using three of them:

- ```ops.var_updates.accumulate_moving_average_(average, new_sample, coefficient)``` updates ```average``` in place with an exponential moving average rule: 
    ```
    average = (coefficient * average) + ((1-coefficient) * new_sample)
    ```
- ```ops.var_updates.accumulate_moving_average_square_(average, new_sample, coefficient)``` does the same, but using the square of the sample.
- ```ops.var_updates.adam_updater(...)``` returns the Adam increment ```dw``` required for the weight update, computed from Adam's internal state, i.e. the first and second moments.
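
To see why the moving averages are biased when they start from zero, the exponential moving average rule can be applied by hand in plain Python. The helper `accumulate_moving_average` below is a hypothetical, non-in-place stand-in for the rule written above, not the PopXL op. The sample is held constant so that the effect of the zero initialisation is easy to see.

```python
def accumulate_moving_average(average, new_sample, coefficient):
    """Out-of-place version of the exponential moving average rule."""
    return coefficient * average + (1.0 - coefficient) * new_sample


beta = 0.9
sample = 0.1      # constant "gradient"
average = 0.0     # the state starts at zero, as in the Adam module
history = []
for t in range(1, 6):
    average = accumulate_moving_average(average, sample, beta)
    # dividing by (1 - beta^t) removes the bias caused by the zero start
    corrected = average / (1.0 - beta ** t)
    history.append((t, average, corrected))
```

After `t` steps the raw average equals `(1 - beta**t) * sample`, so it underestimates the sample early on, while the corrected value recovers the sample at every step.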

Let's inspect the optimizer graph and its use in a simple example.

In [None]:
ir = popxl.Ir()
ir.replication_factor = 1 

with ir.main_graph:
    var = popxl.variable(np.ones((2,2)),popxl.float32)
    grad = popxl.variable(np.full((2,2),0.1),popxl.float32)
    # create graph and factories - float learning rate
    adam_facts, adam = Adam(cache=True).create_graph(var, var.spec, lr=1e-3)
    # create graph and factories - Tensor learning rate
    adam_facts_lr, adam_lr = Adam().create_graph(var, var.spec, lr=popxl.TensorSpec((),popxl.float32))
    print("Adam with float learning rate\n")
    print(adam.print_schedule())
    print("\n Adam with tensor learning rate\n")
    print(adam_lr.print_schedule())
    # instantiate optimizer variables 
    adam_state = adam_facts.init()
    adam_state_lr = adam_facts_lr.init()
    # optimization step for float lr: call the bound graph providing the variable to update and the gradient 
    adam.bind(adam_state).call(var, grad)
    # optimization step for tensor lr: call the bound graph providing the variable to update, the gradient and the learning rate
    adam_lr.bind(adam_state_lr).call(var, grad, popxl.constant(1e-3))

ir.num_host_transfers = 1
session = popxl.Session(ir,"ipu_hw")
print("\n Before adam update")
var_data = session.get_tensor_data(var)
state = session.get_tensors_data(adam_state.tensors)
print("Variable:\n", var_data)
print("Adam state:")
for name, data in state.items():
    print(name,'\n', data)

session.run()

print("\n After adam update")
var_data = session.get_tensor_data(var)
state = session.get_tensors_data(adam_state.tensors)
print("Variable:\n", var_data)
print("Adam state:")
for name, data in state.items():
    print(name,'\n', data)

session.device.detach()

# MNIST with Adam
We can now refactor our MNIST example to incorporate the Adam optimizer.
Note that we need an optimizer for each variable: we first define a utility function to create all the graphs and perform a full weight update for all the variables in the neural network. 

We will use a float learning rate, since we don't plan to change its value during training.

The training code is almost unchanged from the previous tutorial; the only difference is the optimizer-related code in ```train_program```. Also, since we are using Adam, we need to use a smaller learning rate.

You will notice that we create the Adam module using 
```python
optimizer = Adam(cache=True)
```
Using `cache=True` will enable graph reuse, if possible, when calling `optimizer.create_graph`. For our optimizer this would be when there are multiple variables with the same shape/dtype.
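
As a rough illustration of how much reuse is possible, the sketch below counts the distinct optimizer graphs needed for the four-layer network used later in this notebook. It assumes that a graph can be shared whenever shape and dtype match. The real cache key in `popxl_addons` may include more than that, and the dictionary `graph_cache` is a hypothetical stand-in, not the library's implementation.

```python
from typing import Dict, List, Tuple

Spec = Tuple[Tuple[int, ...], str]  # (shape, dtype name)

# (shape, dtype) of every parameter in the network: weight and bias per layer
param_specs: List[Spec] = [
    ((784, 512), "float32"), ((512,), "float32"),  # fc1
    ((512, 512), "float32"), ((512,), "float32"),  # fc2
    ((512, 512), "float32"), ((512,), "float32"),  # fc3
    ((512, 10), "float32"), ((10,), "float32"),    # fc4
]

graph_cache: Dict[Spec, str] = {}
builds = 0
for spec in param_specs:
    if spec not in graph_cache:
        # cache miss: a new optimizer graph would have to be built
        builds += 1
        graph_cache[spec] = f"adam_graph_{builds}"

print(f"{len(param_specs)} variables, {builds} distinct graphs")
```

Under this assumption, the eight variables need only five distinct graphs, because `fc2` and `fc3` have identical weights and the three 512-element biases share one spec.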

In [None]:
'''
Update all variables, creating a per-variable optimizer graph for each.
'''
def optimizer_step(variables: NamedTensors,
                   grads: Dict[popxl.Tensor, popxl.Tensor],
                   optimizer: addons.Module,
                   lr: Union[float, popxl.Tensor] = 1e-3):
    for name, var in variables.named_tensors.items():
        # create optimizer and state factories for the variable
        opt_facts, opt_graph = optimizer.create_graph(
            var,
            var.spec,
            lr=lr, 
            weight_decay=0.0,
            bias_correction=False
            )
        state = opt_facts.init()
        # bind the graph to its state and call it.
        # Both the state and the variables are updated in place and are passed by ref,
        # hence after the graph is called they are updated.
        opt_graph.bind(state).call(var, grads[var])

In [None]:
def get_mnist_data(test_batch_size: int, batch_size: int):
    training_data = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            '~/.torch/datasets',
            train=True,
            download=True,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307, ), (0.3081, )),  # mean and std computed on the training set.
            ])),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)

    validation_data = torch.utils.data.DataLoader(torchvision.datasets.MNIST('~/.torch/datasets',
                                                                             train=False,
                                                                             download=True,
                                                                             transform=torchvision.transforms.Compose([
                                                                                 torchvision.transforms.ToTensor(),
                                                                                 torchvision.transforms.Normalize(
                                                                                     (0.1307, ), (0.3081, )),
                                                                             ])),
                                                  batch_size=test_batch_size,
                                                  shuffle=True,
                                                  drop_last=True)
    return training_data, validation_data


def accuracy(predictions: np.ndarray, labels: torch.Tensor):
    ind = np.argmax(predictions, axis=-1).flatten()
    labels = labels.detach().numpy().flatten()
    return np.mean(ind == labels) * 100.0


class Linear(addons.Module):
    def __init__(self, out_features: int, bias: bool = True):
        super().__init__()
        self.out_features = out_features
        self.bias = bias

    def build(self, x: popxl.Tensor) -> popxl.Tensor:
        # add a state variable to the module
        w = self.add_input_tensor("weight", partial(np.random.normal, 0, 0.02, (x.shape[-1], self.out_features)),
                                  x.dtype)
        y = x @ w
        if self.bias:
            # add a state variable to the module
            b = self.add_input_tensor("bias", partial(np.zeros, y.shape[-1]), x.dtype)
            y = y + b
        return y


class Net(addons.Module):
    def __init__(self, cache: Optional[addons.GraphCache] = None):
        super().__init__(cache=cache)
        self.fc1 = Linear(512)
        self.fc2 = Linear(512)
        self.fc3 = Linear(512)
        self.fc4 = Linear(10)

    def build(self, x: popxl.Tensor):
        x = x.reshape((-1, 28 * 28))
        x = ops.gelu(self.fc1(x))
        x = ops.gelu(self.fc2(x))
        x = ops.gelu(self.fc3(x))
        x = self.fc4(x)
        return x

In [None]:
train_batch_size = 8
test_batch_size = 80
device = "ipu_hw" 
lr = 1e-3
epochs = 1

In [None]:
def train_program(batch_size, device, lr):
    ir = popxl.Ir()
    ir.replication_factor = 1
    with ir.main_graph:
        # Create input streams from host to device 
        img_stream = popxl.h2d_stream((batch_size, 28, 28), popxl.float32, "image")
        img_t = ops.host_load(img_stream) #load data
        label_stream = popxl.h2d_stream((batch_size, ), popxl.int32, "labels")
        labels = ops.host_load(label_stream, "labels")

        # Create forward graph
        facts, fwd_graph = Net().create_graph(img_t)
        # Create backward graph via autodiff transform
        bwd_graph = addons.autodiff(fwd_graph)

        # Initialise variables (weights)
        variables = facts.init()

        # Call the forward with call_with_info because we want to retrieve information from the call site
        fwd_info = fwd_graph.bind(variables).call_with_info(img_t)
        x = fwd_info.outputs[0] # forward output
        
        # Compute loss and starting gradient for backprop 
        loss, dx = addons.ops.cross_entropy_with_grad(x, labels)
        
        # Set up a stream to send loss values from the device to the host
        loss_stream = popxl.d2h_stream(loss.shape, loss.dtype, "loss")
        ops.host_store(loss_stream, loss)
        
        # retrieve activations from the forward
        activations = bwd_graph.grad_graph_info.inputs_dict(fwd_info)
        # call the backward providing the starting value for backprop and activations
        bwd_info = bwd_graph.call_with_info(dx, args=activations)
        
        # Adam Optimizer, with cache
        grads_dict = bwd_graph.grad_graph_info.fwd_parent_ins_to_grad_parent_outs(fwd_info, bwd_info)
        optimizer = Adam(cache=True)
        optimizer_step(variables, grads_dict, optimizer, lr)
            
    ir.num_host_transfers = 1
    return popxl.Session(ir,device), [img_stream, label_stream], variables, loss_stream

In [None]:
training_data, test_data = get_mnist_data(test_batch_size, train_batch_size)
train_session, train_input_streams, train_variables, loss_stream = train_program(train_batch_size,device,lr)

In [None]:
nr_batches = len(training_data)
for epoch in range(1, epochs + 1):
    print("Epoch {0}/{1}".format(epoch, epochs))
    bar = tqdm(training_data, total=nr_batches)
    for data, labels in bar:
        inputs : Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(zip(train_input_streams,
                                                                        [data.squeeze().float(),
                                                                         labels.int()]))
        loss = train_session.run(inputs)[loss_stream]
        bar.set_description("Loss:{:0.4f}".format(loss))

In [None]:
# get weights data
trained_weights_data_dict = train_session.get_tensors_data(train_variables.tensors)
train_session.device.detach()

In [None]:
def test_program(test_batch_size, device):
    ir = popxl.Ir()
    ir.replication_factor = 1
    with ir.main_graph:
        # Inputs
        in_stream = popxl.h2d_stream((test_batch_size, 28, 28), popxl.float32, "image")
        in_t = ops.host_load(in_stream)

        # Create graphs
        facts, graph = Net().create_graph(in_t)

        # Initialise variables
        variables = facts.init()

        # Forward
        outputs, = graph.bind(variables).call(in_t)
        out_stream = popxl.d2h_stream(outputs.shape, outputs.dtype, "outputs")
        ops.host_store(out_stream, outputs)
        
    ir.num_host_transfers = 1
    return popxl.Session(ir, device), [in_stream], variables, out_stream

In [None]:
# Create test program and test session
test_session, test_input_streams, test_variables, out_stream = test_program(test_batch_size,device)

# Copy trained weights to the program, with a single host to device transfer at the end
test_session.write_variables_data(dict(zip(test_variables.tensors,trained_weights_data_dict.values())))

In [None]:
nr_batches = len(test_data)
sum_acc = 0.0
with torch.no_grad():
    for data, labels in tqdm(test_data, total=nr_batches):
        # the test program has a single input stream (the images), so only the data is passed
        inputs : Mapping[popxl.HostToDeviceStream, np.ndarray] = dict(zip(test_input_streams,
                                                                        [data.squeeze().float()]))
        output = test_session.run(inputs)
        sum_acc += accuracy(output[out_stream], labels)
print("Accuracy on test set: {:0.2f}%".format(sum_acc / len(test_data)))