In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
from jax import vmap, jit
import jax.numpy as jnp
from pathlib import Path
import numpy as np
from einops import rearrange
import jax
import matplotlib.pyplot as plt


In [4]:
N = 1_000
A = np.random.uniform(0,1,size=N)*100
c1 = np.sum(np.diff(A))
c2 = A[-1] - A[0]
err = (c1-c2)/(c2)
c1, c2, err

(51.22869590700778, 51.22869590700807, -5.548005649411022e-15)

In [5]:
N = 1_000
A = np.random.uniform(0,1,size=N)*100
A = jnp.asarray(A)
c1 = jnp.sum(jnp.diff(A))
c2 = A[-1] - A[0]
err = (c1-c2)/(c2)
c1, c2, err

(Array(37.84372, dtype=float32),
 Array(37.843845, dtype=float32),
 Array(-3.3264328e-06, dtype=float32))

In [10]:
from hflow.net.build import build_mlp
from jax import vmap, grad, jit
u_config = {'width': 55,
            'layers': ['D']*7,
            'activation': 'swish'}

key = jax.random.key(np.random.randint(1e7))
u_fn, params = build_mlp(u_config, in_dim=1, out_dim=1, key=key)
u_fn_V = vmap(u_fn, (None, 0))

@jit
def loss_1(params, X):
    Y = u_fn_V(params, X)
    l = jnp.sum(Y[1:] - Y[:-1])
    return l
@jit
def loss_2(params, X):
    Y = u_fn_V(params, X)
    Y = jnp.squeeze(Y)
    l = Y[-1] - Y[0]
    return l

N = 5000
X = np.random.uniform(0,1,size=(N, 1))*10
X = np.linspace(-4,10,N)
X = jnp.asarray(X).reshape(-1,1)



l1 = loss_1(params, X)
l2 = loss_2(params, X)
err = (l1-l2)/(l2)
l1, l2, err

(Array(-0.14288783, dtype=float32),
 Array(-0.14288783, dtype=float32),
 Array(-0., dtype=float32))

In [11]:
g1 = grad(loss_1)(params, X)
g2 = grad(loss_2)(params, X)
g1 = jax.flatten_util.ravel_pytree(g1)[0]
g2 = jax.flatten_util.ravel_pytree(g2)[0]
(g1 - g2).sum()

Array(0., dtype=float32)