<a href="https://colab.research.google.com/github/durml91/Transformer_Implementation/blob/main/JAX%26PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Comparing JAX and PyTorch

Resources: https://sjmielke.com/jax-purify.htm and https://kidger.site/thoughts/torch2jax/

Difference in autodiff (reverse mode in both cases): PyTorch builds a computational graph at runtime using the Tensor, Module and Parameter classes. As you propagate forward, each operation records its parent nodes, its args and kwargs, and the function that produced it; the forward functions are wrapped so that gradients can be controlled and this information stored correctly. Calling `backward()` on the loss then relies on every stored forward op having a corresponding "backward" function (its derivative) to call during backprop. The graph is topologically sorted (depth first) so the chain rule can be applied in a dynamic-programming fashion: each node's gradient is computed once and reused, which is also the most memory-efficient order.

JAX is more functional (think of parameters as living outside the function and being passed in) and expresses the gradient as a transformation of a function rather than computing a derivative on the spot: `grad()` takes a function and returns a new function, which you then have to feed an input to get the gradient at that point. In PyTorch, by contrast, each node receives `grad_out` (the gradient of the loss w.r.t. that node's output), and working backwards gives the gradient of the loss w.r.t. the parent node, i.e. `grad_out` times the derivative of the current function from its input x to its output.
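A minimal sketch of the "gradient as a function" point, assuming a made-up one-parameter squared-error loss (the names `loss` and `dloss_dw` are only for illustration):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a one-parameter linear model (hypothetical example).
    return jnp.sum((w * x - y) ** 2)

# grad() computes nothing yet: it returns a new function that evaluates
# d(loss)/dw at whatever inputs it is given later.
dloss_dw = jax.grad(loss)  # differentiates w.r.t. the first argument by default

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

g_at_1 = dloss_dw(1.0, x, y)  # gradient evaluated at w = 1
g_at_2 = dloss_dw(2.0, x, y)  # gradient evaluated at w = 2 (the minimum)
print(g_at_1, g_at_2)
```

The same `dloss_dw` can be called at any number of points; no graph or `.grad` attribute is stored on the inputs.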

Important things to look out for code-wise: PyTorch is object-oriented with mutable state (think in-place operations like the optimiser updating weights: we don't care about keeping the original weight in memory, we just want the updated one), whereas JAX relies on "pure" functions.
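As a rough illustration of the pure-function style, here is what an optimiser step could look like when nothing is mutated. The `sgd_step` helper and the parameter dictionary are assumptions made up for this sketch, not part of any library:

```python
import jax
import jax.numpy as jnp

def sgd_step(params, grads, lr):
    # Pure update: builds and returns a new pytree, `params` is left untouched.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(1.0)}

new_params = sgd_step(params, grads, lr=0.1)
print(params["w"])      # the old weights are still there
print(new_params["w"])  # the updated weights are a separate array
```

Under `jax.jit` the compiler is free to reuse buffers, so writing the update this way does not have to cost extra memory in practice.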

For JAX (and fp in general), an accurate summary: a function should always output the same values given the same input. Also, it should leave no trace of whether it has been called at all (no side effects).


In [19]:
import random
import time

def pure_fn_1(x):
    return 2 * x

def pure_fun_2(xs):
    ys = []
    for x in xs:
        # Mutating stateful variables *inside* the function is fine!
        ys.append(2 * x)
    return ys

def impure_fn_1(xs):
    # Mutating arguments has lasting consequences outside the function! :(
    xs.append(sum(xs))
    return xs

def impure_fn_2(x):
    # Very obviously mutating global state is bad...
    global num_execs
    num_execs += 1
    return 2 * x

def impure_fn_3(x):
    # ...but just accessing it is, too, because now the function depends on the
    # execution context!
    return num_execs * x

def impure_fn_4(x):
    # Things like IO are classic examples of impurity.
    # All three of the following lines are violations of purity:
    print("Hello!")
    user_input = input()
    execution_time = time.time()
    return 2 * x

def impure_fn_5(x):
    # Which constraint does this violate? Both, actually! You access the current
    # state of randomness *and* advance the number generator!
    p = random.random()
    return p * x


In [20]:
num_execs = 0

a = 3
a_out = pure_fn_1(a)
print(a_out) # 2 * 3

b = [1,2,3]
b_out = pure_fun_2(b)
print(b_out) # essentially calling pure_fn_1 inside of function

c = [1,2,3]
c_out = impure_fn_1(c)
print(c_out) # append is an in-place operation, so c itself is now [1, 2, 3, 6] as well

d = 3
d_out = impure_fn_2(d)
print(d) # prints the input d (still 3); d_out is 6
print(num_execs) # the side effect: the global counter was incremented through the global keyword

e = 3
e_out = impure_fn_3(e)
print(e) # prints the input e (still 3); e_out = num_execs * e depends on how often impure_fn_2 ran before

f = 3
f_out = impure_fn_4(f)
print(f)

g = 3
g_out = impure_fn_5(g)
print(g) # prints the input g (still 3); g_out changes from run to run, since randomness violates consistency of output without a fixed key



6
[2, 4, 6]
[1, 2, 3, 6]
3
1
3
Hello!
hi
3
3
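For comparison, a sketch of how the randomness in `impure_fn_5` could be made pure in JAX by passing the PRNG key in explicitly. The function name `pure_random_scale` is made up for this example:

```python
import jax
import jax.random as jr

def pure_random_scale(key, x):
    # The randomness enters through `key`, so the same key gives the same output.
    p = jr.uniform(key)
    return p * x

key = jr.PRNGKey(0)
key_a, key_b = jr.split(key)  # derive fresh keys instead of advancing hidden state

out_1 = pure_random_scale(key_a, 3.0)
out_2 = pure_random_scale(key_a, 3.0)  # same key: identical result
out_3 = pure_random_scale(key_b, 3.0)  # different key: a separate draw
print(out_1, out_2, out_3)
```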




1.   JIT functions: use a single `jax.jit` decorator (or call), since all the operations inside that function are compiled together. Make sure to jit the highest-level function; extra/nested `jax.jit`s do nothing.
2.   Tracing (v. important): under a jitted function, a Python `if` (or any boolean test) cannot depend on the values of a traced array, as the whole graph is traced and passed to the compiler at once before any values exist. For example, if you train a transformer and want to greedy decode during a "jitted" training loop, you can't include an `if` statement on the end token. Tracing really means that the actual value is not known while the code is being executed, but the abstract object is (an array with a given shape, dtype and parallelism/sharding approach); changing the shape, for example, will trigger a recompile of the code. `if` can still be used to define the specific shape of the computation graph before compilation time (i.e. an `if` inside a model is fine when it tests something static, although of course it cannot depend on the contents of an input array, as these may change). For value-dependent choices, use `jnp.where` during the computation. So for the transformer decoding example, end-token or masking/patch-token positions will still have to be generated at full vocab_size, but `jnp.where` lets you choose the output logits that aren't redundant. **Very** useful: inside of jits, use `jax.debug.print` (you don't have to remove the jit all the time!)
3.  `jax.jit` has a PyTorch analogue in `torch.compile`
4.  Pytrees: don't use generic Python classes, since JAX does not treat them as pytrees unless they are registered; use the `nn.Module`-style class from the specific neural-network library you are working with
5.  In-place updates (fp style): use `array.at[].set()` rather than indexing and assigning values as in NumPy or PyTorch; it returns a new array. Problems occur if, in backprop, the function needs the values that were not indexed: in this case a whole copy of the array will be used. To be certain, you will have to construct the new array manually, i.e. set the values to whatever you want yourself (so maybe use a for loop and `jnp.concatenate`)
6.  JAX is about writing code that can be compiled intelligently. Python for loops are unrolled when the function is traced. If a function is applied repeatedly inside a for loop, this creates a copy of the function call per iteration, and each copy has to be compiled separately. To avoid this, write the function calls out explicitly, one after another, or use `jax.lax.scan`, which traces the function once and threads a carry (the carry would be the input in this case) through the desired number of iterations, in place of iterating over `range()`
7.  `jax.lax.cond` can be faster than `jnp.where`, as it doesn't require evaluation of both inputs/branches. Watch out when using it with `vmap`, though: with a batched predicate it is converted to a select that evaluates both branches, which negates the speedup
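A small sketch covering points 2, 5, 6 and 7 from the list above. All function names here (`clip_negatives`, `zero_first`, `step`, `run_scan`, `safe_sqrt`) are hypothetical and only meant to show the calls in context:

```python
import jax
import jax.numpy as jnp

@jax.jit
def clip_negatives(x):
    # `if x > 0:` would fail here because x is a tracer; jnp.where selects per element.
    jax.debug.print("x inside jit: {x}", x=x)  # shows runtime values under jit
    return jnp.where(x > 0, x, 0.0)

@jax.jit
def zero_first(x):
    # Functional update: returns a new array and leaves x unchanged.
    return x.at[0].set(0.0)

def step(carry, _):
    # One loop iteration; the carry is fed back in as the next input.
    new_carry = 0.5 * carry + 1.0
    return new_carry, new_carry  # (next carry, per-step output)

@jax.jit
def run_scan(x0):
    # `step` is traced once, however many iterations are run.
    return jax.lax.scan(step, x0, xs=None, length=4)

@jax.jit
def safe_sqrt(x):
    # Scalar predicate: only the selected branch is executed.
    return jax.lax.cond(x >= 0, jnp.sqrt, jnp.zeros_like, x)

x = jnp.array([-1.0, 2.0, -3.0])
clipped = clip_negatives(x)
updated = zero_first(x)
final, history = run_scan(jnp.array(0.0))
root = safe_sqrt(jnp.array(4.0))
fallback = safe_sqrt(jnp.array(-1.0))
```

Note that both `lax.cond` branches must return values of the same shape and dtype, which is why the fallback uses `jnp.zeros_like` rather than a Python constant.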




from https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

In [25]:
import jax
import jax.numpy as jnp
import jax.random as jr

In [26]:
g=0.0
def imp_global(x):
  return x + g
print(jax.jit(imp_global)(4.)) # 4 + 0 = 4
g = 10. # change the global value
print(jax.jit(imp_global)(5.))  # expect 5 + 10 but get 5 + 0: the cached compiled function still has g = 0.0 baked in (using the global keyword won't change anything)
print(jax.jit(imp_global)(jnp.array([4.]))) # however, changing the input shape by passing an array triggers a retrace/recompile, which picks up the current value of g

4.0
5.0
[14.]
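One way to avoid the stale global, sketched under the assumption that the offset can simply be passed in as an argument (`add_offset` is a made-up name):

```python
import jax

@jax.jit
def add_offset(x, offset):
    # `offset` is an explicit input, so a new value is always picked up.
    return x + offset

first = add_offset(4.0, 0.0)
second = add_offset(5.0, 10.0)  # no recompile needed: only the value changed
print(first, second)
```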


In [31]:
g = 0.
def imp_global_2(x):
  global g
  g = x
  return x

print(jax.jit(imp_global_2)(4.0))
print(g)

4.0
Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
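The global ends up holding a tracer because the assignment runs at trace time, not at run time. A sketch of the pure alternative, where state goes in as an argument and comes back as a return value (`double_and_record` and `count` are names invented for this example):

```python
import jax

@jax.jit
def double_and_record(x, call_count):
    # No global writes: the updated counter is returned alongside the result.
    return 2 * x, call_count + 1

count = 0
y, count = double_and_record(4.0, count)
y, count = double_and_record(y, count)
print(y, count)
```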
