In [33]:
import numpy as np
from scipy.special import softmax

f = lambda x: 0.1 * x**3

a = 10 * np.random.rand(100)
print(np.mean(np.exp(f(a)) * a) / np.mean(np.exp(f(a))))
print(np.sum(np.exp(f(a)) * a) / np.sum(np.exp(f(a))))
print(np.sum(softmax(f(a)) * a))

9.841532088251236
9.841532088251236
9.841532088251236


In [34]:
import jax
from jax import numpy as jnp

m = jnp.array([[0.0, 0.0], [5.0, -5.0], [20.0, -20.0]])
S = jnp.tile(jnp.eye(2), (3, 1, 1))
x = jax.random.multivariate_normal(key=jax.random.PRNGKey(0), mean=m, cov=S, shape=(100, 3))
print(x.shape)
print(x.mean(axis=0))

(100, 3, 2)
[[ -0.07354128   0.06427459]
 [  4.967654    -5.0375333 ]
 [ 20.176752   -20.075262  ]]


In [None]:
import math
import jax
from jax import numpy as jnp

def add(x):
    return x.sum()

a = jax.random.uniform(jax.random.PRNGKey(0), shape=(10, 5))
a_sum_list = []
for i in range(a.shape[1]):
    a_sum_list.append(add(a[:, i]))
a_sum = jax.vmap(add, in_axes=1, out_axes=0)(a)
print(jnp.array(a_sum_list))
print(a_sum)

ValueError: axis 1 is out of bounds for array of dimension 1

In [4]:
import jax
from jax import numpy as jnp
import matplotlib.pyplot as plt

a = 1.0
N = 5

S = jnp.exp(-0.5 * (jnp.subtract.outer(jnp.arange(N), jnp.arange(N)))**2 / 1.0**2)

A = jnp.eye(N, k=1) + jnp.diag(jnp.array([0.0] * (N-1) + [1.0]))
print(A)

print(A @ S @ A.T)

[[0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 1.]]
[[1.         0.60653067 0.13533528 0.011109   0.011109  ]
 [0.60653067 1.         0.60653067 0.13533528 0.13533528]
 [0.13533528 0.60653067 1.         0.60653067 0.60653067]
 [0.011109   0.13533528 0.60653067 1.         1.        ]
 [0.011109   0.13533528 0.60653067 1.         1.        ]]


In [35]:
import jax
from jax import numpy as jnp
from time import perf_counter

def max_min_eig(A):
    eigvals = jnp.linalg.eigvals(A)
    lam_max = jnp.max(eigvals.real)
    lam_min = jnp.min(eigvals.real)
    zero = 0
    return lam_max, lam_min, zero

@jax.jit
def max_eig(A):
    lam_max, lam_min, zero = max_min_eig(A)
    return zero

max_eig(jnp.zeros((500, 500)))

compute_time = 0.0
key = jax.random.PRNGKey(0)
lam_max = []
for i in range(100):
    key = jax.random.split(key, 1).squeeze()
    A = jax.random.uniform(key, (500, 500))
    tic = perf_counter()
    lam_max.append(max_eig(A))
    compute_time += perf_counter() - tic
lam_max_mean = sum(lam_max) / len(lam_max)
print(lam_max_mean, compute_time)

0.0 0.010122553999735828
