In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd.functional import jvp, jacobian
import numpy as np
from time import perf_counter
import einops

In [35]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

ntrials = 1000
A = torch.randn(50, 50, device=device)
X = torch.randn(1_000, 50, device=device)


def _sync():
    """Block until all queued work on ``device`` has finished (no-op on CPU).

    CUDA kernels are launched asynchronously, so without a synchronize before
    reading the clock the timer measures launch overhead, not the computation.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _time_op(op, n_trials=ntrials):
    """Warm up ``op`` once, then time ``n_trials`` calls.

    Returns:
        (elapsed_seconds, result_of_last_call)
    """
    result = op()  # warm-up: excludes one-off setup cost from the measurement
    _sync()
    start = perf_counter()
    for _ in range(n_trials):
        result = op()
    _sync()
    return perf_counter() - start, result


elapsed, y = _time_op(lambda: X @ A.T)
print("Time (X @ A.T):        ", elapsed)

elapsed, y = _time_op(lambda: einops.einsum(A, X, 'i j, k j -> k i'))
print("Time (einops.einsum):  ", elapsed)


Time:  0.014265583005908411
Time:  0.020336519999546


In [38]:
# Small two-layer feedforward network used by the Jacobian-penalty experiments.
class Net(nn.Module):
    """Two-layer MLP: ``Linear(input_dim, hidden_dim) -> tanh -> Linear(hidden_dim, output_dim)``."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Keep fc1 registered before fc2 so parameter order and init-time RNG use are unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = torch.tanh(self.fc1(x))
        return self.fc2(hidden)

# Jacobian penalty estimated from random Jacobian-vector products (forward mode).
def compute_jacobian_penalty1(model, x, n_projections=1):
    """Estimate a Jacobian penalty ``mean ||J v||^2`` from random directions.

    For each projection a random direction ``v`` shaped like ``x`` is drawn.
    Each row (dim 1) of ``v`` is scaled to unit norm. ``J v`` is computed with
    ``torch.autograd.functional.jvp``, squared, summed over dim 1 and averaged
    over dim 0. The result is averaged across projections.

    Args:
        model: callable taking ``x`` and returning a tensor.
        x: input tensor. Normalization uses ``dim=1``, so this presumably expects
            ``(batch, features)`` input and 2-D output — TODO confirm for other ranks.
        n_projections: number of random directions to average over; must be >= 1.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so it can be
        backpropagated into the model's parameters.

    Raises:
        ValueError: if ``n_projections < 1`` (the average would be undefined).
    """
    if n_projections < 1:
        raise ValueError(f"n_projections must be >= 1, got {n_projections}")

    penalty = 0.0
    for _ in range(n_projections):
        v = torch.randn_like(x)  # random direction
        # Per-row unit norm; the epsilon guards against division by zero.
        v = v / (v.norm(dim=1, keepdim=True) + 1e-8)

        # create_graph=True keeps the JVP differentiable w.r.t. the model parameters.
        _, jvp_result = jvp(model, (x,), (v,), create_graph=True)

        # Squared L2 norm of the directional derivative, averaged over the batch.
        penalty += (jvp_result ** 2).sum(dim=1).mean()

    return penalty / n_projections

def compute_jacobian_penalty2(model, x, n_projections=1):
    """Estimate ``mean ||J v||^2`` from the materialized Jacobian instead of forward-mode JVPs.

    The Jacobian of ``model`` at ``x`` is built once; its shape is
    ``output.shape + x.shape``, so for a batch it includes the cross-sample
    blocks. For each projection a random direction ``v`` shaped like ``x`` is
    drawn, and each row (dim 1) is scaled to unit norm. ``J v`` is obtained by
    contracting every input dimension of ``J`` with ``v``; the result has the
    model output's shape. The penalty is ``(J v)**2`` summed over dim 1 and
    averaged over dim 0, then averaged across projections.

    Args:
        model: callable taking ``x`` and returning a tensor.
        x: input tensor. Normalization uses ``dim=1``, so this presumably expects
            ``(batch, features)`` input and 2-D output — TODO confirm for other ranks.
        n_projections: number of random directions to average over; must be >= 1.

    Returns:
        Scalar tensor. It is built with ``create_graph=True`` so it can be
        backpropagated into the model's parameters.

    Raises:
        ValueError: if ``n_projections < 1`` (the average would be undefined).
    """
    if n_projections < 1:
        raise ValueError(f"n_projections must be >= 1, got {n_projections}")

    # Materialize the Jacobian once and reuse it for every projection.
    J = jacobian(model, x, create_graph=True)

    penalty = 0.0
    for _ in range(n_projections):
        v = torch.randn_like(x)  # random direction
        # Per-row unit norm; the epsilon guards against division by zero.
        v = v / (v.norm(dim=1, keepdim=True) + 1e-8)

        # Contract all input dims of J (its trailing x.dim() axes) with v -> J v, shaped like the output.
        # A plain `J @ v.T` would pair every sample with every direction when batch size > 1.
        jvp_result = torch.tensordot(J, v, dims=v.dim())

        # Squared L2 norm of the directional derivative, averaged over the batch.
        penalty += (jvp_result ** 2).sum(dim=1).mean()

    return penalty / n_projections

# Hyperparameters
input_dim = 10
hidden_dim = 64
output_dim = 30
n_timing_trials = 1000      # timed forward+backward iterations per estimator
n_timing_projections = 20   # random directions per penalty evaluation
model = Net(input_dim, hidden_dim, output_dim)


def _time_penalty(penalty_fn, n_trials=n_timing_trials,
                  n_projections=n_timing_projections, batch_size=1):
    """Return mean wall-clock seconds per (penalty + backward) iteration of ``penalty_fn``.

    One warm-up call with the default ``n_projections`` runs first, so one-off
    setup costs are excluded. ``backward()`` is called without zeroing grads, so
    gradients accumulate in ``model``'s parameters across iterations. That does
    not affect the timing comparison.
    """
    penalty_fn(model, torch.randn(batch_size, input_dim))  # warm-up
    start = perf_counter()
    for _ in range(n_trials):
        x = torch.randn(batch_size, input_dim)
        loss = penalty_fn(model, x, n_projections=n_projections)
        loss.backward()
    return (perf_counter() - start) / n_trials


print("Time 1 (jvp):           ", _time_penalty(compute_jacobian_penalty1))
print("Time 2 (full jacobian): ", _time_penalty(compute_jacobian_penalty2))

Time 1:  0.005075670477002859
Time 2:  0.004025802748001297


In [26]:
def f(x):
    """Elementwise sine; its Jacobian at a vector x is diag(cos(x))."""
    return torch.sin(x)

n = 20
x = torch.randn((32, n), requires_grad=True)
# One Jacobian per row: each row of x is treated as an independent input to f.
jacobians = torch.stack([jacobian(f, x_i) for x_i in x])
print(jacobians.shape)

# Check against the analytic Jacobian diag(cos x). Re-computing jacobian(f, x[0])
# and comparing it with itself would be a tautology and could never fail.
assert torch.allclose(jacobians, torch.diag_embed(torch.cos(x.detach())))

torch.Size([32, 20, 20])


In [53]:
from sklearn.decomposition import PCA

# Sample from N(0, cov) and check that the sample covariance recovers cov.
# Local generators make the cell reproducible without reseeding the global RNGs.
cov_gen = torch.Generator().manual_seed(0)
cov = torch.randn(10, 10, generator=cov_gen)
cov = cov @ cov.T  # A @ A.T is symmetric positive semi-definite, hence a valid covariance

# Convert explicitly (and to float64) instead of relying on NumPy coercing a torch tensor.
cov_np = cov.numpy().astype(np.float64)
sample_rng = np.random.default_rng(0)
X = sample_rng.multivariate_normal(np.zeros(10), cov_np, size=10_000)

cov_recon = np.cov(X, rowvar=False)

print(cov_recon)
print(cov)
# A single number is easier to judge than eyeballing two 10x10 matrices.
print("relative Frobenius error:",
      np.linalg.norm(cov_recon - cov_np) / np.linalg.norm(cov_np))

[[ 7.82  -0.645  2.915  2.522  1.944 -6.129  1.915 -0.308  0.352  2.207]
 [-0.645  3.752 -2.887 -0.948  1.736 -0.905  0.407 -1.12  -0.711 -0.411]
 [ 2.915 -2.887 17.034 -6.815  0.54  -7.634  3.163  5.359  2.248  4.239]
 [ 2.522 -0.948 -6.815 15.145  0.577  5.117 -5.266 -4.68   1.255  0.245]
 [ 1.944  1.736  0.54   0.577  4.355 -0.798 -0.741 -2.336  0.869  1.189]
 [-6.129 -0.905 -7.634  5.117 -0.798 16.148 -4.797  0.346  2.728  0.125]
 [ 1.915  0.407  3.163 -5.266 -0.741 -4.797 13.306  6.837 -4.172  4.336]
 [-0.308 -1.12   5.359 -4.68  -2.336  0.346  6.837 10.972  0.955  4.711]
 [ 0.352 -0.711  2.248  1.255  0.869  2.728 -4.172  0.955  3.879  1.377]
 [ 2.207 -0.411  4.239  0.245  1.189  0.125  4.336  4.711  1.377  5.42 ]]
tensor([[ 7.7136, -0.6979,  2.8631,  2.5697,  1.8827, -5.8404,  1.8831, -0.1574,
          0.4195,  2.2244],
        [-0.6979,  3.6852, -2.9118, -0.8948,  1.6696, -0.8958,  0.3533, -1.2211,
         -0.7452, -0.4812],
        [ 2.8631, -2.9118, 17.0376, -6.8457,  0.549

In [7]:
xs = torch.randn(10, requires_grad=True)
ys = torch.tanh(xs)

v = torch.randn(10)
# torch.tanh can be passed directly (no lambda needed). create_graph=True keeps `out`
# attached to the autograd graph (it carries a grad_fn).
_, out = jvp(torch.tanh, (xs,), (v,), create_graph=True)
print(out)

# d/dx tanh(x) = 1 - tanh(x)^2, so the JVP must equal (1 - ys**2) * v elementwise.
assert torch.allclose(out, (1 - ys ** 2) * v)

tensor([ 0.5513,  1.1239, -0.3528, -0.6187,  0.1626, -0.9629,  1.3004,  0.0809,
        -0.7823, -0.2454], grad_fn=<TanhBackwardBackward0>)
