In [1]:
import torch
from torch.autograd.functional import jacobian, hessian
import matplotlib.pyplot as plt

Function is:
\begin{align}
    u(x,t) &= 2x^3t+wx^2t^2-6x
\end{align}
Loss function is:
\begin{align*}
    \mathcal{L} &= MSE\left[u + u_x + u_{xx} \right]\\
    &= MSE\left[2x^3t+wx^2t^2-6x + 6x^2t + 2wxt^2 - 6 + 12xt   + 2wt^2\right]
\end{align*}

In [2]:
### Compute derivatives analytically
def dudx_true(X):
    x = X.view(2,-1)[0]
    return 6*x**2*t + 2*w*x*t**2 - 6

def du2dx2_true(X):
    x = X.view(2,-1)[0]
    return 12*x*t   + 2*w*t**2

def dudt_true(X):
    t = X.view(2,-1)[1]
    return 2*x**3 + 2*w*x**2*t

def du2dt2_true(X):
    t = X.view(2,-1)[1]
    return 2*w*x**2

def mse(x):
    return torch.sum(x**2).mean()

## efficient way

In [4]:
N = 10 # number of samples
w = torch.tensor(5., requires_grad=True)
t = torch.linspace(5,30,N)
t.requires_grad=True
x = torch.linspace(0,1,N)
x.requires_grad=True
X = torch.cat([x,t], dim=0)

def u(x,t):
    '''function'''
    return 2*x**3*t + w*x**2*t**2 - 6*x

u_ = u(x,t)

## Compute first derivative (WRT x)
u_.backward(torch.ones_like(x), create_graph=True)
u_x = x.grad.clone()
u_t = t.grad.clone()

## Compute second derivative
x.grad.zero_() # zero out gradients of x first
u_x.backward(torch.ones_like(x), create_graph=True)
u_xx = x.grad.clone()

## Finally, compute loss WRT parameters
x.grad.zero_() # zero out gradients first
t.grad.zero_()
w.grad.zero_()

## compute loss
loss = mse(u_ + u_x + u_xx)
loss.backward()
print(w.grad)

tensor(4.4646e+08, grad_fn=<ZeroBackward>)


## compare to actual

In [72]:
## compute actual derivatives
u_x_true = dudx_true(X)
u_t_true = dudt_true(X)
u_xx_true = du2dx2_true(X)
u_tt_true = du2dt2_true(X)

print(u_x_true)
print(u_x)
print()
print(u_t_true)
print(u_t)
print()
print(u_xx_true)
print(u_xx)
print()

loss = mse(u(x,t) + u_x + u_xx)
loss.backward()
print(w.grad)

tensor([-6.0000e+00,  6.1791e+01,  2.4473e+02,  5.9548e+02,  1.1667e+03,
         2.0111e+03,  3.1814e+03,  4.7302e+03,  6.7102e+03,  9.1740e+03],
       grad_fn=<SubBackward0>)
tensor([-6.0000e+00,  6.1791e+01,  2.4473e+02,  5.9548e+02,  1.1667e+03,
         2.0111e+03,  3.1814e+03,  4.7302e+03,  6.7102e+03,  9.1740e+03],
       grad_fn=<CloneBackward>)

tensor([  0.0000,   0.9630,   5.2346,  14.8889,  32.0000,  58.6420,  96.8889,
        148.8148, 216.4938, 302.0000], grad_fn=<AddBackward0>)
tensor([  0.0000,   0.9630,   5.2346,  14.8889,  32.0000,  58.6420,  96.8889,
        148.8148, 216.4938, 302.0000], grad_fn=<CloneBackward>)

tensor([ 250.0000,  615.3086, 1142.3457, 1831.1111, 2681.6050, 3693.8274,
        4867.7783, 6203.4565, 7700.8633, 9360.0000], grad_fn=<AddBackward0>)
tensor([ 250.0000,  615.3086, 1142.3457, 1831.1111, 2681.6050, 3693.8274,
        4867.7783, 6203.4565, 7700.8633, 9360.0000])



RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

## very inefficient way...

In [87]:
def u(X):
    '''function'''
    x = X.view(2,-1)[0]
    t = X.view(2,-1)[1]
    return 2*x**3*t + w*x**2*t**2 - 6*x

#### custom derivatives #####
def drv1(X):
    jac = jacobian(func=u, inputs=X, create_graph=True)
    u_x = torch.diag(jac[:,:N])
    u_t = torch.diag(jac[:,N:])
    return torch.cat([u_x, u_t], axis=0)

def drv2(X):
    jac  = jacobian(func=drv1, inputs=X, create_graph=True)
    u_xx = torch.diag(jac[:N,:N])
    u_xt = torch.diag(jac[N:,:N])
    u_tx = torch.diag(jac[:N,N:])
    u_tt = torch.diag(jac[N:,N:])
    return torch.cat([u_xx, u_xt, u_tx, u_tt], dim=0)

## compute actual derivatives
u_x_true = dudx_true(X)
u_t_true = dudt_true(X)
u_xx_true = du2dx2_true(X)
u_tt_true = du2dt2_true(X)

## compute custom derivatives
first_deriv = drv1(X).view(2,-1)
second_deriv = drv2(X).view(4,-1)

## parse derivatives
u_x = first_deriv[0]
u_t = first_deriv[1]
u_xx = second_deriv[0]
u_xt = second_deriv[1]
u_tx = second_deriv[2]
u_tt = second_deriv[3]

print(u_x_true)
print(u_x)
print()
print(u_t_true)
print(u_t)
print()
print(u_xx_true)
print(u_xx)
print()
print(u_tt_true)
print(u_tt)
print()

loss = mse(u(X) + u_x + u_xx)
loss.backward()
print(w.grad)
w.grad.zero_()
print()

tensor([-6.0000e+00,  6.1791e+01,  2.4473e+02,  5.9548e+02,  1.1667e+03,
         2.0111e+03,  3.1814e+03,  4.7302e+03,  6.7102e+03,  9.1740e+03],
       grad_fn=<SubBackward0>)
tensor([-6.0000e+00,  6.1791e+01,  2.4473e+02,  5.9548e+02,  1.1667e+03,
         2.0111e+03,  3.1814e+03,  4.7302e+03,  6.7102e+03,  9.1740e+03],
       grad_fn=<SelectBackward>)

tensor([  0.0000,   0.9630,   5.2346,  14.8889,  32.0000,  58.6420,  96.8889,
        148.8148, 216.4938, 302.0000], grad_fn=<AddBackward0>)
tensor([  0.0000,   0.9630,   5.2346,  14.8889,  32.0000,  58.6420,  96.8889,
        148.8148, 216.4938, 302.0000], grad_fn=<SelectBackward>)

tensor([ 250.0000,  615.3086, 1142.3457, 1831.1111, 2681.6050, 3693.8274,
        4867.7783, 6203.4565, 7700.8633, 9360.0000], grad_fn=<AddBackward0>)
tensor([ 250.0000,  615.3086, 1142.3457, 1831.1111, 2681.6050, 3693.8274,
        4867.7783, 6203.4565, 7700.8633, 9360.0000], grad_fn=<SelectBackward>)

tensor([ 0.0000,  0.1235,  0.4938,  1.1111,  1.9753