#### simple nn example
#### creates a 10-[12-8]-4 neural network.

https://jamesmccaffrey.wordpress.com/2022/01/18/the-flax-neural-network-library/

In [3]:
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn



class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

# creates a 10-[12-8]-4 neural network.
model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))



variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)

[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1.

#### LR w/ Flax
https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/flax_basics.ipynb

In [9]:
import jax
from typing import Any, Callable, Sequence, Optional
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn

# We create one dense layer instance (taking 'features' parameter as input)
model = nn.Dense(features=5)

Layers (and models in general, we'll use that word from now on) are subclasses of the linen.Module class.

Model parameters & initialization: 
Parameters are not stored with the models themselves. You need to initialize parameters by calling the init function, using a PRNGKey and a dummy input parameter.

In [22]:
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input
#print(x.shape)

params = model.init(key2, x) # Initialization call
#print('\n',params)

jax.tree_map(lambda x: x.shape, params) # Checking output shapes


FrozenDict({
    params: {
        bias: (5,),
        kernel: (10, 5),
    },
})

In [17]:
model.apply(params, x)

DeviceArray([-1.3721193 ,  0.61131495,  0.6442836 ,  2.2192965 ,
             -1.1271116 ], dtype=float32)

In [18]:
# Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5

# Generate random ground truth W and b.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a pytree.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
x_samples = random.normal(key_sample, (n_samples, x_dim))
y_samples = jnp.dot(x_samples, W) + b + 0.1 * random.normal(key_noise,(n_samples, y_dim))
print('x shape:', x_samples.shape, '; y shape:', y_samples.shape)

x shape: (20, 10) ; y shape: (20, 5)


In [23]:
# Same as JAX version but using model.apply().
def mse(params, x_batched, y_batched):
  # Define the squared loss for a single pair (x,y)
  def squared_error(x, y):
    pred = model.apply(params, x)
    return jnp.inner(y-pred, y-pred) / 2.0
  # Vectorize the previous to compute the average of the loss on all samples.
  return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)

Gradient Descent

In [36]:
learning_rate = 0.3  # Gradient step size.
print('Loss for "true" W,b: ', mse(true_params, x_samples, y_samples))
loss_grad_fn = jax.value_and_grad(mse)

#same as the line above - loss_grad_fn = jax.value_and_grad(mse)
@jax.value_and_grad
def loss_grad_fn(params, x_batched, y_batched):
  return mse(params, x_batched, y_batched)

#also jax's way
@jax.jit
def update_params(params, learning_rate, grads):
  params = jax.tree_map(
      lambda p, g: p - learning_rate * g, params, grads)
  return params

for i in range(101):
  # Perform one gradient update.
  loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
  params = update_params(params, learning_rate, grads)
  if i % 10 == 0:
    print(f'Loss step {i}: ', loss_val)

Loss for "true" W,b:  0.02363979
Loss step 0:  0.011577631
Loss step 10:  0.011571459
Loss step 20:  0.011569389
Loss step 30:  0.011568717
Loss step 40:  0.011568493
Loss step 50:  0.011568409
Loss step 60:  0.011568387
Loss step 70:  0.011568378
Loss step 80:  0.011568379
Loss step 90:  0.011568367
Loss step 100:  0.01156838


### Optimizing with Optax

Flax used to use its own `flax.optim` package for optimization, but with
[FLIP #1009](https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md)
this was deprecated in favor of
[Optax](https://github.com/deepmind/optax).

Basic usage of Optax is straightforward:

1.   Choose an optimization method (e.g. `optax.sgd`).
2.   Create optimizer state from parameters.
3.   Compute the gradients of your loss with `jax.value_and_grad()`.
4.   At every iteration, call the Optax `update` function to update the internal
     optimizer state and create an update to the parameters. Then add the update
     to the parameters with Optax's `apply_updates` method.

Note that Optax can do a lot more: it's designed for composing simple gradient
transformations into more complex transformations that allows to implement a
wide range of optimizers. There is also support for changing optimizer
hyperparameters over time ("schedules"), applying different updates to different
parts of the parameter tree ("masking") and much more. For details please refer
to the
[official documentation](https://optax.readthedocs.io/en/latest/).


In [1]:
import optax
    
tx = optax.sgd(learning_rate=alpha)

print (alpha)
opt_state = tx.init(params)
loss_grad_fn = jax.value_and_grad(mse)



NameError: name 'alpha' is not defined

In [31]:
for i in range(101):
  loss_val, grads = loss_grad_fn(params)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  if i % 10 == 0:
    print('Loss step {}: '.format(i), loss_val)

TypeError: mse() missing 2 required positional arguments: 'x_batched' and 'y_batched'