# I go through the jax tutorial and attempt to understand 10% of what goes on

In [28]:
import jax.numpy as jnp
from jax import lax
from jax import grad
import jax

- `jax.Array` creation: similar to NumPy.  
- There's some complicated stuff about devices that I'll get to later.
- Tracers: basically, JAX runs the function with tracer objects standing in for the arguments. The tracers record the sequence of operations the function carries out (the jaxpr), which JAX then compiles.
- Pytree: nested data structures.
- JAX uses explicit random keys that you pass into functions, instead of NumPy's global seed.

## JIT


In [18]:
global_list = []

def log2(x, k):
  """Return the logarithm of x in base k, i.e. ln(x) / ln(k).

  Despite the name, the base is the argument `k`, not 2.
  Deliberately impure: appends `x` to the module-level `global_list`, so later cells can
  show when JAX actually runs the Python body (only while tracing).
  """
  global_list.append(x)  # Python side effect: not recorded in the jaxpr below
  ln_x = jnp.log(x)
  ln_k = jnp.log(k)
  return ln_x / ln_k

print(jax.make_jaxpr(log2)(3.0, 5.))
# Jaxpr: JAX's intermediate representation of the traced computation (the log/log/div
# primitives printed below). make_jaxpr builds it by running log2 once with tracers standing
# in for the args you pass. It is not compiled code itself; jit lowers this kind of
# representation to XLA, which compiles it.
# Only array operations are recorded; Python side effects like the append are not.

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[][39m b[35m:f32[][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[][39m = log a
    d[35m:f32[][39m = log b
    e[35m:f32[][39m = div c d
  [34m[22m[1min [39m[22m[22m(e,) }


# IMPORTANT: Does not capture anything about global_list.append(x)
- Feature, not a bug: JAX wants pure functions whose outputs depend only on their arguments.
- Impure functions (functions that read/write global state) are bad because the compiler does weird things with them.
- E.g. a global read during tracing gets baked into the compiled function as a constant (say 4.0). If the global is later updated to 5.0, the cached compiled version still uses 4.0, and your computations silently go wrong.
- `print` counts as impure too.
- Basically: pass anything the function depends on as an argument.

In [23]:
jitted_log2 = jax.jit(log2)

# First call: jit traces log2, so its Python body runs once. The append stores a *tracer*
# (see the Traced<...> entries in the output), not the value 3.
jitted_log2(3, 5)
print(f"length {len(global_list)} and list {global_list} before")
n_before = len(global_list)

# Second call: same argument shapes/dtypes, so the cached compiled version is reused.
# log2's Python body (including the append) does not run again, and global_list does NOT grow.
jitted_log2(1, 5)
print(f"length {len(global_list)} and list {global_list} after")
assert len(global_list) == n_before, "cached jit call should not re-run Python side effects"

# Note: a call with a new shape/dtype (e.g. jitted_log2(1.0, 5)) would retrace, and the
# append would run again. The side effect happens "sometimes", which is why you shouldn't rely on it.
# IN OTHER WORDS: DON'T READ/WRITE GLOBALS inside jitted functions.

length 2 and list [Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>] before
length 2 and list [Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace>] after


# Conditionals

In [25]:
# Under jit, JAX can't branch or loop on values that are only known at runtime:
# a Python `while i < n` needs a concrete bool, but inside jit `n` would be a tracer.
# Option 1: jit only the (expensive) loop body and keep the loop itself in plain Python.

@jax.jit
def loop_body(prev_i):
  """Return prev_i + 1 (jitted; stands in for a computationally expensive loop body)."""
  return prev_i + 1

def g_inner_jitted(x, n):
  """Return x plus the number of loop_body steps needed to reach n (x + n for integer n >= 0)."""
  counter = 0
  while counter < n:
    # This loop is not jitted, so loop_body hands back a concrete array and
    # `counter < n` is an ordinary bool that Python can branch on.
    counter = loop_body(counter)
  return x + counter

g_inner_jitted(10, 20)

Array(30, dtype=int32, weak_type=True)

In [26]:
# Option 2: jit the whole function but mark `n` as static (static_argnames).
# A static arg is a plain Python value baked into the compiled code, so JAX recompiles for
# every new value of `n`. That works, but is not great if `n` changes often.
# Caveat: the loop counter must stay a plain Python int. Jitting g_inner_jitted this way
# fails: inside the outer jit, loop_body(i) returns a tracer, and `while i < n` would then
# need a concrete bool (TracerBoolConversionError).

def g_static_loop(x, n):
  """Return x + n (for integer n >= 0) by counting with a plain Python int."""
  i = 0
  while i < n:
    i += 1  # plain Python arithmetic: i never becomes a tracer
  return x + i

jit_cond = jax.jit(g_static_loop, static_argnames='n')
jit_cond(10, 20)

In [31]:
# Best: use JAX's structured control-flow primitives instead of Python `if`.
# - jnp.where / jnp.piecewise: elementwise selection, like np.where / np.piecewise.
# - lax.cond(pred, true_fun, false_fun, operand): applies true_fun to operand when pred is
#   true, false_fun otherwise. Both branches are compiled, and the predicate is evaluated
#   at runtime.

def cond(pred, true_fun, false_fun, operand):
  """Plain-Python reference for what lax.cond computes."""
  chosen = true_fun if pred else false_fun
  return chosen(operand)

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)

# Also: lax.fori_loop / lax.while_loop give loops that JAX/XLA can compile.

Array([1.], dtype=float32)

# Autovectorization:
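`jax.vmap(f)` vectorizes `f` over a batch axis. By default it maps over axis 0 of every input and stacks the results along axis 0 of the output. Use `in_axes` / `out_axes` to choose where the batch dimension is for each input/output (`None` means "not batched": the same value is broadcast to every batch element).

```python
# Snippet from the JAX tutorial (convolve, xs, w are defined there, not in this notebook):
# map over axis 0 of xs, and pass the same w to every batch element.
batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])

batch_convolve_v3(xs, w)
```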



# Autodiff

In [35]:
grad_tanh = grad(jnp.tanh)  # grad(fxn) returns a new function that computes d(fxn)/dx
gradded_val = grad_tanh(2.)  # reuse it rather than rebuilding grad(jnp.tanh); 1 - tanh(2)^2 ≈ 0.0707
gradded_val

# loss_value, Wb_grad = jax.value_and_grad(fxn, (0, 1))(W, b)
# argnums: an int or a tuple of ints (positional args to differentiate with respect to).
# With a tuple, the gradient comes back as a tuple with one entry per listed argument;
# with an int, you get a single gradient.

Array(0.07065082, dtype=float32, weak_type=True)

# Pytree

In [None]:
# Basically, nested data structures can be seen as a tree. You can define custom ones, but I don't think we need to worry about that right now.

# Sharp bits
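JAX arrays are immutable, in the same way a Python `str` is: `x += 1` works, but it just rebinds `x` to a new array instead of modifying it in place.  
For indexed updates, use `x.at[idx].set(y)`, `.add(num)`, `.multiply(num)`, etc. These return a new array, and under the hood (in jit) XLA can often perform them in place, so they are fast.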

# I try einsum and other np functions

In [39]:
import numpy as np

With `order='C'` (the default), `reshape` reads the source array in row-major order (the last/rightmost axis changes fastest). It then fills the array of the new shape in the same row-major order. (`order='F'` would use column-major order instead, where the first/leftmost axis changes fastest.)

In [44]:
flat = np.arange(20)
test1 = flat.reshape((2, 2, 5), order='C')  # row-major: last axis varies fastest
print(test1)

[[[ 0  1  2  3  4]
  [ 5  6  7  8  9]]

 [[10 11 12 13 14]
  [15 16 17 18 19]]]


In [47]:
# Now say we have a batch in (b, h, w, c) layout:
# 2 images, each 2 rows x 3 columns, with 3 channels per pixel (3 channel values = 1 pixel).
n_images, height, width, channels = 2, 2, 3, 3
batched_img = np.arange(n_images * height * width * channels).reshape(n_images, height, width, channels)
batched_img

array([[[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]]],


       [[[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]],

        [[27, 28, 29],
         [30, 31, 32],
         [33, 34, 35]]]])

In [50]:
#if I wanted to flatten it but retain channels:
test2 = batched_img.reshape(2,6,3) #functions because each 3 values in the flattened get read into the same pixel.
test2

array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26],
        [27, 28, 29],
        [30, 31, 32],
        [33, 34, 35]]])

All right, now onto the hard part: Einstein summation. Even the name is intimidating.  
`np.einsum('<labels of in1>,<labels of in2>-><labels of output>', in1, in2)`: each letter labels one axis. A label shared between inputs pairs those axes up and multiplies elementwise along them. Any label omitted from the output is summed over.  
Examples are the only way I can explain this, so here they are:

In [53]:
# Matrix multiplication: C[i, j] = sum_k A[i, k] * B[k, j].
# k appears in both inputs but not in the output, so it is summed over.
rng = np.random.default_rng(0)  # seeded so Restart & Run All gives the same numbers
A = rng.random((2, 3))
B = rng.random((3, 4))
C = np.einsum('ik,kj->ij', A, B)
assert np.allclose(C, A @ B)  # sanity check against the built-in matmul

In [68]:
#Dot product
A = np.arange(6)
B = np.arange(6)
dot_prod = np.einsum('i,i->', A, B) #holy shit theres something wacko going on. think its cause they sum it after computing the outer product in the second scanrio
outer_prod_sum = np.einsum('i,b->', A, B)
outer_prod = np.einsum('i,b->ib', A, B)
print(A, B)
print(dot_prod, outer_prod_sum, outer_prod)

[0 1 2 3 4 5] [0 1 2 3 4 5]
55 225 [[ 0  0  0  0  0  0]
 [ 0  1  2  3  4  5]
 [ 0  2  4  6  8 10]
 [ 0  3  6  9 12 15]
 [ 0  4  8 12 16 20]
 [ 0  5 10 15 20 25]]


In [73]:
# Contract over one specific axis: a row-wise dot product.
A1 = np.arange(6).reshape(2, 3)
B1 = np.arange(6).reshape(2, 3) + 1
# Label b is shared and dropped from the output, so it is summed; label a is kept.
# The result has shape (2,) (one dot product per row), not (2, 1).
res = np.einsum('ab,ab->a', A1, B1)
assert np.array_equal(res, (A1 * B1).sum(axis=1))
print(A1)
print(B1)
print(res)

[[0 1 2]
 [3 4 5]]
[[1 2 3]
 [4 5 6]]
[ 8 62]


In [105]:
# Now we get relevant to CNNs again.
A2 = np.arange(12).reshape(1, 2, 2, 3)  # batch of 1 image: 2x2 pixels, 3 channels (b, h, w, c)
k = np.arange(20).reshape(2, 2, 5)      # 2x2 kernel with 5 output filters (h, w, filters)

# Labels: a=batch, b=row, x=col, c=channel, d=filter.
# b and x are shared with k and dropped from the output, so they are summed.
# c appears only in A2 and is also dropped, so it is summed too. This kernel has no
# input-channel axis, so all channels share the same weights.
# Net effect: one value per (image, filter) = the kernel applied at the single position
# that covers the whole 2x2 image.
convolved = np.einsum('abxc,bxd->ad', A2, k)

# Repeating a label inside ONE operand ('abbc', 'bbd') does not mean two independent axes:
# einsum takes the diagonal along those axes, so only A2[:, i, i, :] and k[i, i, :] are used.
crap_convolved = np.einsum('abbc,bbd->ad', A2, k)

# Reference implementation with explicit loops, to check the einsum result.
ret_slow = np.zeros((A2.shape[0], k.shape[2]))
for b in range(A2.shape[0]):     # loop over images in the batch
    for f in range(k.shape[2]):  # loop over filters
        # k[:, :, f, np.newaxis] has shape (2, 2, 1) and broadcasts across the 3 channels
        ret_slow[b, f] = np.sum(A2[b] * k[:, :, f, np.newaxis])

assert np.array_equal(ret_slow, convolved)

print(A2)
print("-")
print(k)
print("slow")
print(ret_slow)
print("einsummed")
print(convolved)


[[[[ 0  1  2]
   [ 3  4  5]]

  [[ 6  7  8]
   [ 9 10 11]]]]
-
[[[ 0  1  2  3  4]
  [ 5  6  7  8  9]]

 [[10 11 12 13 14]
  [15 16 17 18 19]]]
slow
[[720. 786. 852. 918. 984.]]
einsummed
[[720 786 852 918 984]]


In [None]:
# Why does einsum behave differently when a label is repeated within one operand (e.g. 'abbc')? Time to find out!
