<a href="https://colab.research.google.com/github/kbrezinski/JAX-Practice/blob/main/jax_more_advanced.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
from jax import random

import numpy as np
import functools
import matplotlib.pyplot as plt

from jax import grad, jit, vmap, pmap
from copy import deepcopy
from typing import Tuple, NamedTuple

In [5]:
## Setup PRNG: build an explicit key from a fixed seed, then split it into
## two subkeys.
seed = 0
state = random.PRNGKey(seed)
state1, state2 = random.split(state, num=2)

In [9]:
class Counter:
  """Stateful counter that mutates ``self.n`` on every call.

  Works fine in plain Python, but it is impure (each call changes hidden
  state on the instance), which is what makes it misbehave under ``jit``.
  """

  def __init__(self):
    self.reset()

  def count(self) -> int:
    """Increment the counter and return its new value."""
    self.n = self.n + 1
    return self.n

  def reset(self):
    """Set the counter back to zero."""
    self.n = 0

counter = Counter()
for i in range(3):
  print(counter.count())  # counter works as expected

# `jit` needs the callable itself. Writing `jit(counter.count())` would call
# the method first and hand its int result to `jit`, which raises a TypeError.
counter.reset()
fast_count = jit(counter.count)

# ...but jitting this impure method still doesn't work: `self.n` is read and
# incremented only while tracing, so the compiled function returns the same
# constant on every call and never updates the counter again.
stale_counts = [int(fast_count()) for _ in range(3)]
assert stale_counts == [1, 1, 1], stale_counts

1
2
3


In [26]:
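## Still don't fully understand this implementation: `count` returns both a value and the new state,
## so the caller passes the state in and out explicitly instead of the object mutating itself.
## That keeps `count` pure, which is what allows it to be jitted below.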
# Type alias: the counter's whole state is a single int (the current count).
CounterState = int


class Counter2:
  """Pure counter: the state is passed in and returned instead of stored on `self`."""

  def count(self, n: CounterState) -> Tuple[int, CounterState]:
    """Return ``(value, new_state)``; both equal ``n + 1``."""
    incremented = n + 1
    return incremented, incremented

  def reset(self) -> CounterState:
    """Return the initial counter state."""
    return 0

counter = Counter2()


def _count_three_times(step):
  """Thread a fresh counter state through three calls of `step`, printing each value.

  Returns the final ``(value, state)`` pair.
  """
  current = counter.reset()
  for _ in range(3):
    out, current = step(current)
    print(out)
  return out, current


## non-jitted version: the caller carries the state between calls
value, state = _count_three_times(counter.count)

## jitted version: works because `count` has no hidden state
fast_count = jax.jit(counter.count)
value, state = _count_three_times(fast_count)

1
2
3
1
2
3


In [30]:
list_of_lists = [
      {'a': 3},
      [1, 2, 3],
      [1, 2],
      [1, 2, 3, 4]
]

# `jax.tree_util.tree_map` applies a function to every leaf. Given one tree it
# takes a single-argument function; given several trees with identical
# structure, it calls the function on the corresponding leaves of each.
# (`jax.tree_multimap` was merged into `tree_map` and removed from JAX, and
# the `jax.tree_map` alias is deprecated/removed in recent releases.)
doubled = jax.tree_util.tree_map(lambda x: x * 2, list_of_lists)
summed = jax.tree_util.tree_map(lambda x, y: x + y, list_of_lists, list_of_lists)
print(doubled)
print(summed)

# Adding an element to only one of the two trees would give them different
# structures, and the multi-tree `tree_map` call would raise an error.

[{'a': 6}, [2, 4, 6], [2, 4], [2, 4, 6, 8]]
[{'a': 6}, [2, 4, 6], [2, 4], [2, 4, 6, 8]]


In [32]:
## MLP init example
def init_MLP_params(layer_widths, rng=None):
  """Initialise weights and biases for a fully connected MLP.

  Args:
    layer_widths: sequence of layer sizes, input first and output last,
      e.g. ``[1, 128, 128, 1]``. Each consecutive pair defines one layer.
    rng: optional ``np.random.Generator`` (or any object with a compatible
      ``normal(size=...)`` method) for reproducible initialisation. When
      omitted, NumPy's global ``np.random`` state is used, as before.

  Returns:
    A list with one dict per layer: ``weights`` of shape ``(n_in, n_out)``
    drawn from a normal distribution with std ``sqrt(2 / n_in)``, and
    ``biases`` of shape ``(n_out,)`` filled with ones. Fewer than two widths
    yields an empty list.
  """
  sampler = np.random if rng is None else rng
  params = []

  for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
    params.append(
        dict(weights=sampler.normal(size=(n_in, n_out)) * np.sqrt(2 / n_in),
             biases=np.ones(shape=(n_out,)))
    )

  return params

# NumPy's global RNG is separate from the JAX key above; seed it so the
# initial weights (and the NumPy-sampled training data) are reproducible.
np.random.seed(seed)
params = init_MLP_params([1, 128, 128, 1])
# `jax.tree_map` is deprecated/removed; `jax.tree_util.tree_map` is the stable spelling.
jax.tree_util.tree_map(lambda x: x.shape, params)

[{'biases': (128,), 'weights': (1, 128)},
 {'biases': (128,), 'weights': (128, 128)},
 {'biases': (1,), 'weights': (128, 1)}]

In [33]:
def forward(params, x):
  """Apply the MLP described by `params` to the inputs `x`.

  Every layer except the last is affine followed by ReLU; the last layer is
  affine only (no activation).

  Args:
    params: list of per-layer dicts with ``'weights'`` and ``'biases'``.
    x: inputs, multiplied on the left of each weight matrix.

  Returns:
    The output of the final (linear) layer.
  """
  *hidden_layers, output_layer = params

  activations = x
  for layer in hidden_layers:
    pre_activation = jnp.dot(activations, layer['weights']) + layer['biases']
    activations = jax.nn.relu(pre_activation)

  return jnp.dot(activations, output_layer['weights']) + output_layer['biases']

def loss_fn(params, x, y):
  """Mean squared error between the MLP's predictions on `x` and the targets `y`."""
  residual = forward(params, x) - y
  return jnp.mean(residual ** 2)

lr = 0.001


@jit
def update(params, x, y):
  """Take one plain gradient-descent step on `loss_fn`.

  Returns a new parameter pytree with the same structure as `params`, in
  which every leaf is ``p - lr * g`` for its gradient ``g``.

  NOTE: `lr` is a module-level global read while tracing, so its value is
  baked into the compiled function. Reassigning `lr` later has no effect
  until `update` is retraced (e.g. re-jitted or called with new shapes/dtypes).
  """
  grads = jax.grad(loss_fn)(params, x, y)
  # `jax.tree_multimap` was removed from JAX; `tree_util.tree_map` accepts
  # several trees and zips their leaves.
  return jax.tree_util.tree_map(
      lambda p, g: p - lr * g, params, grads)

In [34]:
## Toy regression task: fit y = x^2 on 128 standard-normal inputs.
x = np.random.normal(size=(128, 1))
y = np.square(x)

## Full-batch gradient descent; `update` returns fresh params each step.
NUM_STEPS = 5000
for _ in range(NUM_STEPS):
  params = update(params, x, y)