<a href="https://colab.research.google.com/github/kbrezinski/JAX-Practice/blob/main/intro_to_flax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Flax and JAX
!pip install --upgrade -q "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
!pip install --upgrade -q git+https://github.com/google/flax.git
!pip install --upgrade -q git+https://github.com/deepmind/dm-haiku  # Haiku is here just for comparison purposes

[K     |████████████████████████████████| 138.5 MB 58 kB/s 
[K     |████████████████████████████████| 126 kB 5.1 MB/s 
[K     |████████████████████████████████| 65 kB 2.6 MB/s 
[?25h  Building wheel for flax (setup.py) ... [?25l[?25hdone
  Building wheel for dm-haiku (setup.py) ... [?25l[?25hdone


In [None]:
import jax
from jax import lax, random, numpy as jnp

# NN lib built on top of JAX developed by Google Research (Brain team)
# Flax was "designed for flexibility" hence the name (Flexibility + JAX -> Flax)
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn  # nn notation also used in PyTorch and in Flax's older API
from flax.training import train_state  # a useful dataclass to keep train state

# DeepMind's NN JAX lib - just for comparison purposes, we're not learning Haiku here
import haiku as hk 

# JAX optimizers - a separate lib developed by DeepMind
import optax

# Flax doesn't have its own data loading functions - we'll be using PyTorch dataloaders
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# Python libs
import functools  # useful utilities for functional programs
from typing import Any, Callable, Sequence, Optional

# Other important 3rd party libs
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Build a single fully connected (dense) layer with 5 output features.
# Only the layer definition is created here; no parameters exist yet.
model = nn.Dense(features=5)

# Show the direct base classes of the layer's type (the linen Module class).
type(model).__bases__

(flax.linen.module.Module,)

In [None]:
# Derive two PRNG keys from a fixed seed: k1 for the input data, k2 for parameter init
k1, k2 = jax.random.split(jax.random.PRNGKey(2022))
x = jax.random.normal(k1, (10,))  # a single random input vector of shape (10,)

# init_with_output returns both the layer output and the initialized parameters
# (use init() instead when only the parameters are needed)
y, params = model.init_with_output(k2, x)

# Show the parameter shapes: the kernel is the weight matrix, and its input
# dimension is inferred from the shape of x.
# `jax.tree_map` is a deprecated alias that newer JAX releases removed, so the
# `jax.tree_util` version is used. print() is required because only the last
# expression of a cell is displayed automatically.
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, params))

# apply feedforward using the initialized parameters
y = model.apply(params, x)
y  # last expression of the cell -> displayed as its output

DeviceArray([-1.1229506 ,  0.22319266,  1.4273638 , -1.3900615 ,
             -1.0684718 ], dtype=float32)