In [None]:
from init_notebook import *
from src.util.module import dump_module_stacktrace

In [None]:
16 * 60_000  # scratch arithmetic (= 960_000); what this figure refers to is not recorded here

In [None]:
class Patchify(nn.Module):
    """
    Cut a batch of images into non-overlapping square patches.

    Input ``(B, C, H, W)`` becomes ``(B, H/s, W/s, C, s, s)`` with ``s = patch_size``;
    element ``[b, y, x, c, i, j]`` equals ``batch[b, c, y*s + i, x*s + j]``.
    Height and width must both be divisible by ``patch_size``.
    """

    def __init__(self, patch_size: int):
        super().__init__()
        self.patch_size = patch_size

    def extra_repr(self):
        return "patch_size={}".format(self.patch_size)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        # four-way unpacking also rejects inputs that are not 4-D
        _, _, H, W = batch.shape
        assert W % self.patch_size == 0, f"width must be divisible by patch_size, got {W} / {self.patch_size}"
        assert H % self.patch_size == 0, f"height must be divisible by patch_size, got {H} / {self.patch_size}"

        size = self.patch_size
        channels_last = batch.permute(0, 2, 3, 1)      # B, H, W, C
        row_strips = channels_last.unfold(1, size, size)  # B, H/s, W, C, s
        return row_strips.unfold(2, size, size)        # B, H/s, W/s, C, s, s

class Unpatchify(nn.Module):
    """
    Reassemble images from a patch grid (the inverse of ``Patchify``).

    Input ``(B, Y, X, C, H, W)`` — a ``Y x X`` grid of ``H x W`` patches — becomes
    ``(B, C, Y*H, X*W)``, with ``out[b, c, y*H + i, x*W + j] == batch[b, y, x, c, i, j]``.
    The layout is taken from the input shape; ``patch_size`` is only reported in the repr.
    """

    def __init__(self, patch_size: int):
        super().__init__()
        self.patch_size = patch_size

    def extra_repr(self):
        return "patch_size={}".format(self.patch_size)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        B, Y, X, C, H, W = batch.shape
        # bring channels forward and put each in-patch axis right after its grid axis:
        # (B, Y, X, C, H, W) -> (B, C, Y, H, X, W), then merge (Y, H) and (X, W)
        interleaved = batch.permute(0, 3, 1, 4, 2, 5)
        return interleaved.reshape(B, C, Y * H, X * W)

# sanity check: 64x64 RGB images with 8x8 patches give (2, 8, 8, 3, 8, 8);
# flattening to one row per patch yields (2 * 8 * 8, 3 * 8 * 8) = (128, 192)
inp = torch.ones(2, 3, 64, 64)
patches = Patchify(8)(inp)
shape = patches.shape
patches = patches.reshape(shape[0] * shape[1] * shape[2], -1)
patches.shape

In [None]:
class MLPLayer(nn.Module):
    """
    One fully-connected layer with optional activation and residual connection.

    :param channels_in: size of the last input dimension
    :param channels_out: size of the last output dimension
    :param activation: passed to ``activation_to_module``; when that returns ``None``
        no activation is applied
    :param bias: whether the linear layer has a bias
    :param residual: add the input to the output — only honoured when
        ``channels_in == channels_out``, otherwise the shapes could not be added
    """

    def __init__(
            self,
            channels_in: int,
            channels_out: int,
            activation: Union[None, str, Callable] = None,
            bias: bool = True,
            residual: bool = True,
    ):
        super().__init__()
        self._residual = residual and (channels_in == channels_out)
        self.module = nn.Linear(channels_in, channels_out, bias=bias)
        self.act = activation_to_module(activation)

    def extra_repr(self):
        return "residual={}".format(self._residual)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.module(x)
        if self.act is not None:
            out = self.act(out)
        return out + x if self._residual else out


class MLPMixerLayer(nn.Module):
    """
    Token-mixing layer: mixes information across the rows that belong to one sample.

    The input's first dimension must be ``batch * channels`` with the ``channels`` rows
    of each sample stored consecutively (e.g. one row per patch). Rows are regrouped to
    ``(batch, channels, features)`` — trailing dims are flattened into ``features`` and
    restored afterwards — and a kernel-size-1 ``nn.Conv1d`` mixes the ``channels`` axis
    independently at every feature position.

    :param channels: number of consecutive rows that form one sample
    :param activation: passed to ``activation_to_module``; when that returns ``None``
        no activation is applied
    :param bias: whether the convolution has a bias
    :param residual: add the (regrouped) input to the output
    """

    def __init__(
            self,
            channels: int,
            activation: Union[None, str, Callable] = None,
            bias: bool = True,
            residual: bool = True,
    ):
        super().__init__()
        self._residual = residual
        self.channels = channels

        self.module = nn.Conv1d(channels, channels, kernel_size=1, bias=bias)

        self.act = activation_to_module(activation)

    def extra_repr(self):
        return f"residual={self._residual}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        # Without this check a non-multiple can still reshape successfully
        # (e.g. 6 rows, 4 channels -> (1, 4, 6)) and silently mix rows across samples.
        assert shape[0] % self.channels == 0, (
            f"first dimension must be divisible by channels, got {shape[0]} / {self.channels}"
        )
        x = x.reshape(shape[0] // self.channels, self.channels, -1)  # batch, channels, features
        y = self.module(x)
        if self.act is not None:
            y = self.act(y)
        if self._residual:
            y = y + x
        return y.reshape(shape)

        
class MixerMLP(nn.Module):
    """
    Patch-based MLP autoencoder with optional token-mixing layers.

    Encoder:
      1. ``Patchify`` cuts ``(B, C, H, W)`` images into ``(B, Y, X, C, s, s)`` patches
         (``s = patch_size``, ``Y = H / s``, ``X = W / s``), flattened to one row per patch.
      2. For each width in ``hidden_channels`` an ``MLPLayer`` is applied per patch; after
         the layers whose 1-based index is in ``mixer_at`` an ``MLPMixerLayer`` mixes
         information across the ``Y * X`` patches of each image.
      3. A final ``MLPLayer`` maps all patch features of one image to a code of size
         ``hidden_channels[-1]`` (or the patch dimension when ``hidden_channels`` is empty).

    The decoder mirrors these layers in reverse order and ``Unpatchify`` restores
    ``(B, C, H, W)``.

    :param image_shape: ``(C, H, W)`` of the input images
    :param patch_size: edge length of the square patches; must divide ``H`` and ``W``
    :param hidden_channels: widths of the per-patch layers
    :param mixer_at: 1-based indices of per-patch layers that are followed by a mixer layer
    """

    def __init__(
            self,
            image_shape: Tuple[int, int, int],
            patch_size: int,
            hidden_channels: Tuple[int, ...],
            mixer_at: Tuple[int, ...],
    ):
        # Normalize to a tuple: encode() compares the input shape against it and
        # ``(3, 8, 8) == [3, 8, 8]`` is False in Python.
        image_shape = tuple(image_shape)
        assert image_shape[-1] % patch_size == 0, f"width must be divisible by patch_size, got {image_shape[-1]} / {patch_size}"
        assert image_shape[-2] % patch_size == 0, f"height must be divisible by patch_size, got {image_shape[-2]} / {patch_size}"

        super().__init__()
        self.image_shape = image_shape
        self.hidden_channels = hidden_channels
        self.patch_size = patch_size
        self.patch_dim = image_shape[0] * (self.patch_size ** 2)
        # number of patches along (height, width)
        self.patches_shape = (image_shape[1] // self.patch_size, image_shape[2] // self.patch_size)
        num_patches = math.prod(self.patches_shape)
        # patch-tensor shape seen by the most recent encode() call (informational only)
        self._last_patch_shape = None
        self.patchify = Patchify(patch_size)
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.unpatchify = Unpatchify(patch_size)

        channels = (self.patch_dim, *self.hidden_channels)
        for i, (ch, next_ch) in enumerate(zip(channels, channels[1:])):
            self.encoder.append(MLPLayer(ch, next_ch))
            # the decoder is built front-insert so it mirrors the encoder
            self.decoder.insert(0, MLPLayer(next_ch, ch))
            if i + 1 in mixer_at:
                self.encoder.append(MLPMixerLayer(num_patches))
                self.decoder.insert(0, MLPMixerLayer(num_patches))

        # channels[-1] instead of the loop variable, which is undefined when
        # hidden_channels is empty
        code_ch = channels[-1]
        self.encoder.append(MLPLayer(code_ch * num_patches, code_ch))
        self.decoder.insert(0, MLPLayer(code_ch, code_ch * num_patches))

    def encode(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Encode images ``(B, C, H, W)`` into codes ``(B, code_ch)``.

        Raises ``AssertionError`` if ``(C, H, W)`` differs from ``image_shape``.
        """
        B, C, H, W = batch.shape
        assert (C, H, W) == self.image_shape, f"Expected image shape {self.image_shape}, got {(C, H, W)}"

        patch_batch = self.patchify(batch)                       # B, Y, X, C, s, s
        self._last_patch_shape = patch_shape = patch_batch.shape
        y = patch_batch.reshape(math.prod(patch_shape[:3]), -1)  # B*Y*X, C*s*s

        *patch_layers, code_layer = self.encoder
        # every layer but the last works on one row per patch
        for module in patch_layers:
            y = module(y)
        # the last layer sees all patch features of one image at once
        return code_layer(y.reshape(B, -1))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Decode codes ``(B, code_ch)`` back into images ``(B, C, H, W)``.

        The batch size is taken from ``x``, so any batch of codes can be decoded —
        not only the one produced by the most recent ``encode()`` call.
        """
        Y, X = self.patches_shape
        s = self.patch_size
        patch_shape = (x.shape[0], Y, X, self.image_shape[0], s, s)

        code_layer, *patch_layers = self.decoder
        # (B, code_ch * Y * X) -> (B * Y * X, code_ch): one row per patch again
        y = code_layer(x).reshape(math.prod(patch_shape[:3]), -1)
        for module in patch_layers:
            y = module(y)

        y = y.reshape(patch_shape)
        return self.unpatchify(y)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Autoencode ``batch``: ``decode(encode(batch))``."""
        y = self.encode(batch)
        y = self.decode(y)
        return y
        
# Smoke test: 3x8x12 images with 4x4 patches -> a 2x3 patch grid; per-patch widths
# 128 -> 16, with token mixing after the first per-patch layer.
module = MixerMLP((3, 8, 12), 4, [128, 16], mixer_at=[1])
inp = torch.ones(2, 3, 8, 12)
# dump_module_stacktrace is a project helper; presumably it runs `module` on `inp`
# while reporting per-layer details, and returns the output — TODO confirm
outp = dump_module_stacktrace(module, inp)
print("outp", outp.shape)
display(module, outp)

In [None]:
inp = torch.arange(2 * 6 * 48, dtype=torch.float32).reshape(2, 6, 48)
# a kernel-size-1 Conv1d mixes the 6 channels independently at each of the 48 positions
mixer = nn.Conv1d(6, 6, kernel_size=1)
outp = mixer(inp)
print(outp.shape)
outp

In [None]:
# tile `inp` into blocks of at most 16 along its last two dims:
# outer list over chunks of dim -1, inner tuples over chunks of dim -2
[column_chunk.split(16, dim=-2) for column_chunk in inp.split(16, dim=-1)]

In [None]:
batch = torch.arange(2 * 3 * 8 * 8).reshape(2, 3, 8, 8)
display(batch)
# unfold H first; the appended window dim pushes W to position -2 again:
# (2, 3, 8, 8) -> (2, 3, 2, 8, 4) -> (2, 3, 2, 2, 4, 4)
row_windows = batch.unfold(-2, 4, 4)
u = row_windows.unfold(-2, 4, 4)
display(u.shape, u)

In [None]:
batch = torch.arange(2 * 3 * 8 * 8).reshape(2, 3, 8, 8)
display(batch)
# 1x1 kernel with stride 4: keeps every 4th pixel in each direction (8x8 -> 2x2)
# and mixes the 3 channels with randomly initialised weights
strided_1x1 = nn.Conv2d(3, 3, kernel_size=1, stride=4)
p = strided_1x1(batch.float())
display(p.shape, p)

In [None]:
# Load the stored image tensor and preview a grid of 64 samples.
# NOTE(review): torch.load unpickles the file — only use with trusted data.
ds = BaseDataset(TensorDataset(
    torch.load("../datasets/colorful-uint-64x64-340k.pt")
))
# TensorDataset items are tuples; element 0 is the image.
# Presumably `limit(64)` yields the first 64 samples — TODO confirm
images = [sample[0] for sample in ds.limit(64)]
VF.to_pil_image(make_grid(images))