# Introduction

## 1. Install REAX

```bash
pip install reax
```

## 2. Define a ReaxModule

A `reax.Module` keeps track of your model parameters and gives you a place to put the code for the various steps in your training loop (`training_step`, `validation_step`, etc.).

In [None]:
import os
from functools import partial
from flax import linen
import jax
import optax
import reax
from reax import demos


class Autoencoder(linen.Module):
    """Fully-connected autoencoder with a 3-dimensional latent code.

    The encoder maps each input row through ``Dense(128) -> relu -> Dense(3)``;
    the decoder maps the latent code back through
    ``Dense(128) -> relu -> Dense(28 * 28)``, i.e. to a flattened 28x28 image.
    """

    def setup(self):
        # Flax calls ``setup`` lazily to register submodules; the module's
        # dataclass ``__init__`` has already run by then, so no
        # ``super().__init__()`` call belongs here.
        self.encoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(3)])
        self.decoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(28 * 28)])

    def __call__(self, x):
        """Encode ``x`` to the latent code and return the decoded reconstruction.

        ``linen.Dense`` contracts over the last axis of ``x``, so callers are
        expected to pass flattened rows (presumably ``(batch, 28 * 28)`` for
        MNIST) for the reconstruction to be comparable with ``x``.
        """
        z = self.encoder(x)
        return self.decoder(z)


class ReaxAutoEncoder(reax.Module):
    """REAX module that trains :class:`Autoencoder` to reconstruct its inputs.

    Parameters are created lazily in :meth:`configure_model` from the first
    batch. The training loss is the mean squared reconstruction error.
    """

    def __init__(self):
        super().__init__()
        self.ae = Autoencoder()

    @staticmethod
    def _flatten(x):
        """Reshape ``x`` to ``(len(x), -1)``, keeping the leading (batch) axis."""
        return x.reshape(len(x), -1)

    def configure_model(self, stage: reax.Stage, batch, /):
        """Initialise the autoencoder parameters if they do not exist yet.

        ``batch`` is unpacked as ``(inputs, labels)``; the labels are ignored.
        The inputs are flattened exactly as in :meth:`training_step`. This makes
        the first ``Dense`` kernel match the shape later passed to ``apply``.
        Initialising from unflattened images would size that kernel for the
        image's last axis only, and training would then fail with a shape
        mismatch.
        """
        if self.parameters() is None:
            inputs, _ = batch
            params = self.ae.init(self.rngs(), self._flatten(inputs))
            self.set_parameters(params)

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction loss and its gradients for one batch.

        Logs the loss as ``"train_loss"`` and returns ``(loss, grads)``, where
        ``grads`` has the same structure as ``self.parameters()``.
        """
        x = self._flatten(batch[0])
        loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(self.parameters(), x, self.ae)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss, grads

    @staticmethod
    @partial(jax.jit, static_argnums=2)
    def loss_fn(params, x, model):
        """Return the mean squared error between ``model.apply(params, x)`` and ``x``.

        ``model`` is a static argument to ``jax.jit`` (``static_argnums=2``),
        so it is treated as a compile-time constant rather than traced.
        """
        predictions = model.apply(params, x)
        return optax.losses.squared_error(predictions, x).mean()

    def configure_optimizers(self):
        """Return an Adam optimizer (learning rate 1e-3) and its initial state.

        The state is built from ``self.parameters()``. This presumably requires
        :meth:`configure_model` to have run first — confirm the REAX call order.
        """
        opt = optax.adam(learning_rate=1e-3)
        state = opt.init(self.parameters())
        return opt, state

# Instantiate the REAX module. No parameters exist yet: they are created in
# ``configure_model`` when REAX calls it with a batch.
autoencoder = ReaxAutoEncoder()

## 3. Define a dataset

REAX supports any iterable (NumPy arrays, lists, etc.) for the train/val/test/predict datasets.

In [None]:
# Set up the data.
# MNIST demo dataset; ``download=True`` presumably fetches the files when they
# are not already available locally — confirm against ``reax.demos``.
dataset = demos.mnist.MnistDataset(download=True)
# Wrap the dataset in REAX's data loader, constructed with its default options.
data_loader = reax.ReaxDataLoader(dataset)

## 4. Train the model

The REAX Trainer takes the module and dataset and combines them in a training loop, automating away most of the boilerplate.

In [None]:
# Build a trainer with default settings and fit the autoencoder on the loader.
# ``limit_train_batches=100`` and ``max_epochs=1`` presumably cap this demo run
# at a single epoch of at most 100 training batches.
trainer = reax.Trainer()
trainer.fit(autoencoder, data_loader, limit_train_batches=100, max_epochs=1)

## 5. Use the model