In [51]:
import torch
import torch.nn as nn
import time
from torch.autograd.functional import jvp

batch_size = 10
nx = 3

A = torch.randn(nx, nx, requires_grad=True)
B = torch.randn(nx, nx, requires_grad=True)

xs = torch.randn(batch_size, nx, requires_grad=True)
zs = torch.exp(xs)

inds = [0, 1, 2, 3, 5, 6, 7, 8]
zs_next = zs[inds, :] @ B.T

print(zs_next.shape)
print(xs.shape)

vsx = torch.randn(zs_next.shape[0], 3)
jx = torch.autograd.grad(zs_next, xs, grad_outputs=vsx, create_graph=True)
print(jx)
for j, i in enumerate(inds):
    print(
        vsx[None, j, ...] @ B @ torch.diag(
            torch.exp(xs[i, :])
        )
    )
    
zz = zs_next[0, :]
xx = xs[0, :]

jac = B @ torch.diag(torch.exp(xs[0]))
print(vsx[0][None, :] @ jac)

torch.Size([8, 3])
torch.Size([10, 3])
(tensor([[ 1.4862, -1.0360, -0.5143],
        [-1.1530,  1.2557,  0.2224],
        [-0.6994,  0.8583,  1.3914],
        [-0.1733, -0.0907,  1.1319],
        [ 0.0000,  0.0000,  0.0000],
        [-1.8744,  4.4257,  2.1459],
        [-1.3158,  0.1616, -0.0959],
        [-0.0464,  0.0097,  0.1212],
        [-2.0655,  0.1496,  1.5749],
        [ 0.0000,  0.0000,  0.0000]], grad_fn=<MulBackward0>),)
tensor([[ 1.4862, -1.0360, -0.5143]], grad_fn=<MmBackward0>)
tensor([[-1.1530,  1.2557,  0.2224]], grad_fn=<MmBackward0>)
tensor([[-0.6994,  0.8583,  1.3914]], grad_fn=<MmBackward0>)
tensor([[-0.1733, -0.0907,  1.1319]], grad_fn=<MmBackward0>)
tensor([[-1.8744,  4.4257,  2.1459]], grad_fn=<MmBackward0>)
tensor([[-1.3158,  0.1616, -0.0959]], grad_fn=<MmBackward0>)
tensor([[-0.0464,  0.0097,  0.1212]], grad_fn=<MmBackward0>)
tensor([[-2.0655,  0.1496,  1.5749]], grad_fn=<MmBackward0>)
tensor([[ 1.4862, -1.0360, -0.5143]], grad_fn=<MmBackward0>)


In [2]:
import torch
import torch.nn as nn
import time
from torch.autograd.functional import jvp

nx = 3
nu = 2
batch = 7

A = torch.randn(nx, nx, requires_grad=True)
B = torch.randn(nx, nu, requires_grad=True)


def f(x, u):
    return x @ A.T + u @ B.T

x = torch.randn(batch, nx, requires_grad=True)
u = torch.randn(batch, nu, requires_grad=True)
xu = torch.cat([x, u], dim=-1)

y = f(x, u)
print(x.shape, u.shape, y.shape)

n_proj = 10
vsx = torch.randn(batch, n_proj, nx, requires_grad=True)
vsu = torch.randn(batch, n_proj, nx, requires_grad=True)

grads_x = torch.zeros(batch,)
grads_u = torch.zeros(batch,)

for i in range(n_proj):
    Gx, Gu = torch.autograd.grad(y, (x, u), grad_outputs=vsx[:, i, :], create_graph=True)
        
    grads_x += 1/n_proj * (torch.norm(Gx, 2, dim=-1) ** 2)
    grads_u += 1/n_proj * (torch.norm(Gu, 2, dim=-1) ** 2)

grads_x_mean = grads_x.mean(dim=0)
grads_u_mean = grads_u.mean(dim=0)

print(grads_x_mean, torch.norm(A, 'fro') ** 2)
print(grads_u_mean, torch.norm(B, 'fro') ** 2)

torch.Size([7, 3]) torch.Size([7, 2]) torch.Size([7, 3])
tensor(7.8979, grad_fn=<MeanBackward1>) tensor(7.3632, grad_fn=<PowBackward0>)
tensor(9.3245, grad_fn=<MeanBackward1>) tensor(8.6587, grad_fn=<PowBackward0>)


In [3]:
import torch
import torch.nn as nn
import time
from torch.autograd.functional import jvp

# Example network: mapping from R^10 to R^20.
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(10, 32),
            nn.Mish(),
            nn.Linear(32, 32),
            nn.Mish(),
            nn.Linear(32, 32),
            nn.Mish(),
            nn.Linear(32, 20)
        )
    
    def forward(self, x):
        return self.model(x)

model = SimpleModel()

# Dummy input: a batch of 3 samples.
x = torch.randn(3, 10, requires_grad=True)

# Forward pass: compute output f(x) (shape: [3, 20])
output = model(x)

# Sample a random vector v with the same shape as output.
# For a proper unbiased estimator, v could be chosen with entries ±1 or Gaussian.
v = torch.randn_like(output)

# Compute the dot product between output and v, yielding a scalar.
# This is equivalent to summing over all samples and features.
scalar = torch.sum(output * v)

# Compute the gradient of this scalar with respect to x.
# This gradient is the Jacobian-vector product J_f(x)^T * v.
jvp_val = torch.autograd.grad(scalar, x, create_graph=True)[0]

# Now, you can compute a norm on the jvp (for example, its squared L2 norm)
jvp_norm_squared = torch.sum(jvp_val ** 2)

# Use this as a regularization loss term.
loss_reg = jvp_norm_squared

# Total loss combines your task loss and the Jacobian regularization
total_loss = loss_reg

# Backward pass:
total_loss.backward()

In [2]:
n_trials = 1000
start = time.perf_counter()

xs = [
    torch.randn(3, 10, requires_grad=True) for _ in range(n_trials)
]

for i in range(n_trials):
    # Dummy input: a batch of 3 samples.
    x = xs[i]
    output = model(x)

    v = torch.randn_like(output)
    scalar = torch.sum(output * v)

    # Compute the gradient of this scalar with respect to x.
    # This gradient is the Jacobian-vector product J_f(x)^T * v.
    jvp_val = torch.autograd.grad(scalar, x, create_graph=True)[0]
    jvp_norm_squared = torch.sum(jvp_val ** 2)
    total_loss = jvp_norm_squared

    # Backward pass:
    total_loss.backward()

end = time.perf_counter()
print(f"Time per trial: {(end - start) / n_trials:.6f} seconds")

Time per trial: 0.000413 seconds


In [3]:
n_trials = 1000
start = time.perf_counter()

def f(x):
    return model(x)

for _ in range(n_trials):
    # Dummy input: a batch of 3 samples.
    x = xs[i]
    v = torch.randn_like(x)
    output, jvp_val = jvp(f, (x,), (v,), create_graph=True)
    jvp_norm_squared = torch.sum(jvp_val ** 2)
    total_loss = jvp_norm_squared

    # Backward pass:
    total_loss.backward()

end = time.perf_counter()
print(f"Time per trial: {(end - start) / n_trials:.6f} seconds")

Time per trial: 0.000764 seconds


In [4]:
import einops

batch_size = 1
nx = 6
ny = nx
nproj = 100

def f(x):
    return 2*x

x = torch.randn(batch_size, nx, requires_grad=True)
y = f(x)

vs = torch.randn(batch_size, nproj, ny, requires_grad=True)

print(y.shape, vs.shape)

dot_products = torch.einsum('bpm,bm->bp', vs, y)  # shape: (B, P)
mean_dot = dot_products.mean(dim=-1)

grads = torch.autograd.grad(mean_dot, x, grad_outputs=torch.ones_like(mean_dot), create_graph=True)[0]  # shape: (B, nx)
grads_squared_norm = torch.norm(grads, dim=-1) ** 2
mean_grads_squared_norm = grads_squared_norm.mean(dim=0)  # shape: (nx,)

print(mean_grads_squared_norm)

torch.Size([1, 6]) torch.Size([1, 100, 6])
tensor(0.1704, grad_fn=<MeanBackward1>)


In [5]:
# analytic_1 = 1/3 * (
#     vs[0, 0, :] @ (2*torch.eye(ny)) #+ \
#     # vs[0, 1, :] @ (2*torch.eye(ny)) + \
#     # vs[0, 2, :] @ (2*torch.eye(ny))
# )

# print(analytic_1)
print(grads.shape)

torch.Size([1, 6])


In [6]:
batch_size = 1
nx = 6
ny = nx
nproj = 100

def f(x):
    return 2*x

seed = 42
torch.random.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

vs = torch.randn(batch_size, nproj, ny, requires_grad=True)

class JacobianReg(nn.Module):
    '''
    Loss criterion that computes the trace of the square of the Jacobian.

    Arguments:
        n (int, optional): determines the number of random projections.
            If n=-1, then it is set to the dimension of the output 
            space and projection is non-random and orthonormal, yielding 
            the exact result.  For any reasonable batch size, the default 
            (n=1) should be sufficient.
    '''
    def __init__(self, n=1):
        assert n == -1 or n > 0
        self.n = n
        super(JacobianReg, self).__init__()

    def forward(self, x, y):
        '''
        computes (1/2) tr |dy/dx|^2
        '''
        B, C = y.shape
        if self.n == -1:
            num_proj = C
        else:
            num_proj = self.n
        J2 = 0
        for ii in range(num_proj):
            if self.n == -1:
                # orthonormal vector, sequentially spanned
                v = torch.zeros(B, C)
                v[:, ii] = 1
            else:
                # random properly-normalized vector for each sample
                # v = self._random_vector(C=C, B=B)
                v = vs[:, ii, :]
            if x.is_cuda:
                v = v.cuda()
            Jv = self._jacobian_vector_product(y, x, v, create_graph=True)
            J2 += C * torch.norm(Jv) ** 2 / (num_proj * B)
        R = (1 / 2) * J2
        return R

    def _random_vector(self, C, B):
        '''
        creates a random vector of dimension C with a norm of C^(1/2)
        (as needed for the projection formula to work)
        '''
        # if C == 1:
        #     return torch.ones(B)
        # v = torch.randn(B, C)
        # arxilirary_zero = torch.zeros(B, C)
        # vnorm = torch.norm(v, 2, 1, True)
        # v = torch.addcdiv(arxilirary_zero, 1.0, v, vnorm)
        # return v
        # return vs
        pass
    
    def _jacobian_vector_product(self, y, x, v, create_graph=False):
        '''
        Produce jacobian-vector product dy/dx dot v.

        Note that if you want to differentiate it,
        you need to make create_graph=True
        '''
        flat_y = y.reshape(-1)
        flat_v = v.reshape(-1)
        grad_x, = torch.autograd.grad(flat_y, x, flat_v, retain_graph=True, create_graph=create_graph)
        return grad_x


jacreg = JacobianReg(n=nproj)
x = torch.randn(batch_size, nx, requires_grad=True)
y = f(x)

nrm = jacreg.forward(x, y)
print(nrm)

tensor(68.8100, grad_fn=<MulBackward0>)


In [71]:
nx = 3
batch = 2

A = torch.diag(torch.arange(nx * 1.0))

print(x.shape, y.shape)

def f(x):
    return x @ A.T

x = torch.randn(batch, nx, requires_grad=True)
y = f(x)

# vs = torch.randn(batch, nx, requires_grad=True)

# G = torch.autograd.grad(y, x, grad_outputs=vs, create_graph=True)[0]
# print(G.shape)
# print(G)

# for i in range(batch):
#     vT = vs[i, :].reshape((1, -1))
#     print(vT @ A)
    
n_proj = 10
vs = torch.randn(batch, n_proj, nx, requires_grad=True)
# vs = vs / torch.norm(vs, dim=-1, keepdim=True)

grads = torch.zeros(batch,)

for i in range(n_proj):
    G = torch.autograd.grad(y, x, grad_outputs=vs[:, i, :], create_graph=True)[0]    
    grads += 1/nproj * (torch.norm(G, 2, dim=-1) ** 2)

grads_mean = grads.mean(dim=0)
print(grads_mean)
print(torch.norm(A, 'fro') ** 2)

torch.Size([2, 3]) torch.Size([2, 3])
tensor(0.0589, grad_fn=<MeanBackward1>)
tensor(5.)


In [8]:
from torch.func import jvp

jvp_gradx = torch.zeros(batch_size)
# vs = torch.randn(batch_size, nproj, nx)

for i in range(nproj):
    value, grad = jvp(f, (x,), (vs[:, i, :],))
    jvp_gradx += nx * torch.norm(grad)**2 / nproj

jvp_gradx = jvp_gradx.mean(dim=0)
print(jvp_gradx / 2)

tensor(17.8793, grad_fn=<DivBackward0>)


In [9]:
torch.randn(2)

tensor([0.2607, 0.4251])

In [12]:
batch_size = 1
nx = 6
nu = 3
nproj = 1000

A = torch.randn(nx, nx)
B = torch.randn(nx, nu)

def f(x, u):
    return x @ A.T + u @ B.T

jvp_gradx = torch.zeros(batch_size)
jvp_gradu = torch.zeros(batch_size)

vsx = torch.randn(batch_size, nproj, nx)
vsu = torch.randn(batch_size, nproj, nu)

xs = torch.randn(batch_size, nx, requires_grad=True)
us = torch.randn(batch_size, nu, requires_grad=True)
zs = xs @ A.T + us @ B.T

for i in range(nproj):
    value, grad = jvp(lambda x, u: zs, (xs, us), (vsx[:, i, :], vsu[:, i, :]))
    jvp_gradx += torch.norm(grad)**2 / nproj

jvp_gradx = jvp_gradx.mean(dim=0)
mat_AB = torch.cat([A, B], dim=-1)

print(jvp_gradx)
print(torch.norm(mat_AB, 'fro')**2)

tensor(0.)
tensor(47.6827)
