# Torch High-order Derivates with create_graph

In [117]:
import torch 
v = -2.0
x = torch.tensor(v, requires_grad=True)

# function 
f = 2 * x**3
print(f"x={v},         f = {f}")

# first order
f.backward(retain_graph=True)
print(f"x={v},     df/dx = {x.grad}")

# second order
x.grad.zero_()
df = torch.autograd.grad(f, x, create_graph=True)[0]
df.backward(gradient=torch.tensor(1.0))
print(f"x={v}, d^2f/dx^2 = {x.grad}")

# thrid order
x.grad.zero_()
df = torch.autograd.grad(f, x, create_graph=True)[0]
ddf = torch.autograd.grad(df, x, create_graph=True)[0]
ddf.backward(gradient=torch.tensor(1.0))
print(f"x={v}, d^3f/dx^3 = {x.grad}")

# fourth order
x.grad.zero_()
df = torch.autograd.grad(f, x, create_graph=True)[0]
ddf = torch.autograd.grad(df, x, create_graph=True)[0]
dddf = torch.autograd.grad(ddf, x, create_graph=True)[0]
dddf.backward(gradient=torch.tensor(1.0))
print(f"x={v}, d^4f/dx^4 = {x.grad}")


x=-2.0,         f = -16.0
x=-2.0,     df/dx = 24.0
x=-2.0, d^2f/dx^2 = -24.0
x=-2.0, d^3f/dx^3 = 12.0
x=-2.0, d^4f/dx^4 = 0.0


You can check the result with the following equations

$$
\begin{aligned}
    \text{(function)~~~} &f = 2x^3 & \text{when~} x=-2 \Rightarrow 2 \cdot (-8) = -16 \\
    \text{(first-order)~~~} &\frac{df}{dx} = 6x^2 &  \text{when~} x=-2 \Rightarrow 6 \cdot 4 = 24 \\
    \text{(second-order)~~~} &\frac{d}{dx}\Big(\frac{df}{dx}\Big) = 12x & \text{when~} x=-2 \Rightarrow 12 \cdot 2 =24 \\
    \text{(third-order)~~~} &\frac{d}{dx}\Big(\frac{d^2f}{dx^2}\Big) = 12 &  \text{when~} x=-2 \Rightarrow 12 = 12 
\end{aligned}
$$

# Neural Network 

In [56]:
import torch.nn as nn 

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(2,5)
        self.out = nn.Linear(5, 1)
    def forward(self, x):
        x = self.mlp(x**2)
        x = nn.functional.relu(x)
        x = self.out(x)
        return x 
    
net = Net()
input= torch.tensor([1.0]*2, requires_grad=True)


In [57]:
f = net(input)
f.backward()
input.grad 

tensor([-0.0906, -0.0671])

In [60]:
input.grad.zero_() 
f = net(input)
gx = torch.autograd.grad(f, input, create_graph=True)[0]
for i in range(2):
    gx[i].backward(retain_graph=True)
    print(input.grad)

tensor([-0.0906,  0.0000])
tensor([-0.0906, -0.0671])


In [61]:
input.grad.zero_() 
f = net(input)
gx = torch.autograd.grad(f, net.mlp.weight, create_graph=True)[0]
print(gx)
(gx**2).sum().backward()
net.mlp.weight.grad

tensor([[-0.2138, -0.2138],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.3319,  0.3319]], grad_fn=<TBackward0>)


tensor([[-1.0688, -1.0688],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 0.0000,  0.0000],
        [ 1.6595,  1.6595]])