In [None]:
from init_notebook import *

In [None]:
# Load the test photo, convert it to an RGB tensor and cut out a small crop to experiment on.
# VF is presumably torchvision.transforms.functional (comes from init_notebook) — TODO confirm.
# NOTE(review): hardcoded absolute local path; point this at an image available on your machine.
image_path = "/home/bergi/Pictures/Unternehmenskleidung.jpg"
with PIL.Image.open(image_path) as pil_image:
    # convert() returns a new image, so the file can be closed right after
    image = VF.to_tensor(pil_image.convert("RGB"))

# crop(top=110, left=140, height=64, width=64) under the torchvision signature
image = VF.crop(image, 110, 140, 64, 64)
print(image.shape)
VF.to_pil_image(image)

In [None]:
# Push the crop through a freshly initialised (untrained) 3x3 convolution without padding.
# The output is clamped into [0, 1] only so that it can be rendered as an image.
conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
y = conv(image)
print(y.shape)
display(VF.to_pil_image(y.clamp(min=0, max=1)))

In [None]:
# Untrained 8x8 convolution with stride 16: because the kernel (8) is smaller than the
# stride (16), consecutive windows leave gaps and parts of the input are never seen.
conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=8, stride=16, padding=3)
y = conv(image)
print(y.shape)
display(VF.to_pil_image(y.clamp(min=0, max=1)))


In [None]:
class PatchAE(nn.Module):
    """
    Convolutional autoencoder that compresses an image into a flat code vector.

    Encoder: a conv with ``stride=patch_size`` maps the image to a coarse feature
    map, optionally followed by BatchNorm2d, an activation and ``num_residuals``
    residual 3x3 convs. The feature map is then flattened and linearly projected
    to ``code_size`` values.

    Decoder: mirrors the encoder. A linear layer maps back to the flattened
    feature map, which is reshaped, passed through ``num_residuals`` residual 3x3
    convs and a transposed conv back to ``shape[0]`` channels. A final sigmoid
    squashes the output into (0, 1).

    NOTE(review): with ``kernel_size == patch_size`` (even) the decoder output
    height/width equals ``floor(H / patch_size) * patch_size``. So it reproduces
    the input size only when H/W are divisible by ``patch_size``. Other
    combinations are unverified.

    :param shape: input shape as (channels, height, width); used to probe the
        encoder's feature-map size.
    :param patch_size: stride of the first conv and of the final transposed conv.
    :param code_size: length of the code vector.
    :param kernel_size: kernel size of the (transposed) patch conv; defaults to
        ``patch_size``.
    :param channels: number of feature-map channels; defaults to
        ``shape[0] * patch_size ** 2`` (the number of values in one patch).
    :param num_residuals: number of residual 3x3 conv layers in encoder and decoder.
    :param batch_norm: if True, add a BatchNorm2d after the encoder's first conv
        (encoder only).
    :param activation: optional activation (name or callable) resolved via
        ``activation_to_module``; ``None`` adds no activation layers.
    """

    def __init__(
            self,
            shape: Tuple[int, int, int],
            patch_size: int,
            code_size: int,
            kernel_size: Optional[int] = None,
            channels: Optional[int] = None,
            num_residuals: int = 0,
            batch_norm: bool = False,
            activation: Union[None, str, Callable] = None,
    ):
        super().__init__()
        if channels is None:
            channels = shape[0] * patch_size ** 2
        if kernel_size is None:
            kernel_size = patch_size
        padding = int(math.floor(kernel_size / 2))

        self.encoder = nn.Sequential()
        self.encoder.append(
            nn.Conv2d(shape[0], channels, kernel_size, padding=padding, stride=patch_size)
        )

        # Run a dummy input through the patch conv once to learn the feature-map
        # shape (C, H, W) that the linear layers have to flatten / restore.
        with torch.no_grad():
            img = torch.zeros(1, *shape)
            encoded_shape = self.encoder(img).shape[1:]

        if batch_norm:
            self.encoder.append(nn.BatchNorm2d(encoded_shape[0]))

        if activation is not None:
            self.encoder.append(activation_to_module(activation))

        for _ in range(num_residuals):
            self.encoder.append(ResidualAdd(nn.Conv2d(encoded_shape[0], encoded_shape[0], 3, padding=1)))
            if activation is not None:
                self.encoder.append(activation_to_module(activation))

        # Flatten the last three (C, H, W) dims and project to the code vector.
        self.encoder.append(nn.Flatten(-3))
        self.encoder.append(nn.Linear(math.prod(encoded_shape), code_size))

        self.decoder = nn.Sequential()
        # nn.Linear's third positional parameter is `bias`, so only the two sizes are passed.
        self.decoder.append(nn.Linear(code_size, math.prod(encoded_shape)))
        if activation is not None:
            self.decoder.append(activation_to_module(activation))
        # presumably reshapes the flat vector back to (B, *encoded_shape) — see Reshape in init_notebook
        self.decoder.append(Reshape(encoded_shape))

        for _ in range(num_residuals):
            self.decoder.append(ResidualAdd(nn.Conv2d(encoded_shape[0], encoded_shape[0], 3, padding=1)))
            if activation is not None:
                self.decoder.append(activation_to_module(activation))

        # Transposed patch conv: feature map -> image channels at (roughly) input resolution.
        self.decoder.append(
            nn.ConvTranspose2d(channels, shape[0], kernel_size, padding=padding, stride=patch_size)
        )
        self.decoder.append(nn.Sigmoid())

    def forward(self, x):
        """Encode ``x`` to a code vector and decode it back to image space."""
        return self.decoder(self.encoder(x))

# Patch autoencoder for the 64x64 crop: 16x16 patches compressed to a 128-value code.
ae = PatchAE(
    shape=image.shape,
    patch_size=16,
    code_size=128,
    kernel_size=16,
    activation="gelu",
    num_residuals=2,
)
print(f"params: {num_module_parameters(ae):,}")

# One round trip through the untrained model to check shapes and the compression
# ratio; gradients are not needed for this, so skip building the autograd graph.
with torch.no_grad():
    c = ae.encoder(image.unsqueeze(0))
    y = ae.decoder(c).squeeze(0)
print(f"{image.shape} -> {c.shape} -> {y.shape}, RATIO: {math.prod(image.shape) / math.prod(c.shape)}")

display(VF.to_pil_image(y[:3].clamp(0, 1)))

display(ae)

In [None]:
class ResConv(nn.Module):
    """
    Convolution with an additive skip connection: ``act(conv(x) + skip(x))``.

    The skip path is the identity when ``num_in == num_out``. Otherwise it is a
    bias-free 1x1 convolution projecting ``x`` to ``num_out`` channels.

    :param num_in: number of input channels.
    :param num_out: number of output channels.
    :param kernel_size: kernel size of the main convolution. Padding is
        ``floor(kernel_size / 2)``, which preserves the spatial size only for odd
        kernel sizes. Even sizes make the main and skip paths disagree in size.
    :param activation: optional activation (name or callable) applied after the
        residual sum, resolved via ``activation_to_module``; ``None`` disables it.
    :param conv_class: convolution class, e.g. ``nn.Conv1d``, ``nn.Conv2d`` or
        ``nn.Conv3d``. The forward pass does not assume a number of spatial dims.
    """

    def __init__(
            self,
            num_in: int,
            num_out: int,
            kernel_size: int = 3,
            activation: Union[None, str, Callable] = None,
            conv_class: Type[nn.Module] = nn.Conv2d,
    ):
        super().__init__()
        padding = int(math.floor(kernel_size / 2))
        self.conv = conv_class(num_in, num_out, kernel_size, padding=padding)
        # Skip path: identity when channel counts match, else a 1x1 projection.
        self.residual = None
        if num_in != num_out:
            self.residual = conv_class(num_in, num_out, 1, bias=False)
        # Activation applied after the residual sum (same order as ResidualAdd + activation in PatchAE).
        self.activation = activation_to_module(activation) if activation is not None else None

    def forward(self, x):
        # x.shape is deliberately not unpacked, so 1d/3d conv classes work as well.
        r = x if self.residual is None else self.residual(x)
        y = self.conv(x) + r
        if self.activation is not None:
            y = self.activation(y)
        return y

# 10 -> 11 channels, so ResConv builds the 1x1 projection on the skip path.
m = ResConv(num_in=10, num_out=11)
print(f"params: {num_module_parameters(m):,}")
probe = torch.ones(1, 10, 16, 16)
print(m(probe).shape)
display(m)

In [None]:
# Conv weights are laid out as (out_channels, in_channels, kernel_h, kernel_w).
nn.Conv2d(in_channels=3, out_channels=8, kernel_size=1).weight.shape