In [None]:
from init_notebook import *

In [None]:
# Standard convolution: each of the 16 filters spans all 3 input channels,
# so the weight has shape (16, 3, 5, 5). padding=2 (= 5 // 2) keeps 32x32.
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, padding=2)
print("W:", conv.weight.shape)

inp = torch.ones(1, 3, 32, 32)
outp = conv(inp)
print(inp.shape, "->", outp.shape)

In [None]:
# Grouped convolution: groups=4 splits the 4 input channels into 4 groups of 1,
# so each filter only sees in_channels / groups = 1 channel and the weight is
# (16, 1, 5, 5) -- a quarter of the weights of the ungrouped 4 -> 16 conv.
conv = nn.Conv2d(in_channels=4, out_channels=16, kernel_size=5, padding=2, groups=4)
print("W:", conv.weight.shape)

inp = torch.ones(1, 4, 32, 32)
outp = conv(inp)
print(inp.shape, "->", outp.shape)

In [None]:
# Same 4 -> 16 channel mapping without groups: every filter spans all 4 input
# channels, so the weight is (16, 4, 5, 5).
conv = nn.Conv2d(in_channels=4, out_channels=16, kernel_size=5, padding=2)
print("W:", conv.weight.shape)

inp = torch.ones(1, 4, 32, 32)
outp = conv(inp)
print(inp.shape, "->", outp.shape)

- use more channels in shallow layers
- use small kernels in shallow layers and larger kernels in deeper layers
- apply batch-norm either before or after the conv, but not after the activation

In [None]:
class RescaleConv(nn.Module):
    """Change spatial resolution by ``factor`` using pixel (un)shuffle plus a convolution.

    - ``transpose=False`` (downscale): ``PixelUnshuffle(factor)`` folds each
      ``factor x factor`` patch into channels. A ``Conv2d`` then maps
      ``channels * factor**2`` back to ``channels``. H and W shrink by ``factor``.
    - ``transpose=True`` (upscale): a ``ConvTranspose2d`` expands ``channels`` to
      ``channels * factor**2``. ``PixelShuffle(factor)`` then moves them into
      space. H and W grow by ``factor``.

    The channel count is the same at input and output in both modes.

    Args:
        factor: Spatial scaling factor.
        channels: Number of input (and output) channels.
        kernel_size: Kernel size of the convolution.
        padding: Padding of the convolution. The defaults (kernel 3, padding 1)
            leave the spatial size of the conv's input unchanged.
        activation: Passed to ``activation_to_module``. The result is applied
            last unless it is None.
        batch_norm: If True, apply ``BatchNorm2d(channels)`` to the block input
            before the (un)shuffle/conv.
        transpose: Select the upscaling variant.
    """

    def __init__(
            self,
            factor: int,
            channels: int,
            kernel_size: int = 3,
            padding: int = 1,
            activation: Union[None, str, Callable] = None,
            batch_norm: bool = False,
            transpose: bool = False,
    ):
        super().__init__()

        self._transpose = transpose
        # Pixel (un)shuffle trades a factor x factor spatial patch for factor**2 channels.
        chan_mult = factor ** 2

        # The block input has `channels` channels in both modes. Normalizing it
        # here puts the norm before the conv, like ResConvStack's layer order.
        self.norm = nn.BatchNorm2d(channels) if batch_norm else None

        if not transpose:
            self.unshuffle = nn.PixelUnshuffle(factor)
            self.conv = nn.Conv2d(channels * chan_mult, channels, kernel_size, padding=padding)
        else:
            self.conv = nn.ConvTranspose2d(channels, channels * chan_mult, kernel_size, padding=padding)
            self.shuffle = nn.PixelShuffle(factor)

        self.act = activation_to_module(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the optional norm, then rescale and convolve, then the optional activation."""
        if self.norm is not None:
            x = self.norm(x)
        if self._transpose:
            x = self.shuffle(self.conv(x))
        else:
            x = self.conv(self.unshuffle(x))
        if self.act is not None:
            x = self.act(x)
        return x

    def extra_repr(self):
        return f"transpose={self._transpose}"
        
# Round trip: `m` halves H and W while keeping 16 channels, so the element
# ratio is factor**2 = 4. The transposed `mt` must restore the input shape.
m = RescaleConv(2, 16)
mt = RescaleConv(2, 16, transpose=True)
n_params_down = num_module_parameters(m)
n_params_up = num_module_parameters(mt)
print(f"params: {n_params_down:,} {n_params_up:,}")
for size in (32, 64, 128):
    inp = torch.ones(1, 16, size, size)
    outp = m(inp)
    ratio = math.prod(inp.shape) / math.prod(outp.shape)
    print(inp.shape, "->", outp.shape, "ratio:", ratio)
    recon = mt(outp)
    assert inp.shape == recon.shape, f"{inp.shape} != {recon.shape}"
display(m)
display(mt)

In [None]:
class ResConvStack(nn.Module):
    """Sequential stack of (optionally rescaling) convolutions with residual conv blocks.

    ``channels`` lists the channel count at each layer boundary, so the stack
    has ``len(channels) - 1`` layers. Layer ``i`` maps
    ``channels[i] -> channels[i + 1]`` and contains, in this order:

    1. ``BatchNorm2d(channels[i])`` if ``batch_norm[i]``;
    2. ``RescaleConv(scale[i], channels[i], transpose=transpose)`` if
       ``scale[i] > 1``;
    3. the main ``Conv2d`` (``ConvTranspose2d`` when ``transpose``) with
       ``kernel_size[i]`` and ``padding[i]``;
    4. ``activation_to_module(activation[i])`` unless ``activation[i]`` is None;
    5. ``depth[i]`` blocks of ``ResidualAdd(Conv2d(c, c, k, padding=k // 2))``.

    Every per-layer argument is passed through ``param_make_tuple(value, num_layers)``.
    Per the type hints, it may be a single value or a tuple with one entry per layer.

    Args:
        channels: Channel counts at the layer boundaries (at least 2 entries).
        kernel_size: Kernel size of the main and residual convs.
        padding: Padding of the main conv (residual convs always use ``k // 2``).
        activation: Activation spec after each main conv, or None.
        batch_norm: Whether to normalize each layer's input.
        depth: Number of residual conv blocks after each main conv.
        scale: Rescale factor per layer; values > 1 insert a ``RescaleConv``.
        transpose: Build the upsampling (decoder) variant.
    """

    def __init__(
            self,
            channels: Tuple[int, ...],
            kernel_size: Union[int, Tuple[int, ...]] = 3,
            padding: Union[int, Tuple[int, ...]] = 1,
            activation: Union[Union[None, str, Callable], Tuple[Union[None, str, Callable], ...]] = None,
            batch_norm: Union[bool, Tuple[bool, ...]] = False,
            depth: Union[int, Tuple[int, ...]] = 0,
            scale: Union[int, Tuple[int, ...]] = 1,
            transpose: bool = False,
    ):
        super().__init__()

        self._channels = tuple(channels)
        assert len(self._channels) >= 2, f"Got {len(self._channels)}"
        num_layers = len(self._channels) - 1
        self._kernel_size = param_make_tuple(kernel_size, num_layers)
        self._padding = param_make_tuple(padding, num_layers)
        self._activation = param_make_tuple(activation, num_layers)
        self._batch_norm = param_make_tuple(batch_norm, num_layers)
        self._depth = param_make_tuple(depth, num_layers)
        self._scale = param_make_tuple(scale, num_layers)
        self._transpose = transpose

        # The main conv type is the same for every layer.
        conv_class = nn.ConvTranspose2d if transpose else nn.Conv2d

        self.layers = nn.Sequential()
        for idx in range(num_layers):
            ch1 = self._channels[idx]
            ch2 = self._channels[idx + 1]
            layer_kernel = self._kernel_size[idx]
            # Module names (and therefore state_dict keys) are "layer_<n>_<part>".
            prefix = f"layer_{idx + 1}"

            if self._batch_norm[idx]:
                self.layers.add_module(f"{prefix}_norm", nn.BatchNorm2d(ch1))
            if self._scale[idx] > 1:
                self.layers.add_module(
                    f"{prefix}_scale",
                    RescaleConv(self._scale[idx], ch1, transpose=transpose),
                )
            self.layers.add_module(
                f"{prefix}_conv",
                conv_class(ch1, ch2, layer_kernel, padding=self._padding[idx]),
            )
            if self._activation[idx] is not None:
                self.layers.add_module(f"{prefix}_act", activation_to_module(self._activation[idx]))
            # Residual convs at constant channel count. padding=k // 2 keeps H and W
            # unchanged for odd kernel sizes. NOTE(review): no activation is added
            # around these convs here; confirm ResidualAdd supplies one if a
            # non-linearity is intended.
            for i in range(self._depth[idx]):
                self.layers.add_module(
                    f"{prefix}_res_{i + 1}",
                    ResidualAdd(nn.Conv2d(ch2, ch2, layer_kernel, padding=layer_kernel // 2)),
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the input through all layers in order."""
        return self.layers(x)

    def extra_repr(self):
        return (
            f"channels={self._channels}, depth={self._depth}, scale={self._scale}, kernel_size={self._kernel_size}"
            f", padding={self._padding},\nbatch_norm={self._batch_norm}, activation={self._activation}"
            f", transpose={self._transpose}"
        )

# Encoder/decoder round trip. The encoder halves H and W at the three scale=2
# layers, so a (1, 3, S, S) input becomes (1, 3, S/8, S/8). The transposed
# decoder (default kernel/padding, no norm, activation or residual blocks) must
# map it back to the input shape.
# Alternative to try: kernel_size=(3, 5, 9, 9, 9); the main convs would then
# need padding=(1, 2, 4, 4, 4) to keep spatial sizes unchanged.
m = ResConvStack(
    channels=(3, 64, 48, 32, 16, 3),
    kernel_size=3,
    padding=1,
    scale=(1, 2, 2, 2, 1),
    batch_norm=True, activation="gelu", depth=2
)
mt = ResConvStack(
    channels=(3, 64, 48, 32, 16, 3),
    scale=(1, 2, 2, 2, 1),
    transpose=True,
)
print(f"params: {num_module_parameters(m):,} {num_module_parameters(mt):,}")
# Only shapes are checked here, so don't build autograd graphs.
with torch.no_grad():
    for size in (32, 64, 128):
        inp = torch.ones(1, 3, size, size)
        outp = m(inp)
        print(inp.shape, "->", outp.shape, "ratio:", math.prod(inp.shape) / math.prod(outp.shape))
        recon = mt(outp)
        assert inp.shape == recon.shape, f"{inp.shape} != {recon.shape}"
display(m)

In [None]:
# NOTE(review): this cell repeats the encoder/decoder round trip of the cell
# above with an identical configuration. Change the settings under test or
# drop the cell.
stack_channels = (3, 64, 48, 32, 16, 3)
stack_scales = (1, 2, 2, 2, 1)
m = ResConvStack(
    channels=stack_channels,
    kernel_size=3,
    padding=1,
    scale=stack_scales,
    batch_norm=True,
    activation="gelu",
    depth=2,
)
mt = ResConvStack(channels=stack_channels, scale=stack_scales, transpose=True)
print(f"params: {num_module_parameters(m):,} {num_module_parameters(mt):,}")
for size in (32, 64, 128):
    inp = torch.ones(1, 3, size, size)
    outp = m(inp)
    shrink = math.prod(inp.shape) / math.prod(outp.shape)
    print(inp.shape, "->", outp.shape, "ratio:", shrink)
    recon = mt(outp)
    assert inp.shape == recon.shape, f"{inp.shape} != {recon.shape}"
display(m)