In [1]:
import torch
import copy
import torch.nn as nn
from torch.nn.utils.stateless import functional_call

# Torch NN Stateless 
- Turn stateful nn modules into to stateless functional form
- Interesting discussion can be found here: https://github.com/pytorch/pytorch/issues/49171
- https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/stateless.py
Code interesteing read

In [2]:
class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = torch.zeros(1)
    def forward(self, a):
        self.foo =  self.foo + a

In [3]:
a = {'foo': torch.zeros(())}
mod = Foo()  # does self.foo = self.foo + 1
print(mod.foo)  # tensor(0.)
functional_call(mod, a, torch.ones(()))
print(mod.foo)  # tensor(0.)
print(a['foo'])  # tensor(1.)

tensor([0.])
tensor([0.])
tensor(1.)


In [4]:
# Create module and call it functional
lin_mod = nn.Linear(10,15)
param_dict = dict(lin_mod.named_parameters())
output = functional_call(lin_mod, param_dict, torch.randn(1,10))
# Should be same as calling lin_mod 1x15
print(output.shape)

# what if we wanted to reuse module but diffent weights
new_params = {"weight": torch.randn(20,10), "bias": torch.randn(20)}
output = functional_call(lin_mod, new_params, torch.randn(1,10))
# Same computation diffent tensors should be 1x20
print(output.shape)

torch.Size([1, 15])
torch.Size([1, 20])


https://pytorch.org/docs/master/generated/torch.nn.utils.stateless.functional_call.html#torch.nn.utils.stateless.functional_call

In [11]:
jitted = torch.jit.trace(lin_mod, example_inputs=torch.randn(20,10))
functional_call(jitted,param_dict, torch.randn(1,10))

RuntimeError: The stateless API can't be used with Jitted modules

In [15]:
# Forward mode AD
#Currenlty can not store dual tensors as nn.paramaters so this provides a work around

# We need a fresh module because the functional call requires the
# the model to have parameters registered.
import torch.nn as nn
import torch
import torch.autograd.forward_ad as fwAD

model = nn.Linear(5, 5)
input = torch.randn(16, 5)

params = {name: p for name, p in model.named_parameters()}
tangents = {name: torch.rand_like(p) for name, p in params.items()}

# Old way of doing things
with fwAD.dual_level():
    for name, p in params.items():
        delattr(model, name)
        setattr(model, name, fwAD.make_dual(p, tangents[name]))

    out = model(input)
    jvp = fwAD.unpack_dual(out).tangent

# Use stateless to replace paramaters with dual paramaters

dual_params = {}
with fwAD.dual_level():
    for name, p in params.items():
        # Using the same ``tangents`` from the above section
        dual_params[name] = fwAD.make_dual(p, tangents[name])
    out = functional_call(model, dual_params, input)
    jvp2 = fwAD.unpack_dual(out).tangent

# Check our results
assert torch.allclose(jvp, jvp2)