# Flax basics

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.append('..')

In [None]:
import jax
import flax
import flax.linen as nn
import optax

from utils import MLP, AutoEncoder

## Random data

In [None]:
# Root PRNG key for the notebook; later cells split fresh subkeys off it.
key = jax.random.key(seed=0)

In [None]:
# Draw a batch of standard-normal inputs.
batch_size = 32
num_features = 100

# Advance the key chain: `key` is replaced, `subkey` is consumed below.
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, shape=(batch_size, num_features))

print(f'Input shape: {x.shape}')

## Multilayer perceptron

In [None]:
# Build the MLP. `features` presumably lists the layer widths in order
# (confirm against utils.MLP).
features = [
    64,
    32,
    16,
]
mlp = MLP(features)

In [None]:
# Print Flax's tabular summary of the MLP. `tabulate` needs a PRNG key plus
# an example input (one random row here).
key, *subkeys = jax.random.split(key, num=3)
print(mlp.tabulate(
    subkeys[0],
    jax.random.normal(subkeys[1], shape=(1, num_features)),
))

In [None]:
# Create the model variables. Flax modules are stateless: `init` returns a
# fresh variables pytree (named `params` here) rather than storing
# parameters on the `mlp` instance.
key, *subkeys = jax.random.split(key, num=3)

params = mlp.init(
    subkeys[0],  # PRNG key for parameter initialisation
    # Example input. Shape inference happens HERE, during `init`: parameter
    # shapes are derived from this input's feature dimension. The leading
    # batch size of 1 does not fix the batch size used later.
    jax.random.normal(subkeys[1], (1, num_features)),
)

# Run the model on the full batch. `apply` reuses the parameter shapes fixed
# by `init`; it does not re-infer them.
y = mlp.apply(params, x)

print(f'Output shape: {y.shape}')

## Autoencoder

In [None]:
# Build the autoencoder. The decoder widths mirror the encoder's (minus the
# bottleneck) and end at `num_features`, presumably so the reconstruction
# has the same width as `x`.
enc_features = [64, 32, 16]
dec_features = [*reversed(enc_features[:-1]), num_features]

ae = AutoEncoder(enc_features, dec_features)

In [None]:
# Initialise the autoencoder's variables from one example row; parameter
# shapes are inferred here, during `init`.
key, *subkeys = jax.random.split(key, num=3)
params = ae.init(subkeys[0], jax.random.normal(subkeys[1], shape=(1, num_features)))

# Reconstruct the full input batch with those parameters.
x_hat = ae.apply(params, x)
print(f'Output shape: {x_hat.shape}')

In [None]:
# Call the encoder and decoder on their own by passing the bound methods via
# `method=`; both reuse the variables created for the full model.
z = ae.apply(params, x, method=ae.encode)
x_hat = ae.apply(params, z, method=ae.decode)

print(f'Encoding shape: {z.shape}', f'Output shape: {x_hat.shape}', sep='\n')