In [5]:
import jax

# Functions

In [6]:
@jax.jit
def bat(x, f):
    return f(x)

# Function should be marked as static_argnums or names. Error, otherwise.
# bat(1, lambda x: x + 1)

In [7]:
from functools import partial

@partial(jax.jit, static_argnames=["f"])
def bat(x, f):
    return f(x)

bat(1, lambda x: x + 1)

Array(2, dtype=int32, weak_type=True)

# Use tuple than the list

list isn't hashable

In [8]:
@partial(jax.jit, static_argnames=('param',))
def foo(a, param):
    return a + param[0]

foo(1, [1, 2, 3])

ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'foo' while trying to hash an object of type <class 'list'>, [1, 2, 3]. The error was:
TypeError: unhashable type: 'list'


In [9]:
# Trivial way to test hashability of a list
hash([1, 2, 3])

TypeError: unhashable type: 'list'

Use tuple instead.

In [10]:
foo(1, tuple([1, 2, 3]))

Array(2, dtype=int32, weak_type=True)

# Partial

In [11]:
from functools import partial

@partial(jax.jit, static_argnames=["f"])
def foo(x, f):
    return f(x, x)

def f(a, b):
    return a + b

foo(1, f)

Array(2, dtype=int32, weak_type=True)

In [12]:
from functools import partial

@partial(jax.jit, static_argnames=["f"])
def foo(x, f):
    return f(x)

def f(a, b):
    return a + b

foo(1, partial(f, b=2))

Array(3, dtype=int32, weak_type=True)