# Build a random 3-axis test image from a fixed seed so the example is
# reproducible.
key = jax.random.key(0)
image = jax.random.normal(key=key, shape=(3, 32, 32))

# 3x3 identity matrix: the linear part of the transform leaves coordinates
# unchanged, so only the offset below has any effect.
matrix = jax.numpy.eye(3)

# One offset entry per image axis.
# NOTE(review): entries line up with axes 0, 1, 2 of `image`, whose shape is
# (3, 32, 32). If the image is meant to be channel-first, this shifts the
# size-3 axis by 5 and leaves the last axis untouched -- confirm the intended
# layout (HWC vs. CHW) against the dm_pix documentation.
offset = jax.numpy.array([5, 5, 0])

img_t = dm_pix.affine_transform(
    image=image,
    matrix=matrix,
    offset=offset,
)