# Tutorial 0: predictive coding via **pcax**
In this notebook, you will learn how to use pcax to build arbitrary predictive coding networks. We will focus on a fully connected network trained on MNIST: the "hello world" of neural networks and deep learning.
The current version of the library is 0.5.0, which introduced some minor but impactful changes compared to v0.3.0. Be sure to use the right version of pcax, or you'll run into a bunch of errors!

Good luck!

## Part 0: Importing dependencies
pcax is based on JAX and combines the ideas behind [equinox](https://github.com/patrick-kidger/equinox) and [objax](https://github.com/google/objax), two deep-learning libraries built on top of JAX. In particular, *equinox* is currently a dependency as we use the `nn` modules provided by it.


The library is divided into four modules:
- *core*: defines the basic building blocks of pcax, unrelated to predictive coding itself.
- *pc*: contains the whole predictive coding implementation, which will probably keep changing and being updated as more discoveries are made in the field.
- *nn*: simply contains the typical layers you would expect from a deep learning library, currently built as wrappers around *equinox* layers.
- *utils*: various tools that ease the development of complex applications by providing shortcuts for common implementation patterns.

```python
# Core dependencies
import jax
import optax

# pcax

# Importing pcax is equivalent to importing pcax.pc,
# which contains the functionalities to build a predictive coding network.
import pcax as px

# A filter is a core object of pcax. It is used to select which parameters in a network
# should undergo a specific JAX transformation (e.g., jax.grad, jax.jit, etc.).
# Consequently, despite being a core object, pcax offers a shortcut to use it: pcax.f (px.f).

# pcax.nn contains the neural network modules (e.g., Conv2d, Linear, etc.).
# At the moment only Linear is implemented, but more are coming soon!
import pcax.nn as nn

# pcax.utils contains some useful utilities to train and use the network.
import pcax.utils as pxu

# Finally, we can import the library's core, which is simply a wrapper around JAX,
# unrelated to predictive coding. It is useful for some advanced configurations.
# We will not use it in this tutorial.
# import pcax.core as pxc
```

In this example, we will also use the following dependencies:

```python
from typing import Callable, Optional
import timeit
import os
import numpy as np
from torchvision.datasets import MNIST
```

```python
# By default, JAX preallocates most of the available memory on the target GPU device.
# This can be beneficial for performance, but it can also leave the device completely
# unusable by other people. The following flag disables this behaviour.
# PLEASE REMEMBER TO ALWAYS SET THIS FLAG when working on shared machines.
# Both variables must be set before JAX initialises its GPU backend (i.e., before the first JAX computation).

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# As with PyTorch, we can specify which GPU to use:
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
```

## Part 1: Defining a Model
Defining a basic pcax model is very straightforward: simply interpose a `px.Node` between any two `nn.Layer`s in the forward call (e.g., after each linear layer and its activation function). To do so, first define the nodes in the `__init__`; no arguments are required for basic usage!

In the `__init__`, we also define the activation function we are going to use and specify that we do not want to update the `x` of the last node, since it will contain the label we want our network to learn. It is, in fact, common practice to freeze the target nodes during training. Note that this model does not directly support inference on the inputs, as there is no node to store the input `x` of the forward call.

```python
# A model is defined in a way that is similar to PyTorch. In particular:
# - we inherit from px.EnergyModule;
# - we define the activation functions, nodes, and layers in the __init__ method;
# - we define the forward pass in the __call__ method. By default, the forward pass is also used to
#   initialize the network, as described later.

class Model(px.EnergyModule):
    def __init__(self,
                 input_dim: int,
                 hidden_dim: int,
                 output_dim: int,
                 nm_layers: int,
                 act_fn: Callable[[jax.Array], jax.Array]) -> None:
        super().__init__()

        self.act_fn = act_fn

        # This is quite standard. We define the layers and nodes (one node to follow each layer).
        self.layers = [nn.Linear(input_dim, hidden_dim)] + [
            nn.Linear(hidden_dim, hidden_dim) for _ in range(nm_layers - 1)
        ] + [nn.Linear(hidden_dim, output_dim)]

        self.nodes = [px.Node() for _ in range(nm_layers + 1)]

        # We normally use the x of the last node as the target, therefore we don't want to update it
        # during training.
        self.nodes[-1].x.frozen = True

    # Here things are a bit different. __call__ accepts an optional target t (used during training),
    # which is used to set the value of the last node to the target label.
    def __call__(self,
                 x: jax.Array,
                 t: Optional[jax.Array] = None):

        # !!! IMPORTANT !!!
        # Each node (pc layer) contains a cache that stores the important intermediate values computed in the forward pass.
        # By default, these are the incoming activation 'u', the node value 'x' and the energy 'e'.
        # You can access them by using the [] operator, e.g., self.nodes[0]["x"].

        # We forward x through the network, clamping the last node to the target label.
        # We use the activation function defined in the __init__ method for all the layers but the last one.
        # Notice how the input x is passed directly to the first layer before being saved to any node:
        # this PCN works only in forward mode, therefore we don't need to save the input.
        # Finally, note that the syntax `node(x)["x"]` is, by default, a shortcut for the following:
        #
        # node["u"] = x
        # if node.is_init:
        #     node["x"] = node["u"]
        # return node["x"]
        #
        # which means that, by default, the forward pass is also used to initialize the node values.
        for node, layer in zip(self.nodes[:-1], self.layers[:-1]):
            x = node(self.act_fn(layer(x)))["x"]

        x = self.nodes[-1](self.layers[-1](x))["x"]

        # If we have a target, we use it to clamp the last node.
        if t is not None:
            self.nodes[-1]["x"] = t

        # The output of the network is the activation received by the last node (since its x is clamped to the label).
        return self.nodes[-1]["u"]
```
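To make the forward-initialisation comment above more concrete, here is a plain NumPy sketch (not pcax code) of what happens to the energy of a tiny network shaped like `Model`, with a single hidden node. It assumes the common convention E = ½ Σ‖x − u‖² (pcax's exact scaling may differ), and `forward_init` and `energy` are hypothetical helpers. Right after forward initialisation every node equals the activation it receives, so the energy is zero; clamping the output node to the target leaves the output error as the only contribution.

```python
import numpy as np

# Hypothetical helper: forward-initialise a tiny fully connected PCN in plain NumPy.
# Mirrors Model.__call__: hidden nodes receive act_fn(W @ h), the output node receives W @ h.
def forward_init(x, weights, act_fn=np.tanh):
    us, xs = [], []
    h = x
    for W in weights[:-1]:
        u = act_fn(W @ h)          # activation received by the hidden node
        us.append(u)
        xs.append(u.copy())        # forward initialisation: x = u
        h = xs[-1]                 # the next layer reads the node value
    u_out = weights[-1] @ h        # the last layer has no activation
    us.append(u_out)
    xs.append(u_out.copy())
    return us, xs

# Assumed energy convention: E = 1/2 * sum_l ||x_l - u_l||^2.
def energy(us, xs):
    return sum(0.5 * np.sum((x - u) ** 2) for u, x in zip(us, xs))

rng = np.random.default_rng(0)
weights = [rng.normal(scale=0.1, size=(8, 4)), rng.normal(scale=0.1, size=(3, 8))]
x0 = rng.normal(size=4)            # the input is not stored in any node
t = np.array([0.0, 1.0, 0.0])      # one-hot target

us, xs = forward_init(x0, weights)
e_free = energy(us, xs)            # every node equals its prediction, so the energy is zero
xs[-1] = t                         # clamp the output node, as Model does when t is given
e_clamped = energy(us, xs)         # only the output error remains
```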

Let's now define the training parameters we are going to use:

```python
params = {
    "batch_size": 256,
    "x_learning_rate": 0.01,
    "w_learning_rate": 1e-3,
    "num_epochs": 4,
    "hidden_dim": 128,
    "input_dim": 28 * 28,
    "output_dim": 10,
    "seed": 0,
    "T": 4,
}
```

Let's now define the dataloaders we'll need to train and test our model.

```python
# This is all standard and uses PyTorch's datasets and dataloaders.
# We are assuming CUDA is available when setting the dataloaders' parameters (e.g., pin_memory=True).

# We'll train with the standard pc energy function, that is, the sum of the squared differences between
# the node values and their predictions (with the last node clamped to the target).
# Therefore, we need to convert the targets to one-hot vectors.
def one_hot(t, k):
    return np.array(t[:, None] == np.arange(k), dtype=np.float32)


# Converts the 28x28 images with 0-255 pixel values into flat float32 vectors with values in [0, 1].
class FlattenAndCast:
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=np.float32) / 255.0)


train_dataset = MNIST(
    "/tmp/mnist/",
    transform=FlattenAndCast(),
    download=True,
    train=True,
)
train_dataloader = pxu.data.TorchDataloader(
    train_dataset,
    batch_size=params["batch_size"],
    num_workers=8,
    shuffle=True,
    persistent_workers=True,
    pin_memory=True,
)


test_dataset = MNIST(
    "/tmp/mnist/",
    transform=FlattenAndCast(),
    download=True,
    train=False,
)
test_dataloader = pxu.data.TorchDataloader(
    test_dataset,
    batch_size=params["batch_size"],
    num_workers=8,
    shuffle=False,
    persistent_workers=True,
    pin_memory=True,
)
```
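A quick sanity check of the two helpers, without downloading MNIST. The snippet repeats `one_hot` and `FlattenAndCast` so it runs on its own, and uses a constant 28×28 `uint8` array as a stand-in for a PIL image (an assumption that holds because `np.array` turns a grayscale PIL image into exactly such an array).

```python
import numpy as np

# Same helpers as in the cell above, repeated so this snippet is self-contained.
def one_hot(t, k):
    return np.array(t[:, None] == np.arange(k), dtype=np.float32)

class FlattenAndCast:
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=np.float32) / 255.0)

labels = np.array([3, 0, 9])
targets = one_hot(labels, 10)      # shape (3, 10), exactly one 1.0 per row

fake_pic = np.full((28, 28), 255, dtype=np.uint8)  # stand-in for a 28x28 grayscale image
flat = FlattenAndCast()(fake_pic)  # shape (784,), float32 values in [0, 1]
```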

## Part 2: Where the fun begins
As mentioned, pcax is based on JAX, which is a functional framework. Consequently, in order to offer a simple object-oriented interface to it, there are some compromises to be made and strict patterns to follow. In particular, JAX requires keeping track of all the tensors involved in a computation. pcax does this by tracking all the `Params` contained in any `Module` passed to a function **as a keyword argument**. So be careful when defining a function... but we'll go into more details later.

JAX works by transforming simple functions to achieve complex behaviours. Let's see what this means!

### Part 2.1: Defining a function
Let's see how a function involving a `px.Module` (`px.EnergyModule` is a `px.Module`) is defined. We'll proceed step by step (so only the last version is the correct one).

Consider the function `predict(x, t)`, which simply calls the forward pass of the model (we also include the optional target `t`, since it is needed during training).
In plain Python, `predict` would look similar to the following:

```python
def predict(model, x, t = None):
    return model(x, t)
```

**However**, since this function uses a `px.Module` (i.e., `model`), we need to treat it differently and explicitly pass it as a keyword argument. There's not much to add, it needs to be done :P
So `predict` should look like this:

```python
# '*' means that all the following arguments must be passed by name.
# It is standard Python syntax and it can be omitted if you remember to pass the arguments by name.
# However, defining the function in this way makes it easier to avoid this simple mistake.
def predict(x, *, t = None, model = None):
    return model(x, t)
```

Unfortunately, we are not done. There's one more convention that needs to be followed, and it concerns the difference between static and dynamic parameters (another key JAX concept). In general, you can think of it in this way: a tensor (even if it contains a single number) is a dynamic value; anything else is static. More precisely, a dynamic parameter is a parameter that *does not* alter the flow of execution of a program, but only its output. Think about the following function:

```python
def op(a: float, b: float, op: str):
    if op == '+':
        return a + b
    elif op == '-':
        return a - b
```

Here, different `a` and `b` values will produce different results, but the required computation depends exclusively on `op` (given that the types of `a` and `b` are fixed). Thus, `a` and `b` are dynamic parameters, while `op` is static.
This, however, means that we cannot use the value of `a` or `b` to alter the flow of the function. For example, the following is not valid inside a jit-compiled function:

```python
def op(a: float, b: float, op: str):
    if op == '+':
        c = a + b
    elif op == '-':
        c = a - b

    if c < 0:
        return c + 1
    else:
        return c + 2
```
Here we branch on `c` which, being computed from dynamic parameters, is a dynamic parameter as well. This results in an error when JAX traces the function for compilation.
There are primitives that allow us to dynamically execute different pieces of code based on dynamic values and we will see them in another tutorial.
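The simplest of these patterns is to compute both branches and select the result elementwise. The sketch below uses NumPy so it runs anywhere; inside a jitted JAX function the analogous call is `jax.numpy.where`, which takes the same `(condition, x, y)` arguments (`jax.lax.cond` is the other common option). Note that the static `op` can still drive an ordinary Python `if`.

```python
import numpy as np

def op(a, b, op: str):
    # `op` is static: a plain Python branch on it is fine.
    if op == '+':
        c = a + b
    elif op == '-':
        c = a - b
    else:
        raise ValueError(f"unknown op {op!r}")
    # `c` is dynamic: instead of `if c < 0`, compute both results and select elementwise.
    return np.where(c < 0, c + 1, c + 2)

print(op(1.0, 2.0, op='+'))   # c = 3, so 3 + 2
print(op(1.0, 2.0, op='-'))   # c = -1, so -1 + 1
```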
To conclude this section: we always have to decide whether a parameter is static (i.e., we want to use it as a *flag* to compile the same code multiple times with different behaviours) or dynamic (it is an actual input of the function we define). Static parameters are hardcoded into the function when it is compiled, so using different static values for the same parameter results in a recompilation of the function.

pcax adopts the following convention: positional arguments are dynamic, keyword arguments are static. Thus, we would call `op` as `op(1.0, 2.0, op="+")`. Note that this is necessary only for JAX-transformed functions (we'll look at them in a moment).
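To see why static values trigger recompilation, here is a toy, pure-Python model of the caching rule. It is hypothetical and not how JAX or pcax are implemented: `toy_jit` keeps one "compiled" version per combination of positional-argument shapes/dtypes and static keyword values, and counts how many times it had to build a new one.

```python
import numpy as np

# Hypothetical toy, NOT the JAX/pcax implementation: it only mimics the caching rule.
def toy_jit(fn):
    cache = {}

    def wrapper(*args, **static):
        # Dynamic (positional) args contribute only their shape and dtype to the key;
        # static (keyword) args contribute their actual values.
        key = (
            tuple((np.shape(a), np.asarray(a).dtype.str) for a in args),
            tuple(sorted(static.items())),
        )
        if key not in cache:
            wrapper.compilations += 1          # a real jit would trace and compile here
            cache[key] = lambda *a: fn(*a, **static)
        return cache[key](*args)

    wrapper.compilations = 0
    return wrapper

@toy_jit
def op(a, b, op: str):
    return a + b if op == '+' else a - b

op(1.0, 2.0, op='+')                 # first call: "compiles"
op(3.0, 4.0, op='+')                 # new dynamic values, same shapes and static value: cache hit
op(1.0, 2.0, op='-')                 # new static value: recompiles
op(np.ones(3), np.ones(3), op='+')   # new dynamic shape: recompiles
print(op.compilations)
```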

All of this means that `predict` should be defined as:

```python
# Again, note that '*' is not necessary, but it makes it easier to avoid mistakes.
def predict(x, t = None, *, model = None):
    return model(x, t)
```

And now it is ready to be used by pcax!

The next step is **vectorization** (used here to achieve **batching**). JAX (and therefore pcax) defines each computation on a tensor assuming it contains a single sample (that is, there is no batch dimension, unlike, for instance, in PyTorch). Therefore, in the function above, `x` is supposed to be an array with shape [784,] (e.g., a flattened MNIST image) and `t` an array with shape [10,] (the corresponding one-hot encoded label). If we want `predict` to work on batched input (i.e., `x` with shape [n, 784]), we need to *vectorize* the function so that the `predict` computation is repeated along the batch dimension. In JAX, this is achieved using `vmap`. In pcax, given its objax inspiration, it is achieved using `Vectorize` (this may change in the future as the library evolves past its original objax-like formulation).

`pcax.utils` offers a list of functions and decorators that can be used to easily apply these transformations to any function. Since we have already imported it as `pxu`, we can "batch" `predict` by adding the following decorator to it:

```python
@pxu.vectorize(px.f(px.NodeParam, with_cache=True), in_axis=(0, 0))
def predict(x: jax.Array, t: Optional[jax.Array] = None, *, model):
    return model(x, t)
```

There is quite a lot to unpack:
- first of all, `predict` can now only be used with batched input, which means that `x` and `t` must have the same size along dimension `0`.
- `pxu.vectorize` is a thin wrapper around `pxc.Vectorize`. Its first argument, built with `px.f`, specifies which parameters of the models passed to the function we want to vectorize. As of now, for predictive coding, these are always and only the node values, which, in pcax, are identified as `px.NodeParam`. In fact, each data sample should have its own exclusive set of node values (on the contrary, weight values are shared between different batch samples and we do not want them to be vectorized). `with_cache=True` specifies that you want to capture not only the node values, but also their cached transformations (these include, for example, the activation `u` produced by each layer and received by each node). Again, as of now, you will probably always need to specify it when using `vectorize` to batch a function.
- `in_axis` specifies how to treat the *positional* arguments of the function (note how there are only two values in the tuple even though the function takes three arguments: `model`, being a keyword, and thus static, argument, is automatically ignored). `0` means that the function will be vectorized over the first dimension of the corresponding argument, while `None` means that the argument will be ignored by the vectorization and passed as is to the function. In the case of `predict`, both `x` and `t` are batches of samples with shape [batch_dim, ...], so we pass `0` for both of them.
- by default, `pxu.vectorize` expects a single output vectorized along dimension `0`, which is exactly what we have here.

You can look at the definition of `px.f` for instructions and tips on how to use it.
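If the semantics of `in_axis` and `out_axis` feel abstract, this NumPy sketch spells them out with an explicit loop. `toy_vectorize` is a hypothetical helper written for illustration only: it assumes a single output, slices positional arguments marked `0` along their first axis, passes arguments marked `None` and all keyword arguments through unchanged and, like the transformed functions in this tutorial, returns a tuple. The real transformation is built on JAX-style batching rather than a Python loop.

```python
import numpy as np

# Hypothetical toy mimicking in_axis / out_axis semantics (single output only).
def toy_vectorize(in_axis, out_axis=(0,)):
    def decorator(fn):
        def wrapper(*args, **kwargs):
            sizes = {np.shape(a)[0] for a, ax in zip(args, in_axis) if ax == 0}
            assert len(sizes) == 1, "batched arguments must share dimension 0"
            (n,) = sizes
            outs = []
            for i in range(n):
                # Slice the batched args, pass the others as they are.
                sample_args = [a[i] if ax == 0 else a for a, ax in zip(args, in_axis)]
                outs.append(fn(*sample_args, **kwargs))  # keyword (static) args untouched
            out = np.stack(outs)
            # "sum" reduces over the batch; 0 keeps the batch dimension first.
            return (out.sum(axis=0),) if out_axis[0] == "sum" else (out,)
        return wrapper
    return decorator

@toy_vectorize(in_axis=(0, None))
def affine(x, b, *, scale):
    return scale * x + b

@toy_vectorize(in_axis=(0, 0), out_axis=("sum",))
def sq_err(x, t):
    return 0.5 * np.sum((x - t) ** 2)

xs = np.arange(6.0).reshape(3, 2)
(y,) = affine(xs, np.array([10.0, 20.0]), scale=2.0)  # shape (3, 2)
(total,) = sq_err(xs, np.zeros((3, 2)))               # scalar: 0.5 + 6.5 + 20.5
```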

Let's move on to the next function we need: the loss function.  
Similarly to how in PyTorch you compute the gradients from the loss by calling `loss.backward()`, here we compute the gradients by transforming the loss function so that it also outputs the gradients.

In predictive coding, the standard loss function simply computes and returns the model's energy:

```python
@pxu.vectorize(px.f(px.NodeParam, with_cache=True), in_axis=(0, 0), out_axis=("sum",))
def loss(x: jax.Array, t: Optional[jax.Array] = None, *, model):
    model(x)
    return model.energy()
```

Let's observe a few things:
- `loss` takes the target label `t` as it's standard practice; however, here, neither `t` nor the model's output is strictly necessary: we don't need the output to compute the error (it is already used when computing the energy), and we assume that the target was already set as the `x` of the last node when the node values were initialized (so, again, it is already included in the energy computation). That's why the function simply calls `model(x)` and ignores its return value.
- the first two arguments of `pxu.vectorize` are the same as in `predict`; however, here we add a modifier for the output (notice that, even though there is a single output, `out_axis` is specified as a tuple). We use `"sum"` to specify that the output of the function should be summed over the batch dimension. We do this because, as in PyTorch, the loss must be a single floating point value (we don't use the mean since we want the total energy coming from each error node, not their average).
- `model.energy()` computes and caches the energy value of each node in `model`, returning their sum.
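The choice of `"sum"` over the mean also shows up in the gradients. Assuming a ½-scaled squared-error energy for a single output node (so ∂E_b/∂x_b = x_b − u_b), this NumPy check shows that with the summed loss each sample's node-value gradient depends only on its own error, while averaging would divide every node gradient by the batch size (so the x learning rate would have to change with `batch_size`). The arrays are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
u = rng.normal(size=(4, 10))   # hypothetical predictions for a batch of 4
x = rng.normal(size=(4, 10))   # hypothetical node values

grad_sum = x - u                  # d/dx of sum_b E_b: each sample gets its own error
grad_mean = (x - u) / x.shape[0]  # d/dx of mean_b E_b: shrinks as the batch grows

# Forward finite-difference check of one entry for the summed loss.
def total_energy(x):
    return 0.5 * np.sum((x - u) ** 2)

eps = 1e-6
x_shift = x.copy()
x_shift[1, 2] += eps
fd = (total_energy(x_shift) - total_energy(x)) / eps
```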

Now, we want to define two different "backward" functions, one for the *x step* (in which we update the value nodes) and one for the *w step* (in which we update the weights). They are identical except for the fact that they compute gradients with respect to different elements.  
`pxu.grad_and_values` transforms a function so that it outputs the gradients with respect to the specified variables (together with the function's original return values). Note how the filter used in either case does not specify `with_cache=True`, as we do not want (and it would not make much sense) to compute the gradient with respect to cached intermediate values.

```python
# train_x computes the gradients with respect to the node values that are not frozen.
train_x = pxu.grad_and_values(
    px.f(px.NodeParam)(frozen=False), # px.f.__call__(**kwargs) is used to apply a filter, selecting all the variables with the specified properties
)(loss)

# train_w computes the gradients with respect to the weights.
train_w = pxu.grad_and_values(
    px.f(px.LayerParam)
)(loss)
```
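`train_x` and `train_w` differentiate the *same* energy, only with respect to different variables. For a hypothetical network with a single hidden node (assuming a ½-scaled squared-error energy), the two gradients can be written by hand; this NumPy sketch does so and checks them against finite differences. pcax obtains these gradients automatically through JAX, so this only illustrates what each filter selects.

```python
import numpy as np

rng = np.random.default_rng(1)
W0 = rng.normal(scale=0.3, size=(5, 3))
W1 = rng.normal(scale=0.3, size=(2, 5))
x0 = rng.normal(size=3)            # input (not a node, as in Model)
t = np.array([1.0, 0.0])           # target clamped to the output node
x1 = np.tanh(W0 @ x0) + 0.1        # hidden node value, nudged away from its prediction

def energy(x1, W0, W1):
    e1 = x1 - np.tanh(W0 @ x0)     # hidden node error
    e2 = t - W1 @ x1               # output node error (its x is the target t)
    return 0.5 * (e1 @ e1 + e2 @ e2)

# "x step": gradient w.r.t. the unfrozen hidden node value only.
grad_x1 = (x1 - np.tanh(W0 @ x0)) - W1.T @ (t - W1 @ x1)
# "w step": gradient w.r.t. a weight matrix only (W1 shown here).
grad_W1 = -np.outer(t - W1 @ x1, x1)

def fd(f, arr, idx, eps=1e-6):
    # Forward finite difference of f w.r.t. a single entry of arr.
    shifted = arr.copy()
    shifted[idx] += eps
    return (f(shifted) - f(arr)) / eps
```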

This covers almost everything you need to know about defining a function that operates on one (or more) `px.Module`s. We can now instantiate the model. This would normally be done inside the `if __name__ == "__main__":` block, but to keep the tutorial linear we do it here.

```python
model = Model(params["input_dim"], params["hidden_dim"], params["output_dim"], 2, jax.nn.tanh)
```

### Part 2.2: Optimizers
In pcax (as it is inspired by objax), optimizers are `px.Module`s as well. Consequently, they need to be treated similarly to how we used `model`; in particular, they must be passed as keyword arguments to any function that uses them.
pcax offers a single `pxu.Optim` class that allows `px.Module`s to interact with `optax` optimizers (the most common JAX optimization library). As you can see from the following example, you simply have to specify which `optax` optimizer to use for which subset of the `px.Module` variables. In this case, similarly to how we defined the two loss functions, we define an *x optimizer* and a *w optimizer*.

```python
# Dummy run to initialize the node values (and thus the optimizer state) with the right shapes.
with pxu.train(model, np.zeros((params["batch_size"], 28 * 28)), None):
    optim_x = pxu.Optim(
        optax.sgd(params["x_learning_rate"]),
        model.parameters().filter(px.f(px.NodeParam)(frozen=False))
    )
    optim_w = pxu.Optim(
        optax.adam(params["w_learning_rate"]),
        model.parameters().filter(px.f(px.LayerParam)),
    )
```

You may be wondering what the first line is for.

Your optimizers may define some state linked to the individual trainable parameters inside your model. For instance, you could choose to use `optax.adam` for `optim_x` (not recommended: as far as we know, you should stick with stateless optimizers for the value nodes). In this case, the optimizer needs to instantiate the `adam` state and, thus, it needs to know the shape of all the `px.NodeParam`s inside your model. However, if we do not run the model at least once, all the node values will be empty, since they are lazily created. Furthermore, the shape of the value nodes depends on the batch size (since, remember, we want different value nodes for each different input), so we need to perform a dummy run on a dummy input with the same shape as the samples we are going to train on, to correctly initialize all the parameters inside the model. Only then can we safely create the two optimizers. (Note that this requires `batch_size` to be constant throughout the program. To guarantee this, by default, the dataloaders have `drop_last=True`. Unless you know what you're doing, do not attempt to modify that.)
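A minimal illustration of the shape issue, assuming an Adam-like optimizer: `optax.adam` keeps first- and second-moment buffers with the same shapes as the parameters it updates. The helper `init_adam_state` and the parameter name `"nodes.0.x"` are hypothetical NumPy stand-ins, not optax or pcax code. Since node values are shaped `(batch_size, dim)`, buffers created for one batch size don't match a batch of another size.

```python
import numpy as np

# Hypothetical toy Adam state: moment buffers mirror each parameter's shape.
def init_adam_state(params):
    return {name: (np.zeros_like(p), np.zeros_like(p)) for name, p in params.items()}

batch_size, hidden_dim = 256, 128
node_values = {"nodes.0.x": np.zeros((batch_size, hidden_dim), dtype=np.float32)}
state = init_adam_state(node_values)

# A later batch of a different size no longer matches the buffers created at init time.
smaller_batch = np.zeros((100, hidden_dim), dtype=np.float32)
mismatch = state["nodes.0.x"][0].shape != smaller_batch.shape
```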

As for the dummy run itself, it doesn't really matter what you pass as `x` and `t`, as long as their shapes match those of the true input batches (passing `None` as the target also shows that the target of the forward call is optional).

`pxu.train` (and `pxu.eval`) does exactly this: under the hood it behaves like the `predict` function we defined previously (if you pass any arguments other than `model`). **NOTE**: this is one of the few, if not the only, "hidden behaviours" of pcax: `pxu.train` (and `pxu.eval`) calls a batched version of `model.__call__` if you provide any arguments to it. To customise its batching behaviour, `pxu.train` accepts all the kwargs you would normally pass to `pxu.vectorize` (all the other args are passed to `model.__call__`).

In addition, `pxu.train`/`pxu.eval` set the internal status flag of the model so that `is_train`/`is_eval`, respectively, returns `True`.

As we have already hinted, `px.Module` uses an internal caching system to store intermediate values for later computations (such as a layer's activation `u`, necessary to compute a node's energy). If you need to run multiple forward passes on the same input (as during training), you have to enclose each of them in a `pxu.step` context manager. This ensures that the cache is cleared and that the next forward pass computes new, updated values instead of reusing previous ones.
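The reason for clearing the cache can be reproduced with a toy node that computes its energy lazily and caches it. `ToyNode` and `toy_step` are hypothetical and only mimic the behaviour described above; in particular, this sketch clears the cache both when entering and when leaving the step, which may differ from what `pxu.step` actually does.

```python
from contextlib import contextmanager

# Hypothetical toy, not pcax's implementation: a node that caches its energy lazily.
class ToyNode:
    def __init__(self, x, u):
        self.x, self.u = x, u
        self.cache = {}

    def energy(self):
        if "e" not in self.cache:               # reuse the cached value if present
            self.cache["e"] = 0.5 * (self.x - self.u) ** 2
        return self.cache["e"]

@contextmanager
def toy_step(node):
    node.cache.clear()                          # start the step from a clean cache
    try:
        yield node
    finally:
        node.cache.clear()                      # drop values computed during this step

node = ToyNode(x=2.0, u=0.0)
with toy_step(node):
    first = node.energy()   # 2.0
    node.x = 1.0            # an optimizer update changes the node value...
    stale = node.energy()   # ...but the cached energy is still 2.0
with toy_step(node):
    fresh = node.energy()   # 0.5, recomputed from the updated value
```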

### Part 2.3: Training the model
We are almost done. We have all the ingredients to train the model; we just need to assemble them into the training and evaluation functions. There's only one core concept left to introduce: *jitting*. *Just In Time* compilation allows the program to be compiled and optimized (it's the reason we have so many constraints on how code should be written). We jit the computations executed on a single batch, as we want to compile as much code as possible, and `train_on_batch` represents the single largest repeated operation occurring during training.

```python
@pxu.jit()
def train_on_batch(x, y, *, model, optim_w, optim_x):
    # We are working on the input x, so we initialise the internal nodes with it.
    with pxu.train(model, x, y) as (y_hat,):
        for _ in range(params["T"]):
            # Each forward pass caches the intermediate values (such as activations and energies), so we can use them to compute the gradients.
            # pxu.step takes care of managing the cache.
            with pxu.step(model):
                g, (v,) = train_x(x, y, model=model)
                optim_x(g)

        with pxu.step(model):
            g, (v,) = train_w(x, y, model=model)
            optim_w(g)
```

A few remarks:
- this is the standard PC training procedure:
    - we initialize the node values using `pxu.train` (which by default uses forward initialization);
    - we repeat the *x-update* step `T` times and then we perform a single *w-update*.
- `optim_*(g)` updates the node values/weights of the model, therefore all the cached values computed from the old ones must be cleared. That's why we enclose each training operation in a `pxu.step`.
- the values returned by `pxu.train` (and `predict`) and by `train_*` (whose first returned value is the computed gradients, followed by the actual return values of the function) are all tuples, even if the original function (before being transformed) returns a single value. Just something to keep in mind.
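Putting the pieces together, this is a NumPy sketch of what one call to `train_on_batch` computes for a hypothetical network with a single hidden node (tanh activation and a linear output layer, as in `Model`). It uses hand-derived gradients of an assumed ½-scaled squared-error energy and plain gradient descent for both steps (whereas the tutorial uses Adam for the weights), and records the energy after every update so you can watch the x steps relax the hidden node and the w step lower the energy further.

```python
import numpy as np

# Assumed energy, summed over the batch (rows are samples):
# E = 1/2 * sum(||X1 - tanh(X0 W0^T)||^2) + 1/2 * sum(||Y - X1 W1^T||^2)
def energy(X0, X1, Y, W0, W1):
    e1 = X1 - np.tanh(X0 @ W0.T)   # hidden node errors
    e2 = Y - X1 @ W1.T             # output node errors (output node clamped to Y)
    return 0.5 * (np.sum(e1 ** 2) + np.sum(e2 ** 2))

def train_on_batch(X0, Y, W0, W1, steps=4, x_lr=0.1, w_lr=0.01):
    A0 = X0 @ W0.T
    X1 = np.tanh(A0)                           # forward initialisation: x = u
    history = [energy(X0, X1, Y, W0, W1)]
    for _ in range(steps):                     # x steps: only the hidden values move
        grad_X1 = (X1 - np.tanh(A0)) - (Y - X1 @ W1.T) @ W1
        X1 = X1 - x_lr * grad_X1
        history.append(energy(X0, X1, Y, W0, W1))
    # w step: gradients w.r.t. both weight matrices at the relaxed node values
    e1 = X1 - np.tanh(A0)
    e2 = Y - X1 @ W1.T
    grad_W0 = -(e1 * (1 - np.tanh(A0) ** 2)).T @ X0
    grad_W1 = -e2.T @ X1
    W0, W1 = W0 - w_lr * grad_W0, W1 - w_lr * grad_W1
    history.append(energy(X0, X1, Y, W0, W1))
    return W0, W1, history

rng = np.random.default_rng(0)
X0 = rng.normal(size=(4, 5))               # batch of 4 inputs with 5 features
Y = np.eye(3)[[0, 2, 1, 0]]                # one-hot targets, 3 classes
W0 = rng.normal(scale=0.3, size=(8, 5))
W1 = rng.normal(scale=0.1, size=(3, 8))
W0, W1, history = train_on_batch(X0, Y, W0, W1)
```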

Let's see the remaining functions.

```python
@pxu.jit()
def evaluate(x, y, *, model):
    # As in train_on_batch, we initialise the internal nodes with the input x. By doing so we also get the model's output y_hat.
    with pxu.eval(model, x, y) as (y_hat,):
        return (y_hat.argmax(-1) == y.argmax(-1)).mean()


def epoch(dl):
    for batch in dl:
        x, y = batch
        y = one_hot(y, 10)

        train_on_batch(x, y, model=model, optim_w=optim_w, optim_x=optim_x)

    return 0


def test(dl):
    accuracies = []
    for batch in dl:
        x, y = batch
        y = one_hot(y, 10)

        accuracies.append(evaluate(x, y, model=model))

    return np.mean(accuracies)
```
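`test` averages per-batch accuracies, which matches the accuracy over all evaluated samples only when every batch has the same size; a toy NumPy example with a smaller final batch shows the difference. In this tutorial all batches have the same size, assuming the test dataloader also uses the `drop_last=True` default mentioned in Part 2.2; the flip side is that the last 10000 − 39 × 256 = 16 MNIST test images are then never evaluated.

```python
import numpy as np

# Hypothetical per-sample correctness (1 = correct) for 10 test samples,
# split into batches of 4, 4 and a final partial batch of 2.
correct = np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0], dtype=np.float32)
batches = [correct[0:4], correct[4:8], correct[8:10]]

overall = correct.mean()                               # 8 / 10 = 0.8
mean_of_means = np.mean([b.mean() for b in batches])   # (1 + 1 + 0) / 3, about 0.667

# Dropping the partial batch (as drop_last=True would) makes the two agree again,
# but only on the 8 samples that are actually evaluated.
kept = batches[:2]
kept_mean_of_means = np.mean([b.mean() for b in kept])
kept_overall = np.concatenate(kept).mean()
```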

And the main body:

```python
if __name__ == "__main__":
    t = timeit.timeit(lambda: epoch(train_dataloader), number=1)
    print("Compiling + Epoch 1 took", t, "seconds")

    # Average time of an epoch (excluding compilation, which already happened above)
    t = timeit.timeit(lambda: epoch(train_dataloader), number=params["num_epochs"]) / params["num_epochs"]
    print("An Epoch takes on average", t, "seconds")

    print("Final Accuracy:", test(test_dataloader))
```

```
Compiling + Epoch 1 took 2.450080598006025 seconds
An Epoch takes on average 1.141245722246822 seconds
Final Accuracy: 0.96033657
```


If everything went well, you should get a final accuracy of ~96%. On pssr2, you should also get a time per epoch of ~1.20 seconds. The training time is actually heavily bottlenecked by the data transfer between CPU and GPU, which is why we use 8 workers in each dataloader. Depending on your configuration, the final speed may change.

**However**, there is one more step we can take to improve performance: using snapshots. JAX jitting works by keeping track of all the dynamic and static arguments you pass to a jitted function: if the type/shape of a dynamic argument or the value of a static one changes, the function needs to be recompiled. Of course, this is the correct and intended behaviour. However, checking for these changes is computationally expensive. pcax lets the user specify that some static arguments (e.g., the modules we use) are not going to change between calls of the same function (we are not magically going to add new parameters to our model while it's training...) by creating a snapshot of them in a specific state.

```python
train_fn = train_on_batch.snapshot(model=model, optim_w=optim_w, optim_x=optim_x)
test_fn = evaluate.snapshot(model=model)
```

Now we can use these functions inside `epoch` and `test` and speed everything up:

```python
def epoch(dl, train_fn):
    for batch in dl:
        x, y = batch
        y = one_hot(y, 10)

        train_fn(x, y)

    return 0


def test(dl, test_fn):
    accuracies = []
    for batch in dl:
        x, y = batch
        y = one_hot(y, 10)

        accuracies.append(test_fn(x, y))

    return np.mean(accuracies)
```

Since we want a fair training comparison, let's reset the model parameters before training it again. This can be done using `px.move`:

```python
# Create a new, randomly initialised model and copy its layer parameters to the trained model.
px.move(
    Model(params["input_dim"], params["hidden_dim"], params["output_dim"], 2, jax.nn.tanh).parameters().filter(px.f(px.LayerParam)),
    model.parameters().filter(px.f(px.LayerParam))
)
```


The main body now looks like this:

```python
if __name__ == "__main__":
    t = timeit.timeit(lambda: epoch(train_dataloader, train_fn), number=1)
    print("Compiling + Epoch 1 took", t, "seconds")

    # Average time of an epoch (excluding compilation, which already happened above)
    t = timeit.timeit(lambda: epoch(train_dataloader, train_fn), number=params["num_epochs"]) / params["num_epochs"]
    print("An Epoch takes on average", t, "seconds")

    print("Final Accuracy:", test(test_dataloader, test_fn))
```

```
Compiling + Epoch 1 took 1.5790030059870332 seconds
An Epoch takes on average 0.7215498482692055 seconds
Final Accuracy: 0.96033657
```


The training time should drop by ~40%, to around 0.7 seconds per epoch. The first epoch again includes compilation time, since the newly created snapshot has to be compiled.