In [1]:
import torch
from torch.autograd.functional import jacobian, hessian
import matplotlib.pyplot as plt

In [2]:
N = 10 # number of samples
w = torch.tensor(5., requires_grad=True)
t = torch.linspace(5,30,N)
x = torch.linspace(0,1,N)

## case where x is only input

In [3]:
def f(X):
    return 2*X**3*t + w*X**2*t**2 - 6*X
def fprime(x):
    return torch.diag(jacobian(func=f, inputs=x, create_graph=True))
def fprime_actual(x):
    return 6*x**2*t + 2*w*x*t**2 - 6
def fprimeprime_actual(x):
    return 12*x*t   + 2*w*t**2

x = torch.linspace(0,1,10)

f_x  = fprime(x)
f_xx = torch.diag(jacobian(func=fprime, inputs=x, create_graph=True))

print(f_x)
print()
print(fprime_actual(x))
print()
print(f_xx)
print()
print(fprimeprime_actual(x))
print()
def mse(x):
    return torch.sum(x**2).mean()

loss = mse(f(x) + f_x + f_xx)
# loss = f(x) + f_x + f_xx
loss.backward()
print(w.grad)

tensor([-6.0000e+00,  6.1791e+01,  2.4473e+02,  5.9548e+02,  1.1667e+03,
         2.0111e+03,  3.1814e+03,  4.7302e+03,  6.7102e+03,  9.1740e+03],
       grad_fn=<DiagBackward>)

tensor([-6.0000e+00,  6.1791e+01,  2.4473e+02,  5.9548e+02,  1.1667e+03,
         2.0111e+03,  3.1814e+03,  4.7302e+03,  6.7102e+03,  9.1740e+03],
       grad_fn=<SubBackward0>)

tensor([ 250.0000,  615.3086, 1142.3457, 1831.1111, 2681.6050, 3693.8274,
        4867.7783, 6203.4565, 7700.8633, 9360.0000], grad_fn=<DiagBackward>)

tensor([ 250.0000,  615.3086, 1142.3457, 1831.1111, 2681.6050, 3693.8274,
        4867.7783, 6203.4565, 7700.8633, 9360.0000], grad_fn=<AddBackward0>)

tensor(4.4646e+08)


### differentiate with respect to time

In [9]:
def f(t):
    return 2*x**3*t + w*x**2*t**2 - 6*x
def fprime(t):
    return torch.diag(jacobian(func=f, inputs=t, create_graph=True))
def fprime_actual(t):
    return 2*x**3 + 2*w*x**2*t
def fprimeprime_actual(t):
    return 2*w*x**2

f_t  = fprime(t)
f_tt = torch.diag(jacobian(func=fprime, inputs=t, create_graph=True))

# print(f_t)
# print()
print(fprime_actual(t))
print()
# print(f_tt)
# print()
print(fprimeprime_actual(t))

tensor([  0.0000,   0.9630,   5.2346,  14.8889,  32.0000,  58.6420,  96.8889,
        148.8148, 216.4938, 302.0000], grad_fn=<AddBackward0>)

tensor([ 0.0000,  0.1235,  0.4938,  1.1111,  1.9753,  3.0864,  4.4444,  6.0494,
         7.9012, 10.0000], grad_fn=<MulBackward0>)


## case where t is also an input

Function is:
\begin{align}
    f(x,t) &= 2x^3t+wx^2t^2-6x
\end{align}

In [13]:
def f(X):
    x = X.view(2,-1)[0]
    t = X.view(2,-1)[1]
    return 2*x**3*t + w*x**2*t**2 - 6*x

def drv1(X):
    jac = jacobian(func=f, inputs=X, create_graph=True)
    u_x = torch.diag(jac[:,:N])
    u_t = torch.diag(jac[:,N:])
    return torch.cat([u_x, u_t], axis=0)

def drv2(X):
    jac  = jacobian(func=drv1, inputs=X, create_graph=True)
    u_xx = torch.diag(jac[:N,:N])
    u_xt = torch.diag(jac[N:,:N])
    u_tx = torch.diag(jac[:N,N:])
    u_tt = torch.diag(jac[N:,N:])
    return torch.cat([u_xx, u_xt, u_tx, u_tt], dim=0)

X = torch.cat([x,t], dim=0)
first_deriv = drv1(X).view(2,-1)
u_x = first_deriv[0]
u_t = first_deriv[1]

second_deriv = drv2(X).view(4,-1)
u_xx = second_deriv[0]
u_xt = second_deriv[1]
u_tx = second_deriv[2]
u_tt = second_deriv[3]

print(u_x)
print()
print(u_t)
print()
print(u_xx)
print()
print(u_tt)

tensor([-6.0000e+00,  6.1791e+01,  2.4473e+02,  5.9548e+02,  1.1667e+03,
         2.0111e+03,  3.1814e+03,  4.7302e+03,  6.7102e+03,  9.1740e+03],
       grad_fn=<SelectBackward>)

tensor([  0.0000,   0.9630,   5.2346,  14.8889,  32.0000,  58.6420,  96.8889,
        148.8148, 216.4938, 302.0000], grad_fn=<SelectBackward>)

tensor([ 250.0000,  615.3086, 1142.3457, 1831.1111, 2681.6050, 3693.8274,
        4867.7783, 6203.4565, 7700.8633, 9360.0000], grad_fn=<SelectBackward>)

tensor([ 0.0000,  0.1235,  0.4938,  1.1111,  1.9753,  3.0864,  4.4444,  6.0494,
         7.9012, 10.0000], grad_fn=<SelectBackward>)
