# I go through the jax tutorial and attempt to understand 10% of what goes on

In [28]:
import jax.numpy as jnp
from jax import lax
from jax import grad
import jax

# Key Concepts: 
- jax.Array creation: similar to NumPy (`jnp.array`, `jnp.arange`, ...).  
- There is some more complicated stuff about devices that I'll get to later.
- Tracers: JAX runs the function once with tracer objects standing in for the real inputs; the tracers record the sequence of operations the function carries out (the jaxpr), which is what JAX then compiles.
- Pytree: nested data structures (e.g. lists/tuples/dicts of arrays).
- JAX uses explicit random keys that you pass into functions, instead of NumPy's global seed.

# JIT


In [18]:
#Module-level list that log2 mutates on purpose, to show what tracing does with side effects.
global_list = []

def log2(x, k):
  """Return the logarithm of x in base k, i.e. log(x) / log(k).

  Deliberately impure: every time the Python body runs, x is appended to
  global_list. That side effect is what the following cells inspect.
  """
  global_list.append(x)
  return jnp.log(x) / jnp.log(k)

log2_jaxpr = jax.make_jaxpr(log2)(3.0, 5.)
print(log2_jaxpr)
#Jaxpr: jax's low-level intermediate representation of the fxn (the list of primitive ops). You don't call it like the fxn; it is what jax hands on to the compiler.
#The jaxpr is made by tracing the fxn once on the example args given here (only their shapes/dtypes matter, not the values).

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[][39m b[35m:f32[][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[][39m = log a
    d[35m:f32[][39m = log b
    e[35m:f32[][39m = div c d
  [34m[22m[1min [39m[22m[22m(e,) }


# IMPORTANT: Does not capture anything about global_list.append(x)
- Feature, not a bug: JAX wants pure functions, i.e. functions whose output depends only on their arguments.
- Impure functions (ones that read or write global state) are bad because the side effects only happen while tracing, not in the compiled code.
- E.g. a global that is read as 4.0 during tracing gets baked into the compiled function; if the global is later updated to 5.0, the cached function still uses 4.0 and your computations silently go wrong.
- `print` counts as an impure side effect too (it only fires during tracing).
- Basically: pass anything the function depends on as an argument.

In [23]:
#jax.jit wraps log2: the Python body (including the global append) only runs while jax traces the fxn.
jitted_log2 = jax.jit(log2)
#First call with these argument types: log2 gets traced, so the append runs once (on a tracer, not on the value 3).
jitted_log2(3, 5)
print(f"length {len(global_list)} and list {global_list} before")
#Same argument types as the previous call, so the cached compiled version is reused and the Python body is not re-run.
jitted_log2(1, 5)
#A plain Python call would make global_list bigger here, but the jitted call DOESN'T, for the reason above.
print(f"length {len(global_list)} and list {global_list} after")

#Side note: the first jitted call does run the append, because that call includes a tracing pass (which executes the Python body, global append included); later calls with the same argument types skip it.
#What ends up in the list are Tracer objects, not numbers (the other entry in the printed list presumably comes from the earlier make_jaxpr call, which also traces log2).
#IN OTHER WORDS: DON'T READ/WRITE GLOBALS INSIDE JITTED FXNS.

length 2 and list [Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>] before
length 2 and list [Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>] after


# Conditionals

In [25]:
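#jit gets very unhappy with Python control flow (if/while) that depends on the value of a traced argument, so avoid jitting that directly.
#Option 1: a while loop conditioned on n that stays in plain Python, with only the loop body jitted.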

@jax.jit
def loop_body(prev_i):
  """One loop step (stand-in for an expensive body): return the counter advanced by one."""
  next_i = prev_i + 1
  return next_i

def g_inner_jitted(x, n):
  """Count up from 0 until the counter is no longer below n, then return x plus the counter.

  Only loop_body is jit-compiled. The while loop itself runs in plain Python, so
  jax never has to compile control flow that depends on n, a value that is only
  known at runtime.
  """
  steps = 0
  while steps < n:
    #The loop condition is evaluated by Python on a concrete value; only the step is jitted.
    steps = loop_body(steps)
  return x + steps

g_inner_jitted(10, 20)

Array(30, dtype=int32, weak_type=True)

In [26]:
#Or option 2: use static_argnames. n is then treated as a compile-time constant, so the Python `while i < n` is resolved while tracing.
#This works, but is not great: jax has to recompile for each new value of the static args.
#Static means hashable Python values, not jax arrays.
#The loop counter must stay a plain Python int here. Jitting g_inner_jitted directly would fail for n > 0: inside the outer jit,
#loop_body(i) returns a tracer, and `while i < n` on a tracer raises a TracerBoolConversionError.

def g_static_n(x, n):
  """Count up from 0 in a Python loop until the counter is no longer below n; return x plus the counter.

  Meant to be jitted with n static: the loop runs entirely at trace time on
  Python ints, so the compiled fxn just adds a constant to x.
  """
  i = 0
  while i < n:
    i += 1
  return x + i

jit_cond = jax.jit(g_static_n, static_argnames='n')

In [31]:
#Best:

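#jnp.where, jnp.piecewise -> work like np.where / np.piecewise

#jax.lax.cond(pred, true_fun, false_fun, operand): pred is the condition; true_fun(operand) runs if pred is True, false_fun(operand) if it is False.
#Under jit both branches are traced/compiled, and the choice between them is made at runtime. The Python fxn below spells out the equivalent logic.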
def cond(pred, true_fun, false_fun, operand):
  """Pure-Python picture of what lax.cond computes.

  Applies true_fun to operand when pred is truthy, otherwise applies false_fun,
  and returns that result.
  """
  chosen_branch = true_fun if pred else false_fun
  return chosen_branch(operand)

#A one-element float array to push through lax.cond.
operand = jnp.array([0.])
#pred is True, so the first branch (x+1) is the one applied to operand.
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)

#Also: jax has structured loops too (lax.fori_loop and lax.while_loop), which it can compile through XLA instead of relying on Python loops.

Array([1.], dtype=float32)

# Autovectorization:
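`jax.vmap()` maps a function over a batch axis. By default the batch axis is axis 0 of every input and output; use `in_axes` / `out_axes` to say where the batch dimension lives (`None` means that argument is not batched).

Example (`convolve`, `xs` and `w` are assumed to come from the tutorial; they are not defined in the cells above). `xs` is batched along axis 0 and `w` is shared across the batch:

```python
batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])

batch_convolve_v3(xs, w)
```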



# Autodiff

In [35]:
grad_tanh = grad(jnp.tanh) #grad_fxn = grad(fxn_name). Returns a new fxn that computes the derivative
gradded_val = grad_tanh(2.) #reuse grad_tanh rather than rebuilding grad(jnp.tanh); d/dx tanh(x) at x=2 is 1 - tanh(2)**2
gradded_val

# loss_value, Wb_grad = jax.value_and_grad(fxn, (0, 1))(W, b)
#Also: argnums is either an int or a sequence of ints: the positional arg(s) to differentiate wrt (default 0).
#With a sequence, the grads come back as a tuple with one entry per argnum, each with the same structure as the arg it belongs to.

Array(0.07065082, dtype=float32, weak_type=True)

# Pytree

In [None]:
#Pytrees: nested containers (lists, tuples, dicts, ...) of arrays can be seen as a tree whose leaves are the arrays. You can also define your own pytree types, but I don't think we need to worry about that right now.

# Sharp bits

In [None]:
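#jax arrays are immutable (in the same way a str is: you can write x += 1, but that just reassigns x to a new array), because jax transformations such as grad need pure fxns.
#Instead of x[idx] = y, use x.at[idx].set(y), x.at[idx].add(num), x.at[idx].multiply(num), etc. These return an updated copy, which jax can turn into a fast in-place update under the hood (inside jit).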