In [1]:
import torch
import torch.nn as nn


In [2]:
# General-purpose function to compute the Jacobian matrix using VJP
def jacobian_from_vjp(f, x):
    """
    Computes the Jacobian of a function f at x using vector-Jacobian products (VJPs).

    One reverse-mode pass is run per output element: back-propagating the i-th
    one-hot vector through f yields the i-th row of the Jacobian.

    Args:
        f (callable): The function for which to compute the Jacobian. Takes a tensor `x`
            and outputs a tensor `y` of any shape (a 0-dim scalar is allowed).
        x (torch.Tensor): Floating-point input tensor of any shape. It is not modified:
            differentiation runs on a detached copy, so the caller does not need to set
            `requires_grad=True` beforehand.

    Returns:
        torch.Tensor: Jacobian with shape `y.shape + x.shape`. For 1-D input and output
        this is the usual (output_dim, input_dim) matrix. Entries belonging to outputs
        that do not depend on `x` are zero.
    """
    # Differentiate with respect to a detached leaf so the caller's tensor keeps its
    # own requires_grad flag and autograd history.
    x = x.detach().requires_grad_(True)
    y = f(x)

    jacobian_shape = tuple(y.shape) + tuple(x.shape)

    # No output elements, or an output that is not connected to any autograd graph:
    # the Jacobian is identically zero and autograd.grad would raise if called.
    if y.numel() == 0 or not y.requires_grad:
        return torch.zeros(jacobian_shape, dtype=x.dtype, device=x.device)

    # Flatten the output so every element gets its own one-hot vector, whatever
    # the rank of y is.
    y_flat = y.reshape(-1)

    # Compute the Jacobian row-by-row using VJP
    jacobian_rows = []
    for i in range(y_flat.numel()):
        # Unit vector selecting the i-th output element
        grad_output = torch.zeros_like(y_flat)
        grad_output[i] = 1.0
        # Gradient of y_flat[i] w.r.t. x; the graph is retained because it is
        # reused on the next iteration.
        (vjp,) = torch.autograd.grad(
            outputs=y_flat,
            inputs=x,
            grad_outputs=grad_output,
            retain_graph=True,
            allow_unused=True,
        )
        # allow_unused yields None when y requires grad (e.g. through model
        # parameters) but x is not part of its graph.
        jacobian_rows.append(torch.zeros_like(x) if vjp is None else vjp)

    # Stack rows and restore the (output shape, input shape) layout
    return torch.stack(jacobian_rows).reshape(jacobian_shape)


In [3]:
# Example: Neural Network as f
class SimpleNN(nn.Module):
    """Small fully connected network: 3 inputs -> 4 hidden units (ReLU) -> 2 outputs."""

    def __init__(self):
        super().__init__()
        # layer1 is created before layer2 so that seeded weight initialisation
        # consumes the random stream in the same order.
        self.layer1 = nn.Linear(in_features=3, out_features=4)
        self.layer2 = nn.Linear(in_features=4, out_features=2)

    def forward(self, x):
        """Apply layer1, a ReLU, then layer2 to `x` and return the result."""
        hidden = self.layer1(x)
        activated = torch.relu(hidden)
        return self.layer2(activated)
    
# Example: f is a simple quadratic function
def quad_f(x):
    """
    Quadratic example map taking the first three entries of `x` to two outputs.

    Returns:
        torch.Tensor: the stacked pair [x[0]**2 + x[1], x[1]**2 + x[2]].
    """
    first = x[0] ** 2 + x[1]
    second = x[1] ** 2 + x[2]
    return torch.stack((first, second))



In [4]:
# testing code for neural function
# Seed the global RNG so the randomly initialised weights (and therefore the
# printed values) are the same on every fresh run of the notebook
torch.manual_seed(0)

# Instantiate the model
model = SimpleNN()

# Input tensor
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Define the function f using the model
def f(x):
    """Evaluate the module-level `model` at `x`."""
    return model(x)

# Compute the Jacobian of f at x
jacobian = jacobian_from_vjp(f, x)
jacobian_builtin = torch.autograd.functional.jacobian(f, x)

# Fail loudly if the two implementations disagree instead of relying on a
# visual comparison of the printed matrices
assert torch.allclose(jacobian, jacobian_builtin), \
    "VJP Jacobian disagrees with torch.autograd.functional.jacobian"

# Output results
print("Input x:", x)
print("Output f(x):", f(x))
print("Jacobian matrix:\n", jacobian)
print("Jacobian matrix (inbuilt):\n", jacobian_builtin)

Input x: tensor([1., 2., 3.], requires_grad=True)
Output f(x): tensor([-0.4327, -0.4846], grad_fn=<ViewBackward0>)
Jacobian matrix:
 tensor([[ 0.0050, -0.1075,  0.0268],
        [ 0.0085, -0.1835,  0.0457]])
Jacobian matrix (inbuilt):
 tensor([[ 0.0050, -0.1075,  0.0268],
        [ 0.0085, -0.1835,  0.0457]])


In [5]:
# testing code for quadratic function
# Input tensor
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Compute the Jacobian of f at x
jacobian = jacobian_from_vjp(quad_f, x)

# Compute the Jacobian using the inbuilt function
jacobian_builtin = torch.autograd.functional.jacobian(quad_f, x)

# Fail loudly if the two implementations disagree instead of relying on a
# visual comparison of the printed matrices
assert torch.allclose(jacobian, jacobian_builtin), \
    "VJP Jacobian disagrees with torch.autograd.functional.jacobian"

# Output results
print("Input x:", x)
print("Output f(x):", quad_f(x))
print("Jacobian matrix:\n", jacobian)
print("Jacobian matrix (inbuilt):\n", jacobian_builtin)

Input x: tensor([1., 2., 3.], requires_grad=True)
Output f(x): tensor([3., 7.], grad_fn=<StackBackward0>)
Jacobian matrix:
 tensor([[2., 1., 0.],
        [0., 4., 1.]])
Jacobian matrix (inbuilt):
 tensor([[2., 1., 0.],
        [0., 4., 1.]])
