## What is Jax?
Jax is a Python library designed for high-performance ML research. Jax is nothing more than a numerical computing library, just like Numpy, but with some key improvements. It was developed by Google and used internally both by Google and Deepmind teams.

### Jax basics

In [1]:
import jax
import jax.numpy as jnp
import numpy as np

In [2]:
x=np.zeros(10)
x

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [3]:
y=jnp.zeros(10)
y



DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [4]:
# Basics of jax
x = np.random.rand(1000,1000)
y = jnp.array(x)

%timeit -n 1 -r 1 np.dot(x,x)
# 1 loop, best of 1: 52.6 ms per loop

%timeit -n 1 -r 1 jnp.dot(y,y).block_until_ready()
# 1 loop, best of 1: 1.47 ms per loop

12.8 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
72.7 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


The calculations are faster in the GPUs. Also we need the block_until_ready() function. Because JAX is asynchronous, we need to wait until the execution is complete in order to properly measure the time

### Auto differentiation with grad() function

#### singular value

In [5]:
from jax import grad
def f(x):
    return x**2+2*x+1
def f_diff(x):
    return 2*x+2
print(grad(f)(1.0)) # grad return a derivative function 
print(f_diff(1.0))

4.0
4.0


#### vectors

In [6]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [7]:
def f(x):
    return np.sum(jnp.power(x,2) + 2*x +1)  # f needs to have a scaler output

In [8]:
grad(f)(x_small)

DeviceArray([2., 4., 6.], dtype=float32)

In [10]:
x = np.random.rand(1000,1000)  # even for matrix, just needs to explicitly address the position where we wanna differentiate
y = jnp.array(x) 

In [12]:
grad(f)(y)

DeviceArray([[3.5051565, 3.11126  , 3.457899 , ..., 3.4057453, 2.9257603,
              2.1074712],
             [2.4986985, 2.776178 , 2.7964032, ..., 2.7231436, 2.7258153,
              2.4071379],
             [3.6057768, 3.9050226, 2.598824 , ..., 3.2422433, 2.110405 ,
              2.3537593],
             ...,
             [3.58633  , 3.0957935, 3.8398504, ..., 2.5640457, 2.6462164,
              3.9038177],
             [3.57158  , 2.1966538, 3.2050138, ..., 2.3495917, 2.3040285,
              2.8068812],
             [3.541185 , 3.4056678, 3.442199 , ..., 2.9499464, 2.4896505,
              2.3615398]], dtype=float32)

#### Accelerated Linear Algebra (XLA compiler)
One of the factors that make JAX so fast is also Accelerated Linear Algebra or XLA.XLA is a domain-specific compiler for linear algebra that has been used extensively by Tensorflow.In order to perform matrix operations as fast as possible, the code is compiled into a set of computation kernels that can be extensively optimized based on the nature of the code.Just in time (jit) compilation comes hand in hand with XLA. In order to take advantage of the power of XLA, the code must be compiled into the XLA kernels. This is where jit comes into play.

In [6]:
# from jax import jit
from jax import jit

In [7]:
x = np.random.rand(1000,1000)
y = jnp.array(x)

def f(x):

  for _ in range(10):
      x = 0.5*x + 0.1* jnp.sin(x)

  return x

g = jit(f)

%timeit -n 5 -r 5 f(y).block_until_ready()
# 5 loops, best of 5: 10.8 ms per loop

%timeit -n 5 -r 5 g(y).block_until_ready()
# 5 loops, best of 5: 341 µs per loop

175 ms ± 31.5 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
47.9 ms ± 15.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)


jit can also be combined with grad transformation (or any other transformation for that matter), making backpropagation super fast.Pmap is another transformation that enables us to replicate the computation into multiple cores or devices and execute them in parallel(p in pmap stands for parallel)

In [8]:
from jax import pmap

In [9]:
f(np.arange(4))

DeviceArray([0.        , 0.00580175, 0.01047094, 0.0139569 ], dtype=float32)

In [10]:
pmap(f)(np.arange(4))# no avialble devieces

ValueError: compiling computation that requires 4 logical devices, but only 1 XLA devices are available (num_replicas=4, num_partitions=1)

In [11]:
from jax import vmap
f(jnp.arange(10))

DeviceArray([0.        , 0.00580175, 0.01047094, 0.0139569 , 0.01694178,
             0.01998397, 0.02311684, 0.02610062, 0.02872661, 0.0309466 ],            dtype=float32)

Automatic vectorization with vmap
Vmap is, as the name suggests, a function transformation that enables us to vectorize functions (v stands for vector!).
We can take a function that operates on a single data point and vectorize it so it can accept a batch of these data points (or a vector) of arbitrary size.

In [86]:
def f(x):
    return vmap(f)(jnp.arange(x))

#### There have been already couple of popular libraries built on Jax
* Haiku: Haiku is the go-to framework for Deep Learning and it’s used by many Google and Deepmind internal teams. It provides some simple, composable abstractions for machine learning research as well as ready-to-use modules and layers.

* Optax: Optax is a gradient processing and optimization library that contains out-of-the-box optimizers and related mathematical operations.

* RLax: RLax is a reinforcement learning framework with many RL subcomponents and operations.

* Chex: Chex is a library of utilities for testing and debugging JAX code.

* Jraph: Jraph is a Graph Neural Networks library in JAX.

* Flax: Flax is another neural network library with a variety of ready-to-use modules, optimizers, and utilities. It’s most likely the closest we have in an all-in JAX framework.

* Objax: Objax is a third ml library that focuses on object-oriented programming and code readability. Once again it contains the most popular modules, activation functions, losses, optimizers as well a handful of pre-trained models.

* Trax: Trax is an end-to-end library for deep learning that focuses on Transformers

* JAXline: JAXline is a supervised-learning library that is used for distributed JAX training and evaluation.

* ACME: ACME is another research framework for reinforcement learning.

* JAX-MD: JAX-MD is a niche framework that deals with molecular dynamics.

* Jaxchem: JAXChem is another niche library that emphasizes on chemical modeling.

#### Test Jax for a loss function autograd

#### Stax is a small but flexible neural net specification library from scratch. :https://jax.readthedocs.io/en/latest/jax.experimental.stax.html

In [52]:
import torch
import pandas as pd
from sklearn import model_selection
from sklearn.preprocessing import MinMaxScaler,StandardScaler,Normalizer
from sklearn.metrics import mean_absolute_error,mean_absolute_error
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from numpy import exp,log,sqrt
from scipy import stats
from scipy.stats import norm

In [53]:
# Black Scholes Put option and Greeks
def bsput(s0,k,t,r,sigma):
    """
    s0: spot price
    k: strike
    t: T-t maturity
    r: risk free rate
    sigma: volatility
    y: dividend yield ==0
    """
    d1=(log(s0/k)+(r+sigma**2/2)*t)/(sigma*sqrt(t))
    d2=d1-sigma*sqrt(t)    
    return -s0*norm.cdf(-d1)+k*exp(-r*t)*norm.cdf(-d2),s0*norm.pdf(d1)*t**0.5 

def closedform_Euro(row):
    return bsput(row['S'],row['K'],row['T'],row['rf'],row['Vol'])

cdf = torch.distributions.Normal(0,1).cdf
pdf = lambda x: torch.distributions.Normal(0,1).log_prob(x).exp()


def bs_greeks(s0,k,t,r,sigma):
    d1=(log(s0/k)+(r+sigma**2/2)*t)/(sigma*sqrt(t))
    d2=d1-sigma*sqrt(t)
    delta=-norm.cdf(-d1)
    gamma=norm.pdf(d1)/s0/sigma/np.sqrt(t)
    theta=-s0*sigma/2/np.sqrt(t) *norm.pdf(d1)+r*k*exp(-r*t)*norm.cdf(-d2)
    return delta,gamma,theta

def closedform_greeks(row):
    return bs_greeks(row['S'],row['K'],row['T'],row['rf'],row['Vol'])

def numerical_greeks(df_test):
    """
    Input Dataframe:
    S: spot price
    Put_nn: estimated values from model
    Delta: closed form
    Gamma: closed form
    """
    diff=df_test[['S','Put_nn']].sort_values('S').diff(1).dropna().reset_index(drop=True)
    diff['Delta_c']=diff['Put_nn']/diff['S']
    diff['Gamma']=diff['Delta_c'].diff(1)/diff['S'].shift(1)
    
    # Numerical & closed form Delta plot
    plt.rcParams['agg.path.chunksize'] = 100000
    plt.figure(figsize=(10,6))
    plt.scatter(df_test.sort_values('S')['S'][:-1],diff['Delta_c'].values, color='orange')
    plt.show()
    
    plt.figure(figsize=(10,6))
    plt.scatter(df_test['S'],df_test['Delta'].values)
    plt.show()
    
    # Numerical & closed form Gamma plot
    plt.rcParams['agg.path.chunksize'] = 100000
    plt.figure(figsize=(10,6))
    plt.scatter(df_test.sort_values('S')['S'][:-1],diff['Gamma'].values, color='orange')
    plt.show()
    
    plt.figure(figsize=(10,6))
    plt.scatter(df_test['S'],df_test['Gamma'].values)
    plt.show()

In [54]:
def generate_random_train_data(n=250000, corner=False):
    df=pd.DataFrame({"S":np.ones(n),
                     "K":np.random.uniform(low=0.5,high=1.5,size=n),
                     "Vol":np.random.uniform(low=0.1,high=0.4,size=n),
              "T":np.random.uniform(low=0.25,high=2,size=n),
                     "rf":np.random.uniform(low=0.0025,high=0.025,size=n)})
    
    if corner:
        df_corner=pd.DataFrame({"S":np.ones(n),
                         "Vol":np.ones(n)*0.2,
                         "T":np.ones(n),
                         "rf":np.ones(n)*0.01})

        df_corner['K']=df_corner['S']*np.random.uniform(low=0.8,high=0.95,size=n) # special region
        df = df.append(df_corner)

    df[["Put","Vega"]]=df.apply(closedform_Euro, axis=1, result_type="expand")

    # weighted by T
    df["Put"] = df["Put"]/df["K"]/np.sqrt(df['T'])
    df['K']=df['K']/np.sqrt(df['T'])
    
    return df


In [55]:
# 50 * 30 * 7 * 10 = 105K
def generate_mesh_train_data():
    dimension = 5
    K = 1
    S = np.concatenate((np.linspace(1/0.5, 1/0.9, 20), np.linspace(1/0.8, 1/1.5, 20)), axis=None)
    Vol = np.linspace(0.1, 0.4, 11)
    T = np.concatenate((np.linspace(0.25, 1, 10), np.linspace(1, 2, 4)), axis=None)
    rf = np.linspace(0.0025, 0.025, 11)
    grid_data = np.stack(np.meshgrid(S, K, Vol, T, rf), dimension).reshape(-1, dimension)
    df = pd.DataFrame(grid_data, columns=['S', 'K', 'Vol', 'T', 'rf'])
    
    df[["Put", "Vega"]]=df.apply(closedform_Euro, axis=1, result_type="expand")
    df["Put"] = df["Put"]/df["K"]/np.sqrt(df['T'])
    df['K']=df['K']/np.sqrt(df['T'])
    
    return df

In [56]:
def generate_test_data(n=5000, dimension='K', low_bound=0.5, up_bound=1.5):
    cdf = torch.distributions.Normal(0,1).cdf
    pdf = lambda x: torch.distributions.Normal(0,1).log_prob(x).exp()

    df=pd.DataFrame({"S":np.ones(n),
                     "K":np.ones(n), 
                     "Vol":np.ones(n) * 0.2, 
                     "T":np.ones(n),
                     "rf":np.ones(n) * 0.01})

    df[dimension] = np.random.uniform(low=low_bound, high=up_bound, size=n)
    df[["Put","Vega"]]=df.apply(closedform_Euro, axis=1, result_type="expand")
    
    # weighted by T
    df["Put"] = df["Put"]/df["K"]/np.sqrt(df['T'])
    df['K']=df['K']/np.sqrt(df['T'])
    
    return df

In [57]:
df = generate_mesh_train_data()
df_train = df.copy()
df_train['Put'] = np.where(df_train['Put'] < 0.01, 0.01, df_train['Put'])

y = df_train['Put']
x = df_train.drop(['Put','Vega'], 1)

x_train,x_test,y_train,y_test=model_selection.train_test_split(x,y,test_size=0.1)

In [9]:
pip install tensorflow

Collecting tensorflow
  Downloading tensorflow-2.6.0-cp38-cp38-macosx_10_11_x86_64.whl (199.0 MB)
[K     |████████████████████████████████| 199.0 MB 448 kB/s eta 0:00:012
[?25hCollecting tensorboard~=2.6
  Downloading tensorboard-2.6.0-py3-none-any.whl (5.6 MB)
[K     |████████████████████████████████| 5.6 MB 2.0 MB/s eta 0:00:01
[?25hCollecting keras~=2.6
  Downloading keras-2.6.0-py2.py3-none-any.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 10.5 MB/s eta 0:00:01
Collecting flatbuffers~=1.12.0
  Downloading flatbuffers-1.12-py2.py3-none-any.whl (15 kB)
Collecting wrapt~=1.12.1
  Downloading wrapt-1.12.1.tar.gz (27 kB)
Collecting termcolor~=1.1.0
  Downloading termcolor-1.1.0.tar.gz (3.9 kB)
Collecting keras-preprocessing~=1.1.2
  Downloading Keras_Preprocessing-1.1.2-py2.py3-none-any.whl (42 kB)
[K     |████████████████████████████████| 42 kB 3.5 MB/s  eta 0:00:01
Collecting clang~=5.0
  Downloading clang-5.0.tar.gz (30 kB)
Collecting gast==0.4.0
  Downloading g

In [61]:
import time
import itertools

import numpy.random as npr

import jax.numpy as np
from jax.config import config
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax

In [62]:
def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return np.mean(np.sum((targets-preds)**2))

In [63]:
init_random_params, predict = stax.serial(
    Dense(128), Relu,
    Dense(1024), Relu,
    Dense(128), Relu,
    Dense(16),Relu,
    Dense(1))

In [64]:
step_size = 0.0001
num_epochs = 80
batch_size = 128

In [65]:
num_train = x.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

In [66]:
rng = random.PRNGKey(0)
def data_stream():
    rng = npr.RandomState(0)    
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield x.iloc[batch_idx,:].values, y.loc[batch_idx].values
batches = data_stream()

In [67]:
opt_init, opt_update, get_params = optimizers.adam(step_size)

In [68]:
@jit
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

In [69]:
_, init_params = init_random_params(rng, (-1, 5))
opt_state = opt_init(init_params)
itercount = itertools.count()

In [70]:
def accuracy(params, inputs,targets):
    preds=predict(params, inputs)
    return np.mean(np.sum((preds-targets)**2))

In [29]:
print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
        opt_state = update(next(itercount), opt_state, next(batches))
    epoch_time = time.time() - start_time    
    params = get_params(opt_state)
    #train_acc = accuracy(params, x.values, y.values)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    #print("Training set accuracy {}".format(train_acc))


Starting training...
Epoch 0 in 2.69 sec
Epoch 1 in 1.01 sec
Epoch 2 in 1.10 sec
Epoch 3 in 1.14 sec
Epoch 4 in 1.27 sec
Epoch 5 in 1.35 sec
Epoch 6 in 1.39 sec
Epoch 7 in 1.44 sec
Epoch 8 in 1.52 sec
Epoch 9 in 1.59 sec
Epoch 10 in 1.70 sec
Epoch 11 in 1.82 sec
Epoch 12 in 2.01 sec
Epoch 13 in 2.09 sec
Epoch 14 in 2.23 sec
Epoch 15 in 2.45 sec
Epoch 16 in 2.66 sec
Epoch 17 in 2.88 sec
Epoch 18 in 2.78 sec
Epoch 19 in 2.67 sec
Epoch 20 in 2.37 sec
Epoch 21 in 1.99 sec
Epoch 22 in 1.76 sec
Epoch 23 in 1.61 sec
Epoch 24 in 1.48 sec
Epoch 25 in 1.53 sec
Epoch 26 in 1.50 sec
Epoch 27 in 1.46 sec
Epoch 28 in 1.48 sec
Epoch 29 in 1.47 sec
Epoch 30 in 1.50 sec
Epoch 31 in 1.49 sec
Epoch 32 in 1.58 sec
Epoch 33 in 1.49 sec
Epoch 34 in 1.50 sec
Epoch 35 in 1.53 sec
Epoch 36 in 1.49 sec
Epoch 37 in 1.47 sec
Epoch 38 in 1.47 sec
Epoch 39 in 1.45 sec
Epoch 40 in 1.48 sec
Epoch 41 in 1.45 sec
Epoch 42 in 1.47 sec
Epoch 43 in 1.46 sec
Epoch 44 in 1.48 sec
Epoch 45 in 1.48 sec
Epoch 46 in 1.49 sec
E

In [30]:
get_params(opt_state)

[(DeviceArray([[ 1.17039606e-01, -7.68911988e-02,  1.18306898e-01,
                 2.35780347e-02, -5.23120090e-02,  1.06977083e-01,
                 1.41145483e-01, -2.10290313e-01, -1.57700613e-01,
                 5.40650003e-02, -1.46053582e-01,  2.11354256e-01,
                 1.42893761e-01,  1.63040191e-01, -1.45566911e-01,
                 9.56042483e-02, -6.05233619e-03,  3.44735198e-02,
                -2.16554046e-01, -4.66052145e-02, -2.42987037e-01,
                -1.14166282e-01, -6.64530396e-02, -1.80376954e-02,
                -4.27191146e-02, -2.21430436e-01, -6.07723184e-02,
                -1.72491953e-01, -1.02409251e-01,  4.93595675e-02,
                -1.70862094e-01,  1.62178516e-01, -2.47168824e-01,
                -1.55457556e-01,  7.87244141e-02,  7.37744523e-03,
                 6.43459186e-02,  6.75803656e-03,  9.62769985e-02,
                -3.75370979e-02, -1.88406020e-01, -1.08496837e-01,
                 1.98165447e-01,  2.40025092e-02,  1.48074910e

In [31]:
predict(get_params(opt_state), x.values)

DeviceArray([[0.09021164],
             [0.09021601],
             [0.09022038],
             ...,
             [0.09159355],
             [0.09159149],
             [0.09158941]], dtype=float32)

In [36]:
pip install git+https://github.com/deepmind/dm-haiku

Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /private/var/folders/wm/m4whmdd14n57n775j15v1tjw0000gn/T/pip-req-build-msvcx1vm
Collecting jmp>=0.0.2
  Downloading jmp-0.0.2-py3-none-any.whl (16 kB)
Building wheels for collected packages: dm-haiku
  Building wheel for dm-haiku (setup.py) ... [?25ldone
[?25h  Created wheel for dm-haiku: filename=dm_haiku-0.0.5.dev0-py3-none-any.whl size=285278 sha256=6bcb0bcc6bc2b3558fea688ed76fc8c735df535ad07eef3329f59fb130196f53
  Stored in directory: /private/var/folders/wm/m4whmdd14n57n775j15v1tjw0000gn/T/pip-ephem-wheel-cache-8b1_z6vp/wheels/c7/4d/89/b159f184ad7c9e95672c342eafcc176ad92ee0c77f27f3bd23
Successfully built dm-haiku
Installing collected packages: jmp, dm-haiku
[31mERROR: After October 2020 you may experience errors when installing or updating packages. This is because pip will change the way that it resolves dependency conflicts.

We recommend you use --use-feature=2020-resolver t

In [58]:
import haiku as hk

In [59]:
def loss_fn(x_train,y_train):
    mlp = hk.Sequential([hk.Linear(128), jax.nn.relu,
                         hk.Linear(1024), jax.nn.relu,
                         hk.Linear(128),  jax.nn.relu,
                         hk.Linear(16), jax.nn.celu,
                         hk.Linear(1)])
    preds = mlp(x_train)
    return jnp.mean(jnp.sum(preds-y_train)**2)

In [71]:
loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)
# Initial parameter values are typically random. In JAX you need a key in order
# to generate random numbers and so Haiku requires you to pass one in.
rng = jax.random.PRNGKey(42)

# `init` runs your function, as such we need an example input. Typically you can
# pass "dummy" inputs (e.g. ones of the same shape and dtype) since initialization
# is not usually data dependent.
x_init, y_init = next(batches)

# The result of `init` is a nested data structure of all the parameters in your
# network. You can pass this into `apply`.
params = loss_fn_t.init(rng, x_init, y_init)

  lax._check_user_dtype_supported(dtype, "zeros")


In [72]:
params

FlatMap({
  'linear': FlatMap({
              'w': DeviceArray([[-0.4363558 ,  0.0706057 ,  0.20446752, -0.24028288,
                                  0.1068935 , -0.24332774,  0.11834247, -0.43644264,
                                 -0.02317936, -0.38573587,  0.141738  ,  0.11776715,
                                  0.4071066 , -0.89413506, -0.5232067 , -0.29121146,
                                  0.1349041 , -0.14995281, -0.6072486 ,  0.31439045,
                                 -0.24311817, -0.6128429 ,  0.61145765,  0.5701605 ,
                                  0.2899103 , -0.34465683,  0.15892078,  0.17956081,
                                 -0.07094992, -0.16064133,  0.00093165, -0.3747404 ,
                                 -0.59049094,  0.05386654,  0.09549524, -0.04192615,
                                 -0.41203362,  0.39526656, -0.19115703,  0.02204039,
                                  0.18036191,  0.4340976 ,  0.38166875, -0.43801317,
                                 

In [73]:
def sgd(param, update):
    return param - step_size * update

print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
        x_iter, y_iter = next(batches)
        grads = jax.grad(loss_fn_t.apply)(params, x_iter, y_iter)
        params = jax.tree_multimap(sgd, params, grads)
    epoch_time = time.time() - start_time    
    #train_acc = accuracy(params, x.values, y.values)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    #print("Training set accuracy {}".format(train_acc))


Starting training...
Epoch 0 in 16.46 sec
Epoch 1 in 14.94 sec
Epoch 2 in 15.24 sec
Epoch 3 in 14.88 sec
Epoch 4 in 15.07 sec
Epoch 5 in 14.97 sec
Epoch 6 in 15.25 sec
Epoch 7 in 15.08 sec
Epoch 8 in 15.19 sec
Epoch 9 in 15.77 sec
Epoch 10 in 15.03 sec
Epoch 11 in 15.02 sec
Epoch 12 in 15.28 sec
Epoch 13 in 15.24 sec
Epoch 14 in 16.27 sec
Epoch 15 in 15.38 sec
Epoch 16 in 15.15 sec
Epoch 17 in 16.95 sec
Epoch 18 in 16.84 sec
Epoch 19 in 15.16 sec
Epoch 20 in 15.16 sec
Epoch 21 in 15.42 sec
Epoch 22 in 15.28 sec
Epoch 23 in 15.76 sec
Epoch 24 in 15.63 sec
Epoch 25 in 16.22 sec
Epoch 26 in 15.68 sec
Epoch 27 in 15.24 sec
Epoch 28 in 15.27 sec
Epoch 29 in 15.26 sec
Epoch 30 in 15.21 sec
Epoch 31 in 15.23 sec
Epoch 32 in 15.26 sec
Epoch 33 in 15.25 sec
Epoch 34 in 15.04 sec
Epoch 35 in 15.14 sec
Epoch 36 in 15.06 sec
Epoch 37 in 15.22 sec
Epoch 38 in 15.52 sec
Epoch 39 in 15.25 sec
Epoch 40 in 15.14 sec
Epoch 41 in 15.27 sec
Epoch 42 in 15.14 sec
Epoch 43 in 15.17 sec
Epoch 44 in 15.22 se

In [74]:
loss_fn_t.apply(params,x.values,y.values)

DeviceArray(nan, dtype=float32)

In [75]:
params

FlatMap({
  'linear': FlatMap({
              'b': DeviceArray([         nan,          nan, -397.1789   ,  -97.34782  ,
                                -297.47705  ,          nan, -465.82858  ,    0.       ,
                                   0.       , -131.5717   , -336.90704  ,          nan,
                                -216.40575  ,    0.       ,          nan,          nan,
                                -447.69424  ,          nan,    0.       ,          nan,
                                  -6.2996573,          nan,          nan, -155.28235  ,
                                         nan,    0.       , -141.9651   ,          nan,
                                   0.       ,          nan,  -15.818071 ,    0.       ,
                                         nan,    0.       ,    0.       ,  -53.556313 ,
                                         nan, -231.59782  ,          nan,    0.       ,
                                         nan,  -27.464762 , -154.47498  ,    0.       ,
