In [28]:
import jax
from jax import lax, random, numpy as jnp
from flax import linen as nn
from typing import Any, Callable, Sequence
from flax.core import freeze, unfreeze



# Problem sizes for the synthetic linear-regression task.
n = 10      # number of samples in the batch
x_lay = 10  # input feature dimension
y_dim = 10  # output dimension (kept apart from the `yval` target array below)

# A single dense layer with one output feature per target dimension.
mist = nn.Dense(features=y_dim)

# Initialise the layer's parameters from a dummy input of the right size.
tok1, tok2 = random.split(random.PRNGKey(0))
x = random.normal(tok1, (x_lay,))  # Dummy input data
params = mist.init(tok2, x)  # Initialization call
# Shapes of the initialised parameters, bound to a name so they can be inspected
# (a bare expression in the middle of a cell is computed and then thrown away).
param_shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

# Ground-truth weights and bias of the linear map the model should recover.
# NOTE(review): this re-creates PRNGKey(0), so k1/k2 are the same keys as
# tok1/tok2 above, and k1 is both consumed by random.normal and split again
# further down. JAX advises against reusing keys; it is left unchanged here so
# the generated data (and any recorded loss values) stay exactly reproducible.
key = random.PRNGKey(0)
k1, k2 = random.split(key)
W = random.normal(k1, (x_lay, y_dim))
b = random.normal(k2, (y_dim,))
# Packed with the nesting {'params': {'bias', 'kernel'}} and frozen.
true_params = freeze({'params': {'bias': b, 'kernel': W}})

# Synthetic data set: standard-normal inputs, targets from the true linear map
# plus Gaussian noise scaled by 0.1.
KeyVal, key_noise = random.split(k1)
xval = random.normal(KeyVal, (n, x_lay))
yval = jnp.dot(xval, W) + b + 0.1 * random.normal(key_noise, (n, y_dim))

#Created a method to get the loss value for each iteration
def SquaredError(params, x_batched, y_batched):
  """Batch-averaged half squared error of the `mist` layer's predictions.

  For every (input, target) pair along the leading axis of `x_batched` and
  `y_batched`, the layer is applied with `params` and half the inner product
  of the residual with itself is taken. The per-example values are then
  averaged over that leading axis.
  """
  def half_sq_residual(sample_x, sample_y):
    # Residual between the target and the layer's prediction for one example.
    residual = sample_y - mist.apply(params, sample_x)
    return jnp.inner(residual, residual) / 2.0

  # vmap maps the single-example loss over the leading (batch) axis.
  per_example = jax.vmap(half_sq_residual)(x_batched, y_batched)
  return jnp.mean(per_example, axis=0)

# Step size for gradient descent (the training loop below applies this same value).
learning_rate = 0.2
# Wraps the loss so one call returns both its value and its gradient with
# respect to the first argument (`params`).
lgF = jax.value_and_grad(SquaredError)

def update(params, learning_rate, grads):
  """Return `params` after one gradient-descent step.

  Every leaf of `params` is paired with the leaf at the same position in
  `grads` and replaced by `leaf - learning_rate * grad`. A new tree with the
  same structure is returned; the arguments themselves are not modified.
  """
  def descend(weight, gradient):
    return weight - learning_rate * gradient

  return jax.tree_util.tree_map(descend, params, grads)

print("Using flax:")
# Plain gradient descent for 101 steps, reporting the loss on every 10th step.
for i in range(101):
  # Loss and gradients at the current parameters (the loss is the pre-update value).
  loss_val, grads = lgF(params, xval, yval)
  # Take the step with the shared `learning_rate` setting instead of repeating
  # the literal, so changing the setting actually changes the training run.
  params = update(params, learning_rate, grads)
  if i % 10 == 0:
    print(f"The loss result at {i} = ", loss_val)

Using flax:
The loss result at 0 =  59.25503
The loss result at 10 =  1.7329859
The loss result at 20 =  0.42431983
The loss result at 30 =  0.19928001
The loss result at 40 =  0.13186544
The loss result at 50 =  0.10130804
The loss result at 60 =  0.083989404
The loss result at 70 =  0.07311896
The loss result at 80 =  0.065888725
The loss result at 90 =  0.06083995
The loss result at 100 =  0.057132434
