In [2]:
import jax
import jax.numpy as jnp
import spyx
import haiku as hk
import optax

In [3]:
k = 25

@jax.jit
def grad_superspike(x):
    return 1 / (1 + k*jnp.abs(x))**2

In [4]:
V = jnp.arange(16, dtype=jnp.float32)
V

Array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15.], dtype=float32)

In [5]:
grad_superspike(V)

Array([1.0000000e+00, 1.4792900e-03, 3.8446751e-04, 1.7313019e-04,
       9.8029610e-05, 6.2988162e-05, 4.3857726e-05, 3.2283056e-05,
       2.4751862e-05, 1.9578667e-05, 1.5872763e-05, 1.3127495e-05,
       1.1037406e-05, 9.4094621e-06, 8.1168173e-06, 7.0733363e-06],      dtype=float32)

In [8]:
jax.make_jaxpr(grad_superspike)(V)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[16][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[16][39m = pjit[
      name=grad_superspike
      jaxpr={ [34m[22m[1mlambda [39m[22m[22m; c[35m:f32[16][39m. [34m[22m[1mlet
          [39m[22m[22md[35m:f32[16][39m = abs c
          e[35m:f32[16][39m = mul 25.0 d
          f[35m:f32[16][39m = add 1.0 e
          g[35m:f32[16][39m = integer_pow[y=2] f
          h[35m:f32[16][39m = div 1.0 g
        [34m[22m[1min [39m[22m[22m(h,) }
    ] a
  [34m[22m[1min [39m[22m[22m(b,) }

In [11]:
lowered = jax.jit(grad_superspike).lower(V)

In [12]:
print(lowered.as_text())

module @jit_grad_superspike attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16xf32> {mhlo.layout_mode = "default"}) -> (tensor<16xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @grad_superspike(%arg0) : (tensor<16xf32>) -> tensor<16xf32>
    return %0 : tensor<16xf32>
  }
  func.func private @grad_superspike(%arg0: tensor<16xf32>) -> tensor<16xf32> {
    %0 = stablehlo.abs %arg0 : tensor<16xf32>
    %1 = stablehlo.constant dense<2.500000e+01> : tensor<f32>
    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %3 = stablehlo.multiply %2, %0 : tensor<16xf32>
    %4 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %6 = stablehlo.add %5, %3 : tensor<16xf32>
    %7 = stablehlo.multiply %6, %6 : tensor<16xf32>
    %8 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
   

In [13]:
compiled = lowered.compile()

In [14]:
print(compiled.as_text())

HloModule jit_grad_superspike, is_scheduled=true, entry_computation_layout={(f32[16]{0})->f32[16]{0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="97c98a0ca68aa6e801146d4c1b162054"}

%fused_divide (param_0.3: f32[16]) -> f32[16] {
  %constant_0_1 = f32[] constant(1)
  %broadcast.4.1 = f32[16]{0} broadcast(f32[] %constant_0_1), dimensions={}
  %param_0.3 = f32[16]{0} parameter(0)
  %abs.2.1 = f32[16]{0} abs(f32[16]{0} %param_0.3), metadata={op_name="jit(grad_superspike)/jit(main)/jit(grad_superspike)/abs" source_file="/tmp/ipykernel_73553/1167682913.py" source_line=5}
  %constant_1_1 = f32[] constant(25)
  %broadcast.6.1 = f32[16]{0} broadcast(f32[] %constant_1_1), dimensions={}
  %multiply.4.1 = f32[16]{0} multiply(f32[16]{0} %abs.2.1, f32[16]{0} %broadcast.6.1), metadata={op_name="jit(grad_superspike)/jit(main)/jit(grad_superspike)/mul" source_file="/tmp/ipykernel_73553/1167682913.

In [15]:
compiled.cost_analysis()

[{'flops': 80.0,
  'bytes accessed': 128.0,
  'utilization0{}': 1.0,
  'utilization1{}': 4.0,
  'bytes accessed1{}': 256.0,
  'bytes accessedout{}': 64.0,
  'bytes accessed0{}': 64.0}]

## Pallas

In [16]:
from jax.experimental import pallas as pl

import jax
import jax.numpy as jnp
import spyx
import haiku as hk
import optax

In [17]:
def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

In [18]:
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
  )(x, y)
add_vectors(jnp.arange(4), jnp.arange(4))

Array([0, 2, 4, 6], dtype=int32)

In [19]:
V = jnp.arange(512, dtype=jnp.float32)
V

Array([  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
        11.,  12.,  13.,  14.,  15.,  16.,  17.,  18.,  19.,  20.,  21.,
        22.,  23.,  24.,  25.,  26.,  27.,  28.,  29.,  30.,  31.,  32.,
        33.,  34.,  35.,  36.,  37.,  38.,  39.,  40.,  41.,  42.,  43.,
        44.,  45.,  46.,  47.,  48.,  49.,  50.,  51.,  52.,  53.,  54.,
        55.,  56.,  57.,  58.,  59.,  60.,  61.,  62.,  63.,  64.,  65.,
        66.,  67.,  68.,  69.,  70.,  71.,  72.,  73.,  74.,  75.,  76.,
        77.,  78.,  79.,  80.,  81.,  82.,  83.,  84.,  85.,  86.,  87.,
        88.,  89.,  90.,  91.,  92.,  93.,  94.,  95.,  96.,  97.,  98.,
        99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109.,
       110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120.,
       121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131.,
       132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142.,
       143., 144., 145., 146., 147., 148., 149., 15

In [78]:
V.shape

(16,)

In [21]:
@jax.jit
def superspike_kernel(x_ref, o_ref):
    x = x_ref[...]
    o_ref[...] = 1 / (1 + 25*jnp.abs(x))**2

In [22]:
@jax.jit
def pallas_superspike(x: jax.Array) -> jax.Array:
    bspec = pl.BlockSpec(block_shape=(32,), index_map=lambda i: i)
    return pl.pallas_call(
        superspike_kernel,
        out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(16,),
        in_specs=[pl.BlockSpec(lambda i: i, (32,))],
        out_specs=pl.BlockSpec(lambda i: i, (32,))
    )(x,)

In [110]:
pallas_superspike(V)

Array([1.00000000e+00, 1.47928996e-03, 3.84467508e-04, 1.73130189e-04,
       9.80296099e-05, 6.29881615e-05, 4.38577263e-05, 3.22830565e-05,
       2.47518619e-05, 1.95786670e-05, 1.58727635e-05, 1.31274946e-05,
       1.10374058e-05, 9.40946211e-06, 8.11681730e-06, 7.07333629e-06,
       6.21886693e-06, 5.51037056e-06, 4.91639685e-06, 4.41352995e-06,
       3.98404791e-06, 3.61433581e-06, 3.29379668e-06, 3.01408181e-06,
       2.76854166e-06, 2.55182772e-06, 2.35959806e-06, 2.18829882e-06,
       2.03499781e-06, 1.89725961e-06, 1.77304651e-06, 1.66064399e-06,
       1.55860107e-06, 1.46568254e-06, 1.38083215e-06, 1.30314208e-06,
       1.23182895e-06, 1.16621345e-06, 1.10570420e-06, 1.04978506e-06,
       9.98002974e-07, 9.49959883e-07, 9.05304262e-07, 8.63724949e-07,
       8.24945744e-07, 7.88720627e-07, 7.54830353e-07, 7.23078358e-07,
       6.93288484e-07, 6.65302366e-07, 6.38977212e-07, 6.14184216e-07,
       5.90806678e-07, 5.68738926e-07, 5.47884838e-07, 5.28157102e-07,
      

In [111]:
grad_superspike(V)

Array([1.00000000e+00, 1.47928996e-03, 3.84467508e-04, 1.73130189e-04,
       9.80296099e-05, 6.29881615e-05, 4.38577263e-05, 3.22830565e-05,
       2.47518619e-05, 1.95786670e-05, 1.58727635e-05, 1.31274946e-05,
       1.10374058e-05, 9.40946211e-06, 8.11681730e-06, 7.07333629e-06,
       6.21886693e-06, 5.51037056e-06, 4.91639685e-06, 4.41352995e-06,
       3.98404791e-06, 3.61433581e-06, 3.29379668e-06, 3.01408181e-06,
       2.76854166e-06, 2.55182772e-06, 2.35959806e-06, 2.18829882e-06,
       2.03499781e-06, 1.89725961e-06, 1.77304651e-06, 1.66064399e-06,
       1.55860107e-06, 1.46568254e-06, 1.38083215e-06, 1.30314208e-06,
       1.23182895e-06, 1.16621345e-06, 1.10570420e-06, 1.04978506e-06,
       9.98002974e-07, 9.49959883e-07, 9.05304262e-07, 8.63724949e-07,
       8.24945744e-07, 7.88720627e-07, 7.54830353e-07, 7.23078358e-07,
       6.93288484e-07, 6.65302366e-07, 6.38977212e-07, 6.14184216e-07,
       5.90806678e-07, 5.68738926e-07, 5.47884838e-07, 5.28157102e-07,
      

In [23]:
p_lowered = pallas_superspike.lower(V)

In [24]:
print(p_lowered.as_text())

module @jit_pallas_superspike attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<512xf32> {mhlo.layout_mode = "default"}) -> (tensor<512xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @wrapped(%arg0) : (tensor<512xf32>) -> tensor<512xf32>
    return %0 : tensor<512xf32>
  }
  func.func private @wrapped(%arg0: tensor<512xf32>) -> tensor<512xf32> {
    %0 = stablehlo.custom_call @__gpu$xla.gpu.triton(%arg0) {mhlo.backend_config = {debug = false, grid_x = 16 : i32, grid_y = 1 : i32, grid_z = 1 : i32, ir = "ML\EFR\0DMLIR19.0.0git\00\017\07\01\05\09!\01\03\0F\03\13\13\17\1B\1F#'+/3\05\0D7;?CGK\03\BD\99\11\01\95\07\0F\0F\0F\0F\0B\0F\0F\0B\0F\0F\0F\13\0B\0F\0B\0F\0B\1F\13\0B\0B\0B\0B\0F\0F\13\0F\0B\13\0B\0F\0F\0B\13\0B\0F\0F\0B\17\0F\0F\0B\17\0F\0F\0B\17\0F\0F\0B\13\0B\0F\0F\0B\17\0F\0F\17\0F\17\0B\0F\0B\0F\0B\0F\13\1F\0B\0B\0B\0B\05\05YY\01\0F\0F\13\13\07\13\0B\17\03\03A\02\C2\04\1F\1D\93\0D\1D/1\11\01\

In [25]:
p_compiled = p_lowered.compile()

In [26]:
print(p_compiled.as_text())

HloModule jit_pallas_superspike, is_scheduled=true, entry_computation_layout={(f32[512]{0})->f32[512]{0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="6632de43e561b95501cdc7d88cb9d595"}

ENTRY %main.6 (Arg_0.1.0: f32[512]) -> f32[512] {
  %Arg_0.1.0 = f32[512]{0} parameter(0)
  ROOT %custom-call.0.0 = f32[512]{0} custom-call(f32[512]{0} %Arg_0.1.0), custom_call_target="__gpu$xla.gpu.triton", operand_layout_constraints={f32[512]{0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name="jit(pallas_superspike)/jit(main)/jit(wrapped)/pallas_call[name=superspike_kernel which_linear=(False,) in_shapes=(ShapeDtypeStruct(shape=(512,), dtype=float32),) out_shapes=(ShapeDtypeStruct(shape=(512,), dtype=float32),) debug=False interpret=False grid_mapping=GridMapping(grid=(16,), block_mappings=(BlockMapping(block_shape=(32,), index_map_jaxpr={ \033[34m\033[22m\033[1mlambda \033[39m\033[22m\033[

In [30]:
def grad_superspike2(x):
    return 1 / (1 + k*jnp.abs(x))**2

In [31]:
%timeit -r20 grad_superspike2(V).block_until_ready()

273 µs ± 24.8 µs per loop (mean ± std. dev. of 20 runs, 1000 loops each)


In [32]:
%timeit -r20 grad_superspike(V).block_until_ready()

33.5 µs ± 4.39 µs per loop (mean ± std. dev. of 20 runs, 10000 loops each)


In [33]:
%timeit -r20 pallas_superspike(V).block_until_ready()

27.9 µs ± 2.68 µs per loop (mean ± std. dev. of 20 runs, 10000 loops each)


## LIF neuron

In [None]:
class LIF(hk.RNNCore):
    """
    Leaky Integrate and Fire neuron model inspired by the implementation in
    snnTorch:

    https://snntorch.readthedocs.io/en/latest/snn.neurons_leaky.html
    
    """

    def __init__(self, 
                 hidden_shape: tuple, 
                 beta=None,
                 threshold = 1.,
                 activation = superspike(),
                 name="LIF"):

        """
        
        :hidden_size: Size of preceding layer's outputs
        :beta: decay rate. Set to float in range (0,1] for uniform decay across layer, otherwise it will be a normal
                distribution centered on 0.5 with stddev of 0.25
        :threshold: threshold for reset. Defaults to 1.
        :activation: spyx.axn.Axon object, default is Heaviside with Straight-Through-Estimation.
        """
        super().__init__(name=name)
        self.hidden_shape = hidden_shape
        self.beta = beta
        self.threshold = threshold
        self.spike = activation
    
    def __call__(self, x, V):
        """
        :x: input vector coming from previous layer.
        :V: neuron state tensor.

        """
        beta = self.beta # this line can probably be deleted, and the check changed to self.beta
        if not beta:
            beta = hk.get_parameter("beta", self.hidden_shape,
                                init=hk.initializers.TruncatedNormal(0.25, 0.5))
            beta = jnp.clip(beta, 0, 1)
        else:
            beta = hk.get_parameter("beta", [],
                                init=hk.initializers.Constant(beta))
            beta = jnp.clip(beta, 0, 1)
            
        # calculate whether spike is generated, and update membrane potential
        spikes = self.spike(V-self.threshold)
        V = beta*V + x - spikes * self.threshold
        
        return spikes, V

    def initial_state(self, batch_size): 
        return jnp.zeros((batch_size,) + self.hidden_shape)


In [34]:
from spyx.axn import superspike

In [56]:
activation_func = superspike()

In [58]:
def lif_neuron(V, x): # carry, x
    spikes = activation_func(V-1.1)
    V = 0.9*V + x - spikes * 1.1

    return V, spikes # carry, y



In [59]:
x_in = jnp.ones(16) / 2
x_in

Array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
       0.5, 0.5, 0.5], dtype=float32)

In [60]:
v0 = jnp.zeros(1)
v0

Array([0.], dtype=float32)

In [80]:
@jax.jit
def run_lif(x_in):
    return jax.lax.scan(
        lif_neuron,
        v0,
        x_in
    )

@jax.jit
def run_lif_unrolled(x_in):
    return jax.lax.scan(
        lif_neuron,
        v0,
        x_in,
        unroll=True
    )    

In [72]:
print(jax.make_jaxpr(run_lif)(x_in))

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[16][39m. [34m[22m[1mlet
    [39m[22m[22mb[35m:f32[1][39m c[35m:f32[16,1][39m = pjit[
      name=run_lif
      jaxpr={ [34m[22m[1mlambda [39m[22m[22md[35m:f32[1][39m; e[35m:f32[16][39m. [34m[22m[1mlet
          [39m[22m[22mf[35m:f32[1][39m g[35m:f32[16,1][39m = scan[
            _split_transpose=False
            jaxpr={ [34m[22m[1mlambda [39m[22m[22m; h[35m:f32[1][39m i[35m:f32[][39m. [34m[22m[1mlet
                [39m[22m[22mj[35m:f32[1][39m = sub h 1.100000023841858
                k[35m:f32[1][39m = pjit[
                  name=wrapped_fun
                  jaxpr={ [34m[22m[1mlambda [39m[22m[22m; l[35m:f32[1][39m. [34m[22m[1mlet
                      [39m[22m[22mm[35m:f32[1][39m = custom_vjp_call_jaxpr[
                        bwd=<function CustomVJPCallPrimitive.bind.<locals>.<lambda> at 0x7c885a4d8310>
                        fun_jaxpr={ [34m[22m[1mlambda [39m

In [78]:
print(run_lif.lower(x_in).as_text())

module @jit_run_lif attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<16xf32> {mhlo.layout_mode = "default"}) -> (tensor<1xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<16x1xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<0.000000e+00> : tensor<1xf32>
    %1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<f32>) -> tensor<16x1xf32>
    %3 = stablehlo.constant dense<0> : tensor<i32>
    %4:4 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %3, %iterArg_1 = %0, %iterArg_2 = %2) : tensor<16xf32>, tensor<i32>, tensor<1xf32>, tensor<16x1xf32>
     cond {
      %5 = stablehlo.constant dense<16> : tensor<i32>
      %6 = stablehlo.compare  LT, %iterArg_0, %5,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %6 : tensor<i1>
    } do {
      %5 = stablehlo.constant dense<0

In [79]:
print(run_lif.lower(x_in).compile().as_text())

HloModule jit_run_lif, is_scheduled=true, entry_computation_layout={(f32[16]{0})->(f32[1]{0}, f32[16,1]{1,0})}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true,true}, frontend_attributes={fingerprint_before_lhs="ccb804c474066cf7086fa9ad8bb7d353"}

%fused_dynamic_update_slice (param_0: f32[16,1], param_1.18: s32[], param_2.21: f32[1]) -> f32[16,1] {
  %param_0 = f32[16,1]{1,0} parameter(0)
  %param_2.21 = f32[1]{0} parameter(2)
  %constant_6_2 = f32[1]{0} constant({-1.1})
  %add.8.3 = f32[1]{0} add(f32[1]{0} %param_2.21, f32[1]{0} %constant_6_2), metadata={op_name="jit(run_lif)/jit(main)/while/body/sub" source_file="/tmp/ipykernel_73553/3145203169.py" source_line=2}
  %constant_0_2 = f32[1]{0} constant({0})
  %compare.5.3 = pred[1]{0} compare(f32[1]{0} %add.8.3, f32[1]{0} %constant_0_2), direction=GT, metadata={op_name="jit(run_lif)/jit(main)/while/body/jit(wrapped_fun)/gt" source_file="/home/legion/.local/lib/python3.10/site-package

In [81]:
print(run_lif_unrolled.lower(x_in).compile().as_text())

HloModule jit_run_lif_unrolled, is_scheduled=true, entry_computation_layout={(f32[16]{0})->(f32[1]{0}, f32[16,1]{1,0})}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true,true}, frontend_attributes={fingerprint_before_lhs="cb90a9e0f226453786315bab202e6172"}

%fused_concatenate (param_0.166: f32[16]) -> f32[16,1] {
  %constant_191_1 = f32[1,1]{1,0} constant({ {0} })
  %param_0.166 = f32[16]{0} parameter(0)
  %bitcast.443.17 = f32[1,16]{1,0} bitcast(f32[16]{0} %param_0.166)
  %slice.33.17 = f32[1,1]{1,0} slice(f32[1,16]{1,0} %bitcast.443.17), slice={[0:1], [0:1]}, metadata={op_name="jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]" source_file="/tmp/ipykernel_73553/283343340.py" source_line=11}
  %constant_193_2 = f32[1,1]{1,0} constant({ {-1.1} })
  %add.81.9 = f32[1,1]{1,0} add(f32[1,1]{1,0} %slice.33.17, f32[1,1]{1,0} %constant_193_2), metadata={op_name="jit(run_lif_unrolled)/jit(ma

In [75]:
%timeit -r20 run_lif(x_in)[0].block_until_ready()

151 µs ± 9.39 µs per loop (mean ± std. dev. of 20 runs, 1000 loops each)


In [82]:
%timeit -r20 run_lif_unrolled(x_in)[0].block_until_ready()

32.9 µs ± 2.03 µs per loop (mean ± std. dev. of 20 runs, 10000 loops each)


In [94]:
@jax.jit
def lif_kernel(V_ref, X_ref, out_X): # carry, x
    
    v, x = V_ref[...], X_ref[...]
    v, x = lif_neuron(v, x)

    out_X[...] = x



In [95]:
@jax.jit
def pallas_lif(v, x: jax.Array) -> jax.Array:
    bspec = pl.BlockSpec(block_shape=(1,), index_map=lambda i: i)
    return pl.pallas_call(
        lif_kernel,
        out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(1,),
        in_specs=[pl.BlockSpec(lambda i: i, (1,)), pl.BlockSpec(lambda i: i, (1,))],
        out_specs=pl.BlockSpec(lambda i: i, (1,))
    )(v0, x_in)

In [96]:
pallas_lif(v0, x_in)

LoweringError: Exception while lowering eqn:
  a[35m:f32[1][39m = pjit[
  name=wrapped_fun
  jaxpr={ [34m[22m[1mlambda [39m[22m[22m; b[35m:f32[1][39m. [34m[22m[1mlet
      [39m[22m[22mc[35m:f32[1][39m = custom_vjp_call_jaxpr[
        bwd=<function CustomVJPCallPrimitive.bind.<locals>.<lambda> at 0x7c885a4d8310>
        fun_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; d[35m:f32[1][39m. [34m[22m[1mlet
            [39m[22m[22me[35m:bool[1][39m = gt d 0.0
            f[35m:i32[1][39m = pjit[
              name=_where
              jaxpr={ [34m[22m[1mlambda [39m[22m[22m; g[35m:bool[1][39m h[35m:i32[][39m i[35m:i32[][39m. [34m[22m[1mlet
                  [39m[22m[22mj[35m:i32[1][39m = broadcast_in_dim[
                    broadcast_dimensions=()
                    shape=(1,)
                  ] h
                  k[35m:i32[1][39m = broadcast_in_dim[
                    broadcast_dimensions=()
                    shape=(1,)
                  ] i
                  l[35m:i32[1][39m = select_n g k j
                [34m[22m[1min [39m[22m[22m(l,) }
            ] e 1 0
            m[35m:f32[1][39m = convert_element_type[new_dtype=float32 weak_type=False] f
          [34m[22m[1min [39m[22m[22m(m,) }
        fwd_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7c885a4d8c10>
        num_consts=0
        out_trees=<function transformation_with_aux.<locals>.<lambda> at 0x7c885a4d8160>
        symbolic_zeros=False
      ] b
    [34m[22m[1min [39m[22m[22m(c,) }
] n
With context:
  LoweringRuleContext(context=ModuleContext(name='lif_kernel', grid_mapping=GridMapping(grid=(1,), block_mappings=(BlockMapping(block_shape=(1,), index_map_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; a[35m:i32[][39m. [34m[22m[1mlet[39m[22m[22m  [34m[22m[1min [39m[22m[22m(a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x7c885a969000>), BlockMapping(block_shape=(1,), index_map_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; a[35m:i32[][39m. [34m[22m[1mlet[39m[22m[22m  [34m[22m[1min [39m[22m[22m(a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x7c885a969000>), BlockMapping(block_shape=(1,), index_map_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; a[35m:i32[][39m. [34m[22m[1mlet[39m[22m[22m  [34m[22m[1min [39m[22m[22m(a,) }, indexing_mode=<jax._src.pallas.core.Blocked object at 0x7c885a969000>)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0), program_ids=[<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7c885a4f0e70>]), avals_in=[ShapedArray(float32[1])], avals_out=[ShapedArray(float32[1])], block_infos=[None])
With inval types=[RankedTensorType(tensor<1xf32>)]
In jaxpr:
{ [34m[22m[1mlambda [39m[22m[22m; a[35m:MemRef<None>{float32[1]}[39m b[35m:MemRef<None>{float32[1]}[39m c[35m:MemRef<None>{float32[1]}[39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[1][39m <- [32ma[:]
    [39me[35m:f32[1][39m <- [32mb[:]
    [39mf[35m:f32[1][39m = sub d 1.100000023841858
    g[35m:f32[1][39m = pjit[
      name=wrapped_fun
      jaxpr={ [34m[22m[1mlambda [39m[22m[22m; h[35m:f32[1][39m. [34m[22m[1mlet
          [39m[22m[22mi[35m:f32[1][39m = custom_vjp_call_jaxpr[
            bwd=<function CustomVJPCallPrimitive.bind.<locals>.<lambda> at 0x7c885a4d8310>
            fun_jaxpr={ [34m[22m[1mlambda [39m[22m[22m; j[35m:f32[1][39m. [34m[22m[1mlet
                [39m[22m[22mk[35m:bool[1][39m = gt j 0.0
                l[35m:i32[1][39m = pjit[
                  name=_where
                  jaxpr={ [34m[22m[1mlambda [39m[22m[22m; m[35m:bool[1][39m n[35m:i32[][39m o[35m:i32[][39m. [34m[22m[1mlet
                      [39m[22m[22mp[35m:i32[1][39m = broadcast_in_dim[
                        broadcast_dimensions=()
                        shape=(1,)
                      ] n
                      q[35m:i32[1][39m = broadcast_in_dim[
                        broadcast_dimensions=()
                        shape=(1,)
                      ] o
                      r[35m:i32[1][39m = select_n m q p
                    [34m[22m[1min [39m[22m[22m(r,) }
                ] k 1 0
                s[35m:f32[1][39m = convert_element_type[
                  new_dtype=float32
                  weak_type=False
                ] l
              [34m[22m[1min [39m[22m[22m(s,) }
            fwd_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7c885a4d8c10>
            num_consts=0
            out_trees=<function transformation_with_aux.<locals>.<lambda> at 0x7c885a4d8160>
            symbolic_zeros=False
          ] h
        [34m[22m[1min [39m[22m[22m(i,) }
    ] f
    [32mc[:][39m <- g
  [34m[22m[1min [39m[22m[22m() }