# Torch High-order Derivates with create_graph

In [117]:
import torch 
v = -2.0
x = torch.tensor(v, requires_grad=True)

# function 
f = 2 * x**3
print(f"x={v},         f = {f}")

# first order
f.backward(retain_graph=True)
print(f"x={v},     df/dx = {x.grad}")

# second order
x.grad.zero_()
df = torch.autograd.grad(f, x, create_graph=True)[0]
df.backward(gradient=torch.tensor(1.0))
print(f"x={v}, d^2f/dx^2 = {x.grad}")

# thrid order
x.grad.zero_()
df = torch.autograd.grad(f, x, create_graph=True)[0]
ddf = torch.autograd.grad(df, x, create_graph=True)[0]
ddf.backward(gradient=torch.tensor(1.0))
print(f"x={v}, d^3f/dx^3 = {x.grad}")

# fourth order
x.grad.zero_()
df = torch.autograd.grad(f, x, create_graph=True)[0]
ddf = torch.autograd.grad(df, x, create_graph=True)[0]
dddf = torch.autograd.grad(ddf, x, create_graph=True)[0]
dddf.backward(gradient=torch.tensor(1.0))
print(f"x={v}, d^4f/dx^4 = {x.grad}")


x=-2.0,         f = -16.0
x=-2.0,     df/dx = 24.0
x=-2.0, d^2f/dx^2 = -24.0
x=-2.0, d^3f/dx^3 = 12.0
x=-2.0, d^4f/dx^4 = 0.0


You can check the result with the following equations

$$
\begin{aligned}
    \text{(function)~~~} &f = 2x^3 & \text{when~} x=-2 \Rightarrow 2 \cdot (-8) = -16 \\
    \text{(first-order)~~~} &\frac{df}{dx} = 6x^2 &  \text{when~} x=-2 \Rightarrow 6 \cdot 4 = 24 \\
    \text{(second-order)~~~} &\frac{d}{dx}\Big(\frac{df}{dx}\Big) = 12x & \text{when~} x=-2 \Rightarrow 12 \cdot 2 =24 \\
    \text{(third-order)~~~} &\frac{d}{dx}\Big(\frac{d^2f}{dx^2}\Big) = 12 &  \text{when~} x=-2 \Rightarrow 12 = 12 
\end{aligned}
$$

# Neural Network 

In [56]:
import torch.nn as nn 

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(2,5)
        self.out = nn.Linear(5, 1)
    def forward(self, x):
        x = self.mlp(x**2)
        x = nn.functional.relu(x)
        x = self.out(x)
        return x 

In [28]:
net = Net()
input= torch.tensor([1.0]*2, requires_grad=True)

# first order 
f = net(input)
f.backward()
print(input.grad)
print("------------")

# gradient of weights 
input.grad.zero_() 
f = net(input)
gx = torch.autograd.grad(f, input, create_graph=True)[0]
for i in range(2):
    gx[i].backward(retain_graph=True)
    print(input.grad)
print("------------")

# gradient of weights to minimize the norm of the gradients
input.grad.zero_() 
f = net(input)
gx = torch.autograd.grad(f, net.mlp.weight, create_graph=True)[0]
print(gx)
(gx**2).sum().backward()
net.mlp.weight.grad

tensor([ 0.1236, -0.0950])
------------
tensor([0., 0.])
tensor([0., 0.])
------------
tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.2585, 0.2585]], grad_fn=<TBackward0>)


tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.5169, 0.5169]])

# MAML

Step 1. Compute the task specific updated weight 
$$
\theta' = \theta - \alpha \nabla_\theta L_{task} (f(\theta)) 
$$

Step 2. compute the meta-test with the computed weight
$$
\theta \leftarrow \theta - \nabla_\theta L_{{meta}} (f(\theta'))
$$

### Implementation 

1. autograd.grad($L$, $\theta$, $\mathrm{\textcolor{orange}{create\_graph=True}}$) computes $\nabla_\theta L$ as a tensor with computational graph

2. forward with weight - grad (obtained by 1) computes the step 2

In [23]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(2,5)
        self.out = nn.Linear(5, 1)
    def forward(self, x, fast_weights=None):
        if fast_weights is None:
            x = self.mlp(x)
            x = nn.functional.relu(x)
            x = self.out(x)
        else:  # same logit with given weight
            w, b = fast_weights[0], fast_weights[1]
            x = F.linear(x, w, b)
            x = nn.functional.relu(x)
            w, b = fast_weights[2], fast_weights[3]
            x = F.linear(x, w, b)
        return x 

net = Net()
lr=1e-3
meta_optim = torch.optim.Adam(net.parameters(), lr=lr)


meta_loss = 0
num_tasks = 3
for task in range(num_tasks):
    # Step 1 : compute one_step weight
    input= torch.tensor([(task+1.0)]*2)
    f = net(input)
    grad = torch.autograd.grad(f, net.parameters(), create_graph=True)    # Step 1 Core
    fast_weights = list(map(lambda p: p[1] - lr * p[0], zip(grad, net.parameters())))

    # Step 2 : compute the loss with one step weight
    meta_input= torch.tensor([(task+2.0)]*2)
    test_loss = net(input, fast_weights)      # Step 2 Core
    meta_loss += test_loss 

meta_loss /= num_tasks

# optimize theta parameters
meta_optim.zero_grad()
meta_loss.backward()
meta_optim.step()

# Hessian 

Given the Hessian Matrix $H(\theta)_{N\times N}$ and vector $r$, 

the Hessian Vector product is computed by the gradient of the dot product between gradient and the vector. 

$$
\begin{aligned}
H(\theta) \cdot r &= \frac{d^2}{d\theta^2} \Big(  L(\theta) \Big) \cdot r \\
&= \frac{d}{d\theta} \Big( \frac{d}{d\theta} L(\theta) \cdot r \Big)

\end{aligned}
$$


#### Step 1. Compute the gradient 

#### Step 2. Compute the gradient times vector

In [43]:
net = Net()
params = [p for p in net.parameters() if len(p.size()) > 1]
N = sum(p.numel() for p in params)
print(f"Number of parameters : {N}")
print("Params", params)

# Compute the gardients
input= torch.tensor([(task+1.0)]*2)
f = net(input)
grad = torch.autograd.grad(f, inputs=params, create_graph=True)    # Step 1 Core

# Hessian Vector Product
prod = torch.autograd.Variable(torch.zeros(1)).type(type(grad[0].data))
vec = torch.rand_like(prod)
for (g,v) in zip(grad, vec):
    prod = prod + (g * v).cpu().sum()  # Step 2 Core
    
prod.backward()  # Now the params.grad stores the Hessian 

print("----------------------")
print("Hessian Vector Product")
for p in params:
    print(p.grad)

Number of parameters : 15
Params [Parameter containing:
tensor([[ 0.6361,  0.5278],
        [ 0.1699,  0.4604],
        [-0.0896, -0.2206],
        [-0.5090, -0.5044],
        [-0.5976, -0.0930]], requires_grad=True), Parameter containing:
tensor([[ 0.2209, -0.1456, -0.1973,  0.1386,  0.1548]], requires_grad=True)]
----------------------
Hessian Vector Product
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])
tensor([[4.7362, 4.7362, 0.0000, 0.0000, 0.0000]])
