In [1]:

from typing import Optional

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
from jax import lax

In [2]:

def multiply_no_nan(x, y):
    dtype = jnp.result_type(x, y)
    return jnp.where(jnp.equal(x, 0.0), jnp.zeros((), dtype=dtype), jnp.multiply(x, y))

def reshape_to_broadcast(array: jnp.array, shape: tuple, axis: int):
    """ reshapes the `array` to be broadcastable to `shape`"""
    new_shape = [1 for _ in shape]
    new_shape[axis] = shape[axis]
    return jnp.reshape(array, new_shape)

def spmax(z):
    sort_z = jnp.flip(jnp.sort(z))
    k = jnp.arange(z.shape[-1]) + 1
    z_cumsum = jnp.cumsum(sort_z)
    k_array = 1 + k*sort_z
    k_z = jnp.where(z_cumsum<k_array)[0]
    # print(f"k_array:{k_array}")
    # print(f"z_cumsum:{z_cumsum}")
    # print(f"kz array:{k_z}")
    # print(f"sort_z:{sort_z}")
    k_z = jnp.max(k_z)
    tau_z = (z_cumsum[k_z]-1)/(k_z+1)
    # print(f"tau_z:{tau_z}")
    res = z - tau_z
    t = jnp.where(res>0,res,0.)
    return t

# @partial(jax.custom_jvp, nondiff_argnums=(1,))
@partial(jax.jit, static_argnums=(1,))
def _sparsemax(x, axis):
    # get indices of elements in the right axis
    # and reshape to allow broadcasting to other dimensions
    idxs = jnp.arange(x.shape[axis]) + 1
    idxs = reshape_to_broadcast(idxs, x.shape, axis)

    # calculate number of elements that belong to the support
    sorted_x = jnp.flip(lax.sort(x, dimension=axis), axis=axis)
    cum = jnp.cumsum(sorted_x, axis=axis)
    k = jnp.sum(jnp.where(1 + sorted_x * idxs > cum, 1, 0), axis=axis, keepdims=True)

    # calculate threshold and project to simplex
    threshold = (jnp.take_along_axis(cum, k - 1, axis=axis) - 1) / k
    return jnp.maximum(x - threshold, 0)

In [61]:

@jax.custom_jvp
def f(x):
    # return jnp.sum(x**2)
    return jnp.sum(x**2)

x = jnp.array([0.1,0.2,0.6])
# x = jnp.array([0.6])
# x = 0.5

@f.defjvp
def f_jvp(p,t):
    x, = p
    dx, = t
    return f(x), dx

jax.grad(f)(x)

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/kara/.conda/envs/tq/lib/python3.8/site-packages/jax/interpreters/ad.py", line 269, in get_primitive_transpose
    if not is_undefined_primal(val):
KeyError: integer_pow

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/kara/.conda/envs/tq/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/kara/.conda/envs/tq/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/kara/.conda/envs/tq/lib/python3.8/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/home/kara/.conda/envs/tq/lib/python3.8/site-packages/traitlets/config/application.py", line 846, in launch_instance
    app.start()
  File "/home/kara/.conda/envs/tq/lib/python3.8/site-packages/ipykernel/kernelapp.py", line 712, in start
    self.io_loop.start()
  File "/home/kara/

In [8]:

"""f(x,y) = y*x**2+y+2"""

@jax.custom_jvp
def n1(x):
    return x

@n1.defjvp
def n1_jvp(p,t):
    x = p[0]
    dx = t[0]
    return x+2., dx 

def n2(y):
    return y

def n4(x):
    return x**2

def n5(x,y):
    return x*y

def n6(y):
    return y+2

def n7(x,y):
    return x+y

def f(x,y):
    n1_ = n1(x)
    n2_ = n2(y)
    n4_ = n4(n1_)
    n5_ = n5(n4_,n2_)
    n6_ = n6(y)
    n7_ = n7(n5_,n6_)
    return n7_

jax.grad(f,argnums=0)(3.,4.)

DeviceArray(40., dtype=float32, weak_type=True)

In [None]:

key = jax.random.PRNGKey(42)
w = jax.random.normal(key,shape=(4,5))
b = jnp.ones(shape=(4,))
x = jax.random.normal(key,shape=(5,))

f = lambda x:w@x+b
l = lambda p,x,y:jnp.mean(0.5*(y-x)**2)

In [196]:

y = jnp.array([0.1,0.2,0.5,0.3])
jax.grad(l,argnums=1)(f(x),y)

DeviceArray([ 0.44177228,  1.2178631 ,  0.12794471, -0.22830896], dtype=float32)

In [87]:

def model(theta,x):
    w,b = thea

prob:[0.24483848 0.7551615  0.         0.         0.        ]


DeviceArray([0.24483848, 0.7551615 , 0.        , 0.        , 0.        ],            dtype=float32)

In [None]:

x = jnp.array([2.,3.,1.,3.2,0.8])
x = jax.random.normal(key,shape=(5,))
# x = jnp.array([0.5,0.])
res = spmax(x)
res1 = _sparsemax(x, -1)
res1

In [40]:

x = jnp.array([2.,3.,1.,3.2,0.8])
def forward_fn(x):
    lin = hk.Linear(5)
    return jax.nn.selu(lin(x))

h = hk.without_apply_rng(hk.transform(forward_fn))
rng_key = jax.random.PRNGKey(43)
params = h.init(rng_key,x=x)

In [41]:

h.apply(x=x,params=params)

DeviceArray([-1.0214243,  1.484182 , -1.6136022, -1.712439 ,  0.9919937],            dtype=float32)

In [56]:

g = 1.5
p0 = jnp.array([1.,1.,1.,1.,1.])
m1 = spmax(p0*h.apply(x=x,params=params))
p1 = g-m1
m2 = spmax(p1*h.apply(x=x*m1,params=params))
p2 = (g-m1)*(g-m2)
m3 = spmax(p2*h.apply(x=x*m2,params=params))

prob:[0.         0.7460941  0.         0.         0.25390583]
prob:[1. 0. 0. 0. 0.]
prob:[0.28663146 0.         0.         0.         0.71336854]


In [57]:

def test_sort(x):
    # x = jnp.sort(x)
    # x = jnp.flip(x)
    y = jnp.where(x>1.)[0]
    y = jnp.max(y).astype(jnp.float32)
    return jnp.sum(x) / y

x = jnp.array([2.,3.,1.,4.,3.2,0.8])
jax.grad(test_sort)(x)
# test_sort(x)

DeviceArray([0.25, 0.25, 0.25, 0.25, 0.25, 0.25], dtype=float32)

In [11]:

from functools import partial

a = jnp.array([1.,2.,3.])
tree = ({"a":a,"b":2*a},{"a":3*a,"b":4*a})
# tree = ((a,2*a),(3*a,4*a))

def cumsum(prev,t):
    return prev+t,prev+t

jax.lax.scan(cumsum,init=a,xs=tree)

TypeError: unsupported operand type(s) for +: 'DynamicJaxprTracer' and 'tuple'

In [13]:

current_params = []

def transform(f):

    def apply_f(params,*args,**kwargs):
        current_params.append(params)
        outs = f(*args,**kwargs)
        # current_params.pop()
        return outs
    
    return apply_f

def get_params(id):
    return current_params[-1][id]

class Mymodule:
    def apply(self,x):
        a = get_params("w")*x
        b = get_params("w")

tr = transform(Mymodule().apply)
tr

# # %%
#
# params = {"w":5}
# tr(params,5)

<function __main__.transform.<locals>.apply_f(params, *args, **kwargs)>

In [14]:

jtr = jax.jit(tr)
jax.make_jaxpr(jtr)(params,5.)

{ lambda ; a:i32[] b:f32[]. let
    c:f32[] = xla_call[
      call_jaxpr={ lambda ; d:i32[] e:f32[]. let
          f:f32[] = convert_element_type[new_dtype=float32 weak_type=True] d
          g:f32[] = mul f e
        in (g,) }
      name=apply_f
    ] a b
  in (c,) }

In [7]:

x = jnp.zeros([5,])
def forward_fn(x):
    net = hk.nets.MLP([10,20,10])
    return net(x)

f = hk.transform(forward_fn)
rng_key = jax.random.PRNGKey(42)
params = f.init(rng_key,x=x)

In [67]:

def outer(x):
    @hk.transform
    # @hk.transparent
    def inner(t):
        net = hk.nets.MLP([10,20])
        return net(t)
    init_rng = hk.next_rng_key()
    # init_rng = jax.random.PRNGKey(42)
    params = hk.lift(inner.init)(init_rng,x)
    return jax.tree_map(lambda t:t.shape, params)

f = hk.transform(outer)
rng_key = jax.random.PRNGKey(42)
x = jnp.zeros([5,])
params = f.init(rng_key,x=x)
params

{'lifted/mlp/~/linear_0': {'w': DeviceArray([[ 0.6664603 ,  0.4364819 ,  0.1335594 , -0.6329788 ,
                -0.20747606, -0.42482248, -0.6437425 , -0.19343022,
                 0.44711798,  0.7477464 ],
               [-0.04738775,  0.7271467 ,  0.14986119,  0.7600741 ,
                -0.17041759,  0.3213556 ,  0.8654196 , -0.2867315 ,
                -0.14945029, -0.63857454],
               [-0.18451236,  0.46830958,  0.4697181 ,  0.5848382 ,
                -0.00173268, -0.2649058 ,  0.04747368, -0.78986335,
                -0.22707126,  0.707518  ],
               [ 0.02543681, -0.24890223, -0.3334315 , -0.49748698,
                 0.09348346,  0.2531042 , -0.07690459,  0.11085885,
                 0.70203125, -0.2553478 ],
               [-0.3875222 ,  0.08837761,  0.34643173,  0.02241918,
                -0.02876837, -0.5582098 , -0.47068992,  0.52299106,
                -0.00413316,  0.3278867 ]], dtype=float32),
  'b': DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [63]:

f.apply(x=x,params=params,rng=rng_key)

{'mlp/~/linear_0': {'b': (10,), 'w': (5, 10)},
 'mlp/~/linear_1': {'b': (20,), 'w': (10, 20)}}

In [137]:

@partial(jax.custom_jvp, nondiff_argnums=(1,))
@partial(jax.jit, static_argnums=(1,))
def _sparsemax(x, axis):
    # get indices of elements in the right axis
    # and reshape to allow broadcasting to other dimensions
    idxs = jnp.arange(x.shape[axis]) + 1
    idxs = reshape_to_broadcast(idxs, x.shape, axis)

    # calculate number of elements that belong to the support
    sorted_x = jnp.flip(lax.sort(x, dimension=axis), axis=axis)
    cum = jnp.cumsum(sorted_x, axis=axis)
    k = jnp.sum(jnp.where(1 + sorted_x * idxs > cum, 1, 0), axis=axis, keepdims=True)

    # calculate threshold and project to simplex
    threshold = (jnp.take_along_axis(cum, k - 1, axis=axis) - 1) / k
    return jnp.maximum(x - threshold, 0)


@_sparsemax.defjvp
@partial(jax.jit, static_argnums=(0,))
def _sparsemax_jvp(axis, primals, tangents):
    # unpack arguments
    x = primals[0]
    dx = tangents[0]

    # calculate entmax p and auxiliary s
    p = _sparsemax(x, axis)
    s = jnp.where(p > 0, 1, 0)

    # jvp as simplified product with jacobian
    dy = dx * s
    g = jnp.sum(dy, axis=axis) / jnp.sum(s, axis=axis)
    dy = dy - jnp.expand_dims(g, axis) * s
    return p, dy

In [138]:

key = jax.random.PRNGKey(42)

x = jnp.array([2.,3.,1.,3.2,0.8])
x = jax.random.normal(key,shape=(5,))
x

DeviceArray([ 0.6122652,  1.1225883, -0.8544134, -0.8127325, -0.890405 ],            dtype=float32)

In [135]:

def f(x):
    return jnp.product(_sparsemax(x,axis=-1))

jax.jacfwd(partial(_sparsemax, axis=-1))(x)

DeviceArray([[ 0.5, -0.5,  0. ,  0. ,  0. ],
             [-0.5,  0.5,  0. ,  0. ,  0. ],
             [ 0. ,  0. ,  0. ,  0. ,  0. ],
             [ 0. ,  0. ,  0. ,  0. ,  0. ],
             [ 0. ,  0. ,  0. ,  0. ,  0. ]], dtype=float32)

In [99]:

def sparse_max(z):
    sort_z = jnp.flip(jnp.sort(z))
    k = jnp.arange(z.shape[-1]) + 1
    z_cumsum = jnp.cumsum(sort_z)
    k_array = 1 + k*sort_z
    k_z = jnp.sum(jnp.where(z_cumsum<k_array,1,0))-1
    tau_z = (z_cumsum[k_z]-1)/(k_z+1)
    res = z - tau_z
    t = jnp.where(res>0,res,0.)
    return t

def sparse_max_nd(z,axis):
    if z.ndim <= 1:
        return spmax(z)
    else:
        z = jnp.swapaxes(z, -1, axis)
        pre_shape = z.shape
        out = jax.vmap(spmax)(jnp.vstack(z))
        return jnp.swapaxes(out.reshape(pre_shape),axis,-1)

In [105]:

key = jax.random.PRNGKey(42)
a = jax.random.normal(key, shape=(2,3,4))

def f(x):
    axis = 2
    return spmax_nd(x, axis=axis)

DeviceArray([[[0.        , 0.        , 1.        , 0.        ],
              [0.        , 0.22480872, 0.        , 0.7751913 ],
              [0.        , 0.11440706, 0.3228233 , 0.56276953]],

             [[0.        , 0.        , 0.06037372, 0.93962634],
              [0.        , 0.86641276, 0.        , 0.13358724],
              [0.        , 0.        , 1.        , 0.        ]]],            dtype=float32)

In [84]:

spmax(a[0,:,1])

DeviceArray([0.5862078, 0.       , 0.4137922], dtype=float32)

In [118]:

a = np.arange(24).reshape(2,3,4)
a

array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])

In [116]:

jnp.mean(a,[0,1],keepdims=True)

DeviceArray([[[10., 11., 12., 13.]]], dtype=float32)

In [120]:

jnp.split(a,2,axis=-1)
jax.nn.glu(a,axis=-1)

DeviceArray([[[ 0.        ,  0.95257413],
              [ 3.9901097 ,  4.9954453 ],
              [ 7.9996367 ,  8.999849  ]],

             [[11.99999   , 12.999995  ],
              [16.        , 17.        ],
              [20.        , 21.        ]]], dtype=float32)

In [134]:

a = jnp.arange(12.).reshape(2,6)
jnp.split(a,[3,],axis=-1)

[DeviceArray([[0., 1., 2.],
              [6., 7., 8.]], dtype=float32),
 DeviceArray([[ 3.,  4.,  5.],
              [ 9., 10., 11.]], dtype=float32)]

In [133]:

a

DeviceArray([[ 0.,  1.,  2.,  3.,  4.,  5.],
             [ 6.,  7.,  8.,  9., 10., 11.]], dtype=float32)