In [36]:
import numpy as np

import jax
import jax.lax as lax
import jax.nn
import jax.numpy as jnp
import jax_scaled_arithmetics as jsa

In [52]:
# Shape of the test activations: B rows by N columns.
B, N = 128, 10

In [53]:
# Draw the test activations from a seeded generator so that re-running the
# notebook on a fresh kernel reproduces the same (B, N) float32 array.
act = np.random.default_rng(0).standard_normal((B, N), dtype=np.float32)

In [63]:
def logsumexp(a, axis=None, keepdims=True):
    """Compute ``log(sum(exp(a)))`` over ``axis`` using the max-shift trick.

    Args:
        a: Input array.
        axis: Axis, or tuple/list of axes, to reduce over. ``None`` (the
            default) reduces over every axis of ``a``.
        keepdims: If True (the default), reduced axes are kept with size one;
            otherwise they are removed from the result.

    Returns:
        Array holding the log-sum-exp of ``a`` along ``axis``.
    """
    if axis is None:
        dims = tuple(range(jnp.ndim(a)))
    elif isinstance(axis, (tuple, list)):
        dims = tuple(axis)
    else:
        dims = (axis,)
    # Keep the reduced axes while computing: `lax.sub` / `lax.add` need operands
    # of equal rank, so the max has to stay broadcastable against `a`.
    amax = jnp.max(a, axis=dims, keepdims=True)
    # FIXME: not proper scale propagation, introducing NaNs
    # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    # The shift cancels out mathematically, so no gradient flows through it.
    amax = lax.stop_gradient(amax)
    out = lax.exp(lax.sub(a, amax))
    out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=True)), amax)
    if not keepdims:
        # Drop the size-one reduced axes only at the very end.
        out = jnp.squeeze(out, axis=dims)
    return out


In [64]:
def fn(act):
    """Log-sum-exp of `act`, reducing over axis 1."""
    reduce_axis = 1
    return logsumexp(act, axis=reduce_axis)

# Shape check: logsumexp keeps the reduced axis by default, so axis 1 has size 1.
fn(act).shape

(128, 1)

In [58]:
# Trace `fn` on `act` and display the jaxpr of the forward pass.
jax.make_jaxpr(fn)(act)

Traced<ShapedArray(float32[128,10])>with<DynamicJaxprTrace(level=1/0)>


{ lambda ; a:f32[128,10]. let
    b:f32[128] = reduce_max[axes=(1,)] a
    c:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] b
    d:f32[128,1] = stop_gradient c
    e:f32[128,10] = sub a d
    f:f32[128,10] = exp e
    g:f32[128] = reduce_sum[axes=(1,)] f
    h:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] g
    i:f32[128,1] = log h
    j:f32[128,1] = add i d
  in (j,) }

In [47]:
# Forward output of `fn` on `act`, plus the VJP function for pulling output
# cotangents back to the input.
out, fn_vjp = jax.vjp(fn, act)

In [48]:
def fn_with_grad(in_act, out_grad):
    """Run `fn` forward on `in_act` and pull `out_grad` back through it.

    Returns a pair: the forward output, and the result of applying the VJP
    function to `out_grad`.
    """
    primal_out, pullback = jax.vjp(fn, in_act)
    in_grads = pullback(out_grad)
    return primal_out, in_grads

In [66]:
# Display the jaxpr of forward + backward; `act[:, :1]` serves as an output
# cotangent with the same (rows, 1) shape as the forward output.
jax.make_jaxpr(fn_with_grad)(act, act[:, :1])

{ lambda ; a:f32[128,10] b:f32[128,1]. let
    c:f32[128] = reduce_max[axes=(1,)] a
    d:f32[128,1] = reshape[dimensions=None new_sizes=(128, 1)] c
    e:bool[128,10] = eq a d
    f:f32[128,10] = convert_element_type[new_dtype=float32 weak_type=False] e
    _:f32[128] = reduce_sum[axes=(1,)] f
    g:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] c
    h:f32[128,1] = stop_gradient g
    i:f32[128,10] = sub a h
    j:f32[128,10] = exp i
    k:f32[128] = reduce_sum[axes=(1,)] j
    l:f32[128,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 1)] k
    m:f32[128,1] = log l
    n:f32[128,1] = add m h
    o:f32[128,1] = div b l
    p:f32[128] = reduce_sum[axes=(1,)] o
    q:f32[128,10] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(128, 10)] p
    r:f32[128,10] = mul q j
  in (n, r) }

In [72]:
def fn2(x, y):
    """Multiply `x` by `y` with the `*` operator."""
    product = x * y
    return product


def fn2_with_grad(in_act, out_grad):
    """Evaluate `fn2(in_act, in_act)` and its VJP applied to `out_grad`.

    Note that `in_act` is passed as both arguments of `fn2`, so the forward
    value is `in_act * in_act`. Returns the forward output together with the
    result of applying the VJP function to `out_grad`.
    """
    primal_out, pullback = jax.vjp(fn2, in_act, in_act)
    in_grads = pullback(out_grad)
    return primal_out, in_grads

In [71]:
# Display the jaxpr of fn2's forward + VJP; `act` doubles as the output cotangent.
jax.make_jaxpr(fn2_with_grad)(act, act)

{ lambda ; a:f32[128,10] b:f32[128,10]. let
    c:f32[128,10] = mul a a
    d:f32[128,10] = mul a b
    e:f32[128,10] = mul b a
  in (c, e, d) }

In [79]:
# Confirm the shape and dtype of the test input.
act.shape, act.dtype

((128, 10), dtype('float32'))

In [75]:
def fn3(x):
    """Gradient of `jnp.mean(x)` with respect to `x`."""
    mean_grad = jax.grad(jnp.mean)
    return mean_grad(x)

In [81]:
# Display the jaxpr for the gradient of the mean over `act`.
jax.make_jaxpr(fn3)(act)

{ lambda ; a:f32[128,10]. let
    b:f32[] = reduce_sum[axes=(0, 1)] a
    _:f32[] = div b 1280.0
    c:f32[] = div 1.0 1280.0
    d:f32[128,10] = broadcast_in_dim[broadcast_dimensions=() shape=(128, 10)] c
  in (d,) }