# Using `torch.func`

In [1]:
import torch
from torch.func import jacrev, jvp, vmap
from torch import nn

In [2]:
x = torch.Tensor([[1.0, 1.0], [1.0, 1.0]]).to(torch.float32)  # (batch_size, dim)
x

tensor([[1., 1.],
        [1., 1.]])

In [7]:
def vector_field(x):
    y = torch.empty_like(x)
    y[:, 0] = x[:, 0]**2 + x[:, 1]
    y[:, 1] = 3*x[:, 0]**2 - 2*x[:, 1]
    return y

vector_field = vmap(vector_field)

def vector_field_jacobian(x):
    jac = torch.empty((2, 2), dtype=x.dtype)
    jac[0, 0] = 2*x[0]
    jac[0, 1] = 1
    jac[1, 0] = 6*x[0]
    jac[1, 1] = -2
    return jac

vector_field_jacobian = vmap(vector_field_jacobian)

def vector_field_divergence(x):
    div = 2*x[0] - 2
    return div

vector_field_divergence = vmap(vector_field_divergence)

def torch_vector_field_divergence(x):
    divergence = 0
    for i in range(len(x)):
        e_i = torch.zeros_like(x)
        e_i[i] = 1.0
        _, jvp_result = jvp(vector_field, (x,), (e_i,))
        print(jvp_result)
        divergence += jvp_result.flatten()[i]
    return divergence

torch_vector_field_divergence = vmap(torch_vector_field_divergence)

vetor_field_value = vector_field(x)
print(f'Vector field value:\n{vetor_field_value}')

torch_vector_field_jacobian = jacrev(vector_field)
explicit_jacobian = vector_field_jacobian(x)
torch_jacobian = torch_vector_field_jacobian(x)
print(f'Explicit Jacobian:\n{explicit_jacobian}')
print(f'Torch Jacobian:\n{torch_jacobian}')
jacobians_equal = torch.allclose(explicit_jacobian, torch_jacobian)
print(f'Jacobians equal: {jacobians_equal}')
explicit_divergence = vector_field_divergence(x)
torch_divergence = torch_vector_field_divergence(x)
print(f'Explicit divergence: {explicit_divergence}')
print(f'Torch divergence: {torch_divergence}')
divergences_equal = torch.allclose(explicit_divergence, torch_divergence)
print(f'Divergences equal: {divergences_equal}')

IndexError: too many indices for tensor of dimension 1

In [32]:
nn_model = nn.Sequential(
    nn.Linear(2, 16),
    nn.Tanh(),
    nn.Linear(16, 2)
)

def torch_jvp_divergence(func, x):
    divergence = 0
    for i in range(len(x)):
        basis = torch.zeros(len(x), dtype=x.dtype)
        basis[i] = 1.0
        _, vjp = jvp(func, (x,), (basis,))
        divergence += vjp.flatten()[i]
    return divergence

nn_jacobian = jacrev(nn_model)
jac = nn_jacobian(x)
print(f'NN Jacobian: {jac}')
print(f'Divergence: {torch.trace(jac)}')
nn_divergence = torch_jvp_divergence(nn_model, x)
print(f'NN Divergence via JVP: {nn_divergence}')
equal = torch.allclose(torch.trace(jac), nn_divergence)
print(f'Divergences equal: {equal}')

NN Jacobian: tensor([[-0.0936,  0.0523],
        [-0.0084,  0.0908]], grad_fn=<ViewBackward0>)
Divergence: -0.0027619674801826477
NN Divergence via JVP: -0.0027619674801826477
Divergences equal: True
