In [2]:
!nvidia-smi

Wed Aug  7 21:10:36 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.27                 Driver Version: 560.70         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4070 ...    On  |   00000000:01:00.0  On |                  N/A |
|  0%   41C    P8              6W /  285W |    1667MiB /  16376MiB |      6%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [3]:
import jax
import jax.numpy as jnp

In [4]:
jax.__version__

'0.4.30'

In [5]:
jax.devices()

[cuda(id=0)]

## Using jax.jit as transformation or annotation

In [6]:
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Scaled exponential linear unit activation function.'''
  return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [7]:
x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,))

In [8]:
selu_jit = jax.jit(selu)

In [9]:
%timeit -n100 selu(x).block_until_ready()

949 μs ± 404 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
%timeit -n100 selu_jit(x).block_until_ready()

683 μs ± 106 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
@jax.jit
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Scaled exponential linear unit activation function.'''
  return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [16]:
z = selu(x) # warmup

In [19]:
%timeit -n100 selu(x).block_until_ready()

The slowest run took 4.71 times longer than the fastest. This could mean that an intermediate result is being cached.
540 μs ± 284 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Compiling and running on specific hardware

In [20]:
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Scaled exponential linear unit activation function.'''
  return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [21]:
selu_jit_cpu = jax.jit(selu, backend='cpu')
selu_jit_gpu = jax.jit(selu, backend='gpu')

In [22]:
%timeit -n100 selu(x).block_until_ready()

1.06 ms ± 389 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
%timeit -n100 selu_jit_cpu(x).block_until_ready()

4.21 ms ± 396 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [24]:
%timeit -n100 selu_jit_gpu(x).block_until_ready()

The slowest run took 5.55 times longer than the fastest. This could mean that an intermediate result is being cached.
102 μs ± 84.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
x_cpu = jax.device_put(x, jax.devices('cpu')[0])
x_gpu = jax.device_put(x, jax.devices('gpu')[0])

In [26]:
%timeit -n100 selu(x_cpu).block_until_ready()

1.68 ms ± 297 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [27]:
%timeit -n100 selu(x_gpu).block_until_ready()

1.35 ms ± 586 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [28]:
%timeit -n100 selu_jit_cpu(x_cpu).block_until_ready()

198 μs ± 47.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [33]:
%timeit -n100 selu_jit_gpu(x_gpu).block_until_ready()

The slowest run took 5.64 times longer than the fastest. This could mean that an intermediate result is being cached.
451 μs ± 269 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [30]:
%timeit -n100 selu_jit_cpu(x_gpu).block_until_ready()

3.99 ms ± 96.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [31]:
%timeit -n100 selu_jit_gpu(x_cpu).block_until_ready()

2.41 ms ± 589 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [34]:
x_cpu.devices()

{CpuDevice(id=0)}

In [35]:
x_gpu.devices()

{cuda(id=0)}

In [36]:
selu_jit_gpu

<PjitFunction of <function selu at 0x7ff694111f80>>

In [37]:
selu_jit_cpu

<PjitFunction of <function selu at 0x7ff694111f80>>

## Working with function arguments

In [38]:
def dist(order, x, y):
  print("Compiling")
  return jnp.power(jnp.sum(jnp.abs(x-y)**order), 1.0/order)

In [39]:
dist_jit = jax.jit(dist, static_argnums=0)

In [40]:
dist_jit(1, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0]))

Compiling


Array(4., dtype=float32)

In [41]:
dist_jit(2, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0]))

Compiling


Array(2.828427, dtype=float32)

In [42]:
dist_jit(1, jnp.array([10.0, 10.0]), jnp.array([2.0, 2.0]))

Array(16., dtype=float32)

In [43]:
from functools import partial

@partial(jax.jit, static_argnums=0)
def dist(order, x, y):
  return jnp.power(jnp.sum(jnp.abs(x-y)**order), 1.0/order)


In [44]:
dist

<PjitFunction of <function dist at 0x7ff69413ef20>>

In [45]:
dist(2, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0]))

Array(2.828427, dtype=float32)

In [46]:
def dense_layer(x, w, b, activation_func):
    return activation_func(x*w+b)

In [47]:
x = jnp.array([1.0, 2.0, 3.0])
w = jnp.ones((3,3))
b = jnp.ones(3)

In [48]:
dense_layer_jit = jax.jit(dense_layer)

dense_layer_jit(x, w, b, selu)

TypeError: Cannot interpret value of type <class 'function'> as an abstract array; it does not have a dtype attribute

In [49]:
dense_layer_jit = jax.jit(dense_layer, static_argnums=3)

In [50]:
dense_layer_jit(x, w, b, selu)

Array([[2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804]], dtype=float32)

## Pure functions again

In [51]:
global_state = 1

def impure_function(x):
  print(f'Side-effect: printing x={x}')
  y = x*global_state
  return y

In [52]:
impure_function_jit = jax.jit(impure_function)

In [53]:
impure_function_jit(10)

Side-effect: printing x=Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


Array(10, dtype=int32, weak_type=True)

In [54]:
impure_function_jit(10)

Array(10, dtype=int32, weak_type=True)

In [55]:
impure_function_jit(11)

Array(11, dtype=int32, weak_type=True)

In [56]:
global_state = 2

In [57]:
impure_function_jit(10)

Array(10, dtype=int32, weak_type=True)

In [58]:
impure_function(10)

Side-effect: printing x=10


20

## Jaxpr

Getting jaxpr

In [59]:
def f1(x, y, z):
  return jnp.sum(x + y * z)

In [60]:
x = jnp.array([1.0, 1.0, 1.0])
y = jnp.ones((3,3))*2.0
z = jnp.array([2.0, 1.0, 0.0]).T

In [61]:
jax.make_jaxpr(f1)(x,y,z)

{ lambda ; a:f32[3] b:f32[3,3] c:f32[3]. let
    d:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] c
    e:f32[3,3] = mul b d
    f:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] a
    g:f32[3,3] = add f e
    h:f32[] = reduce_sum[axes=(0, 1)] g
  in (h,) }

In [76]:
f1_jaxpr = jax.make_jaxpr(f1)(x,y,z)

In [77]:
type(f1_jaxpr)

jax._src.core.ClosedJaxpr

In [78]:
f1_jaxpr.jaxpr

{ lambda ; a:f32[3] b:f32[3,3] c:f32[3]. let
    d:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] c
    e:f32[3,3] = mul b d
    f:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] a
    g:f32[3,3] = add f e
    h:f32[] = reduce_sum[axes=(0, 1)] g
  in (h,) }

In [79]:
f1_jaxpr.consts

[]

Jaxpr for a function with side-effect

In [80]:
def f2(x, y):
  print(f'x={x}, y={y}, z={z}')
  return jnp.sum(x + y * z)

In [81]:
f2_jaxpr = jax.make_jaxpr(f2)(x,y)

x=Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>, y=Traced<ShapedArray(float32[3,3])>with<DynamicJaxprTrace(level=1/0)>, z=[2. 1. 0.]


In [82]:
type(f2_jaxpr.jaxpr)

jax._src.core.Jaxpr

In [83]:
f2_jaxpr.jaxpr

{ lambda a:f32[3]; b:f32[3] c:f32[3,3]. let
    d:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] a
    e:f32[3,3] = mul c d
    f:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] b
    g:f32[3,3] = add f e
    h:f32[] = reduce_sum[axes=(0, 1)] g
  in (h,) }

In [86]:
f2_jaxpr.jaxpr.constvars

[Var(id=140693875631424):float32[3]]

In [87]:
f2_jaxpr.jaxpr.invars

[Var(id=140693875636864):float32[3], Var(id=140693875636800):float32[3,3]]

In [88]:
f2_jaxpr.jaxpr.outvars

[Var(id=140693875622592):float32[]]

In [89]:
f2_jaxpr.jaxpr.eqns

[a:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] b,
 a:f32[3,3] = mul b c,
 a:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] b,
 a:f32[3,3] = add b c,
 a:f32[] = reduce_sum[axes=(0, 1)] b]

In [90]:
f2_jaxpr.jaxpr.effects

set()

In [91]:
type(f2_jaxpr.consts)

list

In [92]:
f2_jaxpr.consts

[Array([2., 1., 0.], dtype=float32)]

In [93]:
type(jax.make_jaxpr(f2))

function

In [94]:
jax.grad(f2)(x,y)

x=Traced<ConcreteArray([1. 1. 1.], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1., 1., 1.], dtype=float32)
  tangent = Traced<ShapedArray(float32[3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3]), None)
    recipe = LambdaBinding(), y=[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]], z=[2. 1. 0.]


Array([3., 3., 3.], dtype=float32)

Tracing with control structures

In [95]:
def f3(x):
  y = x
  for i in range(5):
    y += i
  return y

In [96]:
jax.make_jaxpr(f3)(0)

{ lambda ; a:i32[]. let
    b:i32[] = add a 0
    c:i32[] = add b 1
    d:i32[] = add c 2
    e:i32[] = add d 3
    f:i32[] = add e 4
  in (f,) }

In [97]:
jax.jit(f3)(0)

Array(10, dtype=int32, weak_type=True)

In [98]:
def f4(x):
  y = 0
  for i in range(x.shape[0]):
    y += x[i]
  return y

In [99]:
jax.make_jaxpr(f4)(jnp.array([1.0, 2.0, 3.0]))

{ lambda ; a:f32[3]. let
    b:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] a
    c:f32[] = squeeze[dimensions=(0,)] b
    d:f32[] = add 0.0 c
    e:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] a
    f:f32[] = squeeze[dimensions=(0,)] e
    g:f32[] = add d f
    h:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] a
    i:f32[] = squeeze[dimensions=(0,)] h
    j:f32[] = add g i
  in (j,) }

In [100]:
jax.jit(f4)(jnp.array([1.0, 2.0, 3.0]))

Array(6., dtype=float32)

Dependance on a parameter value

In [101]:
def f5(x):
  y = 0
  for i in range(x):
    y += i
  return y

In [102]:
f5(5)

10

In [103]:
jax.make_jaxpr(f5)(5)

TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[].
The error occurred while tracing the function f5 at /tmp/ipykernel_52186/4135095833.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [104]:
jax.jit(f5)(5)

TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[].
The error occurred while tracing the function f5 at /tmp/ipykernel_52186/4135095833.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError

In [105]:
def relu(x):
  #print(x)
  if x > 0:
    return x
  return 0.0

In [106]:
relu(10.0)

10.0

In [107]:
jax.make_jaxpr(relu)(10.0)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function relu at /tmp/ipykernel_52186/1355355841.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Using static parameters to overcome this

In [108]:
jax.make_jaxpr(f5, static_argnums=0)(5)

{ lambda ; . let  in (10,) }

In [109]:
jax.jit(f5, static_argnums=0)(5)

Array(10, dtype=int32, weak_type=True)

In [110]:
jax.make_jaxpr(relu, static_argnums=0)(12.3)

{ lambda ; . let  in (12.3,) }

In [111]:
jax.jit(relu, static_argnums=0)(12.3)

Array(12.3, dtype=float32, weak_type=True)

Rewriting with structured control flow primitives

In [112]:
def f5(x):
  return jax.lax.fori_loop(0, x, lambda i,v: v+i, 0)

In [113]:
f5(5)

Array(10, dtype=int32, weak_type=True)

In [114]:
jax.make_jaxpr(f5)(5)

{ lambda ; a:i32[]. let
    _:i32[] _:i32[] b:i32[] = while[
      body_jaxpr={ lambda ; c:i32[] d:i32[] e:i32[]. let
          f:i32[] = add c 1
          g:i32[] = add e c
        in (f, d, g) }
      body_nconsts=0
      cond_jaxpr={ lambda ; h:i32[] i:i32[] j:i32[]. let
          k:bool[] = lt h i
        in (k,) }
      cond_nconsts=0
    ] 0 a 0
  in (b,) }

In [115]:
jax.jit(f5)(5)

Array(10, dtype=int32, weak_type=True)

In [116]:
def relu(x):
  return jax.lax.cond(x>0, lambda x: x, lambda x: 0.0, x)

In [117]:
relu(12.3)

Array(12.3, dtype=float32, weak_type=True)

In [118]:
relu(-12.3)

Array(0., dtype=float32, weak_type=True)

In [119]:
jax.make_jaxpr(relu)(12.3)

{ lambda ; a:f32[]. let
    b:bool[] = gt a 0.0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:f32[] = cond[
      branches=(
        { lambda ; e:f32[]. let  in (0.0,) }
        { lambda ; f:f32[]. let  in (f,) }
      )
      linear=(False,)
    ] c a
  in (d,) }

In [120]:
jax.jit(relu)(12.3)

Array(12.3, dtype=float32, weak_type=True)

## XLA

In [121]:
def f(x, y, z):
  return jnp.sum(x + y * z)

In [122]:
x = jnp.array([1.0, 1.0, 1.0])
y = jnp.ones((3,3))*2.0
z = jnp.array([2.0, 1.0, 0.0]).T

In [123]:
x

Array([1., 1., 1.], dtype=float32)

In [124]:
y

Array([[2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.]], dtype=float32)

In [125]:
z

Array([2., 1., 0.], dtype=float32)

In [126]:
f(x, y, z)

Array(27., dtype=float32)

In [127]:
f_jitted = jax.jit(f)

In [128]:
f_lowered = f_jitted.lower(x,y,z)
print(f_lowered.as_text())

module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<3x3xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<3xf32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [0, 1] : (tensor<1x3xf32>) -> tensor<3x3xf32>
    %2 = stablehlo.multiply %arg1, %1 : tensor<3x3xf32>
    %3 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %4 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x3xf32>) -> tensor<3x3xf32>
    %5 = stablehlo.add %4, %2 : tensor<3x3xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %6 = stablehlo.reduce(%5 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<3x3xf32>, tensor<f

In [129]:
f_compiled = f_lowered.compile()

In [130]:
f_compiled

<jax._src.stages.Compiled at 0x7ff5941a8a10>

In [131]:
print(f_compiled.as_text())

HloModule jit_f, is_scheduled=true, entry_computation_layout={(f32[3]{0}, f32[3,3]{1,0}, f32[3]{0})->f32[]}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="f8ebb7a5511fd87d38b15ebf11a9dfd6"}

%region_0.15 (Arg_0.16: f32[], Arg_1.17: f32[]) -> f32[] {
  %Arg_1.17 = f32[] parameter(1), metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0, 1)]"}
  %Arg_0.16 = f32[] parameter(0), metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0, 1)]"}
  ROOT %add.2 = f32[] add(f32[] %Arg_0.16, f32[] %Arg_1.17), metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0, 1)]" source_file="/tmp/ipykernel_52186/3103132191.py" source_line=2}
}

%fused_reduce (param_0.5: f32[3,3], param_1.8: f32[3], param_2.7: f32[3]) -> f32[] {
  %param_2.7 = f32[3]{0} parameter(2)
  %broadcast.4.3 = f32[3,3]{1,0} broadcast(f32[3]{0} %param_2.7), dimensions={1}, metadata={op_name="jit(f)/jit(main)/add" source_fil

MLIR

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/mlir

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/mlir_hlo



## JIT Limitations

Long representation

In [132]:
def cumulative_sum(x):
  acc = 0.0
  y = []
  for i in range(x.shape[0]):
    acc += x[i]
    y.append(acc)
  return y

In [133]:
cumulative_sum(jnp.array([1.0, 1.0, 5.0, 2.0]))

[Array(1., dtype=float32),
 Array(2., dtype=float32),
 Array(7., dtype=float32),
 Array(9., dtype=float32)]

In [134]:
j = jax.make_jaxpr(cumulative_sum)(jnp.ones(10000))

In [135]:
len(j.jaxpr.eqns)

30000

In [136]:
j

{ lambda ; a:f32[10000]. let
    b:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] a
    c:f32[] = squeeze[dimensions=(0,)] b
    d:f32[] = add 0.0 c
    e:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] a
    f:f32[] = squeeze[dimensions=(0,)] e
    g:f32[] = add d f
    h:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] a
    i:f32[] = squeeze[dimensions=(0,)] h
    j:f32[] = add g i
    k:f32[1] = slice[limit_indices=(4,) start_indices=(3,) strides=None] a
    l:f32[] = squeeze[dimensions=(0,)] k
    m:f32[] = add j l
    n:f32[1] = slice[limit_indices=(5,) start_indices=(4,) strides=None] a
    o:f32[] = squeeze[dimensions=(0,)] n
    p:f32[] = add m o
    q:f32[1] = slice[limit_indices=(6,) start_indices=(5,) strides=None] a
    r:f32[] = squeeze[dimensions=(0,)] q
    s:f32[] = add p r
    t:f32[1] = slice[limit_indices=(7,) start_indices=(6,) strides=None] a
    u:f32[] = squeeze[dimensions=(0,)] t
    v:f32[] = add s u
   

In [137]:
%time cs = jax.jit(cumulative_sum)(jnp.ones(10000))

CPU times: user 1min 6s, sys: 1.49 s, total: 1min 7s
Wall time: 1min 7s


In [138]:
def cumulative_sum_fast(x):
  result, array = jax.lax.scan(lambda carry, elem: (carry+elem, carry+elem), 0.0, x)
  return array

In [139]:
cumulative_sum_fast(jnp.array([1.0, 1.0, 5.0, 2.0]))

Array([1., 2., 7., 9.], dtype=float32)

In [140]:
j = jax.make_jaxpr(cumulative_sum_fast)(jnp.ones(10000))

In [141]:
len(j.jaxpr.eqns)

1

In [142]:
j

{ lambda ; a:f32[10000]. let
    _:f32[] b:f32[10000] = scan[
      _split_transpose=False
      jaxpr={ lambda ; c:f32[] d:f32[]. let
          e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
          f:f32[] = add e d
          g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
          h:f32[] = add g d
        in (f, h) }
      length=10000
      linear=(False, False)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] 0.0 a
  in (b,) }

In [143]:
%time cs = jax.jit(cumulative_sum_fast)(jnp.ones(10000))

CPU times: user 185 ms, sys: 93.7 ms, total: 279 ms
Wall time: 442 ms


Class methods

In [144]:
class ScaleClass:
  def __init__(self, scale: jnp.array):
    self.scale = scale

  @jax.jit
  def apply(self, x: jnp.array):
    return self.scale * x

In [145]:
scale_double = ScaleClass(2)

In [146]:
scale_double.apply(10)

TypeError: Cannot interpret value of type <class '__main__.ScaleClass'> as an abstract array; it does not have a dtype attribute

In [147]:
from functools import partial

class ScaleClass:
  def __init__(self, scale: jnp.array):
    self.scale = scale

  def apply(self, x: jnp.array):
    return _apply_helper(self.scale, x)

@partial(jax.jit, static_argnums=0)
def _apply_helper(scale, x):
  return scale*x

In [148]:
scale_double = ScaleClass(2)


In [149]:
scale_double.apply(10)

Array(20, dtype=int32, weak_type=True)

## AOT compilation

In [150]:
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Scaled exponential linear unit activation function.'''
  print('Function run')
  return scale * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [151]:
selu_jit = jax.jit(selu)

In [152]:
selu_aot = jax.jit(selu).lower(1.0).compile()

Function run


In [153]:
selu_jit(17.8)

Array(18.702477, dtype=float32, weak_type=True)

In [154]:
selu_aot(17.8)

Array(18.702477, dtype=float32, weak_type=True)

In [155]:
selu_jit(17)

Function run


Array(17.861917, dtype=float32, weak_type=True)

In [156]:
selu_aot(17)

TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with float32[] and called with int32[]

In [157]:
selu_jit_batched = jax.vmap(selu_jit)

In [158]:
selu_aot_batched = jax.vmap(selu_aot)

In [159]:
selu_jit_batched(jnp.array([42.0, 78.0, -12.3]))

Function run


Array([44.129444 , 81.95468  , -1.7580913], dtype=float32)

In [160]:
selu_aot_batched(jnp.array([42.0, 78.0, -12.3]))

TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax._src.interpreters.batching.BatchTracer'>.