In [None]:
import equinox as eq
import jax
from jax import lax
from jax._src.api import linear_transpose, ShapeDtypeStruct
import jax.numpy as jnp
import numpy as np
from skimage import data
import matplotlib.pyplot as plt
from einops import repeat, rearrange

In [None]:
# Load scikit-image's bundled "astronaut" sample image.
img = data.astronaut()

In [None]:
# Value range and dtype of the loaded image.
img.min(), img.max(), img.dtype

In [None]:
# Display the loaded image on a single axes.
fig, ax = plt.subplots()

ax.imshow(img)

plt.show()

In [None]:
# Array shape of the loaded image.
img.shape

### Reshape

In [None]:
def squeeze(x, factor=2):
    """Fold ``factor x factor`` spatial patches of a 4-D array into channels.

    Args:
        x: array with axes ``(b, h, w, c)``; ``h`` and ``w`` must be
            multiples of ``factor``.
        factor: side length of the square patch moved into the channel axis.

    Returns:
        Array with axes ``(b, h // factor, w // factor, factor * factor * c)``.

    Raises:
        AssertionError: if ``x`` is not 4-D or ``h``/``w`` is not divisible
            by ``factor``.
    """
    assert x.ndim == 4
    batch, height, width, channels = x.shape

    assert height % factor == 0
    assert width % factor == 0

    # Explicit axis sizes let einops validate the input shape as well.
    axis_sizes = dict(
        b=batch,
        c=channels,
        h=height // factor,
        w=width // factor,
        fh=factor,
        fw=factor,
    )
    return rearrange(x, "b (h fh) (w fw) c -> b h w (fh fw c)", **axis_sizes)


def unsqueeze(x, factor=2):
    """Unfold channels back into ``factor x factor`` spatial patches.

    Applies the reverse einops pattern of ``squeeze``.

    Args:
        x: array with axes ``(b, h, w, factor * factor * c)``.
        factor: side length of the square patch restored from the channels.

    Returns:
        Array with axes ``(b, h * factor, w * factor, c)``.

    Raises:
        AssertionError: if ``x`` is not 4-D.
    """
    assert x.ndim == 4
    batch, height, width, channels = x.shape

    axis_sizes = dict(
        b=batch,
        c=(channels // factor**2),
        h=height,
        w=width,
        fh=factor,
        fw=factor,
    )
    return rearrange(x, "b h w (fh fw c) -> b (h fh) (w fw) c", **axis_sizes)


class Squeeze2D(eq.Module):
    """Invertible space-to-depth layer built on ``squeeze`` / ``unsqueeze``."""

    # Side length of the square spatial patch folded into the channel axis.
    # The static marker must be the field's *value* (`name: type = ...`); used
    # as a bare annotation it is never attached and the field is not static.
    factor: int = eq.static_field()

    def __init__(self, factor: int):
        """Store the squeeze factor."""
        self.factor = factor

    def __call__(self, x, *, key):
        """Apply the forward transform; ``key`` is required but not used."""
        return self.transform(x)

    def transform(self, x):
        """Map ``(b, h, w, c)`` to ``(b, h/factor, w/factor, factor**2 * c)``."""
        return squeeze(x, self.factor)

    def inverse(self, x):
        """Undo ``transform`` by unfolding channels back into space."""
        return unsqueeze(x, self.factor)

In [None]:
# def squeeze(x, factor=2):

#     assert x.ndim == 4
#     b, h, w, c = x.shape

#     assert h % factor == 0
#     assert w % factor == 0

#     x = rearrange(
#         x,
#         "b (h fh) (w fw) c -> b h w (fh fw c)",
#         b=b, c=c, h=h//factor, w=w//factor, fh=factor, fw=factor
#     )

#     return x

# def unsqueeze(x, factor=2):
#     assert x.ndim == 4
#     b, h, w, c = x.shape

#     x = rearrange(
#         x,
#         "b h w (fh fw c) -> b (h fh) (w fw) c",
#         b=b, c=(c//factor**2), h=h, w=w, fh=factor, fw=factor
#     )

#     return x

# def squeeze(x, factor=2):
#     assert x.ndim == 4
#     b, h, w, c = x.shape

#     assert h % factor == 0
#     assert w % factor == 0

#     x = jnp.reshape(x, (b, h//factor, factor, w//factor, factor, c))
#     x = jnp.transpose(x, axes=(0, 1, 3, 5, 2, 4))
#     x = jnp.reshape(x, (b, h//factor, w//factor, c * factor **2))

#     return x

# def unsqueeze(x, factor=2):
#     assert x.ndim == 4
#     b, h, w, c = x.shape

#     x = jnp.reshape(x, (b, h, w, c // factor ** 2, factor, factor))
#     x = jnp.transpose(x, axes=(0, 1, 4, 2, 5, 3))
#     x = jnp.reshape(x, (b, h*factor, w*factor, c // factor ** 2))
#     return x

In [None]:
# `x` is the image with a leading batch axis, i.e. the (b, h, w, c) layout
# that squeeze/unsqueeze expect. It has to be defined here so the notebook
# runs top to bottom on a fresh kernel.
x = img[np.newaxis, ...]

reshape_layer = Squeeze2D(factor=2)

z = reshape_layer.transform(x)

x_ = reshape_layer.inverse(z)

# Round trip: the inverse must reproduce the input exactly.
np.testing.assert_array_equal(x, x_)

print(x.shape, z.shape, x_.shape)

In [None]:
# Cast the squeezed tensor to uint8 for display.
img_squeeze = z.astype(np.uint8)

In [None]:
# Value range, dtype, shape and container type of the squeezed image.
img_squeeze.min(), img_squeeze.max(), img_squeeze.dtype, img_squeeze.shape, type(
    img_squeeze
)

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

# One panel per consecutive 3-channel group of the squeezed image.
for panel, first in zip(axs.ravel(), range(0, 12, 3)):
    panel.imshow(img_squeeze[0, ..., first : first + 3])
plt.show()

## DownSample

In [None]:
class IRev(eq.Module):
    """Fixed 2x2, stride-2 depthwise convolution with a transpose-based inverse.

    Every input channel is filtered by the same four 2x2 filters whose entries
    are ``+/-0.5``, so ``forward`` maps ``(b, c, h, w)`` to
    ``(b, 4 * c, h // 2, w // 2)``. The four filters are mutually orthogonal
    with unit norm, so the linear transpose of the convolution undoes it.
    """

    # Convolution weights in OIHW layout, shape (4 * c, 1, 2, 2), float32.
    kernel: jnp.ndarray
    # NCHW shape of the tensors passed to `forward`. The static marker must be
    # the field's value (`name: type = ...`), not its annotation, to take effect.
    input_shape: tuple = eq.static_field()

    def __init__(self, input_shape, key, n_channels: int = 3):
        """Build the fixed kernel.

        Args:
            input_shape: ``(b, c, h, w)`` shape of the inputs to ``forward``.
            key: accepted but not used; the kernel is deterministic.
            n_channels: not used; the channel count comes from ``input_shape``.
        """
        _, c, _, _ = input_shape

        # Four sign patterns over a 2x2 window, starting from all ones.
        kernel = np.ones((4, 1, 2, 2))
        kernel[1, 0, 0, 1] = -1
        kernel[1, 0, 1, 1] = -1

        kernel[2, 0, 1, 0] = -1
        kernel[2, 0, 1, 1] = -1

        kernel[3, 0, 1, 0] = -1
        kernel[3, 0, 0, 1] = -1
        kernel *= 0.5  # scale each filter to unit norm

        # Same four filters for every input channel (one group per channel).
        kernel = np.concatenate([kernel] * c, 0)
        self.kernel = jnp.asarray(kernel, dtype=jnp.float32)
        self.input_shape = input_shape

    def __call__(self, x, *, key=None):
        # NOTE(review): returns the input unchanged, whereas Squeeze2D.__call__
        # applies its transform -- confirm whether this identity is intended.
        return x

    def _conv(self, x):
        """Grouped stride-2 convolution shared by ``forward`` and ``inverse``."""
        dn = lax.conv_dimension_numbers(
            self.input_shape,  # only ndim matters, not shape
            self.kernel.shape,  # only ndim matters, not shape
            ("NCHW", "OIHW", "NCHW"),
        )

        return lax.conv_general_dilated(
            lhs=x,
            rhs=self.kernel,
            window_strides=(2, 2),
            padding="VALID",
            lhs_dilation=(1, 1),
            rhs_dilation=(1, 1),
            dimension_numbers=dn,
            # One group per input channel, matching how the kernel was tiled.
            feature_group_count=self.input_shape[1],
        )

    def forward(self, x):
        """Downsample ``x`` (NCHW) by 2 in each spatial dimension."""
        return self._conv(x)

    def inverse(self, x):
        """Map an output of ``forward`` back to a tensor of ``input_shape``.

        Applies the linear transpose of the forward convolution.
        """
        dummy_primal = jax.ShapeDtypeStruct(self.input_shape, x.dtype)
        transpose = jax.linear_transpose(self._conv, dummy_primal)
        (z,) = transpose(x)
        return z

In [None]:
# IRev works on channel-first tensors, so move channels in front of h, w.
x_batch = rearrange(x, "b h w c -> b c h w")

# The key argument is not used by IRev, hence None.
layer = IRev(x_batch.shape, None)

z = layer.forward(x_batch.astype(jnp.float32))
# z = jax.vmap(layer.forward)(x.astype(jnp.float32))
x_batch_ = layer.inverse(z)

# Channel-last uint8 copy of the downsampled tensor for plotting.
img_squeeze = rearrange(z, "b c h w ->b h w c").astype(np.uint8)

# Round trip: inverse(forward(x)) should match x up to float error.
np.testing.assert_array_almost_equal(x_batch, x_batch_)

print(x_batch.shape, z.shape, x_batch_.shape, img_squeeze.shape)

In [None]:
# Value range of the uint8 display copy.
img_squeeze.min(), img_squeeze.max()

In [None]:
# Displays the array's repr (the whole array, not a slice).
img_squeeze

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

# Show the 12 output channels as four consecutive 3-channel images.
for panel, first in zip(axs.ravel(), range(0, 12, 3)):
    panel.imshow(img_squeeze[0, ..., first : first + 3])
plt.show()

In [None]:
b, h, w, c = x.shape

# NOTE(review): every assignment below writes 1 into an array that is already
# all ones, so the kernel stays all ones; the second assignment also repeats
# the first index. Looks like scratch work -- confirm whether it is needed.
kernel = np.ones((4, 1, 2, 2))
kernel[1, 0, 0, 1] = 1
kernel[1, 0, 0, 1] = 1

kernel[2, 0, 1, 0] = 1
kernel[2, 0, 1, 1] = 1

# Repeat the four filters three times along the output axis -> (12, 1, 2, 2).
kernel = np.concatenate([kernel] * 3, 0).astype(np.float32)

# kernel[3, 0, 1, 0] = -1
# kernel[3, 0, 0, 1] = -1

In [None]:
b, h, w, c = x.shape

# Four +/-1 sign patterns over a 2x2 window, written out explicitly;
# the inserted axis is the singleton input-channel axis of the OIHW layout.
sign_patterns = np.array(
    [
        [[1, 1], [1, 1]],
        [[1, -1], [1, -1]],
        [[1, 1], [-1, -1]],
        [[1, -1], [-1, 1]],
    ],
    dtype=np.float64,
)[:, np.newaxis]

# Repeat the four filters three times along the output axis.
kernel = np.tile(sign_patterns, (3, 1, 1, 1)).astype(np.float32)
# kernel = jnp.asarray(kernel).transpose([1, 2, 3, 0])
# kernel = jnp.asarray(kernel).transpose([1,0,2,3])

kernel.shape

In [None]:
# Container type and dtype of the first image after moving channels first.
type(jnp.transpose(x[0], axes=(2, 0, 1))), jnp.transpose(x[0], axes=(2, 0, 1)).dtype

In [None]:
# Equinox Conv2d with the same geometry as the hand-built kernel:
# 3 -> 12 channels, 2x2 window, stride 2, no padding, 3 groups, no bias.
layer_forward = eq.nn.Conv2d(
    in_channels=3,
    out_channels=12,
    kernel_size=(2, 2),
    use_bias=False,
    stride=(2, 2),
    padding=0,
    groups=3,
    key=jax.random.PRNGKey(123),
)

# layer_forward.weight = kernel
# The layer is called on one channel-first (c, h, w) image, without a batch axis.
img_squeeze = layer_forward(jnp.transpose(x[0], axes=(2, 0, 1)).astype(jnp.float32))

# Back to channel-last with a leading batch axis, cast to uint8 for plotting.
img_squeeze = rearrange(img_squeeze, "c h w -> 1 h w c").astype(np.uint8)
img_squeeze.shape, layer_forward.weight.shape

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

# Four panels, each showing one consecutive 3-channel slice of the output.
for panel, first in zip(axs.ravel(), range(0, 12, 3)):
    panel.imshow(img_squeeze[0, ..., first : first + 3])
plt.show()

In [None]:
# Shape of the Conv2d weight tensor.
layer_forward.weight.shape

In [None]:
print(img.shape)
# Channel-first (axes 0, 3, 1, 2) float32 copy of the image batch.
x_torch = x.transpose([0, 3, 1, 2]).astype(jnp.float32)
x_torch.shape

In [None]:
# NHWC layout: a black 200x200 image with one 10x10 square per channel,
# placed along the diagonal.
img = jnp.zeros((1, 200, 200, 3), dtype=jnp.float32)
for k in range(3):
    # Use `row`/`col` rather than `x`/`y`: `x` holds the image batch that
    # later cells still read (`x.shape`, `x.transpose`).
    row = 30 + 60 * k
    col = 20 + 60 * k
    img = img.at[0, row : row + 10, col : col + 10, k].set(1.0)

print("out shape: ", img.shape, " <-- original shape")
print("Original Image:")
plt.imshow(img[0]);

In [None]:
# Value range of the synthetic image; its batch axis has one entry (index 0).
img[0].min(), img[0].max()

In [None]:
# 2D kernel - HWIO layout
kernel = jnp.zeros((3, 3, 3, 3), dtype=jnp.float32)
# Broadcast one 3x3 pattern over every (input, output) channel pair.
kernel += jnp.array([[1, 1, 0], [1, 0, -1], [0, -1, -1]])[
    :, :, jnp.newaxis, jnp.newaxis
]

print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0]);

In [None]:
# Dimension numbers for channel-last images and HWIO kernels.
dn = lax.conv_dimension_numbers(
    img.shape,  # only ndim matters, not shape
    kernel.shape,  # only ndim matters, not shape
    ("NHWC", "HWIO", "NHWC"),
)  # the important bit
print(dn)

In [None]:
# Stride-2 convolution of the synthetic image with the edge kernel.
out = lax.conv_general_dilated(
    lhs=img,
    rhs=kernel,
    window_strides=(2, 2),
    padding="SAME",
    lhs_dilation=(1, 1),
    rhs_dilation=(1, 1),
    dimension_numbers=dn,
)
print("out shape: ", out.shape, " <-- half the size of above")
plt.figure(figsize=(5, 5))
print("First output channel:")
plt.imshow(np.array(out)[0, :, :, 0]);

In [None]:
# Value range of the convolution output.
out.min(), out.max()

### Transpose

In [None]:
# Shape of the current kernel.
kernel.shape

In [None]:
# The following is equivalent to tensorflow:
# N,H,W,C = img.shape
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))

# transposed conv = 180deg kernel rotation plus LHS dilation
# rotate kernel 180deg (the rotation is commented out; kernel_rot is the
# unrotated kernel and is not used by the call below):
kernel_rot = kernel  # jnp.rot90(jnp.rot90(kernel, axes=(0,1)), axes=(0,1))
# need a custom output padding:
padding = ((2, 1), (2, 1))

img_ori = lax.conv_general_dilated(
    out,  # lhs = image tensor
    kernel,  # rhs = conv kernel tensor
    (1, 1),  # window strides
    padding,  # padding mode
    (2, 2),  # lhs/image dilation
    (1, 1),  # rhs/kernel dilation
    dn,
)  # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", img_ori.shape, "<-- larger than original!")
plt.figure(figsize=(5, 5))
print("First output channel:")
print("Original Image:")
plt.imshow(img_ori[0].astype(np.uint8));

In [None]:
# Value range of the transposed-convolution output.
img_ori.min(), img_ori.max()

In [None]:
# Compare the value range of the reconstruction with that of the input image.
img_ori[0].min(), img_ori[0].max(), img[0].min(), img[0].max()

In [None]:
# Shape of the transposed-convolution output. `rearrange` is already imported
# in the first cell; re-importing it here made the import the cell's last
# statement, so the shape was never displayed.
img_ori.shape

In [None]:
b, h, w, c = x.shape

# Four +/-0.5 sign patterns over a 2x2 window; the inserted axis is the
# singleton input-channel axis of the OIHW layout.
half_patterns = 0.5 * np.array(
    [
        [[1, 1], [1, 1]],
        [[1, -1], [1, -1]],
        [[1, 1], [-1, -1]],
        [[1, -1], [-1, 1]],
    ],
    dtype=np.float64,
)[:, np.newaxis]

# Repeat the four filters three times along the output axis.
kernel = np.tile(half_patterns, (3, 1, 1, 1))
# kernel = jnp.asarray(kernel).transpose([1, 2, 3, 0])
# kernel = jnp.asarray(kernel).transpose([1,0,2,3])
# kernel = repeat(kernel, "A 1 B C -> A 3 B C")
kernel = kernel.astype(np.float32)
kernel.shape

In [None]:
print(img.shape)
# Channel-first (axes 0, 3, 1, 2) float32 copy of the image batch `x`.
x_torch = x.transpose([0, 3, 1, 2]).astype(jnp.float32)
x_torch.shape

In [None]:
# `lax` is already imported in the first cell, so no re-import is needed here.
# Dimension numbers for channel-first images and OIHW kernels.
dn = lax.conv_dimension_numbers(
    x_torch.shape,  # only ndim matters, not shape
    kernel.shape,  # only ndim matters, not shape
    ("NCHW", "OIHW", "NCHW"),
)
print(dn)

In [None]:
# Shapes of the channel-first input and the kernel.
x_torch.shape, kernel.shape

In [None]:
# `lax.conv` has no grouping option, so a kernel with one input channel per
# filter cannot be applied to a 3-channel image with it. Use the general
# primitive with one group per image channel; with no dimension_numbers given
# it uses the NCHW / OIHW layout, the same one `lax.conv` assumes.
x_squeezed = lax.conv_general_dilated(
    lhs=x_torch,  # lhs = image tensor
    rhs=kernel,  # rhs = conv kernel tensor
    window_strides=(2, 2),  # window strides
    padding="SAME",  # padding mode
    feature_group_count=3,
)

In [None]:
# Shape of the downsampled tensor.
x_squeezed.shape

In [None]:
# Grouped stride-2 convolution: 3 groups, one per input channel.
x_squeezed = lax.conv_general_dilated(
    lhs=x_torch,  # lhs = image tensor
    rhs=kernel,  # rhs = conv kernel tensor
    window_strides=(2, 2),  # window strides
    padding="SAME",  # padding mode
    lhs_dilation=(1, 1),  # lhs/image dilation
    rhs_dilation=(1, 1),  # rhs/kernel dilation
    dimension_numbers=dn,  # lhs, rhs, out dimension permutation
    feature_group_count=3,
)

In [None]:
# Value range of the downsampled tensor.
x_squeezed.min(), x_squeezed.max()

In [None]:
from sklearn.preprocessing import MinMaxScaler

In [None]:
# Rescale all values jointly to [0, 255] for display. The scaler works on a
# single column, so flatten first and then restore the array's own shape
# instead of a hard-coded one.
x_squeezed = (
    MinMaxScaler((0.0, 255.0))
    .fit_transform(x_squeezed.ravel()[:, None])
    .reshape(x_squeezed.shape)
)

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

# Panel k shows channels k, k + 4, k + 8 as a channel-last uint8 image.
for offset, panel in enumerate(axs.ravel()):
    channels_last = x_squeezed[0, offset::4, ...].transpose([1, 2, 0])
    panel.imshow(channels_last.astype(np.uint8))
plt.show()

In [None]:
# Shapes of the input and the rescaled downsampled tensor.
x_torch.shape, x_squeezed.shape,

In [None]:
# Kernel shape before and after swapping its first two axes.
kernel.shape, kernel.transpose([1, 0, 2, 3]).shape

In [None]:
# np.testing.assert_array_equal(kernel_rot, kernel.transpose([1, 0, 2, 3]))

In [None]:
dn = lax.conv_dimension_numbers(
    x_torch.shape,  # only ndim matters, not shape
    kernel.shape,  # only ndim matters, not shape
    ("NCHW", "OIHW", "NCHW"),
)
print(dn)

# Transposed convolution written as an ordinary convolution over the
# lhs-dilated input. The forward kernel is (12, 1, 2, 2) with 3 groups, and
# cannot be reused as-is: here each group must read its 4 filter responses
# and write 1 image channel, i.e. the kernel must be (3, 4, 2, 2).
# lax convolutions are cross-correlations, so the 2x2 window is also flipped.
n_groups = x_torch.shape[1]
kernel_t = np.ascontiguousarray(
    kernel.reshape(n_groups, -1, *kernel.shape[2:])[:, :, ::-1, ::-1]
)

# Pad by (window size - 1) so each output pixel sees exactly one input sample.
padding = ((1, 1), (1, 1))
x_torch_ori = lax.conv_general_dilated(
    lhs=x_squeezed,  # lhs = image tensor
    rhs=kernel_t,  # rhs = regrouped, flipped kernel
    window_strides=(1, 1),  # window strides
    padding=padding,  # padding mode
    lhs_dilation=(2, 2),  # lhs/image dilation
    rhs_dilation=(1, 1),  # rhs/kernel dilation
    dimension_numbers=dn,
    feature_group_count=n_groups,
)

In [None]:
# Shape of the upsampled tensor.
x_torch_ori.shape,

In [None]:
# Value ranges of the original input and of the upsampled tensor.
x_torch.min(), x_torch.max(), x_torch_ori.min(), x_torch_ori.max()

In [None]:
fig, ax = plt.subplots()

# Drop the batch axis and move channels last for imshow.
ax.imshow(x_torch_ori.squeeze().transpose([1, 2, 0]))

plt.show()

In [None]:
# Shape of the upsampled tensor.
x_torch_ori.shape

In [None]:
# NOTE(review): if the cells run in order, `out` is the channel-last result of
# the edge-kernel demo, while `dn` is the NCHW/OIHW spec and `kernel` is the
# (12, 1, 2, 2) grouped kernel -- these operands look inconsistent; confirm
# which tensor this call is meant to invert.
out_inv = lax.conv_transpose(
    out,  # lhs = image tensor
    kernel,  # rhs = conv kernel tensor
    strides=(2, 2),  # window strides
    padding="SAME",  # padding mode
    rhs_dilation=(1, 1),  # rhs/kernel dilation
    dimension_numbers=dn,
)  # dimension_numbers = lhs, rhs, out dimension permutation
out_inv.shape

In [None]:
# `lax.conv` cannot group channels, so the one-input-channel-per-filter kernel
# does not fit a 3-channel image with it. Use the general primitive with one
# group per image channel; with no dimension_numbers given it uses the
# NCHW / OIHW layout, the same one `lax.conv` assumes.
out = lax.conv_general_dilated(
    lhs=x_torch,  # lhs = NCHW image tensor
    rhs=kernel,  # rhs = OIHW conv kernel tensor
    window_strides=(1, 1),  # window strides
    padding="SAME",  # padding mode
    feature_group_count=3,
)