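# Reference: UvA DL tutorial notebooks, JAX tutorial 2, section "Implementing a Neural Network with Flax":
# https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html#Implementing-a-Neural-Network-with-Flax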

In [2]:
import numpy as np
import flax
from flax import linen as nn
import jax
import jax.numpy as jnp

import torch.utils.data as data

In [None]:
class SimpleClassifier(nn.Module):
    """Two-layer MLP: Dense -> tanh -> Dense.

    Submodules are registered explicitly in ``setup``; their attribute names
    (``linear1``, ``linear2``) are the names they appear under in the
    parameter tree.

    Attributes:
        num_hidden: width of the hidden Dense layer.
        num_outputs: width of the output Dense layer.
    """

    num_hidden: int
    num_outputs: int

    def setup(self):
        # Only the output width is specified; the input width is inferred
        # from the batch passed to ``init``.
        self.linear1 = nn.Dense(features=self.num_hidden)
        self.linear2 = nn.Dense(features=self.num_outputs)

    def __call__(self, x):
        """Forward pass (the Flax counterpart of PyTorch's ``forward``)."""
        hidden = nn.tanh(self.linear1(x))
        return self.linear2(hidden)

In [None]:
# nn.compact: submodules are declared inline inside __call__, which removes the need for a separate setup() method

class SimpleClassifierCompact(nn.Module):
    """Dense -> tanh -> Dense MLP, written with ``nn.compact``.

    Layers are constructed inline in ``__call__`` instead of in ``setup``;
    Flax names them automatically in creation order (``Dense_0``,
    ``Dense_1``).

    Attributes:
        num_hiden: width of the hidden layer. NOTE(review): misspelling of
            ``num_hidden``; left unchanged because renaming the field would
            change the constructor keyword callers pass.
        num_outputs: width of the output layer.
    """

    num_hiden: int
    num_outputs: int

    @nn.compact
    def __call__(self, x):
        # Hidden layer is created first so it keeps the Dense_0 name.
        hidden = nn.tanh(nn.Dense(features=self.num_hiden)(x))
        return nn.Dense(features=self.num_outputs)(hidden)

In [None]:
# Build the module. This only records the hyperparameters; the weights are
# created separately by `model.init`.
model = SimpleClassifier(num_outputs=1, num_hidden=8)
print(model)

In [None]:
# Split one seed into three independent PRNG keys. `rng` is not consumed in
# this cell (presumably kept for later cells); `input_rng` draws the dummy
# batch and `init_rng` seeds the parameter initialization.
rng, input_rng, init_rng = jax.random.split(jax.random.PRNGKey(0), num=3)
# Dummy batch of 8 two-dimensional points; `init` runs the model on it so the
# Dense layers can infer their input width.
inputs = jax.random.normal(input_rng, shape=(8, 2))
params = model.init(init_rng, inputs)
print(f"Parameters: {params=}")

In [None]:
# Forward pass with the freshly initialized parameters. Kept as a bare
# expression so the notebook displays the output, which for the (8, 2)
# inputs and num_outputs=1 is an (8, 1) array.
model.apply(params, inputs)

In [None]:
class XORDataset(data.Dataset):
    """Continuous XOR toy dataset of noisy 2-D points.

    Each sample starts as a random corner of the unit square ({0, 1}^2). Its
    label is 1 when exactly one coordinate is 1, else 0 (XOR). Gaussian noise
    with standard deviation ``std`` is then added to the coordinates.

    Args:
        size: number of samples to generate.
        seed: seed for the dataset's private ``np.random.RandomState``; the
            same ``(size, seed, std)`` always yields the same samples.
        std: standard deviation of the Gaussian noise added to each point.
    """

    def __init__(self, size, seed, std=0.1):
        super().__init__()
        self.size = size
        self.np_rng = np.random.RandomState(seed=seed)
        self.std = std
        # Populate self.data / self.targets, which __len__ and __getitem__ read.
        self.generate_continuous_xor()

    def generate_continuous_xor(self):
        """(Re)generate the samples.

        Sets ``self.data`` (float32, shape ``(size, 2)``) and ``self.targets``
        (int32, shape ``(size,)``), drawing from ``self.np_rng``.
        """
        points = self.np_rng.randint(low=0, high=2, size=(self.size, 2)).astype(np.float32)
        # Labels come from the clean corners, before any noise is applied.
        targets = (points.sum(axis=1) == 1).astype(np.int32)
        # In-place add keeps the float32 dtype of `points`.
        points += self.np_rng.normal(loc=0.0, scale=self.std, size=points.shape)
        self.data = points
        self.targets = targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]