In [None]:
from init_notebook import *

In [None]:
class ConvMixer(nn.Module):
    """Residual encoder/decoder of pixel (un)shuffles and depthwise convolutions.

    Encoder: ``num_layers`` times ``PixelUnshuffle(factor)`` (the channel count
    grows by ``factor ** 2`` each time) followed by a depthwise ``Conv2d``
    (``groups`` equal to the channel count). Decoder: the mirror image using
    ``PixelShuffle(factor)``. ``forward`` returns ``decoder(encoder(x)) + x``.

    Args:
        channels: number of input (and output) channels.
        factor: pixel (un)shuffle factor of every stage. Input height and width
            must be divisible by ``factor ** num_layers``.
        num_layers: number of unshuffle/conv stages in encoder and decoder.
        kernel_size: depthwise kernel size. Must be odd, because a padding of
            ``kernel_size // 2`` only preserves the spatial size for odd kernels
            (and the residual add requires matching shapes).
        activation: optional activation (anything ``activation_to_module``
            accepts), added after every encoder conv and after every decoder
            conv except the last one.
        batch_norm: add ``BatchNorm2d`` after every encoder/decoder stage except
            the last one of each; therefore it has no effect when
            ``num_layers == 1``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd number.
    """

    def __init__(
            self,
            channels: int,
            factor: int,
            num_layers: int = 2,
            kernel_size: int = 3,
            activation: Union[None, str, Callable] = None,
            batch_norm: bool = False,
    ):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            # an even kernel with ``kernel_size // 2`` padding grows the feature map
            # by one pixel per conv, so ``forward`` would fail on the residual add
            raise ValueError(f"kernel_size must be a positive odd number, got {kernel_size}")

        self._channels = channels
        self._factor = factor
        self._num_layers = num_layers

        padding = kernel_size // 2  # "same" padding for odd kernels
        chan_mult = factor ** 2     # channel growth per (un)shuffle
        cur_channels = channels

        self.encoder = nn.Sequential()
        for i in range(num_layers):
            self.encoder.add_module(f"unshuffle_{i + 1}", nn.PixelUnshuffle(factor))
            cur_channels = cur_channels * chan_mult
            # depthwise conv: each channel is filtered independently
            self.encoder.add_module(f"conv_{i + 1}", nn.Conv2d(cur_channels, cur_channels, kernel_size, padding=padding, groups=cur_channels))
            if activation is not None:
                self.encoder.add_module(f"act_{i + 1}", activation_to_module(activation))
            if batch_norm and i < num_layers - 1:
                self.encoder.add_module(f"norm_{i + 1}", nn.BatchNorm2d(cur_channels))

        # ``cur_channels`` now holds the encoder's output channel count; the decoder
        # divides it back down to ``channels``
        self.decoder = nn.Sequential()
        for i in range(num_layers):
            self.decoder.add_module(f"shuffle_{i + 1}", nn.PixelShuffle(factor))
            cur_channels = cur_channels // chan_mult
            self.decoder.add_module(f"conv_{i + 1}", nn.Conv2d(cur_channels, cur_channels, kernel_size, padding=padding, groups=cur_channels))
            # no activation / norm after the final decoder stage
            if activation is not None and i < num_layers - 1:
                self.decoder.add_module(f"act_{i + 1}", activation_to_module(activation))
            if batch_norm and i < num_layers - 1:
                self.decoder.add_module(f"norm_{i + 1}", nn.BatchNorm2d(cur_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``decoder(encoder(x)) + x`` (same shape as ``x``)."""
        return self.decoder(self.encoder(x)) + x

    def extra_repr(self):
        return f"channels={self._channels}, factor={self._factor}, num_layers={self._num_layers}"


# NOTE: ConvMixer only inserts BatchNorm between stages, so with num_layers=1
# the ``batch_norm=True`` flag below adds no BatchNorm2d modules.
m = ConvMixer(3, 8, 1, activation="gelu", batch_norm=True, kernel_size=5)
print(f"params: {num_module_parameters(m):,}")

# visualization only: no need to build an autograd graph
with torch.no_grad():
    for size in (32, 64, 128):
        # impulse input: one bright pixel (all channels) at the image centre
        inp = torch.zeros(1, 3, size, size)
        inp[..., :, size // 2, size // 2] = 100
        outp = m(inp)
        print(inp.shape, "->", outp.shape)
        display(VF.to_pil_image(outp[0].clamp(0, 1)))

display(m)


In [None]:
class ConvStrideMixer(nn.Module):
    """Residual patchify -> 1x1 conv -> unpatchify block.

    Layers (in ``self.encoder``):
      * ``patch``: ``Conv2d(channels, hidden, 2*stride, stride=stride)``
      * ``act``: optional activation (only when ``activation`` is given)
      * ``conv``: 1x1 ``Conv2d(hidden, hidden)``
      * ``unpatch``: ``ConvTranspose2d(hidden, channels, 2*stride, stride=stride)``

    The spatial size is restored exactly when the input height and width are
    multiples of ``stride`` and at least ``2 * stride``; otherwise the residual
    add fails (or, with ``residual=False``, a differently sized map is returned).

    NOTE(review): there is no nonlinearity between ``conv`` and ``unpatch``, so
    the 1x1 conv composes linearly with the transposed conv — confirm this is
    intended.

    Args:
        channels: input/output channel count.
        stride: patch stride; the patch kernel is ``2 * stride`` wide.
        hidden_channels: channels inside the block, ``channels * stride`` if None.
        activation: optional activation passed to ``activation_to_module``.
        residual: add the input to the output.
    """

    def __init__(
            self,
            channels: int,
            stride: int,
            hidden_channels: Optional[int] = None,
            activation: Union[None, str, Callable] = None,
            residual: bool = True,
    ):
        super().__init__()
        hidden = hidden_channels if hidden_channels is not None else channels * stride
        patch_size = 2 * stride

        self._channels = channels
        self._hidden_channels = hidden
        self._stride = stride
        self._residual = residual

        # modules are created in the same order as before (same RNG consumption)
        layers = [("patch", nn.Conv2d(channels, hidden, patch_size, stride=stride))]
        if activation is not None:
            layers.append(("act", activation_to_module(activation)))
        layers.append(("conv", nn.Conv2d(hidden, hidden, 1)))
        layers.append(("unpatch", nn.ConvTranspose2d(hidden, channels, patch_size, stride=stride)))

        self.encoder = nn.Sequential()
        for layer_name, layer in layers:
            self.encoder.add_module(layer_name, layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the patch/unpatch stack, adding ``x`` when ``residual`` is set."""
        out = self.encoder(x)
        return out + x if self._residual else out

    def extra_repr(self):
        return f"channels={self._channels}, stride={self._stride}, residual={self._residual}"


m = ConvStrideMixer(3, 16)
print(f"params: {num_module_parameters(m):,}")

# visualization only: no need to build an autograd graph
with torch.no_grad():
    for size in (32, 64, 128):
        # impulse input: one bright pixel (all channels) at the image centre
        inp = torch.zeros(1, 3, size, size)
        inp[..., :, size // 2, size // 2] = 100
        outp = m(inp)
        print(inp.shape, "->", outp.shape)
        # 1px border filled with 0.5 so adjacent images stay distinguishable
        outp = VF.pad(outp, 1, fill=.5)
        # NOTE(review): ``resize`` comes from init_notebook; presumably upscales by 3
        display(VF.to_pil_image(resize(outp[0, :3].clamp(0, 1), 3)))

display(m)

In [None]:
class ConvDilationMixer(nn.Module):
    """Residual stack of dilated 3x3 convolutions followed by a plain 3x3 projection.

    For every entry ``d`` in ``dilations`` a ``Conv2d(..., 3, padding=d, dilation=d)``
    is added (size preserving), optionally followed by an activation. A final
    ``Conv2d(..., channels, 3, padding=1)`` maps back to ``channels``.

    Args:
        channels: input/output channel count.
        dilations: dilation of each 3x3 conv; any iterable of ints (it is
            materialized once, so generators work). May be empty, in which case
            only the final projection conv remains.
        hidden_channels: channels between the dilated convs, ``channels`` if None.
        activation: optional activation (passed to ``activation_to_module``) after
            every dilated conv.
        residual: add the input to the output.
    """

    def __init__(
            self,
            channels: int,
            dilations: Tuple[int, ...] = (6, 5, 3),
            hidden_channels: Optional[int] = None,
            activation: Union[None, str, Callable] = None,
            residual: bool = True,
    ):
        super().__init__()
        self._channels = channels
        # materialize once: re-iterating ``dilations`` would see an exhausted
        # iterator if the caller passed a generator
        self._dilations = tuple(dilations)
        self._hidden_channels = channels if hidden_channels is None else hidden_channels
        self._residual = residual

        self.encoder = nn.Sequential()
        ch = channels
        next_ch = self._hidden_channels
        for i, dil in enumerate(self._dilations, start=1):
            # padding == dilation keeps the spatial size for a 3x3 kernel
            self.encoder.add_module(f"conv_{i}", nn.Conv2d(ch, next_ch, 3, padding=dil, dilation=dil))
            if activation is not None:
                self.encoder.add_module(f"act_{i}", activation_to_module(activation))
            ch = next_ch
        # final projection back to ``channels``; numbered via len() so that an empty
        # ``dilations`` does not reference an unbound loop variable
        self.encoder.add_module(f"conv_{len(self._dilations) + 1}", nn.Conv2d(ch, channels, 3, padding=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the conv stack, adding ``x`` when ``residual`` is set."""
        y = self.encoder(x)
        if self._residual:
            y = y + x
        return y

    def extra_repr(self):
        return f"channels={self._channels}, dilations={self._dilations}, residual={self._residual}"

# Baseline kept for comparison: a plain stack of dilated 3x3 convs without
# hidden channels, activations or residual.
#m = nn.Sequential(
#    nn.Conv2d(3, 3, 3, padding=6, dilation=6),
#    nn.Conv2d(3, 3, 3, padding=5, dilation=5),
#    nn.Conv2d(3, 3, 3, padding=3, dilation=3),
#    nn.Conv2d(3, 3, 3, padding=1, dilation=1),
#)
m = ConvDilationMixer(3, hidden_channels=10)
print(f"params: {num_module_parameters(m):,}")

# visualization only: no need to build an autograd graph
with torch.no_grad():
    for size in (32, 64, 128):
        # input: centre pixel plus a vertical line in the centre column from the
        # top edge down to the centre (all channels)
        inp = torch.zeros(1, 3, size, size)
        inp[..., :, size // 2, size // 2] = 100
        inp[..., :, :size // 2, size // 2] = 100
        outp = m(inp)
        print(inp.shape, "->", outp.shape)
        # 1px border filled with 0.5 so adjacent images stay distinguishable
        outp = VF.pad(outp, 1, fill=.5)
        # NOTE(review): ``resize`` comes from init_notebook; presumably upscales by 3
        display(VF.to_pil_image(resize(outp[0, :3].clamp(0, 1), 3)))

display(m)