In [3]:
import torch
from torch.func import jacrev, vmap

In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


Notes: 

- torch.autograd.functional.jacobian calculates cross-terms which we don't need
- the otehr version is faster

In [None]:
N = 5 # number of points along x and y axes
D = 2 # dimensions

# each point on 5 x 5 2D grid maps to a 2 x 2 matrix
any_square_matrix = torch.randn(size = (N, N, D, D))

# construct the antisymmetric (skew-symmetric) matrix (no scaling needed) 
# we transpose just the last two dimensions
A = any_square_matrix - any_square_matrix.transpose(-1, -2)

# assert raises an error if the condition is not met
assert torch.allclose(A, - A.transpose(-1, -2)), "Matrix is not antisymmetric!"

# https://github.com/facebookresearch/neural-conservation-law/blob/main/pytorch/divfree.py

In [39]:
A.diagonal(dim1 = - 1, dim2 = -2).sum(dim = -1)

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

In [40]:
import jax
import jax.numpy as jnp
from jax import jit, vmap
from jax import jacfwd, jacrev, grad, jvp

import flax.linen as nn
from einops import rearrange

import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'jax'

## Old

In [18]:
A_flat = A.reshape(-1) # flattens A into a 1D tensor

In [19]:
x.nelement()

10

In [None]:
# from NCL
def div(u):
    """Accepts a function u:R^D -> R^D."""
    J = jacrev(u)
    return lambda x: torch.trace(J(x))

In [None]:
# Define vector function f: R^2 -> R^2
def f(x):
    x1, x2 = x[..., 0], x[..., 1]
    return torch.stack([
        x1**2 + x2,
        torch.sin(x1) + x2**3
    ], dim = -1)  # Keep vector output shape

# Generate a batch of 5 random vectors in R^2
x = torch.randn(5, 2, requires_grad=True)

# Compute Jacobian correctly: Per-sample, not across batch
J_fn = vmap(jacrev(f))  # Vectorized Jacobian function
J = J_fn(x)  # Apply to batch

# Print output shape
print("Corrected Jacobian shape:", J.shape)  # Should be (5, 2, 2)

# Print an example Jacobian for the first input
print("Jacobian for the first input:\n", J[0])

Corrected Jacobian shape: torch.Size([5, 2, 2])
Jacobian for the first input:
 tensor([[ 4.5247,  1.0000],
        [-0.6377,  0.8643]], grad_fn=<SelectBackward0>)


In [None]:
x = torch.randn(5, 3)

15

In [15]:
# https://pytorch.org/functorch/nightly/generated/functorch.jacrev.html
J_fn = jacrev(func = f, argnums = 1)
J_fn(x)

RuntimeError: Got argnum=1, but only 1 positional inputs

# My annotation of the code


In [None]:
import torch
import torch.nn as nn
from functorch import make_functional
from functorch import vmap
from functorch import jacrev


def div(u):
    """Accepts a function u:R^D -> R^D."""
    J = jacrev(u)
    return lambda x: torch.trace(J(x))


def build_divfree_vector_field(module):
    """Returns an unbatched vector field, i.e. assumes input is a 1D tensor."""

    F_fn, params = make_functional(module)

    J_fn = jacrev(F_fn, argnums=1)

    def A_fn(params, x):
        J = J_fn(params, x)
        A = J - J.T
        return A

    def A_flat_fn(params, x):
        A = A_fn(params, x)
        A_flat = A.reshape(-1)
        return A_flat

    def ddF(params, x):
        D = x.nelement() # dimension of the input (counts all elements across dimensions)
        dA_flat = jacrev(A_flat_fn, argnums = 1)(params, x)
        Jac_all = dA_flat.reshape(D, D, D)
        ddF = vmap(torch.trace)(Jac_all)
        return ddF

    return ddF, params, A_fn

In [None]:
# jac rev is faster