In [1]:
import jax
import jax.numpy as jnp
import jaxtorch
from jaxtorch import nn, Context
from tqdm import tqdm
import optax

In [2]:
# Define a model using pytorch-style modules
class MLP(jaxtorch.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, output_dim)
    
    # The forward function accepts cx, a Context object as the first argument
    # always. This provides random number generation as well as the parameters.
    def forward(self, cx: Context, x: jax.Array):
        x = jax.nn.relu(self.linear1(cx, x))
        return self.linear2(cx, x)
    
    # If you want to do a custom initialization of paramaters, you can override setup()
    def setup(self, cx):
        super().setup(cx)
        print('Setup!')
        cx[self.linear2.weight] = jnp.zeros(self.linear2.weight.shape)

# Create and initialize the model
model = MLP(10, 64, 1)
rng = jaxtorch.PRNG(jax.random.PRNGKey(0))
params = model.initialize(rng.split())

# In jaxtorch, parameters are stored seperately from the model itself.
# The params dict contains all parameters of the model indexed by name.
print('params=', {k: v.shape for (k, v) in params.items()})

# The "weights" in the Module tree are jaxtorch.Param objects which store only the name and shape.
print('model.linear2.weight=', model.linear2.weight)

Setup!
params= {'linear1.weight': (64, 10), 'linear1.bias': (64,), 'linear2.weight': (1, 64), 'linear2.bias': (1,)}
model.linear2.weight= <Param at linear2.weight (1, 64)>


In [3]:
# Create an optimizer using optax
opt = optax.adam(0.01)
opt_state = opt.init(params)

# Training loop example with dummy data
@jax.jit
def loss_fn(params, key, x, y):
    # To run the model forward, we need to create the Context object and pass it.
    cx = Context(params, key)
    pred = model(cx, x)
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.value_and_grad(loss_fn)

x = jax.random.normal(rng.split(), (100, 10))
y = jax.random.normal(rng.split(), (100, 1))

print('Initial loss:', loss_fn(params, rng.split(), x, y))
for step in tqdm(range(100)):
    loss, grads = grad_fn(params, rng.split(), x, y)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
print('Final loss:', loss)

Initial loss: 0.9808868


100%|██████████| 100/100 [00:00<00:00, 118.37it/s]

Final loss: 0.0001356766



