In [None]:
import jax
from jax import jit
from jax import numpy as jnp

## Simple AXPY

In [None]:
def AXPY(a: jnp.float32,
         x: jax.Array,
         y: jax.Array) -> jax.Array:
    """Compute the AXPY update ``a * x + y``.

    Args:
        a: scalar multiplier applied to ``x``.
        x: values scaled by ``a``.
        y: values added to the scaled ``x``.

    Returns:
        The elementwise result of ``a * x + y``.
    """
    scaled_x = a * x
    result = scaled_x + y

    # To see the traced objects JAX passes in while building a jaxpr,
    # temporarily add e.g. print("x =", x) or print("result =", result) here.
    return result

In [None]:
# Wrap AXPY with JAX's jit; the result is a callable object, not a plain Python function.
AXPYjitted = jax.jit(AXPY)
type(AXPYjitted)

In [None]:
import inspect
import sys

# List the plain Python functions bound at the top level of this notebook's namespace.
# Note: this can include imported functions (e.g. `jit`), not only ones defined here.
inspect.getmembers(sys.modules[__name__], predicate=inspect.isfunction)
# AXPYjitted is absent: it is an instance of PjitFunction rather than a Python function,
# so inspect.isfunction rejects it.

In [None]:
# Sample inputs for the jaxpr demos: a scalar multiplier and two 5x5 arrays.
a = 0.5
x = jnp.full(shape=(5, 5), fill_value=2.0)
y = jnp.full(shape=(5, 5), fill_value=3.0)

In [None]:
# Trace AXPY on the sample inputs and show the resulting jaxpr.
# Leaving the ClosedJaxpr as the cell's last expression lets Jupyter render it
# directly (as the AXPY2 cell does), instead of wrapping it in print().
jax.make_jaxpr(AXPY)(a, x, y)

## AXPY with side effects (printing and reading a global variable)

In [None]:
b = 10  # module-level global that AXPY2 reads without taking it as a parameter

def AXPY2(a: jnp.float32,
          x: jax.Array,
          y: jax.Array) -> jax.Array:
    """AXPY variant with side effects: it prints and reads the global ``b``.

    Computes ``a * x + y + b``. The function is intentionally impure.
    The ``print`` call is a side effect, and ``b`` is taken from
    module scope rather than from the argument list.

    Args:
        a: scalar multiplier applied to ``x``.
        x: values scaled by ``a``.
        y: values added to the scaled ``x``.

    Returns:
        ``a * x + y`` shifted by the global ``b``.
    """
    # Core AXPY computation.
    axpy_result = a * x + y

    # Side effect: this print runs at the Python level. It fires while the
    # function is traced, but it is not recorded in the jaxpr.
    print("Performing AXPY")

    # Implicit input: the global `b` is read here instead of being passed in.
    return axpy_result + b

In [None]:
# Tracing AXPY2 executes its Python body, so "Performing AXPY" is printed here.
# The print itself is not an operation in the resulting jaxpr.
jax.make_jaxpr(AXPY2)(a, x, y)