In [None]:
# default_exp utils

# Utils
> Utilities and convenience functions.

In [None]:
# exporti
from functools import wraps

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

## Convolutions

In [None]:
# export
def conv(lhs, rhs, window_strides=(1,1), padding="SAME", **kwargs):
    """Run ``jax.lax.conv`` after validating the operand dtypes.

    Args:
        lhs: input array, passed unchanged to ``lax.conv``.
        rhs: kernel array, passed unchanged to ``lax.conv``.
        window_strides: convolution strides, default ``(1,1)``.
        padding: ``lax.conv`` padding spec, default ``"SAME"``.
        **kwargs: forwarded unchanged to ``lax.conv``.

    Returns:
        The result of ``lax.conv(lhs, rhs, window_strides, padding, **kwargs)``.

    Raises:
        ValueError: if ``lhs`` and ``rhs`` have different dtypes, or share a dtype
            other than float16 / float32 / float64.
    """
    supported = (jnp.float16, jnp.float32, jnp.float64)
    # Mismatched dtypes are reported first, before the supported-dtype check.
    if lhs.dtype != rhs.dtype:
        raise ValueError(f"Cannot do convolution. Different dtypes for 'lhs' and 'rhs'. Got: {lhs.dtype}, {rhs.dtype}")
    if lhs.dtype not in supported:
        raise ValueError(f"Cannot do convolution. Unsupported dtype: {lhs.dtype}.")
    return lax.conv(lhs, rhs, window_strides, padding, **kwargs)

In [None]:
# export
def conv2d(lhs, rhs, window_strides=(1,1), padding="SAME", **kwargs):
    """Convolve one 2-D array with one 2-D kernel via ``conv``.

    ``lax.conv`` works on rank-4 operands (by default ``(N, C, H, W)`` input and
    ``(O, I, H, W)`` kernel), so both operands get two leading singleton axes.
    Those axes are dropped again from the result.

    Args:
        lhs: 2-D input array.
        rhs: 2-D kernel; must have the same float dtype as ``lhs``.
        window_strides: strides over the two spatial axes, default ``(1,1)``.
        padding: ``lax.conv`` padding spec, default ``"SAME"``.
        **kwargs: forwarded to ``conv`` (and from there to ``lax.conv``).

    Returns:
        2-D convolution result.

    Raises:
        ValueError: propagated from ``conv`` on mismatched or unsupported dtypes.
    """
    # No functools.wraps(conv) here: it would make this function report its name as
    # "conv" and replace this docstring with conv's.
    return conv(lhs[None, None, :, :], rhs[None, None, :, :], window_strides, padding, **kwargs)[0, 0, :, :]

In [None]:
# export
def batch_conv2d(lhs, rhs, window_strides=(1,1), padding="SAME", **kwargs):
    """Convolve a stack of 2-D arrays with a single kernel or with one kernel per item.

    Args:
        lhs: stack of 2-D inputs, indexed ``lhs[b, :, :]``.
        rhs: stack of 2-D kernels, indexed ``rhs[k, :, :]``. Two layouts are accepted:
            - one kernel (``rhs.shape[0] == 1``), applied to every item of ``lhs``;
            - one kernel per item (``rhs.shape[0] == lhs.shape[0]``), where item ``b``
              is convolved with kernel ``b``.
        window_strides: strides over the two spatial axes, default ``(1,1)``.
        padding: ``lax.conv`` padding spec, default ``"SAME"``.
        **kwargs: forwarded to ``conv`` (and from there to ``lax.conv``).

    Returns:
        Stack of 2-D results, one per item of ``lhs``.

    Raises:
        ValueError: if ``rhs`` holds neither 1 nor ``lhs.shape[0]`` kernels. Also
            propagated from ``conv`` on mismatched or unsupported dtypes.
    """
    n_items, n_kernels = lhs.shape[0], rhs.shape[0]
    if n_kernels not in (1, n_items):
        raise ValueError(
            f"Cannot do batch convolution. 'rhs' must hold 1 or {n_items} kernels, got {n_kernels}."
        )
    # Each lhs item is a single-channel image and each kernel an output channel, so
    # out[b, k] is item b convolved with kernel k, for every (b, k) pair.
    out = conv(lhs[:, None, :, :], rhs[:, None, :, :], window_strides, padding, **kwargs)
    if n_kernels == 1:
        return out[:, 0, :, :]
    # One kernel per item: keep only the matching pairs out[b, b]. Taking out[:, 0]
    # here would silently apply rhs[0] to every item.
    idx = jnp.arange(n_items)
    return out[idx, idx, :, :]

In [None]:
# export
def dilute(touches, brush, threshold=1e-10):
    """Spread ``touches`` by ``brush`` and binarize the result.

    Convolves ``touches`` with ``brush`` using ``conv2d`` with stride 1 and ``"SAME"``
    padding, so the response has the same shape as ``touches``. The result is 1.0
    wherever the response exceeds ``threshold`` and 0.0 elsewhere.

    NOTE(review): the name presumably means "dilate". For non-negative ``touches``
    and ``brush`` this behaves like a binary dilation. TODO: confirm the intended inputs.

    Args:
        touches: 2-D array to spread (presumably a 0/1 float mask; TODO confirm).
        brush: 2-D kernel; must have the same float dtype as ``touches``
            (``conv`` raises ValueError otherwise).
        threshold: responses at or below this value count as untouched. The
            default ``1e-10`` presumably filters floating-point noise from the convolution.

    Returns:
        Float array shaped like ``touches`` that contains only 1.0 and 0.0.
    """
    response = conv2d(
        lhs=touches,
        rhs=brush,
        window_strides=(1, 1),
        padding="SAME",
    )
    return jnp.where(response > threshold, 1.0, 0.0)

## Random

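For convenience, these helpers draw samples with NumPy's stateful random number generator instead of the JAX way (explicit `jax.random` PRNG keys), then convert the result to a JAX array. Pass an integer seed as `r` for reproducible draws. Because sampling happens in NumPy, calling these inside a `jax.jit`-compiled function would fix the sampled values at trace time.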

In [None]:
# export
def randn(shape, r=None, dtype=float):
    """Draw standard-normal samples with NumPy and return them as a JAX array.

    Args:
        shape: output shape, as a sequence of ints or a single int.
        r: source of randomness. Accepted values:
            - ``None`` uses NumPy's global ``np.random`` state;
            - an integer seed (Python ``int`` or NumPy integer) builds a fresh
              ``np.random.RandomState``;
            - any other object exposing ``randn`` (e.g. a ``RandomState``) is used as-is;
            - a ``np.random.Generator`` is also accepted.
        dtype: dtype passed to ``jnp.asarray``. ``float`` maps to JAX's default float,
            which is float32 unless x64 mode is enabled.

    Returns:
        JAX array of shape ``shape``.
    """
    if r is None:
        r = np.random
    elif isinstance(r, (int, np.integer)):
        # np.int64 and friends are not Python ints but are valid seeds.
        r = np.random.RandomState(seed=r)
    shape = (shape,) if isinstance(shape, (int, np.integer)) else tuple(shape)
    if isinstance(r, np.random.Generator):
        # Generator has no legacy `randn`; `standard_normal` samples the same distribution.
        samples = r.standard_normal(shape)
    else:
        # Legacy RandomState / np.random API: keeps the sample stream for a given seed.
        samples = r.randn(*shape)
    return jnp.asarray(samples, dtype=dtype)

In [None]:
# export
def rand(shape, r=None, dtype=float):
    """Draw uniform samples on [0, 1) with NumPy and return them as a JAX array.

    Args:
        shape: output shape, as a sequence of ints or a single int.
        r: source of randomness. Accepted values:
            - ``None`` uses NumPy's global ``np.random`` state;
            - an integer seed (Python ``int`` or NumPy integer) builds a fresh
              ``np.random.RandomState``;
            - any other object exposing ``rand`` (e.g. a ``RandomState``) is used as-is;
            - a ``np.random.Generator`` is also accepted.
        dtype: dtype passed to ``jnp.asarray``. ``float`` maps to JAX's default float,
            which is float32 unless x64 mode is enabled.

    Returns:
        JAX array of shape ``shape``.
    """
    if r is None:
        r = np.random
    elif isinstance(r, (int, np.integer)):
        # np.int64 and friends are not Python ints but are valid seeds.
        r = np.random.RandomState(seed=r)
    shape = (shape,) if isinstance(shape, (int, np.integer)) else tuple(shape)
    if isinstance(r, np.random.Generator):
        # Generator has no legacy `rand`; `random` samples the same [0, 1) distribution.
        samples = r.random(shape)
    else:
        # Legacy RandomState / np.random API: keeps the sample stream for a given seed.
        samples = r.rand(*shape)
    return jnp.asarray(samples, dtype=dtype)

## Argmax / Argmin

In [None]:
# export
@jax.jit
def argmax2d(arr2d):
    """Return the (row, col) index of the largest element of a 2-D array.

    Ties resolve to the first occurrence in row-major order, because ``jnp.argmax``
    is applied to the flattened array.

    Args:
        arr2d: 2-D array (the shape unpacking rejects any other rank).

    Returns:
        Tuple ``(row, col)`` of scalar integer arrays.
    """
    _, n_cols = arr2d.shape
    flat_idx = jnp.argmax(arr2d.ravel())
    # ravel() is row-major, so the flat index splits by the number of COLUMNS.
    # Dividing by the row count is only correct for square arrays.
    return flat_idx // n_cols, flat_idx % n_cols

In [None]:
# export
@jax.jit
def argmin2d(arr2d):
    """Return the (row, col) index of the smallest element of a 2-D array.

    Ties resolve to the first occurrence in row-major order, because ``jnp.argmin``
    is applied to the flattened array.

    Args:
        arr2d: 2-D array (the shape unpacking rejects any other rank).

    Returns:
        Tuple ``(row, col)`` of scalar integer arrays.
    """
    _, n_cols = arr2d.shape
    flat_idx = jnp.argmin(arr2d.ravel())
    # ravel() is row-major, so the flat index splits by the number of COLUMNS.
    # Dividing by the row count is only correct for square arrays.
    return flat_idx // n_cols, flat_idx % n_cols

## Float Mask
Boolean arrays don't work well with CUDA, so masks here are represented as floats instead. `float_mask` is a convenience function that maps a boolean array to a `float32` array containing only `1.0` (for `True`) and `0.0` (for `False`).

In [None]:
# export
@jax.jit
def float_mask(boolean_mask):
    """Convert a boolean mask to a float32 mask (True -> 1.0, False -> 0.0).

    Raises:
        AssertionError: if ``boolean_mask`` is not of dtype ``bool``. Dtypes are
            static under ``jax.jit``, so the check runs at trace time. Like any
            ``assert``, it is skipped under ``python -O``.
    """
    assert boolean_mask.dtype == bool
    # A bool -> float32 cast yields exactly 1.0 / 0.0.
    return boolean_mask.astype(jnp.float32)

In [None]:
# Expected: [1., 0.] as float32 (True -> 1.0, False -> 0.0).
float_mask(jnp.array([True, False]))

## Boolean Operations on Float Masks

In [None]:
# Truth-table inputs: the pairs (lhs[i], rhs[i]) run through (1, 1), (1, 0), (0, 1), (0, 0).
lhs = jnp.array([1.0, 1.0, 0.0, 0.0], dtype=jnp.float32)
rhs = jnp.array([1.0, 0.0, 1.0, 0.0], dtype=jnp.float32)

In [None]:
# export
@jax.jit
def not_(lhs):
    """Logical NOT on a float mask: each entry x becomes 1.0 - x (1.0 <-> 0.0)."""
    return jnp.subtract(1.0, lhs)

In [None]:
# Expected: [0., 0., 1., 1.] (each entry of lhs flipped).
not_(lhs)

In [None]:
# export
@jax.jit
def or_(lhs, rhs):
    """Logical OR of two float masks.

    The masks are summed and the sum is capped at 1.0, so 1 + 1 stays 1 and
    {0, 1} inputs give {0, 1} outputs. The result is cast with ``dtype=float``,
    which is float32 unless JAX x64 mode is enabled.
    """
    capped_sum = jnp.minimum(lhs + rhs, 1.0)
    return jnp.asarray(capped_sum, dtype=float)

In [None]:
# Expected: [1., 1., 1., 0.] (only the (0, 0) pair is off).
or_(lhs, rhs)

In [None]:
# Expected: [1., 1., 0., 1.] (lhs OR NOT rhs).
or_(lhs, not_(rhs))

In [None]:
# Expected: [1., 0., 1., 1.] (NOT lhs OR rhs).
or_(not_(lhs), rhs)

In [None]:
# export
@jax.jit
def and_(lhs, rhs):
    """Logical AND of two float masks: the elementwise product, 1 only where both are 1."""
    return jnp.multiply(lhs, rhs)

In [None]:
# Expected: [1., 0., 0., 0.] (only the (1, 1) pair is on).
and_(lhs, rhs)

In [None]:
# Expected: [0., 1., 0., 0.] (lhs AND NOT rhs).
and_(lhs, not_(rhs))

In [None]:
# Expected: [0., 0., 1., 0.] (NOT lhs AND rhs).
and_(not_(lhs), rhs)

In [None]:
# export
@jax.jit
def xor_(lhs, rhs):
    """Logical XOR of two float masks.

    Sums the masks and reduces modulo 2 (floor-mod, the same as ``%``), so 1 + 1
    wraps to 0 while 1 + 0 stays 1.
    """
    mask_sum = jnp.add(lhs, rhs)
    return jnp.remainder(mask_sum, 2.0)

In [None]:
# Expected: [0., 1., 1., 0.] (on where exactly one input is on).
xor_(lhs, rhs)

In [None]:
# Expected: [1., 0., 0., 1.] (lhs XOR NOT rhs).
xor_(lhs, not_(rhs))

In [None]:
# Expected: [1., 0., 0., 1.] (NOT lhs XOR rhs, same as lhs XOR NOT rhs).
xor_(not_(lhs), rhs)

In [None]:
# export
@jax.jit
def where_(float_mask, x, y):
    """Select ``x`` where the float mask is on (> 0.5) and ``y`` elsewhere.

    The 0.5 cutoff presumably tolerates masks that are only approximately 0.0 / 1.0.
    Inside this body, the parameter name ``float_mask`` shadows the module-level
    function of the same name.
    """
    is_on = float_mask > 0.5
    return jnp.where(is_on, x, y)