In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchopt
from torch.func import grad, grad_and_value, vmap

Functional programming
순수함수로 구현해야함
- 동일한 입력 인자들에 대해서 동일한 결과를 리턴해야함
- 사이드 이펙트가 없어야함. 입력 인자들을 코드 내부에서 수정해서는 안됨. 

In [2]:
class SimpleNN(nn.Module):
    def __init__(
        self,
        num_layers: int = 1,
        num_neurons: int = 5,
    ) -> None:
        """Basic neural network architecture with linear layers
        
        Args:
            num_layers (int, optional): number of hidden layers
            num_neurons (int, optional): neurons for each hidden layer
        """
        super().__init__()
        
        layers = []

        # input layer
        layers.append(nn.Linear(1, num_neurons))

        # hidden layers with linear layer and activation
        for _ in range(num_layers):
            layers.extend([nn.Linear(num_neurons, num_neurons), nn.Tanh()])

        # output layer
        layers.append(nn.Linear(num_neurons, 1))

        # build the network
        self.network = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x.reshape(-1, 1)).squeeze()

torch의 nn.Module 의 경우 stateful 함

In [10]:
import torch

x = torch.randn([])
model = SimpleNN() # constructed above
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
# modify the state of the model
# by applying a single optimization step
out1 = model(x)
true = torch.tensor(1.0)
loss = F.mse_loss(out1, true)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# recompute the output with exactly the same input
out2 = model(x)
assert not torch.equal(out1, out2)

In [11]:
model

SimpleNN(
  (network): Sequential(
    (0): Linear(in_features=1, out_features=5, bias=True)
    (1): Linear(in_features=5, out_features=5, bias=True)
    (2): Tanh()
    (3): Linear(in_features=5, out_features=1, bias=True)
  )
)

 ## 순수함수 형태의 stateless 연산

`functional_call(model, params, (arg1, arg2, ...))`

In [15]:
import torch
from torch.func import functional_call

x = torch.randn([]) # random input data
model = SimpleNN() # constructed above
params = dict(model.named_parameters()) # model parameters

# make a functional call to the model above
out = functional_call(model, params, (x,))
out

tensor(-0.0357, grad_fn=<SqueezeBackward0>)

In [26]:
grad_fn = grad(model)
params = tuple(model.named_parameters())
grad_values = grad_fn(x)
grad_values

tensor(-0.0926, grad_fn=<ViewBackward0>)

In [27]:
def mse_loss(params, x, t):
    pred = functional_call(model, params, (x,))
    loss = (t - pred) ** 2
    return loss

In [28]:
loss_grad_fn = grad(mse_loss)
params = tuple(model.named_parameters())
grad_values = loss_grad_fn(params, x, true)
grad_values

ValueError: Thing passed to transform API must be Tensor, got <class 'str'>

In [29]:
def make_functional_fwd(_model):
    def fn(data, parameters):
        return functional_call(_model, parameters, (data,))
    return fn

In [34]:
model_func = make_functional_fwd(model) # functional forward
params = tuple(model.named_parameters())

In [35]:
grad_params = grad(model_func, argnums=1)(x, params)

ValueError: Thing passed to transform API must be Tensor, got <class 'str'>

In [36]:
optimizer = torchopt.adam()

In [39]:
import functorch

In [40]:
fmodel, fparams = functorch.make_functional(model)

  warn_deprecated('make_functional', 'torch.func.functional_call')


In [43]:
optimizer = torchopt.Adam(model.parameters())