## Neural Networks (2023-2024)
https://sites.google.com/uniroma1.it/neuralnetworks2023/

This is a short notebook highlighting the use of [torch.func](https://pytorch.org/docs/stable/func.html), a PyTorch module that provides a functional interface to the framework, together with composable transformations (grad, vmap, ...) mirroring those found in [JAX](https://jax.readthedocs.io/en/latest/). The notebook aims to highlight the difference between working in the object-oriented (OOP) and functional paradigms, and to show examples where the functional approach is simpler and more elegant because functional transformations can easily be chained.
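
As a minimal illustration of what "chaining transformations" means (a sketch that is not part of the original notebook), the snippet below applies `func.grad` twice to `torch.sin` and then lifts the result to a batch of inputs with `func.vmap`. The input values are arbitrary.

```python
import torch
from torch import func

# func.grad maps a scalar-valued function f to a new function computing df/dx.
d_sin = func.grad(torch.sin)              # should equal cos(x)
# The result is again a plain function, so the transformation can be re-applied.
d2_sin = func.grad(func.grad(torch.sin))  # should equal -sin(x)
# vmap lifts a function written for a single input to a whole batch of inputs.
batched_d2_sin = func.vmap(d2_sin)

x = torch.tensor(0.5)
print(d_sin(x), torch.cos(x))
print(d2_sin(x), -torch.sin(x))

xs = torch.linspace(0.0, 3.0, 7)
print(batched_d2_sin(xs))
```

Each transformation returns an ordinary Python function, which is what makes the nesting `grad(grad(...))` and the composition `vmap(grad(...))` possible.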

In [None]:
import torch
from torch import nn, func

In [None]:
# Define a simple PyTorch model
net = nn.Sequential(
    nn.Linear(3, 4),
    nn.ReLU(inplace=True),
    nn.Linear(4, 5),
    nn.Softmax(1)
)

In [None]:
# Parameters can be accessed through iterators
net.parameters()

<generator object Module.parameters at 0x7fe0d7305a10>
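
`named_parameters()` is the companion generator that also yields each parameter's name; these names are the keys that reappear in the gradient dictionaries later on. The sketch below re-creates the same network so that it runs on its own. Since `ReLU` and `Softmax` hold no parameters, indices 1 and 3 should not appear among the names.

```python
import torch
from torch import nn

# Same architecture as in the notebook, re-created so the snippet is self-contained.
net = nn.Sequential(
    nn.Linear(3, 4),
    nn.ReLU(inplace=True),
    nn.Linear(4, 5),
    nn.Softmax(1)
)

# Each item pairs a dotted name ("<submodule index>.<attribute>") with the
# Parameter tensor stored inside the module.
for name, p in net.named_parameters():
    print(name, tuple(p.shape), p.requires_grad)
```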

In [None]:
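# Note: from a functional perspective, .backward() is very strange: it returns nothing, and
# its result is a side effect (gradients are written, and accumulated, into the .grad
# attribute of each parameter). Also note that each softmax row sums to 1, so net(x).sum()
# is constant and the gradients below are zero up to floating-point round-off.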
x = torch.randn((10, 3))
net(x).sum().backward()

In [None]:
net[0].weight.grad

tensor([[ 6.1237e-09,  3.7508e-09, -4.9393e-11],
        [ 7.4083e-09, -1.0464e-09,  1.0326e-09],
        [-4.3719e-10, -3.8304e-09,  6.3779e-11],
        [ 1.2196e-09,  1.0686e-08, -1.7792e-10]])
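
To make the side effect concrete, this sketch (re-creating the network with an arbitrary seed) calls `.backward()` twice and checks that the second call adds to `.grad` instead of overwriting it. The squared-output loss is a hypothetical choice, picked only because its gradient is not identically zero.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(inplace=True), nn.Linear(4, 5), nn.Softmax(1))
x = torch.randn((10, 3))

# Hypothetical loss with a non-zero gradient (used only for this illustration).
loss_fn = lambda: (net(x) ** 2).sum()

loss_fn().backward()
g1 = net[0].weight.grad.clone()

# A second backward() call does not overwrite .grad: it adds to it.
loss_fn().backward()
g2 = net[0].weight.grad.clone()
print(torch.allclose(g2, 2 * g1))

# This hidden state has to be reset explicitly; depending on the PyTorch
# version, zero_grad() either sets .grad to None or fills it with zeros.
net.zero_grad()
print(net[0].weight.grad)
```

This accumulation is the reason training loops written in the OOP style call `zero_grad()` at every step.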

In [None]:
# To move to a functional representation, we first extract the parameters,
# then we convert the object instance into a pure function taking as input
# both x and the parameters.
w = dict(net.named_parameters())
net_fcn = lambda w, x: func.functional_call(net, w, x)
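
One way to see that `net_fcn` behaves as a pure function is to call it with a different set of parameters and check that the module itself is untouched. In this sketch the all-zero parameter dictionary `w_zero` is purely hypothetical: with zero weights and biases both `Linear` layers output zeros, so the softmax should be uniform (0.2 per class).

```python
import torch
from torch import nn, func

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(inplace=True), nn.Linear(4, 5), nn.Softmax(1))
x = torch.randn((10, 3))

w = dict(net.named_parameters())
net_fcn = lambda w, x: func.functional_call(net, w, x)

# Snapshot of the module's own first-layer weights before the call.
before = net[0].weight.detach().clone()

# Hypothetical "what-if" parameters: every weight and bias set to zero.
w_zero = {k: torch.zeros_like(v) for k, v in w.items()}
out = net_fcn(w_zero, x)
print(out[0])

# The module was not modified: its stored parameters are unchanged.
print(torch.equal(net[0].weight.detach(), before))
```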

In [None]:
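# func.grad is an operator (higher-order function) that returns a new function
# evaluating the gradient with respect to the first argument (argnums=0 by default),
# here the parameters w. The wrapped function must return a scalar, hence the .sum().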
net_grad_fcn = func.grad(lambda w, x: net_fcn(w, x).sum())
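
As a sanity check (a sketch with a hypothetical squared-output loss), the gradients returned by `func.grad` should match the ones that `.backward()` writes into `.grad`. Passing `argnums=1` differentiates with respect to `x` instead of `w`.

```python
import torch
from torch import nn, func

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(inplace=True), nn.Linear(4, 5), nn.Softmax(1))
x = torch.randn((10, 3))
w = dict(net.named_parameters())
net_fcn = lambda w, x: func.functional_call(net, w, x)

# Hypothetical non-constant loss, used only for the comparison.
loss = lambda w, x: (net_fcn(w, x) ** 2).sum()

# Functional route: gradients are returned as a dict keyed like w.
grads = func.grad(loss)(w, x)
# Gradient with respect to the input x instead (second argument).
grad_x = func.grad(loss, argnums=1)(w, x)

# Stateful route: the same gradients, read from .grad after backward().
net.zero_grad()
(net(x) ** 2).sum().backward()

for name, p in net.named_parameters():
    print(name, torch.allclose(grads[name], p.grad))
print(grad_x.shape)
```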

In [None]:
# In this approach, gradients are returned directly as the output of the function,
# as a dictionary with the same keys as w.
net_grad_fcn(w, x)

{'0.weight': tensor([[ 6.1237e-09,  3.7508e-09, -4.9393e-11],
         [ 7.4083e-09, -1.0464e-09,  1.0326e-09],
         [-4.3719e-10, -3.8304e-09,  6.3779e-11],
         [ 1.2196e-09,  1.0686e-08, -1.7792e-10]], grad_fn=<TBackward0>),
 '0.bias': tensor([ 5.4171e-09,  1.0515e-08,  2.0189e-09, -5.6320e-09],
        grad_fn=<ViewBackward0>),
 '2.weight': tensor([[1.6468e-08, 1.1425e-08, 5.0754e-09, 8.2010e-09],
         [1.3497e-08, 7.5688e-09, 3.1503e-09, 5.0904e-09],
         [3.4575e-08, 1.6428e-08, 6.0394e-09, 9.7587e-09],
         [1.7959e-08, 8.6938e-09, 3.2041e-09, 5.1774e-09],
         [1.8011e-08, 1.0379e-08, 4.3835e-09, 7.0831e-09]],
        grad_fn=<TBackward0>),
 '2.bias': tensor([3.3357e-08, 2.4654e-08, 5.7422e-08, 2.9984e-08, 3.3398e-08],
        grad_fn=<ViewBackward0>)}
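
The `grad_fn` entries in the output above appear because the tensors in `w` are `nn.Parameter`s that require grad, so the result stays attached to the ordinary autograd graph. A sketch, assuming we only need the gradient values: detaching the parameters first should yield plain tensors.

```python
import torch
from torch import nn, func

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(inplace=True), nn.Linear(4, 5), nn.Softmax(1))
x = torch.randn((10, 3))
net_fcn = lambda w, x: func.functional_call(net, w, x)
net_grad_fcn = func.grad(lambda w, x: net_fcn(w, x).sum())

# Detached copies of the parameters: no outer autograd graph is recorded.
w_detached = {k: v.detach() for k, v in net.named_parameters()}
grads = net_grad_fcn(w_detached, x)

for name, g in grads.items():
    print(name, g.requires_grad, g.grad_fn)
```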

In [None]:
# Suppose we add a new dimension to our input (multi-view input): we have a batch of 10
# elements, each composed of 5 vectors of dimension 3. In this case Linear works as expected
# (it acts on the last axis), but Softmax(1) normalizes across the wrong axis (the 5 views
# instead of the 5 output classes), so a single output row no longer sums to 1.
x = torch.randn((10, 5, 3))
net(x)[0, 0].sum()

tensor(0.9590, grad_fn=<SumBackward0>)
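
The wrong-axis claim can be checked directly: on a `(10, 5, 3)` input the output has shape `(10, 5, 5)`, and with `Softmax(1)` the entries should sum to 1 along the views axis rather than along the classes axis. A sketch with an arbitrary seed:

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(inplace=True), nn.Linear(4, 5), nn.Softmax(1))
x = torch.randn((10, 5, 3))

out = net(x)  # axes: (batch, view, class)
print(out.shape)

# Softmax(1) normalizes over axis 1, i.e. across the 5 views of each element...
print(torch.allclose(out.sum(dim=1), torch.ones(10, 5)))
# ...and not across the 5 classes of each view, which is what we wanted.
print(torch.allclose(out.sum(dim=-1), torch.ones(10, 5)))
```

In the OOP style the fix would be to edit the module (e.g. `nn.Softmax(-1)`); the notebook instead keeps the model as-is and changes how the function is applied.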

In [None]:
# We can use vmap to vectorize functions (apply them in parallel) over new axes.
net_vect_fcn = func.vmap(net_fcn, in_dims=(None, 1), out_dims=1)

In [None]:
net_vect_fcn(w, x)[0, 0].sum()

tensor(1.0000, grad_fn=<SumBackward0>)
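
To see what `in_dims=(None, 1)` and `out_dims=1` do, the sketch below compares the vmapped function with an explicit Python loop over the views axis (same network, arbitrary seed).

```python
import torch
from torch import nn, func

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(inplace=True), nn.Linear(4, 5), nn.Softmax(1))
w = dict(net.named_parameters())
net_fcn = lambda w, x: func.functional_call(net, w, x)
x = torch.randn((10, 5, 3))

# in_dims=(None, 1): w is shared (not mapped), x is mapped over its axis 1 (the 5 views).
# out_dims=1: the per-view outputs are stacked back along axis 1.
net_vect_fcn = func.vmap(net_fcn, in_dims=(None, 1), out_dims=1)
out_vmap = net_vect_fcn(w, x)

# Equivalent explicit loop: each call sees a (10, 3) slice,
# so Softmax(1) normalizes across the 5 classes as intended.
out_loop = torch.stack([net_fcn(w, x[:, i]) for i in range(x.shape[1])], dim=1)

print(out_vmap.shape, torch.allclose(out_vmap, out_loop))
print(torch.allclose(out_vmap.sum(dim=-1), torch.ones(10, 5)))
```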

In [None]:
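# We can combine vmap and grad as many times as needed. Here the inner grad
# differentiates w.r.t. the input x (argnums=1); the outer grad differentiates
# the squared norm of that input gradient w.r.t. the parameters w.
input_grad_fcn = func.grad(lambda w, x: net_vect_fcn(w, x).sum(), argnums=1)
net_grad_grad_fcn = func.grad(lambda w, x: input_grad_fcn(w, x).pow(2).sum())
net_grad_grad_fcn(w, x)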
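
A common composition in practice is `vmap(grad(...))`, which yields per-sample gradients. The sketch below assumes a hypothetical per-sample loss (the log-probability of class 0) and detached parameters; it illustrates the composition and is not taken from the original notebook.

```python
import torch
from torch import nn, func

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(inplace=True), nn.Linear(4, 5), nn.Softmax(1))
w = {k: v.detach() for k, v in net.named_parameters()}
net_fcn = lambda w, x: func.functional_call(net, w, x)

# Hypothetical per-sample loss: log-probability of class 0 for a single input vector.
# The sample gets a leading axis of size 1 because Softmax(1) expects 2-D input.
def sample_loss(w, x_i):
    return torch.log(net_fcn(w, x_i.unsqueeze(0))[0, 0])

# grad gives the gradient for one sample; vmap maps it over axis 0 of x,
# sharing the parameters (in_dims=None) and returning one gradient per sample.
per_sample_grad_fcn = func.vmap(func.grad(sample_loss), in_dims=(None, 0))

x = torch.randn((10, 3))
per_sample_grads = per_sample_grad_fcn(w, x)
for name, g in per_sample_grads.items():
    print(name, tuple(g.shape))
```

Every gradient gains a leading axis of size 10 (one entry per sample); summing over that axis should recover the gradient of the summed batch loss.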