In [1]:
%env CUDA_VISIBLE_DEVICES=8
import jax
from jax import numpy as jnp, vmap, jit, random, lax, value_and_grad
from jax.numpy import linalg as jla
from jax.flatten_util import ravel_pytree
from neural_tangents import taylor_expand

from util import fold, laxmap

from matplotlib import pyplot as plt
import numpy as np
import pickle

from flax import linen as nn
from flax.linen import initializers as jinit
from functools import partial
from typing import Callable

env: CUDA_VISIBLE_DEVICES=8


In [2]:
# Uniform, fan-in variance-scaling initializer with scale 1/2, intended to
# mimic PyTorch's default initialization for ReLU layers.
# NOTE(review): torch.nn.Linear's own default looks like it corresponds to
# scale = 1/3 with a non-zero uniform bias — confirm that scale 1/2 and zero
# biases are what is intended here.
torch_init = jinit.variance_scaling(1 / 2, "fan_in", "uniform")
# Dense-layer factory: kernel drawn from torch_init, bias initialised to zero.
TorchLinear = partial(
    nn.Dense, kernel_init=torch_init, bias_init=jinit.zeros, dtype=None
)

# Number of hidden units used by the sMLP hidden layer.
width = 100

class sMLP(nn.Module):
    """One-hidden-layer MLP with a scalar output per example.

    Attributes:
        sigma: activation applied to the hidden layer's pre-activations.
    """

    sigma: Callable

    @nn.compact
    def __call__(self, x):
        """Map a batch `x` (leading axis = batch) to one scalar per example."""
        # Collapse every non-batch axis into a single feature axis.
        batch_size = x.shape[0]
        features = x.reshape((batch_size, -1))
        # Hidden layer of `width` units followed by the activation.
        hidden = self.sigma(TorchLinear(width)(features))
        # Single-unit readout; drop the trailing length-1 axis.
        readout = TorchLinear(1)(hidden)
        return readout[..., 0]

In [3]:
# Master seed; every random stream used in the notebook is derived from it.
seed = 1
# Five independent keys: model init, training inputs, test inputs, the target
# function's direction, and one spare key.
(model_rng, train_rng, test_rng, fn_rng, key) = random.split(
    random.PRNGKey(seed), num=5
)

In [91]:
def get_test_loss(d, n, T, lr=0.05, n_test=10000):
    """Train an sMLP with full-batch gradient descent and return its test loss.

    The regression target is a quadratic form over the coordinates with index
    greater than d/2, plus a cubic polynomial of the projection onto a fixed
    random unit vector. Inputs are standard-normal vectors of dimension `d`.

    Args:
        d: input dimension.
        n: number of training samples.
        T: number of gradient-descent steps passed to `fold`.
        lr: gradient-descent step size.
        n_test: number of held-out samples used for the test loss.

    Returns:
        The last test loss recorded during training. Each recorded test loss
        is evaluated on the parameters *before* that step's update, so the
        returned value does not include the effect of the final update.
        (Presumably `fold` stacks the per-step `save` outputs along a leading
        axis — TODO confirm against util.fold.)
    """
    # Fixed unit-norm direction for the cubic component of the target.
    beta = random.normal(fn_rng, (d,))
    beta = beta / jla.norm(beta)

    # Boolean diagonal selecting coordinates with index > d/2. Built once here
    # instead of via a Python comprehension on every trace of f_star.
    A = jnp.diag(jnp.arange(d) > d / 2)

    def f_star(x):
        """Target function evaluated on a single input vector."""
        proj = beta @ x
        return x.T @ A @ x / jnp.sqrt(675.) + (proj**3 - 3*proj) / jnp.sqrt(6)

    X_train = random.normal(train_rng, (d, n)).T
    y_train = vmap(f_star)(X_train)

    X_test = random.normal(test_rng, (d, n_test)).T
    y_test = vmap(f_star)(X_test)

    model = sMLP(sigma=jax.nn.relu)
    init_params = model.init(model_rng, X_train[:1])
    # Work with a flat parameter vector so the update is a single axpy.
    init_params, unravel = ravel_pytree(init_params)

    def f(p, x):
        return model.apply(unravel(p), x)

    def loss_fn(W):
        return jnp.mean((y_train - f(W, X_train))**2)

    def test_loss_fn(W):
        return jnp.mean((y_test - f(W, X_test))**2)

    @jit
    def step(params):
        # One pass yields both the training loss and its gradient.
        loss, grads = value_and_grad(loss_fn)(params)
        test_loss = test_loss_fn(params)
        return dict(state=params - lr*grads, save=(loss, test_loss))

    res = fold(step, init_params, steps=T, show_progress=True)
    return res['save'][1][-1]



In [92]:
def get_best_n(d, threshhold=0.1, T=30000, max_multiple=20, test_loss_fn=None):
    """Find the smallest sample size n = k * d**2 whose test loss is below a threshold.

    Candidates k = 1, ..., max_multiple are tried in increasing order and the
    search stops at the first success.

    Args:
        d: input dimension.
        threshhold: test-loss value that must be undercut (strictly).
        T: number of training steps forwarded to the evaluator.
        max_multiple: largest multiple of d**2 to try.
        test_loss_fn: callable `(d, n, T) -> test loss`; defaults to
            `get_test_loss`.

    Returns:
        The first qualifying n, or None if no candidate reaches the threshold.
    """
    if test_loss_fn is None:
        test_loss_fn = get_test_loss

    base = int(d**2)
    for multiple in range(1, max_multiple + 1):
        n = multiple * base
        print(d, n)
        if test_loss_fn(d, n, T) < threshhold:
            return n
    # Made explicit: callers must be prepared for "no sample size was enough".
    return None

In [None]:
# Sweep the input dimension and record, for each one, the sample size reported
# by get_best_n. Results are accumulated one at a time so that entries already
# computed are kept if the (long-running) sweep is interrupted.
ns_nostop = []
for d in (10, 20, 30, 40, 50, 100):
    ns_nostop += [get_best_n(d)]

In [None]:
# Set the font size before the figure exists: rcParams are read when artists
# are created, so setting it afterwards can leave the first run of this cell
# with differently sized text than later runs.
plt.rcParams['font.size'] = '16'

fig, axs = plt.subplots(1, 1, figsize=(6, 4))
axs = np.ravel(axs)
ax = axs[0]

ds = [10, 20, 30, 40, 50, 100]
# get_best_n returns None when no sample size reached the threshold; plot
# those dimensions as gaps instead of failing on a None value.
ns_plot = [np.nan if n is None else n for n in ns_nostop]

# Reference scalings n = 6 d^2 and n = d^3, plus the measured sample sizes.
ax.plot(ds, [6*d**2 for d in ds], linestyle='dotted', color='gray', label=r'$n = 6d^2$')
ax.plot(ds, [d**3 for d in ds], linestyle='dashed', color='red', label=r'$n = d^3$')
ax.plot(ds, ns_plot, marker='o', color='blue', label='samples')
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel(r'$d$')
ax.set_ylabel(r'$n$')
ax.legend()
# Label the x axis with the dimensions actually evaluated.
ax.set_xticks(ds)
ax.set_xticklabels([str(d) for d in ds])
fig.tight_layout()