In [None]:
# default_exp utils

# Utils
> Utility and convenience functions.

In [None]:
# hide
import sys

# Put the parent directory first on the import path so modules living there
# (per the original note, presumably a compiled Rust extension) can be imported.
sys.path.insert(0, "..")

In [None]:
# exporti
from functools import partial, wraps

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.lib import xla_bridge

## Convolutions

In [None]:
# export
@partial(
    jax.jit,
    static_argnames=("window_strides", "padding", "precision", "preferred_element_type"),
)
def conv(lhs, rhs, window_strides=(1, 1), padding="SAME", **kwargs):
    """Backend-aware wrapper around ``lax.conv``.

    Inputs use ``lax.conv``'s default layout: ``lhs`` is (N, C, H, W) and
    ``rhs`` is (O, I, kh, kw).

    On the CPU backend the call is forwarded unchanged. On any other backend,
    floating-point inputs are also forwarded unchanged. Other dtypes are
    promoted to JAX's default float dtype, convolved, and converted back:

    * ``bool`` results are ``True`` wherever the float result exceeds 0.5,
      i.e. wherever the overlap count is at least one.
    * Integer results are rounded to the nearest integer before the cast.

    Args:
        lhs: input array.
        rhs: kernel array; must have the same dtype as ``lhs`` off-CPU.
        window_strides: stride per spatial dimension (static under ``jit``).
        padding: ``"SAME"``, ``"VALID"`` or explicit (hashable) padding pairs
            (static under ``jit``).
        **kwargs: forwarded to ``lax.conv``. ``precision`` and
            ``preferred_element_type`` are configuration values rather than
            arrays, so they are marked static; otherwise ``jit`` would reject
            them as invalid JAX types.

    Returns:
        The convolution result.

    Raises:
        ValueError: on non-CPU backends, if ``lhs`` and ``rhs`` dtypes differ.
    """
    if jax.default_backend() == "cpu":
        return lax.conv(lhs, rhs, window_strides, padding, **kwargs)

    # Non-CPU (e.g. GPU) path: only floating-point inputs are convolved directly.
    if lhs.dtype != rhs.dtype:
        raise ValueError(
            f"Cannot do convolution. Different dtypes for "
            f"'lhs' and 'rhs'. Got: {lhs.dtype}, {rhs.dtype}"
        )
    dtype = lhs.dtype
    if jnp.issubdtype(dtype, jnp.floating):
        return lax.conv(lhs, rhs, window_strides, padding, **kwargs)

    result = lax.conv(
        jnp.asarray(lhs, dtype=float),
        jnp.asarray(rhs, dtype=float),
        window_strides,
        padding,
        **kwargs,
    )
    if dtype == bool:
        # Overlap counts are whole numbers; 0.5 separates "none" from ">= 1"
        # with a wide margin even if the float result carries rounding noise.
        return result > 0.5
    if jnp.issubdtype(dtype, jnp.integer):
        # Float accumulation (or an FFT/Winograd algorithm the backend may pick)
        # can land just below an integer; round instead of truncating on cast.
        result = jnp.round(result)
    return jnp.asarray(result, dtype=dtype)

In [None]:
# export
def conv2d(lhs, rhs, window_strides=(1, 1), padding="SAME", **kwargs):
    """Convolve a single 2-D array with a single 2-D kernel using ``conv``.

    Both operands are lifted to the (1, 1, H, W) layout ``conv`` expects. The
    singleton batch and channel axes are then dropped from the result. (This
    function intentionally carries its own name and docstring rather than
    copying ``conv``'s via ``functools.wraps``.)

    Args:
        lhs: 2-D input array (H, W).
        rhs: 2-D kernel (kh, kw).
        window_strides: stride per spatial dimension.
        padding: ``"SAME"``, ``"VALID"`` or explicit padding pairs.
        **kwargs: forwarded to ``conv``.

    Returns:
        2-D convolution result.
    """
    batched = conv(lhs[None, None], rhs[None, None], window_strides, padding, **kwargs)
    return batched[0, 0]

In [None]:
# export
def batch_conv2d(lhs, rhs, window_strides=(1, 1), padding="SAME", **kwargs):
    """Convolve a batch of 2-D arrays with one shared kernel or one kernel each.

    Args:
        lhs: stack of N 2-D inputs, shape (N, H, W).
        rhs: kernels, shape (1, kh, kw) to apply the same kernel to every
            input, or (N, kh, kw) to convolve ``lhs[i]`` with ``rhs[i]``.
        window_strides: stride per spatial dimension.
        padding: ``"SAME"``, ``"VALID"`` or explicit padding pairs.
        **kwargs: forwarded to ``conv``.

    Returns:
        Array of shape (N, H', W').

    Raises:
        ValueError: if ``rhs`` holds neither 1 nor N kernels.
    """
    n_images, n_kernels = lhs.shape[0], rhs.shape[0]
    if n_kernels == 1:
        # One kernel for every image: images become the conv batch axis and the
        # single kernel produces one output channel.
        return conv(lhs[:, None], rhs[:, None], window_strides, padding, **kwargs)[:, 0]
    if n_kernels != n_images:
        raise ValueError(
            f"Cannot do batch convolution. 'rhs' must hold 1 kernel or one per "
            f"image in 'lhs'. Got {n_kernels} kernels for {n_images} images."
        )

    def _single(image, kernel):
        # Convolve one (H, W) image with its own (kh, kw) kernel.
        return conv(image[None, None], kernel[None, None], window_strides, padding, **kwargs)[0, 0]

    # Pair lhs[i] with rhs[i]. A single conv over lhs[:, None] and rhs[:, None]
    # would instead build an (N, N) grid of outputs.
    return jax.vmap(_single)(lhs, rhs)

In [None]:
# export
def dilute(touches, brush):
    """Spread ``touches`` with ``brush`` using a stride-1, ``"SAME"``-padded ``conv2d``.

    NOTE(review): the name may be intended as "dilate"; confirm with callers.

    Args:
        touches: 2-D input array.
        brush: 2-D kernel.

    Returns:
        2-D array with the same height and width as ``touches``; stride 1
        with ``"SAME"`` padding preserves the spatial size.
    """
    unit_strides, same_padding = (1, 1), "SAME"
    return conv2d(touches, brush, unit_strides, same_padding)

## Random

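For convenience, these helpers draw random numbers with NumPy instead of threading explicit `jax.random` PRNG keys. They use the global `np.random` state or a seeded generator, then convert the samples to a JAX array.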

In [None]:
# export
def randn(shape, r=None, dtype=float):
    """Draw standard-normal samples with NumPy and return them as a JAX array.

    Args:
        shape: output shape, either a tuple of ints or a single int for 1-D output.
        r: source of randomness.
            * ``None`` uses NumPy's global ``np.random`` state.
            * A Python or NumPy integer seeds a fresh ``np.random.RandomState``.
            * Anything else is used as-is and must provide
              ``standard_normal(size=...)``, e.g. ``np.random.RandomState`` or
              ``np.random.Generator``.
        dtype: dtype of the returned array. ``float`` maps to JAX's default
            float dtype.

    Returns:
        ``jnp`` array of the requested shape and dtype.
    """
    if isinstance(shape, (int, np.integer)):
        shape = (shape,)
    if r is None:
        r = np.random
    elif isinstance(r, (int, np.integer)):
        # np.int64 etc. are not `int` subclasses; treat them as seeds too.
        r = np.random.RandomState(seed=r)
    # On a RandomState, `standard_normal(size=shape)` draws exactly what
    # `randn(*shape)` draws, and unlike `randn` it also exists on Generator.
    return jnp.asarray(r.standard_normal(size=shape), dtype=dtype)

In [None]:
# export
def rand(shape, r=None, dtype=float):
    """Draw uniform samples on [0, 1) with NumPy and return them as a JAX array.

    Args:
        shape: output shape, either a tuple of ints or a single int for 1-D output.
        r: source of randomness.
            * ``None`` uses NumPy's global ``np.random`` state.
            * A Python or NumPy integer seeds a fresh ``np.random.RandomState``.
            * Anything else is used as-is and must provide ``random(size=...)``,
              e.g. ``np.random.RandomState`` or ``np.random.Generator``.
        dtype: dtype of the returned array. ``float`` maps to JAX's default
            float dtype.

    Returns:
        ``jnp`` array of the requested shape and dtype.
    """
    if isinstance(shape, (int, np.integer)):
        shape = (shape,)
    if r is None:
        r = np.random
    elif isinstance(r, (int, np.integer)):
        # np.int64 etc. are not `int` subclasses; treat them as seeds too.
        r = np.random.RandomState(seed=r)
    # On a RandomState, `random(size=shape)` (alias of `random_sample`) draws
    # exactly what `rand(*shape)` draws, and it also exists on Generator.
    return jnp.asarray(r.random(size=shape), dtype=dtype)

## Argmax / Argmin

In [None]:
# export
@jax.jit
def argmax2d(arr2d):
    """Return ``(row, col)`` of the largest entry of a 2-D array.

    Ties resolve to the first occurrence in row-major order, matching
    ``jnp.argmax`` on the flattened array.
    """
    _, n_cols = arr2d.shape  # unpacking also rejects non-2-D input at trace time
    flat_idx = jnp.argmax(arr2d.reshape(-1))
    row, col = jnp.divmod(flat_idx, n_cols)
    return row, col

In [None]:
# export
@jax.jit
def argmin2d(arr2d):
    """Return ``(row, col)`` of the smallest entry of a 2-D array.

    Ties resolve to the first occurrence in row-major order, matching
    ``jnp.argmin`` on the flattened array.
    """
    _, width = arr2d.shape  # unpacking also rejects non-2-D input at trace time
    position = jnp.argmin(arr2d.reshape(-1))
    row_idx, col_idx = jnp.divmod(position, width)
    return row_idx, col_idx