In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Optional, Union

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from sklearn.decomposition import PCA, FactorAnalysis
from sklearn.cluster import KMeans

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

from src.datasets import *
from src.util import *
from src.util.image import * 
from src.algo import Space2d, IFS
from src.datasets import *
from src.models.cnn import *
from src.models.encoder import *
from src.models.decoder import *
from src.models.transform import *
from src.util.embedding import *

def resize(img, scale: float, aa: bool = False):
    """
    Rescale the last two dimensions of `img` by the factor `scale`.

    Args:
        img: Image tensor; only its last two dimensions are resized.
        scale: Multiplier applied to each of the last two dimensions. The
            result is truncated to an integer and never drops below 1.
        aa: When True, use bilinear interpolation with antialiasing;
            otherwise use nearest-neighbour interpolation without it.

    Returns:
        The result of `VF.resize` for the computed target size.
    """
    target_size = [max(1, int(side * scale)) for side in img.shape[-2:]]
    if aa:
        mode = VF.InterpolationMode.BILINEAR
    else:
        mode = VF.InterpolationMode.NEAREST
    return VF.resize(img, target_size, mode, antialias=aa)

In [None]:
class CodebookAutoencoder(nn.Module):
    """
    Convolutional autoencoder whose bottleneck is quantized against a code book.

    The encoder output is viewed as `code_size` vectors ("slots") of `code_dim`
    values each. Every slot is replaced by the index of its best matching row of
    `code_book`; the decoder reconstructs the input from the looked-up rows.

    The index lookup (argmax/argmin) is not differentiable, so `forward` does not
    propagate gradients into the encoder.

    Args:
        shape: Input shape handed to the encoder and decoder, e.g. (3, 32, 32).
        code_size: Number of slots per sample. It is also the number of rows
            in the code book.
        code_dim: Length of each slot vector / code book row.
        channels: Channel sizes for the encoder; the decoder receives them reversed.
        kernel_size: Kernel size(s), passed unchanged to encoder and decoder.
        space_to_depth: Passed unchanged to encoder and decoder.
    """

    def __init__(
            self,
            shape: Tuple[int, int, int],
            code_size: int,
            code_dim: int,
            channels: Iterable[int] = (16, 24, 32),
            kernel_size: Union[int, Iterable[int]] = (3, 4, 3),
            space_to_depth: bool = True,
    ):
        super().__init__()
        self.shape = shape
        self.code_size = code_size
        self.code_dim = code_dim
        self.encoder = EncoderConv2d(shape=shape, code_size=code_size * code_dim, channels=channels, kernel_size=kernel_size, space_to_depth=space_to_depth)
        self.decoder = DecoderConv2d(shape=shape, code_size=code_size * code_dim, channels=list(reversed(channels)), kernel_size=kernel_size, space_to_depth=space_to_depth)
        self.code_book = nn.Embedding(code_size, code_dim)

    def _slot_vectors(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder and view its output as (batch, code_size, code_dim)."""
        return self.encoder(x).view(-1, self.code_size, self.code_dim)

    def encode_X(self, x: torch.Tensor) -> torch.Tensor:
        """
        Quantize by L1 distance: for every slot return the index of the code
        book row with the smallest sum of absolute differences.

        Returns:
            Integer tensor of shape (batch, code_size) with values in
            [0, number of code book rows).
        """
        codes = self._slot_vectors(x)
        book = self.code_book.weight.unsqueeze(0).expand(codes.shape[0], -1, -1)
        # pairwise distances: (batch, code_size, number of code book rows)
        dist = torch.cdist(codes, book, p=1.0)
        return dist.argmin(dim=-1)

    def encode_Y(self, x: torch.Tensor) -> torch.Tensor:
        """Alias of `encode` (dot-product similarity), kept for existing callers."""
        return self.encode(x)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Quantize by dot-product similarity: for every slot return the index of
        the code book row with the largest dot product.

        Returns:
            Integer tensor of shape (batch, code_size) with values in
            [0, number of code book rows).
        """
        codes = self._slot_vectors(x)
        # sim[b, slot, row] = <codes[b, slot], code_book[row]>
        sim = codes @ self.code_book.weight.T
        # Reduce over the code-book axis: one code book index per slot.
        return sim.argmax(dim=-1)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Look up the code book rows for the indices in `x`, flatten them per
        sample to `code_size * code_dim` values and run the decoder.
        """
        codes = self.code_book(x).view(-1, self.code_size * self.code_dim)
        return self.decoder(codes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode `x` to code book indices and decode them again."""
        return self.decode(self.encode(x))


# Smoke test: build a freshly initialised model and push a random batch through it.
with torch.no_grad():
    ae = CodebookAutoencoder((3, 32, 32), 128, 256)
    print(f"params: {num_module_parameters(ae):,}")
    image = torch.randn(10, 3, 32, 32)
    output = ae(image)
print(output.shape)
# Cell output: the batch tiled by make_grid, clamped to [0, 1] and enlarged 2x
# (resize defaults to nearest-neighbour interpolation).
VF.to_pil_image(resize(make_grid(output).clamp(0, 1), 2))

In [None]:
from experiments.datasets import rpg_tile_dataset_3x32x32
# NOTE(review): presumably the validation split of the RPG tile dataset with
# items shaped (1, 32, 32) — confirm against experiments.datasets.
ds = rpg_tile_dataset_3x32x32((1, 32, 32), validation=True)

In [None]:
# Build the model and restore its weights from the saved snapshot.
ae = CodebookAutoencoder((1, 32, 32), 128, 128)
# The notebook only runs inference with this model, so switch it to eval mode
# (same as is done for the VQVAE below).
ae.eval()
# map_location="cpu" lets a snapshot written on a GPU machine load without CUDA.
# NOTE: torch.load unpickles the file — only load snapshots from a trusted source.
ae.load_state_dict(torch.load("../checkpoints/ae/codebook1/snapshot.pt", map_location="cpu")["state_dict"])

In [None]:
# Reconstruct the first 10 dataset items and compare them with the originals.
with torch.no_grad():
    # Stack the image part (element 0) of each item along a new batch dimension.
    images = torch.stack([ds[idx][0] for idx in range(10)])
    print(images.shape)
    repros = ae(images)
    # One row per grid: originals first, reconstructions second.
    display(VF.to_pil_image(make_grid(images, nrow=len(images))))
    display(VF.to_pil_image(make_grid(repros, nrow=len(images))))
    # Mean absolute error between reconstruction and input.
    print(F.l1_loss(repros, images))

In [None]:
from src.models.encoder.vqvae import *

# Build the VQ-VAE and restore its weights from the saved snapshot.
vae = VQVAE(
    in_channel=1,
    channel=256,
    n_res_block=2,
    n_res_channel=64,
    embed_dim=16,
    n_embed=1024,
)
# map_location="cpu" lets a snapshot written on a GPU machine load without CUDA.
# NOTE: torch.load unpickles the file — only load snapshots from a trusted source.
vae.load_state_dict(torch.load("../checkpoints/ae/vqvae4/snapshot.pt", map_location="cpu")["state_dict"])
# Inference only from here on.
vae = vae.eval()

In [None]:
# Run the VQ-VAE on real samples and inspect reconstruction and codes.
with torch.inference_mode():
    #image = torch.randn(10, 1, 32, 32)
    # First 10 dataset items as one batch; a scale of 1 requests the same height and width.
    image = resize(torch.concat([ds[i][0].unsqueeze(0) for i in range(10)]), 1, aa=False)
    # The model is evaluated twice on the same batch: once for the
    # reconstruction (plus `diff`), once for the encoder outputs.
    output, diff = vae(image)
    codes = vae.encode(image)
    print("codes", [c.shape for c in codes])
    print("codes", [(c.mean(), c.std()) for c in codes[:2]])
    # Elements per input sample divided by the elements per sample of the first two encode() outputs.
    print("reduction", math.prod(image.shape[-3:]) / (math.prod(codes[0].shape[-3:]) + math.prod(codes[1].shape[-3:])))
print(output.shape, diff.shape)
# Cell output: reconstructions tiled by make_grid, clamped to [0, 1] and enlarged 2x.
VF.to_pil_image(resize(make_grid(output).clamp(0, 1), 2))

In [None]:
# Scratch arithmetic, evaluates to (1024, 160).
# NOTE(review): presumably pixels per 32x32 image vs. number of code values — confirm.
32*32, (2*4*4 + 2*8*8)

In [None]:
# Decode random feature maps to see what the decoder produces for noise.
with torch.inference_mode():
    # Two random tensors (1x1 and 2x2 spatial size), scaled to different standard deviations.
    # NOTE(review): both use 64 channels while `vae` above was built with embed_dim=16 —
    # confirm that vae.decode accepts this channel count.
    features = torch.randn(10, 64, 1, 1) * .2, torch.randn(10, 64, 2, 2) * 0.6
    output = vae.decode(*features)
    print("out:", output.shape)
# Cell output: decoded batch tiled by make_grid, clamped to [0, 1] and enlarged 2x.
VF.to_pil_image(resize(make_grid(output).clamp(0, 1), 2))

In [None]:
# Stand-alone quantizer for experimentation; it is left in its default mode
# because the eval() call is commented out.
# NOTE(review): arguments are presumably (embedding dim, number of embeddings) — confirm.
q = Quantize(12, 1000)#.eval()

In [None]:
# Feed a constant (2, 12) input through the quantizer and show its result.
print(q(torch.ones(2, 12)))
# List the module's registered buffers after the call.
print(list(q.buffers()))

In [None]:
from src.models.encoder.vqvae import *

class Autoencoder(nn.Module):
    """
    Autoencoder with a flat code vector, built from the VQ-VAE `Encoder` and
    `Decoder` blocks.

    Encoder: two stacked `Encoder` stages (stride 4, then stride 2), flattened
    and projected to `code_size` values by a linear layer.
    Decoder: linear layer back to the flattened encoder shape, reshape, then a
    single `Decoder` stage.

    Args:
        shape: Input shape without batch dimension; `shape[0]` is the channel count.
        code_size: Length of the flat code vector.
        channel: Channel count used by the encoder/decoder stages.
        n_res_block: Passed to the `Encoder`/`Decoder` stages.
        n_res_channel: Passed to the `Encoder`/`Decoder` stages.

    Attributes:
        enc_shape: Shape (without batch dimension) of the convolutional encoder
            output, determined by a probe run in the constructor.
    """

    def __init__(
            self,
            shape: Tuple[int, int, int],
            code_size: int,
            channel=128,
            n_res_block=2,
            n_res_channel=32,
    ):
        super().__init__()
        self.encoder = nn.Sequential()
        self.encoder.add_module("enc_b", Encoder(shape[0], channel, n_res_block, n_res_channel, stride=4))
        self.encoder.add_module("enc_t", Encoder(channel, channel, n_res_block, n_res_channel, stride=2))

        # Probe the convolutional stages to learn their output shape. Zeros are
        # used instead of uninitialised memory so the probe never feeds
        # NaN/inf garbage through the layers.
        with torch.no_grad():
            enc_shape = self.encoder(torch.zeros(1, *shape)).shape[-3:]
        self.enc_shape = tuple(enc_shape)

        # The module name "flatteb" (sic) is kept unchanged so that existing
        # attribute access by name keeps working.
        self.encoder.add_module("flatteb", nn.Flatten(1))
        self.encoder.add_module("linear", nn.Linear(math.prod(enc_shape), code_size))

        self.decoder = nn.Sequential()
        self.decoder.add_module("linear", nn.Linear(code_size, math.prod(enc_shape)))
        self.decoder.add_module("reshape", Reshape(enc_shape))
        # NOTE(review): the encoder stages use strides 4 and 2 while the decoder
        # is given stride=6 — confirm that `Decoder` supports this value and
        # restores the input resolution.
        self.decoder.add_module("dec", Decoder(
            channel,
            shape[0],
            channel,
            n_res_block,
            n_res_channel,
            stride=6,
        ))

    def forward(self, x):
        """Encode `x` to the flat code and decode it again."""
        return self.decoder(self.encoder(x))
    
# Smoke test: build a freshly initialised Autoencoder and push a random batch through it.
# Note that this rebinds the notebook-level name `ae`.
with torch.no_grad():
    ae = Autoencoder((3, 32, 32), 128)
    print(f"params: {num_module_parameters(ae):,}")
    image = torch.randn(10, 3, 32, 32)
    output = ae(image)
print(output.shape)
# Cell output: the batch tiled by make_grid, clamped to [0, 1] and enlarged 2x.
VF.to_pil_image(resize(make_grid(output).clamp(0, 1), 2))

In [None]:
class Sobel(nn.Module):
    """
    Sobel edge filter returning gradient magnitude and/or direction.

    Two fixed 3x3 kernels are applied to every input channel independently:
    `kernel_1` responds to intensity changes along the width axis, `kernel_2`
    to changes along the height axis. With responses `g1` and `g2`:

    - magnitude is `sqrt(g1**2 + g2**2)`
    - direction is `atan2(g1, g2)` (radians)

    For an input with C channels each requested quantity has C channels. When
    both are requested they are concatenated along the channel axis as
    [magnitude channels..., direction channels...].

    Args:
        magnitude: Include the gradient magnitude in the output.
        direction: Include the gradient direction in the output.
        padding: Padding passed to the convolution; 0 shrinks height and
            width by 2, 1 keeps them.

    Raises:
        ValueError: In `forward`, if neither `magnitude` nor `direction` is set.
    """

    def __init__(
            self,
            magnitude: bool = True,
            direction: bool = True,
            padding: int = 0,
    ):
        super().__init__()
        self.magnitude = magnitude
        self.direction = direction
        self.padding = padding
        # Stored as frozen parameters (shape (1, 1, 3, 3)) so they follow the
        # module across devices and stay part of its parameter list.
        self.kernel_1 = nn.Parameter(torch.tensor([[[
            [1., 0., -1.], [2., 0., -2.], [1., 0., -1.]
        ]]]), requires_grad=False)
        self.kernel_2 = nn.Parameter(torch.tensor([[[
            [1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]
        ]]]), requires_grad=False)

    def forward(self, x):
        """
        Apply the filter to `x` of shape (..., C, H, W).

        Returns:
            The magnitude, the direction, or both concatenated along the
            channel axis, depending on the constructor flags.
        """
        # Fail before doing any work if there is nothing to return.
        if not (self.magnitude or self.direction):
            raise ValueError("Must define at least one of `magnitude` or `direction`")

        # One copy of each kernel per input channel; groups=channels filters
        # every channel on its own (identical to the plain convolution for C == 1).
        channels = x.shape[-3]
        kernel_1 = self.kernel_1.to(x.dtype).repeat(channels, 1, 1, 1)
        kernel_2 = self.kernel_2.to(x.dtype).repeat(channels, 1, 1, 1)
        g1 = F.conv2d(x, kernel_1, padding=self.padding, groups=channels)
        g2 = F.conv2d(x, kernel_2, padding=self.padding, groups=channels)

        outputs = []
        if self.magnitude:
            outputs.append(torch.sqrt(g1 ** 2 + g2 ** 2))
        if self.direction:
            outputs.append(torch.atan2(g1, g2))

        if len(outputs) == 1:
            return outputs[0]
        return torch.concat(outputs, dim=-3)

# `images` is the batch built in the reconstruction cell earlier in the notebook.
# Show the input batch, then the Sobel output computed with padding=1, both enlarged 3x.
display(VF.to_pil_image(resize(make_grid(images, nrow=20), 3)))
display(VF.to_pil_image(resize(make_grid(Sobel(direction=True, padding=1)(images), nrow=20), 3)))
# Cell output: shape of the Sobel result with the default padding of 0.
Sobel(direction=True)(images).shape


In [None]:
# IPython help lookup for F.conv2d (interactive use only; not valid plain Python).
F.conv2d?