In [None]:
from init_notebook import *
import re

In [None]:
# Load the test photo as an RGB tensor; the `with` block closes the file
# handle that PIL opened for it.
# NOTE(review): machine-specific absolute path — the photo is not part of the repo.
with PIL.Image.open("/home/bergi/Pictures/Unternehmenskleidung.jpg") as photo:
    full_image = VF.to_tensor(photo.convert("RGB"))

# Keep a fixed 256x256 region as the "full" image and cut a small 32x32 patch
# out of it (assuming VF.crop takes (top, left, height, width) — confirm).
full_image = VF.crop(full_image, 0, 0, 256, 256)
image = VF.crop(full_image, 110, 140, 32, 32)
print(image.shape)
for preview in (image, full_image):
    display(VF.to_pil_image(preview))

In [None]:
# A single untrained 3x3 convolution applied to the patch. There is no padding,
# so each spatial side shrinks by kernel_size - 1 = 2. The result is clamped to
# [0, 1] only so it can be shown as an image.
conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
y = conv(image)
print(y.shape)
display(VF.to_pil_image(torch.clamp(y, 0, 1)))

In [None]:
# Strided variant: 8x8 kernel, padding 3, stride 16. Each output side is
# floor((size + 2*3 - 8) / 16) + 1, i.e. 2 for a 32x32 input.
conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=8, stride=16, padding=3)
y = conv(image)
print(y.shape)
display(VF.to_pil_image(torch.clamp(y, 0, 1)))


In [None]:
class ResConv(nn.Module):
    """Conv2d (or ConvTranspose2d) layer with a residual connection.

    ``forward`` returns ``act(conv(x)) + r``, where ``r`` is the input brought
    to ``num_out`` channels according to ``residual``:

    - ``"map"``: if ``num_in != num_out``, ``r`` is a bias-free 1x1 convolution
      of the input; otherwise ``r`` is the input itself.
    - ``"add"``: parameter-free. If ``num_in < num_out`` the input channels are
      tiled and truncated to ``num_out``. If ``num_in > num_out`` the input
      channels are cut into consecutive chunks of ``num_out`` channels which
      are summed (a shorter last chunk only reaches the first output channels).

    Args:
        num_in: number of input channels.
        num_out: number of output channels.
        kernel_size: kernel size of the main convolution. Padding is
            ``kernel_size // 2``, which keeps height/width unchanged only for
            odd kernel sizes (even sizes grow the output by one pixel, which
            then does not match the residual).
        activation: forwarded to ``activation_to_module``; if that returns
            ``None`` no activation is applied.
        groups: ``groups`` for the main conv and for the 1x1 mapping conv.
        residual: ``"map"`` or ``"add"``.
        transpose: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.

    Raises:
        ValueError: if ``residual`` is not one of ``RESIDUAL_MODES``.
    """

    RESIDUAL_MODES = ("add", "map")

    def __init__(
            self,
            num_in: int,
            num_out: int,
            kernel_size: int = 3,
            activation: Union[None, str, Callable] = None,
            groups: int = 1,
            residual: str = "map",  # "add", "map"
            transpose: bool = False,
    ):
        super().__init__()
        if residual not in self.RESIDUAL_MODES:
            raise ValueError(
                f"residual must be one of {self.RESIDUAL_MODES}, got {residual!r}"
            )

        conv_class = nn.ConvTranspose2d if transpose else nn.Conv2d

        self.residual_mode = residual
        self.transpose = transpose
        self.num_in = num_in
        self.num_out = num_out

        # Learned 1x1 channel mapping, only needed in "map" mode when the
        # channel counts differ. Created before `self.conv` (same RNG order).
        self.residual = None
        if residual == "map" and num_in != num_out:
            self.residual = conv_class(num_in, num_out, 1, bias=False, groups=groups)

        self.conv = conv_class(num_in, num_out, kernel_size, padding=kernel_size // 2, groups=groups)
        self.act = activation_to_module(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv (+ activation) and add the residual branch.

        Expects a 4-D ``(batch, channels, height, width)`` tensor; raises
        ``ValueError`` otherwise.
        """
        if x.ndim != 4:
            raise ValueError(f"expected a 4-D input (B, C, H, W), got shape {tuple(x.shape)}")

        r = x if self.residual is None else self.residual(x)

        y = self.conv(x)
        if self.act is not None:
            y = self.act(y)

        if self.residual_mode == "add":
            if self.num_in < self.num_out:
                # Tile the input channels until there are at least num_out, then cut.
                repeats = self.num_out // self.num_in + 1
                r = r.repeat(1, repeats, 1, 1)[:, :self.num_out]
            elif self.num_in > self.num_out:
                # Fold the input channels onto num_out channels: zero-pad the
                # channel dim to a multiple of num_out, split it into chunks and
                # sum them. Done out-of-place — writing into `y` in place would
                # break autograd for activations that save their output (ReLU, ...).
                pad = (-self.num_in) % self.num_out
                if pad:
                    r = nn.functional.pad(r, (0, 0, 0, 0, 0, pad))
                r = r.reshape(r.shape[0], -1, self.num_out, *r.shape[2:]).sum(dim=1)

        return y + r

# Smoke-test ResConv for a channel increase (with GELU) and a channel decrease:
# parameter count, output shape for a 16x16 input, and the module repr.
for config in (
    dict(num_in=10, num_out=11, activation="gelu"),
    dict(num_in=33, num_out=10),
):
    m = ResConv(**config)
    print(f"params: {num_module_parameters(m):,}")
    print(m(torch.ones(1, m.num_in, 16, 16)).shape)
    display(m)

In [None]:
class ResidualScriptedAE(nn.Module):
    """Symmetric autoencoder whose layers are described by a small script.

    The script is a sequence of commands separated by newlines and/or ``|``.
    Surrounding whitespace, empty commands and anything after ``#`` are
    ignored (note: a ``|`` inside such a comment still splits commands).
    Commands are applied in order, starting from ``channels``:

    - ``ch=N`` / ``ch*N`` / ``ch/N``: set the channel count to ``N``,
      ``ch * N`` or ``ch // N``. Appends ``ResConv(ch, new)`` to the encoder
      and prepends a transposed ``ResConv(new, ch)`` to the decoder.
    - ``down``: appends ``nn.PixelUnshuffle(2)`` to the encoder and prepends
      ``nn.PixelShuffle(2)`` to the decoder; the channel count is multiplied by 4.

    The decoder always ends with ``nn.Sigmoid()``.

    Args:
        channels: number of input (and reconstructed) channels.
        script: the layer script described above.
        kernel_size, activation, residual: forwarded to every ``ResConv``.
        groups: forwarded to a ``ResConv`` only if both its input and output
            channel counts are divisible by it; otherwise that layer uses 1.

    Raises:
        SyntaxError: if a command does not fully match any ``COMMANDS`` pattern.
        ValueError: if a ``ch`` command would produce fewer than one channel.
    """

    # Each pattern must match a whole (comment-stripped) command.
    COMMANDS = {
        "ch": re.compile(r"ch\s*([*/=])\s*(\d+)"),
        "down": re.compile(r"down"),
    }

    def __init__(
            self,
            channels: int,
            script: str,
            kernel_size: int = 1,
            activation: Union[None, str, Callable] = None,
            residual: str = "map",
            groups: int = 1,
    ):
        super().__init__()
        self.encoder = nn.Sequential()
        self.decoder = nn.Sequential()
        # Normalized script: stripped, non-empty lines joined by "|".
        self.script = "|".join(filter(bool, (l.strip() for l in script.splitlines())))

        ch = channels
        for line in self.script.split("|"):
            line = line.split("#", 1)[0].strip()
            if not line:
                continue

            cmd, args = self._parse_line(line)

            if cmd == "ch":
                new_ch = self._new_channel_count(ch, args[0], int(args[1]), line)
                # grouped convs need both channel counts divisible by `groups`
                groups_param = groups if ch % groups == 0 and new_ch % groups == 0 else 1
                self.encoder.append(ResConv(ch, new_ch, kernel_size=kernel_size, activation=activation, groups=groups_param, residual=residual))
                self.decoder.insert(0, ResConv(new_ch, ch, kernel_size=kernel_size, activation=activation, groups=groups_param, residual=residual, transpose=True))
                ch = new_ch

            elif cmd == "down":
                self.encoder.append(nn.PixelUnshuffle(2))
                self.decoder.insert(0, nn.PixelShuffle(2))
                ch = ch * 4

        self.decoder.append(nn.Sigmoid())

    @classmethod
    def _parse_line(cls, line: str) -> tuple:
        """Return ``(command_name, regex_groups)`` for one command or raise SyntaxError."""
        for cmd, regex in cls.COMMANDS.items():
            match = regex.fullmatch(line)
            if match:
                return cmd, match.groups()
        raise SyntaxError(f"Could not parse line `{line}`")

    @staticmethod
    def _new_channel_count(ch: int, op: str, value: int, line: str) -> int:
        """Apply a ``ch`` operator (``=``, ``*`` or ``/``) to the current channel count."""
        if op == "=":
            new_ch = value
        elif op == "*":
            new_ch = ch * value
        else:  # "/" — integer division (raises ZeroDivisionError for /0)
            new_ch = ch // value
        if new_ch < 1:
            raise ValueError(f"Line `{line}` yields {new_ch} channels (from {ch})")
        return new_ch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode and immediately decode ``x``."""
        return self.decoder(self.encoder(x))

    def extra_repr(self) -> str:
        return f"script={repr(self.script)}"

    def debug_forward(self, x: torch.Tensor, file=None) -> torch.Tensor:
        """Like ``forward`` but prints each layer's output size and shape to ``file``."""
        print(f"encoding {x.shape}", file=file)
        for layer in self.encoder:
            x = layer(x)
            print(f"{type(layer).__name__:20} -> {math.prod(x.shape):10} = {tuple(x.shape)}", file=file)
        print(f"decoding {x.shape}", file=file)
        for layer in self.decoder:
            x = layer(x)
            print(f"{type(layer).__name__:20} -> {math.prod(x.shape):10} = {tuple(x.shape)}", file=file)
        return x
        
# Build an autoencoder from a script: widen to 32 channels, stack residual
# layers at that width, then alternate pixel-unshuffle ("down") with channel
# reductions; the final `ch=4` fixes the latent channel count.
# Alternative settings tried earlier:
#   groups=8,
#   script="ch=32|ch*2|ch*2|down|ch/4|down|ch/4|down|ch/4|ch/2|ch/2|ch/2"
m = ResidualScriptedAE(
    channels=3,
    kernel_size=3,
    script="""
        ch=32
        ch*1
        ch*1
        ch*1
        ch*1
        ch*1
        ch*1
        down
        ch/4
        down
        ch/4
        down
        ch=4
    """
)
print(f"params: {num_module_parameters(m):,}")
print(f"script: \"{m.script}\"")

# Compression ratio = number of input values per latent value.
batch = image.unsqueeze(0)
enc = m.encoder(batch).squeeze(0)
compression = math.prod(image.shape) / math.prod(enc.shape)
print("RATIO:", compression, image.shape, "->", enc.shape)
m.debug_forward(batch)
display(m)

# The decoder must map the latent back to the input shape.
decoded = m.decoder(enc.unsqueeze(0)).squeeze(0)
assert decoded.shape == image.shape


In [None]:
# Rough speed check: 100 forward passes of the model `m` from the cell above
# over a batch of 16 copies of the patch (outputs are discarded).
data = image[None].repeat(16, 1, 1, 1)
for _ in tqdm(range(100)):
    m(data)

In [None]:
# Rebuild the trained autoencoder; the constructor arguments must match the
# training configuration for the saved state dict to fit.
model = ResidualScriptedAE(
    channels=3,
    kernel_size=5,
    activation="gelu",
    script="ch=32|ch*1|ch*1|ch*1|down|ch*1|ch*1|ch*1|down|ch/2|ch/2|ch/2|down|ch/2|ch=4",
)
# map_location="cpu": the model above lives on the CPU, so load the tensors
# there regardless of the device they were saved from (a GPU-saved checkpoint
# would otherwise fail to load on a machine without CUDA).
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
snapshot = torch.load(
    "../checkpoints/ae/resscriptae-ks-5_act-gelu_script-ch32ch1ch1ch1downch1ch1ch1/best.pt",
    map_location="cpu",
)
model.load_state_dict(snapshot["state_dict"])

In [None]:
# Encode and decode the 256x256 image. Inference only, so skip building an
# autograd graph.
with torch.no_grad():
    enc = model.encoder(full_image.unsqueeze(0))
    recon = model.decoder(enc)[0]
print(enc.shape)
# Preview the first 3 latent channels as RGB. The encoder output is not bounded
# (its last layer is a residual sum), so clamp to [0, 1] before converting —
# otherwise out-of-range values overflow in the uint8 conversion.
display(VF.to_pil_image(enc[0, :3].clamp(0, 1)))
# The decoder ends with a Sigmoid, so `recon` is already in [0, 1].
display(VF.to_pil_image(recon))

In [None]:
# Side by side: input, reconstruction and the first 3 latent channels upscaled
# via `resize(..., 8)` (presumably a scale factor matching the three `down`
# steps — confirm). Latent values are unbounded, so clamp them for display.
VF.to_pil_image(make_grid(
    [full_image, recon, resize(enc[0, :3], 8).clamp(0, 1)]
))