In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Type

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.util.image import *
from src.util import *
from src.algo import ca1
from src.models.util import *
from src.models.transform import *
from src.models.debug import DebugLayer

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Scale the last two (spatial) dimensions of `img` by the factor `scale`.

    Uses `mode` (nearest-neighbour by default) without antialiasing; every
    resulting side length is clamped to at least one pixel.
    """
    def _scaled(length: int) -> int:
        # never let a side collapse to zero pixels
        return max(1, int(length * scale))

    new_size = list(map(_scaled, img.shape[-2:]))
    return VF.resize(img, new_size, mode, antialias=False)

In [None]:
# Load the tile dataset (stored as unsigned ints according to the file name,
# presumably uint8) and rescale it to floats by dividing by 255.
patches = torch.load("../datasets/rpg-3x32x32-uint-test.pt")
patches = patches.float().div(255.)

In [None]:
# Cut every tile into overlapping 16x16 sub-patches (stride 4 in both directions)
# and stack them along a new leading dimension into a single tensor.
small_patches = torch.stack([
    sub_patch
    for tile in patches
    for sub_patch in iter_image_patches(tile, shape=(16, 16), stride=(4, 4))
])
small_patches.shape

In [None]:
from sklearn.decomposition import PCA

# Fit a PCA basis on the flattened small patches and reshape each principal
# component back into a (C, H, W) patch, giving a library of 128 patch kernels.
# random_state pins the randomized SVD solver that svd_solver="auto" may select
# for data of this size, so the library saved below is reproducible.
pca = PCA(128, random_state=0)
pca.fit(small_patches.flatten(1))
lib = torch.as_tensor(pca.components_, dtype=torch.float32).view(-1, *small_patches.shape[-3:])
print(f"min/max: {lib.min()}/{lib.max()}")

display(VF.to_pil_image(resize(make_grid(lib, normalize=True, nrow=16), 4)))

In [None]:
# Persist the patch library. The file name encodes the tensor shape
# (e.g. "pca-128x3x16x16.pt") and is derived from `lib` so it cannot go stale.
torch.save(lib, f"../datasets/pca-{'x'.join(str(dim) for dim in lib.shape)}.pt")

In [None]:
# Sanity check: use the patch library as the kernel of a transposed convolution
# whose stride equals its kernel size. Each of the 8 random (lib_size, 4, 4) code
# grids then mixes library patches per grid cell and tiles them without overlap
# into a (4 * patch_size) x (4 * patch_size) image.
lib_size = lib.shape[0]
patch_size = lib.shape[-1]

w = torch.randn(8, lib_size, 4, 4)
c = nn.ConvTranspose2d(lib_size, lib.shape[1], patch_size, stride=patch_size, bias=False)
# Pure visualisation: no autograd graph is needed.
with torch.no_grad():
    c.weight.copy_(lib)
    print("weight", c.weight.shape)
    o = c(w)
    print(f"output {o.shape}, min/max: {o.min()}/{o.max()}")
    o = torch.sigmoid(o)
VF.to_pil_image(resize(make_grid(o, normalize=False, nrow=4), 4))

In [None]:
class LibDecoder2d(nn.Module):
    """
    Decode a code vector into an image by mixing patches from a patch library.

    A linear layer maps the code to one coefficient per library patch for every
    cell of the patch grid. A transposed convolution whose stride equals its
    kernel size (the patch size) then pastes the weighted library patches side
    by side, without overlap.

    :param shape: output image shape (C, H, W); H and W must be divisible by the patch size
    :param code_size: length of the input code vector
    :param lib_size: number of patches in the library
    :param patch_size: patch (height, width), or a single int for square patches
    :param patch_filename: optional file holding a tensor shaped like the transposed
        convolution weight, i.e. (lib_size, C, patch_h, patch_w), used to initialize
        the library (e.g. a PCA patch library)
    :param output_activation: activation applied to the decoded image
    :raises ValueError: if `shape` is not divisible by `patch_size`, or the loaded
        library does not have the expected shape
    """
    def __init__(
            self, 
            shape: Tuple[int, int, int],
            code_size: int,
            lib_size: int,
            patch_size: Union[int, Tuple[int, int]],
            patch_filename: Optional[str] = None,
            output_activation: Union[None, str, Callable] = "sigmoid",
    ):
        super().__init__()
        self.shape = tuple(shape)
        self.code_size = code_size
        self.lib_size = lib_size
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.patch_size = tuple(patch_size)

        # Otherwise the decoded image would silently be smaller than `shape`.
        if shape[-2] % self.patch_size[-2] or shape[-1] % self.patch_size[-1]:
            raise ValueError(
                f"shape {self.shape} is not divisible by patch_size {self.patch_size}"
            )
        self.grid_shape = (shape[-2] // self.patch_size[-2], shape[-1] // self.patch_size[-1])

        # One coefficient per library patch for every cell of the patch grid.
        self.mlp = nn.Sequential(
            nn.Linear(code_size, lib_size * math.prod(self.grid_shape))
        )

        self.conv = nn.ConvTranspose2d(
            lib_size, shape[0], kernel_size=self.patch_size, stride=self.patch_size,
        )
        if patch_filename:
            # NOTE: torch.load unpickles the file; only load trusted library files.
            library = torch.load(patch_filename)
            # Check explicitly: slice assignment would silently broadcast a
            # smaller (but broadcastable) tensor into the weight.
            if tuple(library.shape) != tuple(self.conv.weight.shape):
                raise ValueError(
                    f"patch library in {patch_filename!r} has shape {tuple(library.shape)}, "
                    f"expected {tuple(self.conv.weight.shape)}"
                )
            with torch.no_grad():
                self.conv.weight.copy_(library)

        self.output_activation = activation_to_callable(output_activation)

    def forward(self, x):
        # (B, code_size) -> (B, lib_size, grid_h, grid_w) coefficient map
        grid = self.mlp(x).reshape(-1, self.lib_size, *self.grid_shape)
        return self.output_activation(self.conv(grid))

    def extra_repr(self):
        return (
            f"shape={self.shape}, code_size={self.code_size}, lib_size={self.lib_size}"
            f", patch_size={self.patch_size}"
        )

model = LibDecoder2d(
    shape=(3, 64, 64),
    code_size=100,
    lib_size=128,
    patch_size=(16, 16),
    patch_filename="../datasets/pca-128x3x16x16.pt",
)
print(f"params: {num_module_parameters(model):,}")
print("out:", model(torch.rand(1, 100)).shape)

# Decode 16 random codes and show them as a 4x4 grid, upscaled 4x.
display(VF.to_pil_image(resize(
    make_grid(model(torch.rand(16, 100)), normalize=False, nrow=4),
    4,
)))

model

In [None]:
class ResMLP(nn.Module):
    """
    Residual MLP block computing ``x + mlp(x)``.

    The MLP has three linear layers mapping
    ``num_channels -> num_hidden -> num_hidden -> num_channels``. Each layer is
    followed by `activation`, including the last one.

    :param num_channels: size of the last input dimension (equal to the output size)
    :param num_hidden: width of the two hidden layers
    :param activation: activation after each linear layer, or None for none
    """
    def __init__(
            self,
            num_channels: int,
            num_hidden: int,
            activation: Union[None, str, Callable] = "relu",
    ):
        super().__init__()
        self.num_channels = num_channels
        self.num_hidden = num_hidden

        layer_dims = [
            (num_channels, num_hidden),
            (num_hidden, num_hidden),
            (num_hidden, num_channels),
        ]
        self.mlp = nn.Sequential()
        for index, (n_in, n_out) in enumerate(layer_dims, start=1):
            self.mlp.add_module(f"linear{index}", nn.Linear(n_in, n_out))
            # Skip the activation when none is requested. activation_to_module(None)
            # may return None, and a None child inside nn.Sequential would be called
            # in forward() and fail.
            if activation is not None:
                # NOTE(review): the activation also follows the last layer, so with
                # "relu" the residual branch can only add non-negative values.
                self.mlp.add_module(f"act{index}", activation_to_module(activation))

    def forward(self, x):
        y = self.mlp(x)
        return x + y

In [None]:
#class PatchConv(nn.Module):
    

class PatchAutoencoder2d(nn.Module):
    """
    Autoencoder that works on a grid of non-overlapping image patches.

    Encoder:
      * a stride == kernel Conv2d maps each patch to `lib_size` coefficients;
      * the flattened coefficient grid passes through `mlp_blocks` ResMLP blocks;
      * a linear projection produces a code of size `code_size`.

    Decoder: mirrors the encoder, ending in a stride == kernel ConvTranspose2d
    and an optional output activation.

    :param shape: image shape (C, H, W); H and W must be divisible by the patch size
    :param code_size: size of the latent code
    :param lib_size: number of coefficients (channels) per patch
    :param patch_size: patch (height, width), or a single int for square patches
    :param mlp_blocks: number of ResMLP blocks in encoder and in decoder
    :param mlp_cells: hidden width of the ResMLP blocks (defaults to the flattened grid size)
    :param mlp_activation: activation used inside the ResMLP blocks
    :param output_activation: activation after the decoder, or None for none
    :raises ValueError: if `shape` is not divisible by `patch_size`
    """
    def __init__(
            self,
            shape: Tuple[int, int, int],
            code_size: int,
            lib_size: int,
            patch_size: Union[int, Tuple[int, int]],
            mlp_blocks: int = 2,
            mlp_cells: Optional[int] = None,
            mlp_activation: Union[None, str, Callable] = "relu",
            output_activation: Union[None, str, Callable] = "sigmoid",
    ):
        super().__init__()
        self.shape = tuple(shape)
        self.code_size = code_size
        self.lib_size = lib_size
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.patch_size = tuple(patch_size)

        # Otherwise the reconstruction would be smaller than the input image.
        if shape[-2] % self.patch_size[-2] or shape[-1] % self.patch_size[-1]:
            raise ValueError(
                f"shape {self.shape} is not divisible by patch_size {self.patch_size}"
            )
        self.grid_shape = (shape[-2] // self.patch_size[-2], shape[-1] // self.patch_size[-1])
        # Length of the flattened (lib_size, grid_h, grid_w) coefficient map.
        grid_size = lib_size * math.prod(self.grid_shape)
        if mlp_cells is None:
            mlp_cells = grid_size

        self.encoder = nn.Sequential()
        self.encoder.add_module("patch", nn.Conv2d(
            shape[0], lib_size, kernel_size=self.patch_size, stride=self.patch_size,
        ))
        self.encoder.add_module("flatten", nn.Flatten(-3))
        for i in range(mlp_blocks):
            self.encoder.add_module(f"block{i + 1}", ResMLP(grid_size, mlp_cells, activation=mlp_activation))
        self.encoder.add_module("proj", nn.Linear(grid_size, code_size))

        self.decoder = nn.Sequential()
        self.decoder.add_module("proj", nn.Linear(code_size, grid_size))
        for i in range(mlp_blocks):
            self.decoder.add_module(f"block{i + 1}", ResMLP(grid_size, mlp_cells, activation=mlp_activation))
        self.decoder.add_module("unflatten", Reshape((lib_size, *self.grid_shape)))
        self.decoder.add_module("patch", nn.ConvTranspose2d(
            lib_size, shape[0], kernel_size=self.patch_size, stride=self.patch_size,
        ))
        if output_activation is not None:
            self.decoder.add_module("act_out", activation_to_module(output_activation))

    def forward(self, x):
        return self.decoder(self.encoder(x))

    def extra_repr(self):
        return f"shape={self.shape}, code_size={self.code_size}, lib_size={self.lib_size}, patch_size={self.patch_size}"

model = PatchAutoencoder2d((3, 64, 64), 100, 256, (16, 16), mlp_cells=128)
print(f"params: {num_module_parameters(model):,}")
print(model)
# Smoke-test a forward pass. Display only the output shape instead of dumping
# the whole (1, 3, 64, 64) reconstruction tensor into the notebook output.
model(torch.rand(1, 3, 64, 64)).shape

In [None]:
from src.models.decoder import *
from src.models.encoder import *

class EnsembleDecoder2d(nn.Module):
    """
    Decoder that sums two sub-decoders: ``sigmoid(conv(x)) + manifold(x)``.

    `conv` is a DecoderConv2d for `shape`. `manifold` is an ImageManifoldDecoder
    with positional-embedding frequencies (7, 17, 77) and no output activation.
    Because the manifold term is added after the sigmoid, the result is not
    bounded to [0, 1].

    :param shape: output image shape (C, H, W)
    :param code_size: size of the input code
    :param activation: NOTE(review): accepted but not used anywhere in this class;
        TODO wire it into the sub-decoders or drop it
    """
    def __init__(
            self,
            shape: Tuple[int, int, int],
            code_size: int,
            activation: Union[None, str, Callable] = "relu",
    ):
        super().__init__()
        self.shape = tuple(shape)
        self.code_size = code_size

        conv_decoder = DecoderConv2d(shape=shape, code_size=code_size)
        manifold_decoder = ImageManifoldDecoder(
            default_shape=shape[-2:],
            num_input_channels=code_size,
            num_output_channels=shape[0],
            pos_embedding_freqs=(7, 17, 77),
            activation_out=None,
        )
        # Registration order (conv first, then manifold) matches the original module layout.
        self.conv = conv_decoder
        self.manifold = manifold_decoder

    def forward(self, x):
        squashed = torch.sigmoid(self.conv(x))
        return squashed + self.manifold(x)


model = EnsembleDecoder2d((3, 64, 64), 100)
print(f"params: {num_module_parameters(model):,}")
with torch.no_grad():
    out = model(torch.rand(8, 100))
print(out.shape)
# The ensemble output is sigmoid(conv) + manifold and is not bounded to [0, 1].
# to_pil_image converts float tensors via mul(255).byte(), so out-of-range values
# would not convert correctly; clamp before display.
display(VF.to_pil_image(make_grid(out.clamp(0, 1))))
print(model)
