In [1]:
import pytensor
import numpy as np
import jax
import jax.numpy as jnp
import optax
from jax import grad

from pytensor_ml.loss import SquaredError
from pytensor_ml.model import Model
from pytensor_ml.optimizers import Adam

In [2]:

# A single seeded Generator feeds every random draw below, so Restart & Run All
# reproduces exactly the same data.
rng = np.random.default_rng(42)

# Synthetic simple-linear-regression data: y = 3x + 2 + noise, 10 rows x 1 column,
# with x drawn uniformly from [0, 1) and Gaussian noise of standard deviation 0.1,
# everything kept in float32.
X_np = rng.random((10, 1), dtype=np.float32)
y_np = 3 * X_np + 2 + rng.normal(0.0, 0.1, (10, 1)).astype(np.float32)

In [3]:
# Sanity check before building either model: inputs and targets must line up row-for-row.
assert X_np.shape == y_np.shape, f"X/y shape mismatch: {X_np.shape} vs {y_np.shape}"
X_np.shape

(10, 1)

In [None]:
# Pytensor setup.
# pytensor.function and Model operate on *symbolic* variables; the numpy arrays
# X_np / y_np are data to feed in later, not graph inputs/outputs.
X_in = pytensor.tensor.tensor('X_in', shape=X_np.shape)
# NOTE(review): this graph re-creates the data-generating process (fixed true
# coefficients 3 and 2 plus a *new* noise draw from rng, which differs from the
# noise already baked into y_np). It contains no trainable parameters, so it
# appears there is nothing for Adam to fit. To mirror the JAX model below it
# should presumably be a linear model X_in @ w + b built from pytensor_ml
# weights — TODO confirm how pytensor_ml declares trainable weights.
y_in = 3 * X_in + 2 + rng.normal(scale=0.1, size=X_np.shape).astype(np.float32)
f = pytensor.function([X_in], y_in)
# Assumes Model takes (symbolic input, symbolic output) — TODO confirm against pytensor_ml.
model = Model(X_in, y_in)
# Named distinctly because the JAX cell below rebinds `loss_fn` to a plain function.
loss_fn_pt = SquaredError()
optim = Adam(model, loss_fn_pt, ndim_out=1, learning_rate=0.01)


In [None]:

# JAX setup
def model_jax(params, x):
    """Linear model: returns dot(x, params['w']) + params['b'].

    params is a dict with a weight under 'w' and a bias under 'b'.
    """
    weight, bias = params['w'], params['b']
    return jnp.dot(x, weight) + bias

# Zero-initialised float32 parameters: a (1, 1) weight and a (1,) bias.
params_jax = dict(
    w=jnp.zeros((1, 1), dtype=jnp.float32),
    b=jnp.zeros((1,), dtype=jnp.float32),
)

def loss_fn(params, x, y):
    """Mean squared error between model_jax(params, x) and the targets y."""
    predictions = model_jax(params, x)
    residuals = predictions - y
    return jnp.mean(residuals ** 2)

# Adam with learning rate 0.01 (the same value given to the pytensor Adam above);
# its state is initialised from the current params_jax.
optimizer_jax = optax.adam(0.01)
opt_state = optimizer_jax.init(params_jax)

# Training step for JAX
@jax.jit
def train_step_jax(params, opt_state, x, y):
    """Apply one optimizer_jax update to the linear-model parameters.

    Returns a 4-tuple:
      - the updated params,
      - the updated optimizer state,
      - the loss evaluated at the *updated* params,
      - the gradients computed at the *incoming* params (not the optax updates).
    """
    loss_grad_fn = jax.grad(loss_fn)
    gradients = loss_grad_fn(params, x, y)
    param_updates, next_opt_state = optimizer_jax.update(gradients, opt_state)
    next_params = optax.apply_updates(params, param_updates)
    next_loss = loss_fn(next_params, x, y)
    return next_params, next_opt_state, next_loss, gradients

# Convert the numpy data to JAX arrays (jnp.array keeps the float32 dtype of X_np / y_np).
X_jax = jnp.array(X_np)
y_jax = jnp.array(y_np)

# Training loop: each epoch is one train_step_jax call on the full dataset.
num_epochs = 100

for epoch in range(num_epochs):
    params_jax, opt_state, loss_jax, updates_jax = train_step_jax(params_jax, opt_state, X_jax, y_jax)

    if epoch % 10 == 0:
        # Something to compare