In [11]:
from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional

import jax
import jax.numpy as jnp
from jax import jit, vmap, grad, random, lax

import os
import pwd
os.chdir("/home/zongchen/thinned_mfld")
from utils.configs import CFG
from utils.problems import Problem
from jaxtyping import Array 
from jax_tqdm import scan_tqdm


In [12]:
def q1_nn(z, x):
    # Simple 2-layer NN for demonstration
    d_hidden = x.shape[0] - 3
    W1, b1, W2, b2 = x[:d_hidden], x[d_hidden+1], x[d_hidden+2], x[d_hidden+3]
    h = jnp.tanh(z @ W1 + b1)
    return jnp.dot(W2, h) + b2

In [13]:
Z = jnp.array([[0.5, 1.0], [1.5, -0.5], [-1.0, 2.0]])
x_all = jax.random.normal(jax.random.PRNGKey(0), (3, 5))  # 3 samples of NN params

preds_all = jax.vmap(q1_nn, in_axes=(None, 0))(Z, x_all)
preds = preds_all.mean(axis=0)
print("Predictions:", preds)

Predictions: [ 0.0637453  0.8020474 -0.380552 ]


In [None]:
def two_layer_nn(params, x):
    """Forward pass: x ∈ R^{d_in}, output ∈ R^{d_out}"""
    W1, b1, W2, b2 = params
    h = jnp.tanh(x @ W1 + b1)
    return h @ W2 + b2

preds = two_layer_nn(
    params=x_all,
    x=Z,
)
print("Predictions (two_layer_nn):", preds)