In [None]:
import sys
sys.path.append("..")

import random
import math
import time
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.decoder import *
from src.models.util import *

In [None]:
SHAPE = (3, 64, 64)
# First 1000 images of the pre-rendered uint8 dataset, presumably rescaled to
# floats in [0, 1] by the 1/255 multiplier.
dataset = TensorDataset(
    torch.load(f"../datasets/kali-uint8-{SHAPE[-2]}x{SHAPE[-1]}.pt")[:1000]
)
dataset = TransformDataset(dataset, dtype=torch.float, multiply=1. / 255.)
print(len(dataset))
# Preview the first 64 images as a labeled grid.
VF.to_pil_image(make_grid_labeled(
    [sample[0] for sample, _ in zip(dataset, range(64))]
))

In [None]:
class Encoder(nn.Module):
    """
    Small convolutional encoder: a stack of unpadded Conv2d layers followed by a
    single max-pooling layer, flattened per sample.

    With the defaults and a 3x64x64 input the spatial size goes
    64 -> 60 -> 56 -> 52 through the three 5x5 convolutions and 52 -> 4 through
    the 11x11 max-pool, giving 20 * 4 * 4 = 320 features per image.

    NOTE(review): there is no nonlinearity between the convolutions, so the conv
    stack composes to a single linear map.

    :param channels: input channel count followed by the output channel count of
        every conv layer. Defaults to ``[SHAPE[0], 50, 50, 20]``.
    :param kernel_size: kernel size of every convolution (stride 1, no padding).
    :param pool_kernel_size: kernel size of the final max-pool; its stride
        defaults to the same value, so pooling windows do not overlap.
    """

    def __init__(
            self,
            channels: Optional[Iterable[int]] = None,
            kernel_size: int = 5,
            pool_kernel_size: int = 11,
    ):
        super().__init__()
        self.channels = [SHAPE[0], 50, 50, 20] if channels is None else list(channels)
        self.layers = nn.Sequential()
        for in_channels, out_channels in zip(self.channels, self.channels[1:]):
            self.layers.append(nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
            ))
        self.layers.append(nn.MaxPool2d(
            kernel_size=pool_kernel_size,
            return_indices=False,
        ))

    def forward(self, x):
        """Map a batch of images ``(B, C, H, W)`` to flattened features ``(B, num_features)``."""
        return self.layers(x).flatten(1)

encoder = Encoder()
print("params:", num_module_parameters(encoder))
# Feed one sample (batch dimension added via [None]) to check the flattened output size.
print("output:", encoder(dataset[0][0][None]).shape)

In [None]:
class Sinus(nn.Module):
    """
    Element-wise ``sin(x * freq + phase)`` with a learnable frequency and phase
    per feature (broadcast over the last dimension of the input).

    :param size: number of features, i.e. the length of ``freq`` and ``phase``.
    :param freq_scale: standard deviation of the normal init of the frequencies.
        Phases are initialized with a standard deviation of 3.
    """

    def __init__(self, size: int, freq_scale: float = 3.):
        super().__init__()
        # freq is drawn before phase so a seeded RNG yields the same values as before
        self.freq = nn.Parameter(freq_scale * torch.randn(size))
        self.phase = nn.Parameter(3. * torch.randn(size))

    def forward(self, x):
        return (x * self.freq + self.phase).sin()

class Decoder(FreescaleImageModule):
    """
    Small sine-activated MLP decoder on top of ``FreescaleImageModule``.

    ``forward_state`` maps each input row through
    Linear(code_size + 2 -> code_size) -> Sinus -> Linear(code_size -> code_size)
    -> Sinus -> Linear(code_size -> 3) -> Linear(3 -> 3).
    NOTE(review): the 2 extra input features are presumably per-pixel coordinates
    added by ``FreescaleImageModule`` -- confirm against that base class.
    """

    def __init__(self, code_size: int = 100):
        super().__init__(num_in=code_size)
        widths = [code_size + 2, code_size, code_size]
        modules = []
        for (n_in, n_out), freq_scale in zip(zip(widths, widths[1:]), (30, 20)):
            modules.extend([nn.Linear(n_in, n_out), Sinus(n_out, freq_scale)])
        # project to 3 output channels, followed by a final 3 -> 3 linear layer
        modules.extend([nn.Linear(code_size, 3), nn.Linear(3, 3)])
        # module indices (state-dict keys "layers.0" ... "layers.5") match the original layout
        self.layers = nn.Sequential(*modules)

    def forward_state(self, x: torch.Tensor, shape: Tuple[int, int, int]) -> torch.Tensor:
        # `shape` is accepted for the base-class hook signature but not used here
        return self.layers(x)

    
decoder = Decoder(100)
print(f"params: {num_module_parameters(decoder):,}")
code = torch.randn(1, 100)
# Render the same code at a low and a high resolution; both are upscaled to
# 256x256 with nearest-neighbour sampling so they can be compared directly.
for render_size in (32, 128):
    display(VF.to_pil_image(VF.resize(
        decoder(code, (3, render_size, render_size))[0],
        (256, 256),
        VF.InterpolationMode.NEAREST,
    )))

In [None]:
class ResidualLinearBlock(nn.Module):
    """
    A stack of ``num_layers`` (Linear -> activation) pairs of constant width,
    optionally preceded by ``BatchNorm1d``, wrapped in a residual connection.

    :param num_hidden: feature width of the input and of every linear layer.
    :param num_layers: number of Linear + activation pairs.
    :param batch_norm: prepend a ``BatchNorm1d(num_hidden)`` named "norm".
    :param concat: if True return ``[x, f(x)]`` concatenated along the last
        dimension (doubling the width) instead of ``x + f(x)``.
    :param activation: anything accepted by ``activation_to_module``.
    """

    def __init__(
            self,
            num_hidden: int,
            num_layers: int,
            batch_norm: bool = True,
            concat: bool = False,
            activation: Union[str, Callable, nn.Module] = "relu6",
    ):
        super().__init__()
        self.do_concat = concat

        named_modules = []
        if batch_norm:
            named_modules.append(("norm", nn.BatchNorm1d(num_hidden)))
        for layer_idx in range(1, num_layers + 1):
            named_modules.append((
                f"layer_{layer_idx}",
                nn.Sequential(OrderedDict([
                    ("linear", nn.Linear(num_hidden, num_hidden)),
                    ("act", activation_to_module(activation)),
                ])),
            ))
        self.layers = nn.Sequential(OrderedDict(named_modules))

    def forward(self, x):
        residual = self.layers(x)
        if not self.do_concat:
            return x + residual
        return torch.concat([x, residual], dim=-1)

    def extra_repr(self):
        if self.do_concat:
            return "concat=True"
        return ""

    
class ImageManifoldDecoder(nn.Module):
    """
    Decode a code vector into an image by running a per-pixel network over a
    positional embedding of the pixel grid.

    Every pixel gets a positional embedding of size ``pos_embedding_size``:
    its 2 coordinates (from ``Space2d``), then ``sin(freq * coord)`` for every
    frequency in ``pos_embedding_freqs``, then ``cos(freq * coord)``.
    How the code is combined with that embedding depends on ``cross_attention``:

    * ``False``: code and embedding are concatenated per pixel and fed to
      ``pos_to_color`` (``linear_in`` -> ``act_in`` -> residual blocks ->
      ``linear_out`` -> ``act_out``).
    * ``True``: the embedding is projected to the code size (``upscale_pos``) and
      used as attention query against the code (``cross_atn``). The attention
      output plus that query is projected to ``num_hidden`` (``proj``) and fed to
      ``pos_to_color`` (residual blocks -> ``linear_out`` -> ``act_out``).

    :param num_input_channels: size of the code vector.
    :param num_output_channels: number of channels of the decoded image.
    :param num_hidden: feature width entering the first residual block.
    :param num_blocks: number of ``ResidualLinearBlock`` modules.
    :param num_layers_per_block: linear layers inside each residual block.
    :param concat_residual: one bool for all blocks, or one bool per block;
        a concatenating block doubles the feature width.
    :param pos_embedding_freqs: frequencies of the sin/cos position features.
    :param batch_norm: passed to every residual block.
    :param default_shape: ``(H, W)`` used when ``forward`` gets no ``shape``.
    :param activation: activation of the hidden layers (see ``activation_to_module``).
    :param activation_out: activation applied to the output colors.
    :param cross_attention: select the cross-attention path described above.
    :param cross_attention_heads: number of attention heads; ``nn.MultiheadAttention``
        requires it to divide ``num_input_channels``.
    :raises ValueError: if ``concat_residual`` is a sequence whose length is not ``num_blocks``.
    """

    def __init__(
            self,
            num_input_channels: int,
            num_output_channels: int = 3,
            num_hidden: int = 256,
            num_blocks: int = 2,
            num_layers_per_block: int = 2,
            concat_residual: Union[bool, Iterable[bool]] = False,
            pos_embedding_freqs: Iterable[float] = (7, 17),
            batch_norm: bool = True,
            default_shape: Optional[Tuple[int, int]] = None,
            activation: Union[str, Callable, nn.Module] = "gelu",
            activation_out: Union[str, Callable, nn.Module] = "sigmoid",
            cross_attention: bool = True,
            cross_attention_heads: int = 8,
    ):
        super().__init__()
        self.num_input_channels = num_input_channels
        self.default_shape = default_shape
        self.num_output_channels = num_output_channels
        self.pos_embedding_freqs = tuple(pos_embedding_freqs)
        # x, y, then sin-x, sin-y per freq, then cos-x, cos-y per freq
        self.pos_embedding_size = (len(self.pos_embedding_freqs) * 2 + 1) * 2

        if isinstance(concat_residual, bool):
            self.concat_residual = (concat_residual, ) * num_blocks
        else:
            # materialize first: `concat_residual` may be a one-shot iterable without len()
            self.concat_residual = tuple(concat_residual)
            if len(self.concat_residual) != num_blocks:
                raise ValueError(f"len(concat_residual) must be {num_blocks}, got {len(self.concat_residual)}")

        # feature width entering each residual block (plus the final width);
        # a concatenating block doubles it
        hidden_sizes = [num_hidden]
        for concat in self.concat_residual:
            hidden_sizes.append(hidden_sizes[-1] * 2 if concat else hidden_sizes[-1])

        if not cross_attention:
            self.pos_to_color = nn.Sequential()
            self.pos_to_color.add_module("linear_in", nn.Linear(num_input_channels + self.pos_embedding_size, hidden_sizes[0]))
            self.pos_to_color.add_module("act_in", activation_to_module(activation))
        else:
            self.upscale_pos = nn.Linear(self.pos_embedding_size, self.num_input_channels)
            self.cross_atn = nn.MultiheadAttention(self.num_input_channels, num_heads=cross_attention_heads)
            self.proj = nn.Linear(self.num_input_channels, hidden_sizes[0])
            self.pos_to_color = nn.Sequential()

        self.pos_to_color.add_module("resblocks", nn.Sequential(OrderedDict([
            (
                f"resblock_{i+1}",
                ResidualLinearBlock(
                    num_hidden=hs,
                    num_layers=num_layers_per_block,
                    batch_norm=batch_norm,
                    concat=concat,
                    activation=activation,
                )
            )
            for i, (concat, hs) in enumerate(zip(self.concat_residual, hidden_sizes))
        ])))
        self.pos_to_color.add_module("linear_out", nn.Linear(hidden_sizes[-1], num_output_channels))
        self.pos_to_color.add_module("act_out", activation_to_module(activation_out))

        # cached positional embedding (H * W, pos_embedding_size) and the (H, W) it was built for
        self._cur_space = None
        self._cur_space_shape = None

    def extra_repr(self):
        args = [
            f"pos_embedding_freqs={self.pos_embedding_freqs}",
            f"concat_residual={self.concat_residual}",
        ]
        if self.default_shape is not None:
            args.append(f"default_shape={self.default_shape}")
        return ", ".join(args)

    def forward(self, x: torch.Tensor, shape: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """
        :param x: one code ``(num_input_channels,)`` or a batch ``(B, num_input_channels)``.
        :param shape: ``(H, W)`` of the output image; falls back to ``default_shape``.
        :return: ``(num_output_channels, H, W)``, or ``(B, num_output_channels, H, W)`` for a batch.
        :raises ValueError: if ``x`` is not 1-D or 2-D, or no shape is known.
        """
        if x.ndim not in (1, 2):
            raise ValueError(f"Expecting ndim 1 or 2, got {x.shape}")

        if x.ndim == 2:
            # decode every code separately and stack the resulting images
            return torch.concat([
                self.forward(x_i, shape).unsqueeze(0)
                for x_i in x
            ])

        space, shape = self.get_pos_embedding(shape)
        # repeat the code for every pixel: (H * W, num_input_channels)
        input_codes = x.unsqueeze(0).expand(space.shape[0], x.shape[-1])

        if getattr(self, "cross_atn", None) is None:
            codes = torch.concat([input_codes, space], dim=1)
        else:
            embedding = self.upscale_pos(space)
            attended, _ = self.cross_atn(
                query=embedding,
                key=input_codes,
                value=input_codes,
            )
            # All keys/values are the same code vector, so the attention weights are
            # uniform and `attended` is identical for every pixel. Adding the query back
            # (the usual residual connection) keeps the per-pixel position information;
            # without it the decoded image would be spatially constant.
            codes = self.proj(attended + embedding)
        color = self.pos_to_color(codes)

        # (H * W, C) -> (C, H, W)
        return color.permute(1, 0).reshape(self.num_output_channels, *shape)

    def get_pos_embedding(self, shape: Optional[Tuple[int, int]] = None):
        """
        Return ``(embedding, shape)`` where ``embedding`` is ``(H * W, pos_embedding_size)``.

        The embedding is cached. It is rebuilt when the shape changes or the module
        has been moved to another device/dtype. It is also rebuilt when it was built
        under ``torch.inference_mode()`` but is now needed outside of it, because
        inference tensors cannot take part in autograd.
        """
        if shape is None:
            shape = self.default_shape
        if shape is None:
            raise ValueError("Must either define `default_shape` or `shape`")
        shape = tuple(shape)

        reference = self.pos_to_color[-2].weight
        cached = self._cur_space
        stale = (
            cached is None
            or shape != self._cur_space_shape
            or cached.device != reference.device
            or cached.dtype != reference.dtype
            or (cached.is_inference() and not torch.is_inference_mode_enabled())
        )
        if stale:
            space = Space2d(shape=(2, *shape)).space().to(reference)
            # (2, H, W) -> (H * W, 2)
            space = space.permute(1, 2, 0).reshape(-1, 2)
            space = torch.concat([
                space,
                *((space * freq).sin() for freq in self.pos_embedding_freqs),
                *((space * freq).cos() for freq in self.pos_embedding_freqs),
            ], 1)
            self._cur_space = space
            self._cur_space_shape = shape

        return self._cur_space, shape

    def weight_images(self, **kwargs):
        """Return all 2-D parameters (weight matrices) that are not a single row/column."""
        return [
            p for p in self.parameters()
            if p.ndim == 2 and any(s > 1 for s in p.shape)
        ]
    
    
    
CODE_SIZE = 128
decoder = ImageManifoldDecoder(
    CODE_SIZE,
    num_blocks=8,
    num_layers_per_block=2,
    num_hidden=64,
    default_shape=SHAPE[-2:],
    concat_residual=[False, True] * 4,
    cross_attention=True,
)
print(f"params {num_module_parameters(decoder):,}")
# Smoke test: decode 8 random codes at 64x64 and show them as a normalized grid.
with torch.no_grad():
    images = decoder(torch.randn(8, CODE_SIZE), (64, 64))
    print(images.shape)
    display(VF.to_pil_image(make_grid(images, normalize=True)))
decoder

In [None]:
# Interactive help (IPython `?` syntax): shows the signature and docstring of
# nn.MultiheadAttention.forward.
nn.MultiheadAttention.forward?

In [None]:
# Rough decoding throughput of `decoder` on random codes.
# eval() so BatchNorm layers use (and do not update) their running statistics
# during the benchmark; the previous train/eval mode is restored afterwards.
BENCH_BATCH_SIZE, BENCH_RUNS = 64, 5
was_training = decoder.training
decoder.eval()
try:
    with torch.inference_mode():
        data = torch.randn(BENCH_BATCH_SIZE, decoder.num_input_channels)
        # perf_counter is monotonic and meant for measuring intervals
        start_time = time.perf_counter()
        for _ in tqdm(range(BENCH_RUNS)):
            decoder(data)
        seconds = time.perf_counter() - start_time
finally:
    decoder.train(was_training)
print(f"inference rate {BENCH_BATCH_SIZE * BENCH_RUNS / seconds:,.2f}/s")

# Training test: fit a decoder to a few dataset images from fixed random codes

In [None]:
from torch.nn.utils import clip_grad_norm_

def train_test(
    decoder: nn.Module,
    codes: torch.Tensor,
    targets: torch.Tensor,
    iters: int = 10000,
    device="auto",
    lr=0.0001,
    print_interval_seconds: float = 10.,
    print_interval_iters: int = 300,
):
    """
    Fit ``decoder`` so that ``decoder(codes)`` reproduces ``targets`` (L1 loss, Adam).

    Only the decoder's parameters are optimized; ``codes`` stay fixed. The decoder is
    moved to ``device`` in place (``nn.Module.to``). Progress is reported only when
    more than ``print_interval_seconds`` have passed AND at least
    ``print_interval_iters`` iterations have run since the last report. A report shows
    the loss and a grid with the current reconstructions in the first row and the
    targets in the second. Interrupting the kernel stops training early without
    raising, leaving the partially trained decoder in place.

    :param decoder: module mapping ``codes`` to images shaped like ``targets``.
    :param codes: fixed batch of codes passed to ``decoder`` as-is.
    :param targets: target images, compared with the decoder output clamped to [0, 1].
    :param iters: number of optimization steps.
    :param device: anything accepted by ``to_torch_device``.
    :param lr: Adam learning rate.
    :param print_interval_seconds: minimum seconds between progress reports.
    :param print_interval_iters: minimum iterations between progress reports.
    """
    device = to_torch_device(device)
    print(device)

    decoder = decoder.to(device)
    targets = targets.to(device)
    codes = codes.to(device)

    optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)

    last_print_time = time.time()
    last_print_it = 0
    try:
        for it in tqdm(range(iters)):
            pixels = decoder(codes).clamp(0, 1)
            loss = F.l1_loss(pixels, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cur_time = time.time()
            if (
                    cur_time - last_print_time > print_interval_seconds
                    and it - last_print_it >= print_interval_iters
            ):
                last_print_time = cur_time
                last_print_it = it
                print(f"train loss {float(loss)}")
                # detach: the preview must not hold on to the autograd graph
                display(VF.to_pil_image(make_grid(
                    torch.concat([pixels.detach(), targets]),
                    nrow=len(pixels),
                )))

    except KeyboardInterrupt:
        # deliberate: lets the user stop a long run from the notebook
        pass

    
decoder = ImageManifoldDecoder(
    128,
    num_blocks=8,
    num_layers_per_block=2,
    num_hidden=256,
    default_shape=SHAPE[-2:],
)
print(f"params {num_module_parameters(decoder):,}")

# Overfit 8 fixed random codes onto dataset images 16..23.
train_test(
    decoder,
    torch.randn(8, 128),
    torch.stack([dataset[idx][0] for idx in range(16, 24)]),
    lr=.001,
)

In [None]:
# Decode 32 new random codes. They are moved to the device the trained decoder
# lives on (train_test() moves it to whatever `to_torch_device("auto")` resolves to),
# instead of assuming CUDA is available.
with torch.no_grad():
    display(VF.to_pil_image(make_grid(decoder(
        torch.randn(8 * 4, CODE_SIZE).to(next(decoder.parameters()).device)
    ))))

In [None]:
decoder = ImageManifoldDecoder(
    128,
    num_blocks=2,
    num_layers_per_block=2,
    num_hidden=256,
    default_shape=SHAPE[-2:],
)
print(f"params {num_module_parameters(decoder):,}")

# Same fit as above, but with only 2 residual blocks.
train_test(
    decoder,
    torch.randn(8, 128),
    torch.stack([dataset[idx][0] for idx in range(16, 24)]),
    lr=.001,
)

In [None]:
# Fixed: the keyword was garbled as `num_hidden=dden=256`, a SyntaxError.
decoder = ImageManifoldDecoder(
    128,
    num_blocks=4,
    num_layers_per_block=3,
    num_hidden=256,
    default_shape=SHAPE[-2:],
)
print(f"params {num_module_parameters(decoder):,}")

# 4 blocks of 3 layers; uses train_test's default learning rate.
train_test(
    decoder,
    torch.randn(8, 128),
    torch.concat([dataset[16 + i][0].unsqueeze(0) for i in range(8)]),
)

# Encoder: map an image back to a code vector

In [None]:
class ImageManifoldEncoder(nn.Module):
    """
    Encode an image into a code vector by running a per-pixel MLP on
    (color, position) pairs and averaging the per-pixel outputs.

    Each pixel's ``num_input_channels`` color values are concatenated with its
    2 coordinates (from ``Space2d``) and passed through ``color_to_pos``
    (``linear_in`` -> GELU -> residual blocks -> ``linear_out`` -> sigmoid).
    The result is then averaged over all pixels.

    :param num_output_channels: size of the produced code vector.
    :param num_input_channels: number of image channels.
    :param num_hidden: width of the hidden layers.
    :param num_blocks: number of ``ResidualLinearBlock`` modules.
    :param num_layers_per_block: linear layers inside each residual block.
    """

    def __init__(
            self,
            num_output_channels: int,
            num_input_channels: int = 3,
            num_hidden: int = 256,
            num_blocks: int = 2,
            num_layers_per_block: int = 2,
    ):
        super().__init__()
        self.num_input_channels = num_input_channels
        self.num_output_channels = num_output_channels
        self.color_to_pos = nn.Sequential(OrderedDict([
            ("linear_in", nn.Linear(num_input_channels + 2, num_hidden)),
            ("act_in", nn.GELU()),
            ("resblocks", nn.Sequential(OrderedDict([
                (f"resblock_{i+1}", ResidualLinearBlock(num_hidden=num_hidden, num_layers=num_layers_per_block))
                for i in range(num_blocks)
            ]))),
            ("linear_out", nn.Linear(num_hidden, num_output_channels)),
            ("act_out", nn.Sigmoid()),
        ]))
        # cached coordinate grid (2, H, W) and the (H, W) it was built for
        self._cur_space = None
        self._cur_space_shape = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: one image ``(C, H, W)`` or a batch ``(B, C, H, W)`` with ``C == num_input_channels``.
        :return: ``(num_output_channels,)``, or ``(B, num_output_channels)`` for a batch.
        :raises ValueError: if ``x`` is not 3-D or 4-D.
        """
        if x.ndim not in (3, 4):
            raise ValueError(f"Expecting ndim 3 or 4, got {x.shape}")

        if x.ndim == 4:
            # encode every image separately (the old code passed an undefined `shape` here)
            return torch.concat([
                self.forward(x_i).unsqueeze(0)
                for x_i in x
            ])

        space = self.get_space(x.shape[-2:])
        # (C + 2, H, W) -> (H * W, C + 2): one row per pixel
        codes = torch.concat([x, space], dim=0).flatten(1).permute(1, 0)
        # (H * W, num_output_channels) -> average over all pixels
        return self.color_to_pos(codes).permute(1, 0).mean(-1)

    def get_space(self, shape: Optional[Tuple[int, int]] = None):
        """
        Return the cached ``Space2d`` coordinate grid of shape ``(2, H, W)``.

        The cache is rebuilt when the shape changes or the module has been moved to
        another device/dtype. It is also rebuilt when the grid was built under
        ``torch.inference_mode()`` but is now needed outside of it.

        :raises ValueError: if ``shape`` is None.
        """
        if shape is None:
            raise ValueError("`shape` is required")
        shape = tuple(shape)

        reference = self.color_to_pos[0].weight
        cached = self._cur_space
        stale = (
            cached is None
            or shape != self._cur_space_shape
            or cached.device != reference.device
            or cached.dtype != reference.dtype
            or (cached.is_inference() and not torch.is_inference_mode_enabled())
        )
        if stale:
            self._cur_space = Space2d(shape=(2, *shape)).space().to(reference)
            self._cur_space_shape = shape

        return self._cur_space

    def weight_images(self, **kwargs):
        """Return all 2-D parameters (weight matrices)."""
        return [p for p in self.parameters() if p.ndim == 2]

encoder = ImageManifoldEncoder(num_output_channels=128)
# Smoke test on a single random 3x64x64 image (no batch dimension).
out = encoder(torch.randn(3, 64, 64))
print("out", out.shape)
print(out)
encoder


In [None]:
# map_location="cpu" so a checkpoint saved from a GPU run also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
state = torch.load("../checkpoints/ae-manifold-7/best.pt", map_location="cpu")

In [None]:
decoder = ImageManifoldDecoder(
    num_input_channels=128,
    num_output_channels=1,
    num_blocks=8,
    num_layers_per_block=2,
    num_hidden=256,
    default_shape=(64, 64),
)
# The checkpoint's state dict holds the whole model; keep only the entries
# prefixed with "decoder." and strip that prefix.
decoder.load_state_dict({
    key[len("decoder."):]: value
    for key, value in state["state_dict"].items()
    if key.startswith("decoder.")
})

In [None]:
# Decode 16 random codes drawn with standard deviation 0.5.
with torch.no_grad():
    display(VF.to_pil_image(make_grid(
        decoder(torch.randn(2 * 8, CODE_SIZE) * .5)
    )))

# Combine several decoders into a weighted ensemble

In [None]:
class DecoderEnsemble(nn.Module):
    """
    Weighted sum of several decoders that share the same call signature.

    Every decoder is called with the same arguments; each output is multiplied
    by its (optionally trainable) weight and the products are summed.

    :param decoders: the modules to combine; at least one is required.
    :param weights: one weight per decoder, defaulting to the uniform
        ``1 / len(decoders)``. Any iterable of numbers (including generators)
        or a tensor is accepted; integer tensors are converted to the default
        floating point dtype so they can be trained.
    :param train_weights: if True the weights are a trainable parameter.
    :raises ValueError: if no decoder is given or the number of weights does not
        match the number of decoders.
    """

    def __init__(
            self,
            *decoders: nn.Module,
            weights: Union[None, Iterable[float], torch.Tensor] = None,
            train_weights: bool = True,
    ):
        super().__init__()
        if not decoders:
            raise ValueError("DecoderEnsemble requires at least one decoder")

        self.decoders = nn.ModuleDict({
            f"decoder_{i + 1}": decoder
            for i, decoder in enumerate(decoders)
        })

        if weights is None:
            weights = torch.ones(len(self.decoders)) / len(self.decoders)
        elif isinstance(weights, torch.Tensor):
            # nn.Parameter can only require gradients for floating point tensors
            if not weights.is_floating_point():
                weights = weights.to(torch.get_default_dtype())
        else:
            # list() first so one-shot iterables such as generators are accepted
            weights = torch.tensor(list(weights), dtype=torch.get_default_dtype())

        if weights.ndim == 0 or weights.shape[0] != len(self.decoders):
            raise ValueError(
                f"Expected one weight per decoder ({len(self.decoders)}), "
                f"got weights of shape {tuple(weights.shape)}"
            )
        self.weights = nn.Parameter(weights, requires_grad=train_weights)

    def forward(self, *args, **kwargs):
        """Return ``sum_i weights[i] * decoders[i](*args, **kwargs)``."""
        output_sum = None
        for weight, decoder in zip(self.weights, self.decoders.values()):
            output = decoder(*args, **kwargs) * weight
            output_sum = output if output_sum is None else output_sum + output
        return output_sum
    
# nn.MultiheadAttention requires the embedding size to be divisible by the number of
# heads; num_input_channels=100 is not divisible by the default 8 heads, so these
# decoders use the plain positional-MLP path instead of cross-attention.
dec = DecoderEnsemble(
    ImageManifoldDecoder(
        num_input_channels=100, num_output_channels=1, default_shape=(64, 64), num_hidden=32,
        cross_attention=False,
    ),
    ImageManifoldDecoder(
        num_input_channels=100, num_output_channels=1, default_shape=(64, 64), num_hidden=256,
        cross_attention=False,
    ),
)
with torch.no_grad():
    output = dec(torch.rand(8, 100))
print(output.shape)
VF.to_pil_image(make_grid(output))