In [None]:
import torch
import torch.nn as nn

# Matrix field reconstruction

The matrix field reconstruction has one less Jacobian step than the vector field reconstruction which is why it is computationally more efficient.

We use the following deterministic transformations to get from the NN output to a divergence-free vector field v:
1. Parameterise the Skew-Symmetric decomposition of A
    - U = NN(x) (non-zero values of the Upper Triangular U are of size N(N - 1)/(2) (x2))
2. Construct anti-symmetric matrix A
    - A = U - U.T 
3. Attain divergence-free vector field v via
    - v = (div(A1), div(A2)), trace of the Jacobian

see https://github.com/facebookresearch/neural-conservation-law/blob/main/pytorch/divfree.py

Dimensionalities:
- If our input in (4 x 4, 2) so flat that is (16, 2) coordinate pairs, U should be (6, 2, 2)
- 6 = (sqrt(N)(sqrt(N) - 1)/2)

NN batch-wise or not


## Questions

- Do we use a NN that processes a batch or points individually? (Batch)
- What is the output shape of u_v?
    - For the dim = 2 case, do it is always (dim * (dim - 1) / 2), which is 1 because we only estimate the upper right corner of every (N) 2 x 2 matrix
    - Can we build this directly into the net?

The model seems to only be implemented under [jax > models.py > Divfree()](https://github.com/facebookresearch/neural-conservation-law/blob/20a403d00affad905d1c47b041bc60d0ff0ea360/jax/models.py#L118). DivfreeSparse() and DivFreeImplicit() are not used anywhere.

The model is used in [jax > hh_experiment_DivFree.py](https://github.com/facebookresearch/neural-conservation-law/blob/20a403d00affad905d1c47b041bc60d0ff0ea360/jax/hh_experiment_DivFree.py#L53). Hodge decomp.

dim = 10.  
mlp = MLP(depth = layers, width = width, act = act, out_dim = **dim * (dim-1) // 2**, std = 1, bias = True)

For dim = 2, at each point, each matrix A (2 x 2) is antisymm. So we only have to estimate a scalar for each input point.

u_fn, params, _ = build_divfree_vector_field(self.module)

In [29]:
class MLP(nn.Module):
    def __init__(self, input_dim = 2, hidden_dim = 32):
        super().__init__()
        output_dim = int((input_dim * (input_dim - 1)) / 2)
        print(output_dim)
        
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)  # Output shape: (4,)

In [None]:
model = MLP()

N = 16  # N should be a perfect square
N_side = int(N ** 0.5)
dims = 2
inputs = torch.randn(N, dims)  # Random (N, 2) inputs
inputs.shape

1


torch.Size([16, 2])

In [None]:
# U_fill shape is (N, 1)
U_fill = model(inputs)
U = torch.zeros(N, dims, dims)
# Fills all top right corners of the (N, 2, 2) tensor
U[:, 0, 1] = U_fill.squeeze()
# U is (N, 2, 2), so we need to swap the last two dims and then subtract
A = U - U.transpose(1, 2)

def compute_A(inputs):
    # U_fill is (N, 1)
    U_fill = model(inputs)
    # This version works with vmap
    U = torch.triu(torch.ones(N, dims, dims), diagonal = 1)
    U = U * U_fill.unsqueeze(1)
    A = U - U.transpose(1, 2)
    return A

In [92]:
# U = U.index_add(1, torch.tensor([0, 1]), U_fill.squeeze().unsqueeze(1).repeat(1, 2))

## Jacobian

- Jacobian is probably among the most expensive functions
- [torch.func.jacrev](https://pytorch.org/docs/stable/generated/torch.func.jacrev.html#torch.func.jacrev)
    - The implementation goes forward
    - torch.func.jacobian chooses based on efficiency
- batched Jacobians via vmap
- torch.autograd.functional.jacobian(f, x)
    - not as fast as func

In [None]:
from torch.func import jacrev, jacfwd, vmap
x = torch.randn(5, 2, 2)
jacobian = vmap(jacrev(torch.sin))(x)
jacobian.shape

torch.Size([5, 2, 2, 2, 2])

In [181]:
# Issue with how we construct U 
# Jacobian: torch.Size([16, 16, 2, 2, 2])
# without vmap it was also torch.Size([16, 2, 2, 16, 2])
jacobian = vmap(jacrev(compute_A))(inputs)

# Remove redundant dim
# (jacobian[:, 0, : , :, :] == jacobian[:, 15, : , :, :]).any()
jacobian_sq = jacobian[:, 0, : , :, :]
jacobian_sq.shape

torch.Size([16, 2, 2, 2])

In [187]:
# torch.diagonal(jacobian_sq, dim1 = 2, dim2 = 3).sum(dim = 1)
# torch.diagonal(jacobian_sq, dim1 = 1, dim2 = 3).sum(dim = 1)

In [196]:
v = torch.diagonal(jacobian_sq, dim1 = 2, dim2 = 3).sum(dim = 1)

In [173]:
jacobian[:, :, 1, 0, 0]
jacobian[:, 0, 1, 0, 0] # second d is unnessary
jacobian[:, 0, 0, 1, 0]
jacobian[:, 0, 0, :, 1]

tensor([[ 0.0000,  0.0344],
        [ 0.0000, -0.0620],
        [ 0.0000, -0.0434],
        [ 0.0000, -0.0447],
        [ 0.0000, -0.0142],
        [ 0.0000,  0.0077],
        [ 0.0000,  0.0283],
        [ 0.0000,  0.0057],
        [ 0.0000,  0.0440],
        [ 0.0000,  0.0528],
        [ 0.0000, -0.0360],
        [ 0.0000,  0.0541],
        [ 0.0000,  0.0662],
        [ 0.0000,  0.0424],
        [ 0.0000,  0.1336],
        [ 0.0000,  0.0620]], grad_fn=<SelectBackward0>)

In [197]:
# Compute the Jacobian using autograd
# Takes in function & input
# torch.Size([16, 2, 2, 16, 2])
# jacobian = torch.autograd.functional.jacobian(compute_A, inputs, vectorize = True)
jacobian_func = torch.func.vmap(torch.func.jacfwd(compute_A))(inputs)
# trace_result = torch.trace(jacobian, dim1 = 1, dim2 = 2)

In [200]:
vmap(torch.func.jacfwd(compute_A))(inputs).shape

torch.Size([16, 16, 2, 2, 2])

In [201]:
torch.func.jacrev(compute_A)(inputs).diagonal(dim1 = 2, dim2 = 3).shape
torch.func.jacfwd(compute_A)(inputs)[0, :, :, 0]
# torch.func.jacrev(compute_A)(inputs)[0, 0, 1, 0, 0]
# torch.func.jacrev(compute_A)(inputs)[0, 1, 0, 0, 0]

tensor([[[ 0.0000,  0.0000],
         [ 0.0273,  0.0344]],

        [[-0.0273, -0.0344],
         [ 0.0000,  0.0000]]], grad_fn=<SelectBackward0>)

In [58]:
jacobian_func.diagonal(dim1 = 1, dim2 = 2).sum()

AttributeError: 'function' object has no attribute 'diagonal'

In [54]:
jacobian.diagonal(dim1 = 1, dim2 = 2).sum()

tensor(0.)

In [55]:
jacobian.shape

torch.Size([16, 2, 2, 16, 2])

In [38]:
U_empty[:, 0, 1]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [None]:
U = jnp.zeros((N_side, N_side))
idx = jnp.triu_indices(N, 1)
U = U.at[idx].set(b) # go through via row
A = U - U.T # miuns now multiplication

In [11]:
import torch
import torch.nn as nn
import itertools
import math

class MLP(nn.Module):
    def __init__(self, input_dim = 2, hidden_dim = 32, output_dim = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)  # Output shape: (4,)

def generate_pairwise_matrices(inputs, model):
    """
    Takes (N, 2) inputs, applies the model, and outputs (sqrt(N)(sqrt(N)-1)/2, 2, 2) shaped tensor.
    """
    N = inputs.shape[0]
    
    # Check that N is a perfect square for âˆšN pairs
    sqrt_N = math.isqrt(N)
    assert sqrt_N ** 2 == N, "N should be a perfect square!"

    # Compute the number of pairwise combinations: sqrt(N) * (sqrt(N) - 1) / 2
    num_pairs = (sqrt_N * (sqrt_N - 1)) // 2

    pairs = list(itertools.combinations(range(N), 2))  # Generate all unique (i, j) pairs

    outputs = []
    for i, j in pairs:
        pair_input = (inputs[i] + inputs[j]) / 2  # Combine inputs (simple average)
        matrix_flat = model(pair_input)  # Get (4,) shaped output
        matrix = matrix_flat.view(2, 2)  # Reshape to (2, 2)
        outputs.append(matrix)

    return torch.stack(outputs)  # Shape: (num_pairs, 2, 2)

# Example usage
N = 16  # N should be a perfect square
inputs = torch.randn(N, 2)  # Random (N, 2) inputs

model = MLP()
output_matrices = generate_pairwise_matrices(inputs, model)

print("Input shape:", inputs.shape)  # (N, 2)
print("Output shape:", output_matrices.shape)  # (12, 2, 2) for N = 16

Input shape: torch.Size([16, 2])
Output shape: torch.Size([120, 2, 2])
