In [2]:
import jax, jax.numpy as jp
from functools import partial

@partial(jax.jit, static_argnums=(1, 2))
def unfold(x: jp.ndarray, dim: int, p: int):
    """Split axis ``dim`` of ``x`` into non-overlapping windows of length ``p``.

    Trailing elements along ``dim`` that do not fill a complete window are
    dropped.

    Args:
        x: Input array with at least one dimension.
        dim: Axis to unfold; negative values count from the last axis.
            Treated as a static argument by ``jax.jit``.
        p: Window length, which is also the step between windows; must be
            positive. Treated as a static argument by ``jax.jit``.

    Returns:
        An array with one more dimension than ``x``: axis ``dim`` has length
        ``x.shape[dim] // p`` and a new trailing axis of length ``p`` holds
        the elements of each window.

    Raises:
        ValueError: If ``p`` is not positive or ``dim`` is out of range.
    """
    if p <= 0:
        raise ValueError(f"p must be a positive integer, got {p}")
    if not -x.ndim <= dim < x.ndim:
        raise ValueError(
            f"dim {dim} is out of range for an array with {x.ndim} dimension(s)"
        )
    # Normalise negative axes: the axis arithmetic below is only valid for a
    # non-negative index.
    axis = dim % x.ndim
    n_windows = x.shape[axis] // p
    # Row i lists the source indices of window i; the remainder is never referenced.
    ids = jp.arange(n_windows * p).reshape(n_windows, p)
    # take() with 2-D indices replaces `axis` by (n_windows, p); the window
    # axis is then moved to the last position.
    x_unf = jp.moveaxis(x.take(ids, axis=axis), axis + 1, -1)
    return x_unf

In [6]:
# Fixed PRNG key so the random test input is reproducible.
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (256, 128, 128))
# Unfold axis 1 (length 128) into windows of length 32.
y = unfold(x, 1, 32)
y.shape

(256, 4, 128, 32)

In [7]:
# Time the jitted JAX unfold; block_until_ready() is part of the timed
# statement so the measurement waits for the result.
%timeit unfold(x, 1, 32).block_until_ready()

76.2 µs ± 527 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [8]:
import torch as th

# Torch tensor with the same (256, 128, 128) shape as the JAX input `x`.
z = th.randn(256, 128, 128)

In [9]:
# Time torch's built-in unfold on dimension 1 with size 32 and step 32.
# NOTE(review): torch documents Tensor.unfold as returning a view, so this
# presumably measures view creation only and is not directly comparable to
# the JAX timing, which materialises a new array — confirm before comparing.
%timeit z.unfold(1, 32, 32)

1.03 µs ± 1.93 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [19]:
import jax, jax.numpy as jp
import torch as th

def patch_mask_th(x, p):
    """Flatten non-overlapping ``p`` x ``p`` patches of ``x`` into rows (torch).

    Args:
        x: Tensor with three dimensions; dimensions 1 and 2 are cut into
            windows of size ``p`` with step ``p``.
        p: Patch edge length, also used as the step.

    Returns:
        A 2-D tensor with one row per patch position (dimension-1 position
        varying slowest). Each row holds the patch values ordered by
        dimension 0 of ``x``, then row within the patch, then column.
    """
    # Window dimension 2 first, then dimension 1; each unfold appends its
    # window axis at the end of the shape.
    tiles = x.unfold(2, p, p).unfold(1, p, p)
    # Put the two patch-position axes in front and order the window axes as
    # (rows, columns) before materialising a contiguous copy for view().
    tiles = tiles.permute(1, 2, 0, 4, 3).contiguous()
    n_patches = tiles.size(0) * tiles.size(1)
    return tiles.view(n_patches, -1)


def patch_mask_jx(x, p):
    """Flatten non-overlapping ``p`` x ``p`` patches of ``x`` into rows (JAX).

    Args:
        x: 3-D array; the last two axes are tiled into patches.
        p: Patch edge length, also used as the window stride.

    Returns:
        A 2-D array with one row per patch position. Each row holds the
        patch values ordered by the leading axis of ``x``, then row within
        the patch, then column.
    """
    # Present x as a single-example, single-channel 3-D volume so that a
    # (1, p, p) window extracts one patch per leading-axis index.
    patches = jax.lax.conv_general_dilated_patches(
        x[None, None, ...],
        filter_shape=(1, p, p),
        window_strides=(1, p, p),
        padding='valid'
    )
    # With one input channel the feature axis has exactly p * p entries, so
    # split it with the known integer p instead of a float square root.
    tiles = patches.reshape(p, p, *patches.shape[2:]).transpose(3, 4, 2, 0, 1)
    return tiles.reshape(tiles.shape[0] * tiles.shape[1], -1)

In [20]:
import numpy as np

# One shared random input so both implementations see the same values.
# NOTE(review): no seed is set in the visible cells, so x differs between runs.
x = np.random.randn(3, 256, 256)
x_th = th.from_numpy(x)
x_jx = jp.array(x)

# Run both implementations with patch size 16.
y1 = patch_mask_th(x_th, 16)
y2 = patch_mask_jx(x_jx, 16)

# Check that the two outputs agree within np.allclose's default tolerances.
np.allclose(y1.numpy(), y2)

True

In [21]:
# Time the torch implementation on the (3, 256, 256) input with patch size 16.
%timeit patch_mask_th(x_th, 16)

22.9 µs ± 281 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [22]:
%timeit patch_mask_th(x_th, 16)

21.9 µs ± 260 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
