In [1]:
!nvidia-smi

Sat Oct 28 10:59:15 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import jax
import jax.numpy as jnp

In [3]:
jax.__version__

'0.4.16'

In [4]:
jax.devices()

[gpu(id=0)]

## Using jax.jit as transformation or annotation

In [5]:
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Apply the scaled exponential linear unit (SELU) element-wise.

  Inputs greater than zero are passed through unchanged; every other input
  is mapped to ``alpha * exp(x) - alpha``. The selected value is then
  multiplied by ``scale``.
  '''
  is_positive = x > 0
  saturating_part = alpha * jnp.exp(x) - alpha
  return scale * jnp.where(is_positive, x, saturating_part)

In [6]:
# Benchmark input: 1,000,000 normally distributed samples drawn with a fixed PRNG key (seed 42).
x = jax.random.normal(jax.random.PRNGKey(42), (1_000_000,))

In [7]:
selu_jit = jax.jit(selu)

In [8]:
%timeit -n100 selu(x).block_until_ready()

The slowest run took 7.09 times longer than the fastest. This could mean that an intermediate result is being cached.
2.04 ms ± 2.07 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
%timeit -n100 selu_jit(x).block_until_ready()

The slowest run took 7.18 times longer than the fastest. This could mean that an intermediate result is being cached.
296 µs ± 308 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
@jax.jit
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Jit-compiled scaled exponential linear unit (SELU).

  Same computation as the undecorated version; here ``jax.jit`` is applied
  as a decorator instead of being called on the function afterwards.
  '''
  negative_branch = alpha * jnp.exp(x) - alpha
  selected = jnp.where(x > 0, x, negative_branch)
  return scale * selected

In [11]:
%timeit -n100 selu(x).block_until_ready()

The slowest run took 4.16 times longer than the fastest. This could mean that an intermediate result is being cached.
224 µs ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Compiling and running on specific hardware

In [12]:
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Scaled exponential linear unit (SELU), redefined without ``jax.jit``.

  This replaces the decorated ``selu`` from the earlier cell, so the cells
  that follow can wrap the plain Python function themselves.
  '''
  exp_term = alpha * jnp.exp(x) - alpha
  return scale * jnp.where(x > 0, x, exp_term)

In [13]:
# Two jitted wrappers around the same Python function, each requesting a specific backend.
# NOTE(review): the `backend` argument of jax.jit is presumably experimental/deprecated in
# newer JAX releases — confirm it is still accepted by the JAX version in use.
selu_jit_cpu = jax.jit(selu, backend='cpu')
selu_jit_gpu = jax.jit(selu, backend='gpu')

In [14]:
%timeit -n100 selu(x).block_until_ready()

1.11 ms ± 47.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
%timeit -n100 selu_jit_cpu(x).block_until_ready()

1.94 ms ± 221 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
%timeit -n100 selu_jit_gpu(x).block_until_ready()

207 µs ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
# Place `x` explicitly on the first CPU device and on the first GPU device, so the
# timings below can separate compute cost from cross-device transfer cost.
x_cpu = jax.device_put(x, jax.devices('cpu')[0])
x_gpu = jax.device_put(x, jax.devices('gpu')[0])

In [18]:
%timeit -n100 selu(x_cpu).block_until_ready()

3.09 ms ± 371 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
%timeit -n100 selu(x_gpu).block_until_ready()

1.12 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
%timeit -n100 selu_jit_cpu(x_cpu).block_until_ready()

1.26 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
%timeit -n100 selu_jit_gpu(x_gpu).block_until_ready()

257 µs ± 23 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
%timeit -n100 selu_jit_cpu(x_gpu).block_until_ready()

2.14 ms ± 302 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
%timeit -n100 selu_jit_gpu(x_cpu).block_until_ready()

1.44 ms ± 46.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
x_cpu.device()

CpuDevice(id=0)

In [26]:
x_gpu.device()

gpu(id=0)

In [27]:
selu_jit_gpu

<PjitFunction of <function selu at 0x7cfd108bbf40>>

In [28]:
selu_jit_cpu

<PjitFunction of <function selu at 0x7cfd108bbf40>>

## Working with function arguments

In [29]:
def dist(order, x, y):
  '''Return ``sum(|x - y| ** order) ** (1 / order)`` for the given ``order``.'''
  # Runs whenever the Python body executes, which makes (re)tracing visible.
  print("Compiling")
  abs_diff = jnp.abs(x - y)
  total = jnp.sum(abs_diff ** order)
  return jnp.power(total, 1.0 / order)

In [30]:
dist_jit = jax.jit(dist, static_argnums=0)

In [31]:
dist_jit(1, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0]))

Compiling


Array(4., dtype=float32)

In [32]:
dist_jit(2, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0]))

Compiling


Array(2.828427, dtype=float32)

In [33]:
dist_jit(1, jnp.array([10.0, 10.0]), jnp.array([2.0, 2.0]))

Array(16., dtype=float32)

In [34]:
from functools import partial

@partial(jax.jit, static_argnums=0)
def dist(order, x, y):
  '''Jit-compiled ``sum(|x - y| ** order) ** (1 / order)``.

  ``order`` (positional argument 0) is declared static, so it is available
  as a concrete Python value while the function is traced.
  '''
  powered = jnp.abs(x - y) ** order
  return jnp.power(jnp.sum(powered), 1.0 / order)


In [35]:
dist

<PjitFunction of <function dist at 0x7cfcf002c670>>

In [36]:
dist(2, jnp.array([0.0, 0.0]), jnp.array([2.0, 2.0]))

Array(2.828427, dtype=float32)

In [45]:
def dense_layer(x, w, b, activation_func):
    '''Apply ``activation_func`` to ``x*w+b``.

    ``activation_func`` is a Python callable, not array data.

    NOTE(review): ``x*w`` uses the ``*`` operator, which for arrays is an
    element-wise (broadcast) product rather than a matrix product — confirm
    this is what the example intends.
    '''
    pre_activation = x*w+b
    return activation_func(pre_activation)

In [46]:
x = jnp.array([1.0, 2.0, 3.0])
w = jnp.ones((3,3))
b = jnp.ones(3)

In [47]:
# Jit the layer without marking any argument as static.
dense_layer_jit = jax.jit(dense_layer)

# Expected to fail (the recorded output is a TypeError): `selu` is a Python function and is
# passed here as an ordinary argument. The next cell re-jits with static_argnums=3 to fix this.
dense_layer_jit(x, w, b, selu)

TypeError: ignored

In [48]:
dense_layer_jit = jax.jit(dense_layer, static_argnums=3)

In [49]:
dense_layer_jit(x, w, b, selu)

Array([[2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804],
       [2.101402, 3.152103, 4.202804]], dtype=float32)

## Pure functions again

In [37]:
global_state = 1

def impure_function(x):
  '''Deliberately impure example: prints ``x`` and scales it by ``global_state``.

  Both the ``print`` call and the read of the module-level ``global_state``
  happen only when this Python body actually executes.
  '''
  print(f'Side-effect: printing x={x}')
  return x*global_state

In [38]:
impure_function_jit = jax.jit(impure_function)

In [39]:
impure_function_jit(10)

Side-effect: printing x=Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


Array(10, dtype=int32, weak_type=True)

In [40]:
impure_function_jit(10)

Array(10, dtype=int32, weak_type=True)

In [41]:
impure_function_jit(11)

Array(11, dtype=int32, weak_type=True)

In [42]:
global_state = 2

In [43]:
impure_function_jit(10)

Array(10, dtype=int32, weak_type=True)

In [44]:
impure_function(10)

Side-effect: printing x=10


20

## Jaxpr

Getting jaxpr

In [50]:
def f1(x, y, z):
  '''Return the sum over all elements of ``x + y * z``.'''
  scaled = y * z
  return jnp.sum(x + scaled)

In [51]:
# Example inputs: two length-3 vectors and a 3x3 matrix.
x = jnp.array([1.0, 1.0, 1.0])
y = jnp.ones((3,3))*2.0
# No `.T` here: transposing a 1-D array returns the same shape-(3,) array, so it
# only suggested a column vector without creating one.
z = jnp.array([2.0, 1.0, 0.0])

In [52]:
jax.make_jaxpr(f1)(x,y,z)

{ lambda ; a:f32[3] b:f32[3,3] c:f32[3]. let
    d:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] c
    e:f32[3,3] = mul b d
    f:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] a
    g:f32[3,3] = add f e
    h:f32[] = reduce_sum[axes=(0, 1)] g
  in (h,) }

In [53]:
f1_jaxpr = jax.make_jaxpr(f1)(x,y,z)

In [55]:
type(f1_jaxpr)

jax._src.core.ClosedJaxpr

In [56]:
f1_jaxpr.jaxpr

{ lambda ; a:f32[3] b:f32[3,3] c:f32[3]. let
    d:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] c
    e:f32[3,3] = mul b d
    f:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] a
    g:f32[3,3] = add f e
    h:f32[] = reduce_sum[axes=(0, 1)] g
  in (h,) }

In [57]:
f1_jaxpr.consts

[]

Jaxpr for a function with side effects

In [58]:
def f2(x, y):
  '''Print the inputs, then return the sum over all elements of ``x + y * z``.

  ``z`` is not a parameter; it is read from the enclosing (module-level)
  scope. The ``print`` call is a side effect of executing this Python body.
  '''
  print(f'x={x}, y={y}, z={z}')
  weighted = y * z
  return jnp.sum(x + weighted)

In [59]:
f2_jaxpr = jax.make_jaxpr(f2)(x,y)

x=Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>, y=Traced<ShapedArray(float32[3,3])>with<DynamicJaxprTrace(level=1/0)>, z=[2. 1. 0.]


In [None]:
type(f2_jaxpr.jaxpr)

jax.core.Jaxpr

In [60]:
f2_jaxpr.jaxpr

{ lambda a:f32[3]; b:f32[3] c:f32[3,3]. let
    d:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] a
    e:f32[3,3] = mul c d
    f:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] b
    g:f32[3,3] = add f e
    h:f32[] = reduce_sum[axes=(0, 1)] g
  in (h,) }

In [61]:
f2_jaxpr.jaxpr.constvars

[a]

In [62]:
f2_jaxpr.jaxpr.invars

[b, c]

In [63]:
f2_jaxpr.jaxpr.outvars

[h]

In [64]:
f2_jaxpr.jaxpr.eqns

[a:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] b,
 a:f32[3,3] = mul b c,
 a:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] b,
 a:f32[3,3] = add b c,
 a:f32[] = reduce_sum[axes=(0, 1)] b]

In [65]:
f2_jaxpr.jaxpr.effects

set()

In [66]:
type(f2_jaxpr.consts)

list

In [67]:
f2_jaxpr.consts

[Array([2., 1., 0.], dtype=float32)]

In [69]:
type(jax.make_jaxpr(f2))

function

In [70]:
jax.grad(f2)(x,y)

x=Traced<ConcreteArray([1. 1. 1.], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1., 1., 1.], dtype=float32)
  tangent = Traced<ShapedArray(float32[3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3]), None)
    recipe = LambdaBinding(), y=[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]], z=[2. 1. 0.]


Array([3., 3., 3.], dtype=float32)

Tracing with control structures

In [71]:
def f3(x):
  '''Add the integers 0 through 4 to ``x``, one Python loop step at a time.

  The loop is ordinary Python with a fixed trip count, so each iteration
  contributes its own addition.
  '''
  total = x
  for increment in range(5):
    total += increment
  return total

In [72]:
jax.make_jaxpr(f3)(0)

{ lambda ; a:i32[]. let
    b:i32[] = add a 0
    c:i32[] = add b 1
    d:i32[] = add c 2
    e:i32[] = add d 3
    f:i32[] = add e 4
  in (f,) }

In [73]:
jax.jit(f3)(0)

Array(10, dtype=int32, weak_type=True)

In [74]:
def f4(x):
  '''Sum the entries of ``x`` along its first axis using a Python loop.

  The number of iterations is taken from ``x.shape[0]``.
  '''
  length = x.shape[0]
  total = 0
  for position in range(length):
    total += x[position]
  return total

In [75]:
jax.make_jaxpr(f4)(jnp.array([1.0, 2.0, 3.0]))

{ lambda ; a:f32[3]. let
    b:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] a
    c:f32[] = squeeze[dimensions=(0,)] b
    d:f32[] = add 0.0 c
    e:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] a
    f:f32[] = squeeze[dimensions=(0,)] e
    g:f32[] = add d f
    h:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] a
    i:f32[] = squeeze[dimensions=(0,)] h
    j:f32[] = add g i
  in (j,) }

In [76]:
jax.jit(f4)(jnp.array([1.0, 2.0, 3.0]))

Array(6., dtype=float32)

Dependence on a parameter value

In [77]:
def f5(x):
  '''Return ``0 + 1 + ... + (x - 1)``.

  ``x`` is handed to ``range``, so it must be usable as a concrete Python
  integer when this body runs.
  '''
  return sum(range(x))

In [78]:
f5(5)

10

In [79]:
jax.make_jaxpr(f5)(5)

TracerIntegerConversionError: ignored

In [80]:
jax.jit(f5)(5)

TracerIntegerConversionError: ignored

In [83]:
def relu(x):
  '''Return ``x`` when it is greater than zero, otherwise ``0.0``.

  The comparison is evaluated as a Python boolean, so ``x`` needs a concrete
  value when this body runs.
  '''
  return x if x > 0 else 0.0

In [84]:
relu(10.0)

10.0

In [85]:
jax.make_jaxpr(relu)(10.0)

TracerBoolConversionError: ignored

Using static parameters to overcome this

In [86]:
jax.make_jaxpr(f5, static_argnums=0)(5)

{ lambda ; . let  in (10,) }

In [87]:
jax.jit(f5, static_argnums=0)(5)

Array(10, dtype=int32, weak_type=True)

In [88]:
jax.make_jaxpr(relu, static_argnums=0)(12.3)

{ lambda ; . let  in (12.3,) }

In [89]:
jax.jit(relu, static_argnums=0)(12.3)

Array(12.3, dtype=float32, weak_type=True)

Rewriting with structured control flow primitives

In [90]:
def f5(x):
  '''Return ``0 + 1 + ... + (x - 1)`` using ``jax.lax.fori_loop``.

  The loop runs from 0 (inclusive) to ``x`` (exclusive) and starts with an
  accumulator of 0.
  '''
  def add_index(i, running_total):
    # The loop body receives the current index and the carried value.
    return running_total + i

  return jax.lax.fori_loop(0, x, add_index, 0)

In [91]:
f5(5)

Array(10, dtype=int32, weak_type=True)

In [92]:
jax.make_jaxpr(f5)(5)

{ lambda ; a:i32[]. let
    _:i32[] _:i32[] b:i32[] = while[
      body_jaxpr={ lambda ; c:i32[] d:i32[] e:i32[]. let
          f:i32[] = add c 1
          g:i32[] = add e c
        in (f, d, g) }
      body_nconsts=0
      cond_jaxpr={ lambda ; h:i32[] i:i32[] j:i32[]. let
          k:bool[] = lt h i
        in (k,) }
      cond_nconsts=0
    ] 0 a 0
  in (b,) }

In [93]:
jax.jit(f5)(5)

Array(10, dtype=int32, weak_type=True)

In [94]:
def relu(x):
  '''Return ``x`` when ``x > 0`` and ``0.0`` otherwise, via ``jax.lax.cond``.'''
  def keep(value):
    # Branch taken when the predicate is true.
    return value

  def zero(value):
    # Branch taken when the predicate is false; the operand is ignored.
    return 0.0

  return jax.lax.cond(x > 0, keep, zero, x)

In [95]:
relu(12.3)

Array(12.3, dtype=float32, weak_type=True)

In [96]:
relu(-12.3)

Array(0., dtype=float32, weak_type=True)

In [97]:
jax.make_jaxpr(relu)(12.3)

{ lambda ; a:f32[]. let
    b:bool[] = gt a 0.0
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:f32[] = cond[
      branches=(
        { lambda ; e:f32[]. let  in (0.0,) }
        { lambda ; f:f32[]. let  in (f,) }
      )
      linear=(False,)
    ] c a
  in (d,) }

In [98]:
jax.jit(relu)(12.3)

Array(12.3, dtype=float32, weak_type=True)

## XLA

In [99]:
def f(x, y, z):
  '''Return the sum over all elements of ``x + y * z``.'''
  product = y * z
  shifted = x + product
  return jnp.sum(shifted)

In [100]:
# Inputs for the lowering/compilation walkthrough: two length-3 vectors and a 3x3 matrix.
x = jnp.array([1.0, 1.0, 1.0])
y = jnp.ones((3,3))*2.0
# No `.T` here: transposing a 1-D array returns the same shape-(3,) array.
z = jnp.array([2.0, 1.0, 0.0])

In [101]:
x

Array([1., 1., 1.], dtype=float32)

In [102]:
y

Array([[2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.]], dtype=float32)

In [103]:
z

Array([2., 1., 0.], dtype=float32)

In [104]:
f(x, y, z)

Array(27., dtype=float32)

In [105]:
f_jitted = jax.jit(f)

In [106]:
f_lowered = f_jitted.lower(x,y,z)
print(f_lowered.as_text())

module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<3x3xf32> {mhlo.sharding = "{replicated}"}, %arg2: tensor<3xf32> {mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [0, 1] : (tensor<1x3xf32>) -> tensor<3x3xf32>
    %2 = stablehlo.multiply %arg1, %1 : tensor<3x3xf32>
    %3 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xf32>) -> tensor<1x3xf32>
    %4 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<1x3xf32>) -> tensor<3x3xf32>
    %5 = stablehlo.add %4, %2 : tensor<3x3xf32>
    %6 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %7 = stablehlo.reduce(%5 init: %6) applies stablehlo.add across dimensions = [0, 1] : (tensor<3x3xf32>, tensor<f32>) -> tensor<f32>
    retu

In [107]:
f_compiled = f_lowered.compile()

In [108]:
f_compiled

<jax._src.stages.Compiled at 0x7cfd1090c9a0>

In [109]:
print(f_compiled.as_text())

HloModule jit_f, is_scheduled=true, entry_computation_layout={(f32[3]{0}, f32[3,3]{1,0}, f32[3]{0})->f32[]}, allow_spmd_sharding_propagation_to_output={true}

%region_0.15 (Arg_0.16: f32[], Arg_1.17: f32[]) -> f32[] {
  %Arg_1.17 = f32[] parameter(1)
  %Arg_0.16 = f32[] parameter(0)
  ROOT %add = f32[] add(f32[] %Arg_0.16, f32[] %Arg_1.17), metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(0, 1)]" source_file="<ipython-input-99-95d48614b981>" source_line=2}
}

%fused_computation (param_0.4: f32[3,3], param_1.4: f32[3], param_2.2: f32[3]) -> f32[] {
  %param_2.2 = f32[3]{0} parameter(2)
  %broadcast.5 = f32[3,3]{1,0} broadcast(f32[3]{0} %param_2.2), dimensions={1}, metadata={op_name="jit(f)/jit(main)/add" source_file="<ipython-input-99-95d48614b981>" source_line=2}
  %param_0.4 = f32[3,3]{1,0} parameter(0)
  %param_1.4 = f32[3]{0} parameter(1)
  %broadcast.4 = f32[3,3]{1,0} broadcast(f32[3]{0} %param_1.4), dimensions={1}, metadata={op_name="jit(f)/jit(main)/mul" source_file="<ipython

MLIR

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/mlir

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/mlir_hlo



## JIT Limitations

Long representation

In [110]:
def cumulative_sum(x):
  '''Return a Python list holding the running totals of ``x``.

  Entry ``k`` of the result is the sum of ``x[0]`` through ``x[k]``. The loop
  is plain Python and runs ``x.shape[0]`` times.
  '''
  running_total = 0.0
  partial_sums = []
  for position in range(x.shape[0]):
    running_total += x[position]
    partial_sums.append(running_total)
  return partial_sums

In [111]:
cumulative_sum(jnp.array([1.0, 1.0, 5.0, 2.0]))

[Array(1., dtype=float32),
 Array(2., dtype=float32),
 Array(7., dtype=float32),
 Array(9., dtype=float32)]

In [112]:
j = jax.make_jaxpr(cumulative_sum)(jnp.ones(10000))

In [113]:
len(j.jaxpr.eqns)

30000

In [114]:
j

{ lambda ; a:f32[10000]. let
    b:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] a
    c:f32[] = squeeze[dimensions=(0,)] b
    d:f32[] = add 0.0 c
    e:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=None] a
    f:f32[] = squeeze[dimensions=(0,)] e
    g:f32[] = add d f
    h:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=None] a
    i:f32[] = squeeze[dimensions=(0,)] h
    j:f32[] = add g i
    k:f32[1] = slice[limit_indices=(4,) start_indices=(3,) strides=None] a
    l:f32[] = squeeze[dimensions=(0,)] k
    m:f32[] = add j l
    n:f32[1] = slice[limit_indices=(5,) start_indices=(4,) strides=None] a
    o:f32[] = squeeze[dimensions=(0,)] n
    p:f32[] = add m o
    q:f32[1] = slice[limit_indices=(6,) start_indices=(5,) strides=None] a
    r:f32[] = squeeze[dimensions=(0,)] q
    s:f32[] = add p r
    t:f32[1] = slice[limit_indices=(7,) start_indices=(6,) strides=None] a
    u:f32[] = squeeze[dimensions=(0,)] t
    v:f32[] = add s u
   

In [115]:
%time cs = jax.jit(cumulative_sum)(jnp.ones(10000))

CPU times: user 9min 18s, sys: 3.95 s, total: 9min 22s
Wall time: 9min 20s


In [116]:
def cumulative_sum_fast(x):
  '''Return the running totals of ``x`` as one array, using ``jax.lax.scan``.

  The scan starts from a carry of ``0.0``. Each step emits the updated total
  both as the next carry and as that step's output; only the stacked outputs
  are returned.
  '''
  def step(carry, elem):
    return carry+elem, carry+elem

  _, running_totals = jax.lax.scan(step, 0.0, x)
  return running_totals

In [117]:
cumulative_sum_fast(jnp.array([1.0, 1.0, 5.0, 2.0]))

Array([1., 2., 7., 9.], dtype=float32)

In [118]:
j = jax.make_jaxpr(cumulative_sum_fast)(jnp.ones(10000))

In [119]:
len(j.jaxpr.eqns)

1

In [120]:
j

{ lambda ; a:f32[10000]. let
    _:f32[] b:f32[10000] = scan[
      jaxpr={ lambda ; c:f32[] d:f32[]. let
          e:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
          f:f32[] = add e d
          g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] c
          h:f32[] = add g d
        in (f, h) }
      length=10000
      linear=(False, False)
      num_carry=1
      num_consts=0
      reverse=False
      unroll=1
    ] 0.0 a
  in (b,) }

In [121]:
%time cs = jax.jit(cumulative_sum_fast)(jnp.ones(10000))

CPU times: user 174 ms, sys: 1.01 ms, total: 175 ms
Wall time: 223 ms


Class methods

In [122]:
class ScaleClass:
  '''Holds a scale factor and multiplies inputs by it.

  This version is the failing demonstration: ``apply`` is wrapped in
  ``jax.jit`` directly, so ``self`` reaches the jitted function as an
  ordinary argument. The recorded call ``scale_double.apply(10)`` ends in a
  ``TypeError``.
  '''
  def __init__(self, scale: jnp.array):
    # Stored as given; read back by apply().
    self.scale = scale

  @jax.jit
  def apply(self, x: jnp.array):
    '''Return ``self.scale * x``.'''
    return self.scale * x

In [123]:
scale_double = ScaleClass(2)

In [124]:
scale_double.apply(10)

TypeError: ignored

In [125]:
from functools import partial

class ScaleClass:
  '''Multiplies inputs by a fixed scale factor.

  The multiplication is delegated to the module-level, jit-compiled
  ``_apply_helper`` because decorating a method with ``jax.jit`` directly
  would pass ``self`` to the jitted function as an argument.
  '''

  def __init__(self, scale: jax.Array):
    '''Store the scale factor.

    Args:
      scale: Multiplier used by ``apply``; a Python scalar or an array.
    '''
    self.scale = scale

  def apply(self, x: jax.Array):
    '''Return ``self.scale * x``, computed by the jitted helper.'''
    return _apply_helper(self.scale, x)

@jax.jit
def _apply_helper(scale, x):
  '''Jit-compiled ``scale * x``.

  Both arguments are traced. ``scale`` is deliberately not listed in
  ``static_argnums``: static arguments must be hashable, which rules out
  array-valued scales, and every distinct static value triggers a
  recompilation.
  '''
  return scale*x

In [126]:
scale_double = ScaleClass(2)


In [127]:
scale_double.apply(10)

Array(20, dtype=int32, weak_type=True)

## AOT compilation

In [142]:
def selu(x,
         alpha=1.6732632423543772848170429916717,
         scale=1.0507009873554804934193349852946):
  '''Scaled exponential linear unit (SELU) that announces each execution.

  The ``print`` call fires whenever this Python body runs, which makes it
  visible when the function is being traced.
  '''
  print('Function run')
  below_zero_value = alpha * jnp.exp(x) - alpha
  return scale * jnp.where(x > 0, x, below_zero_value)

In [143]:
selu_jit = jax.jit(selu)

In [144]:
# Ahead-of-time path: lower the jitted function for an example argument (1.0), then compile
# the lowered result. The recorded output ("Function run") shows the Python body executes here.
selu_aot = jax.jit(selu).lower(1.0).compile()

Function run


In [145]:
selu_jit(17.8)

Array(18.702477, dtype=float32, weak_type=True)

In [146]:
selu_aot(17.8)

Array(18.702477, dtype=float32, weak_type=True)

In [147]:
selu_jit(17)

Function run


Array(17.861917, dtype=float32, weak_type=True)

In [148]:
selu_aot(17)

TypeError: ignored

In [149]:
selu_jit_batched = jax.vmap(selu_jit)

In [150]:
# NOTE: the recorded call of selu_aot_batched further down ends in a TypeError, whereas the
# vmap-of-jit variant (selu_jit_batched) runs and returns a result.
selu_aot_batched = jax.vmap(selu_aot)

In [151]:
selu_jit_batched(jnp.array([42.0, 78.0, -12.3]))

Function run


Array([44.129444 , 81.95468  , -1.7580913], dtype=float32)

In [152]:
selu_aot_batched(jnp.array([42.0, 78.0, -12.3]))

TypeError: ignored