In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
from sklearn.manifold import TSNE
from sklearn.decomposition import IncrementalPCA

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
import torchaudio.transforms as AT
import torchaudio.functional as AF
from torchvision.utils import make_grid
from IPython.display import display, Audio
import torchaudio
from torchaudio.io import StreamReader

from src.datasets import *
from src.algo import GreedyLibrary
from src.util.image import *
from src.util import to_torch_device, iter_batches
from src.patchdb import PatchDB, PatchDBIndex
from src.models.encoder import *
from src.models.cnn import *
from src.models.util import *
from src.models.transform import *
from src.util.audio import *
from src.util.files import *
from src.util.embedding import *
from scripts import datasets

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """Scale the last two (height, width) dims of ``img`` by ``scale``.

    Each new side length is ``int(side * scale)`` (truncated) but never below 1.
    Antialiasing is disabled; interpolation defaults to nearest-neighbour.
    """
    target_size = [max(1, int(side * scale)) for side in img.shape[-2:]]
    return VF.resize(img, target_size, interpolation=mode, antialias=False)

In [None]:
def get_vec(*shape):
    """Return a float tensor of ``shape`` whose values are 0, 1, 2, ... in row-major order.

    Every element equals its flat index, which makes reshapes and slicing easy to eyeball.
    """
    count = math.prod(shape)
    values = torch.linspace(0, count - 1, steps=count)
    return values.reshape(shape)

# Preview: a (1, 3, 10, 8) tensor whose values equal their flat index
get_vec(1, 3, 10, 8)

In [None]:
def space_to_depth(x: torch.Tensor) -> torch.Tensor:
    """Fold every 2x2 spatial patch into the channel axis.

    ``(..., C, H, W) -> (..., 4*C, H/2, W/2)``. The four channel blocks are, in order:
    (even row, even col), (odd row, even col), (even row, odd col), (odd row, odd col).
    ``depth_to_space`` in this notebook undoes this layout.

    Args:
        x: tensor with at least 3 dims; the channel axis is ``-3`` and the last two
            axes are height and width (both must be even).

    Returns:
        Tensor with 4x the channels and half the height and width.

    Raises:
        ValueError: if ``x`` has fewer than 3 dims or an odd height or width.
    """
    if x.ndim < 3:
        raise ValueError(f"Expected at least 3 dims (..., C, H, W), got shape {tuple(x.shape)}")
    height, width = x.shape[-2:]
    if height % 2 or width % 2:
        raise ValueError(f"Height and width must be even, got {height}x{width}")
    return torch.cat(
        [
            x[..., ::2, ::2],
            x[..., 1::2, ::2],
            x[..., ::2, 1::2],
            x[..., 1::2, 1::2],
        ],
        # Channel axis counted from the end so unbatched (C, H, W) input is folded
        # along channels too (a hard-coded dim=1 would concatenate along H there).
        dim=-3,
    )
    
# Small index-valued batch so the 2x2 folding is visible in the printed values
batch = get_vec(1, 2, 4, 6)
print(batch.shape, batch, sep="\n")
# Fold into channels: (1, 2, 4, 6) -> (1, 8, 2, 3)
d = space_to_depth(batch)
print(d.shape)
d

In [None]:
def depth_to_space(x: torch.Tensor) -> torch.Tensor:
    """Inverse of ``space_to_depth``: unfold channel blocks back into 2x2 spatial patches.

    Expects ``x`` of shape ``(..., 4*C, H, W)`` whose four channel blocks are ordered
    (even row, even col), (odd row, even col), (even row, odd col), (odd row, odd col),
    which is the layout ``space_to_depth`` produces. Returns ``(..., C, 2*H, 2*W)``;
    leading (batch) dims are preserved.

    Args:
        x: tensor with at least 3 dims; the channel axis is ``-3``.

    Returns:
        Tensor with a quarter of the channels and doubled height and width.

    Raises:
        ValueError: if ``x`` has fewer than 3 dims or its channel count is not divisible by 4.
    """
    if x.ndim < 3:
        raise ValueError(f"Expected at least 3 dims (..., C, H, W), got shape {tuple(x.shape)}")
    channels, height, width = x.shape[-3:]
    if channels % 4:
        raise ValueError(f"Channel count must be divisible by 4, got {channels}")
    lead = x.shape[:-3]

    # Blocks 0/1 hold even columns and blocks 2/3 odd columns: pair them on a new last
    # axis and flatten it into the width, so output column = 2 * w + col_parity.
    half = channels // 2
    s = torch.stack([x[..., :half, :, :], x[..., half:, :, :]], dim=-1)
    s = s.reshape(*lead, half, height, width * 2)

    # Now the first half of the channels holds even rows and the second half odd rows:
    # pair them on a new axis before the width, so output row = 2 * h + row_parity.
    quarter = half // 2
    s = torch.stack([s[..., :quarter, :, :], s[..., quarter:, :, :]], dim=-2)
    return s.reshape(*lead, quarter, height * 2, width * 2)

# Unfold back: (1, 8, 2, 3) -> (1, 2, 4, 6)
rd = depth_to_space(d)
print(rd.shape)
# depth_to_space must exactly invert space_to_depth on the demo batch
assert torch.equal(rd, batch), "depth_to_space(space_to_depth(x)) != x"
rd

In [None]:
class ConvTest(nn.Module):
    """Two conv stages, each preceded by ``space_to_depth`` (2x2 patches -> channels).

    Shapes for a (N, 3, 64, 64) input:
    3x64x64 -> fold 12x32x32 -> conv1 4x28x28 -> fold 16x14x14 -> conv2 3x10x10.
    """

    def __init__(self):
        super().__init__()
        # in_channels are 4x the previous stage because each fold quadruples channels
        self.conv1 = nn.Conv2d(4 * 3, 4, 5)
        self.conv2 = nn.Conv2d(4 * 4, 3, 5)

    def forward(self, x):
        for layer in (self.conv1, self.conv2):
            x = layer(space_to_depth(x))
        return x

conv = ConvTest()
image = get_vec(1, 3, 64, 64)
print(image.shape)
out = conv(image)
print(out.shape)
# 64 -> fold 32 -> conv5 28 -> fold 14 -> conv5 10, with 3 output channels
assert out.shape == (1, 3, 10, 10), out.shape

In [None]:
# Display the ConvTest output tensor computed in the cell above
out

In [None]:
class PrintShape(nn.Module):
    """Identity layer that prints ``prefix`` and its input's shape; a debugging aid
    for inspecting intermediate shapes inside ``nn.Sequential`` stacks."""

    def __init__(self, prefix: str = ""):
        super().__init__()
        self.prefix = prefix

    def forward(self, x):
        shape = x.shape
        print(f"{self.prefix} {shape}")
        return x
    
class Conv2dBlockLOCAL(Conv2dBlock):
    """Notebook-local experimental variant of ``Conv2dBlock``.

    Calls ``Conv2dBlock.__init__`` with every argument, then overwrites
    ``self.channels``, ``self.kernel_size`` and ``self._act_fn`` and rebuilds
    ``self.layers`` so that ``SpaceToDepth2d`` layers can be interleaved with the
    convolutions. ``PrintShape`` debug layers are inserted next to each
    ``SpaceToDepth2d``, so every pass through ``self.layers`` prints shapes.
    """

    def __init__(
            self,
            channels: Iterable[int],
            kernel_size: Union[int, Iterable[int]] = 5,
            stride: int = 1,
            pool_kernel_size: int = 0,
            pool_type: str = "max",  # "max", "average"
            act_fn: Optional[nn.Module] = None,
            act_last_layer: bool = True,
            bias: bool = True,
            transpose: bool = False,
            batch_norm: bool = False,
            space_to_depth: bool = False,  # note: shadows the module-level space_to_depth() function inside __init__
    ):
        """Build the layer stack.

        Args:
            channels: channel counts; one conv layer is created for each consecutive pair.
            kernel_size: one int for all layers, or exactly one value per layer.
            stride: forwarded to every ``_create_layer`` call.
            pool_kernel_size: if non-zero, a pooling layer is appended after the last conv.
            pool_type: ``"max"`` or ``"average"``; any other value raises ``KeyError``
                when pooling is enabled.
            act_fn: activation module appended after each conv (the same instance is
                appended at every position).
            act_last_layer: whether ``act_fn`` is also appended after the last conv.
            bias: forwarded to ``_create_layer``.
            transpose: forwarded to ``_create_layer`` and to ``SpaceToDepth2d``.
            batch_norm: if True, ``BatchNorm2d(channels[0])`` is prepended.
            space_to_depth: if True, a ``SpaceToDepth2d`` is inserted after every conv
                (``transpose=False``) or before every conv (``transpose=True``).

        Raises:
            AssertionError: fewer than two channels (checked after the base constructor runs).
            ValueError: ``kernel_size`` is a sequence whose length is not ``len(channels) - 1``.
        """
        super().__init__(channels, kernel_size, stride, pool_kernel_size, pool_type, act_fn, act_last_layer, bias, transpose, batch_norm, space_to_depth)
        # NOTE(review): `channels` / `kernel_size` are iterated again below after being passed to the
        # base constructor; a one-shot iterator (e.g. a generator) may already be exhausted here —
        # pass lists.
        self.channels = list(channels)
        assert len(self.channels) >= 2, f"Got: {channels}"
            
        num_layers = len(self.channels) - 1
        # Normalise kernel_size to one value per conv layer.
        if isinstance(kernel_size, int):
            self.kernel_size = [kernel_size] * num_layers 
        else:
            self.kernel_size = list(kernel_size)
            if len(self.kernel_size) != num_layers:
                raise ValueError(f"Expected `kernel_size` to have {num_layers} elements, got {self.kernel_size}")
                
        self._act_fn = act_fn

        # Rebuild the layer stack from scratch, discarding whatever the base constructor built.
        self.layers = nn.Sequential()

        if batch_norm:
            self.layers.append(
                nn.BatchNorm2d(self.channels[0])
            )

        # Disabled divisibility check for the transposed space-to-depth path:
        #if space_to_depth and transpose:
        #    for i, ch in enumerate(self.channels):
        #        if ch / 4 != ch // 4:
        #            raise ValueError(f"with 'space_to_depth' and 'transpose', channels must be divisible by 4, got {ch} at position {i}")
                    
        # Multipliers applied to each conv's nominal in/out channel counts. They assume
        # SpaceToDepth2d multiplies channels by 4 (and, with transpose=True, divides them
        # by 4) — TODO confirm against SpaceToDepth2d.
        in_channel_mult = 1
        out_channel_mult = 1
        for i, (in_channels, out_channels, kernel_size) in enumerate(zip(self.channels, self.channels[1:], self.kernel_size)):
            
            if space_to_depth and transpose:
                # Transposed path: a SpaceToDepth2d(transpose=True) precedes each conv. Every
                # conv except the last emits 4x out_channels (presumably so the next inverse
                # fold yields `out_channels`).
                self.layers.append(SpaceToDepth2d(transpose=transpose))
                self.layers.append(PrintShape("before conv"))  # debug output
                out_channel_mult = 1
                if i < len(self.channels) - 2:
                    out_channel_mult = 4
                    
            self.layers.append(
                self._create_layer(
                    in_channels=in_channels * in_channel_mult,
                    out_channels=out_channels * out_channel_mult,
                    kernel_size=kernel_size,
                    stride=stride,
                    bias=bias,
                    transpose=transpose,
                )
            )
            
            # NOTE(review): i is at most len(self.channels) - 2, so the bound below is always true and
            # a SpaceToDepth2d also follows the last conv (the transposed path above likewise starts
            # with one). Use `- 2` if the final fold is unintended.
            if space_to_depth and not transpose and i < len(self.channels) - 1:
                self.layers.append(PrintShape("after conv"))  # debug output
                self.layers.append(SpaceToDepth2d(transpose=transpose))
                # every following conv sees the folded channel count
                in_channel_mult = 4

            # Pooling is only added after the last conv.
            if pool_kernel_size and i == len(self.channels) - 2:
                klass = {
                    "max": nn.MaxPool2d,
                    "average": nn.AvgPool2d,
                }[pool_type]
                self.layers.append(
                    klass(pool_kernel_size)
                )
            # Activation after every conv; after the last one only when act_last_layer is set.
            if self._act_fn and (act_last_layer or i + 2 < len(self.channels)):
                self.layers.append(act_fn)
    
# Encoder / decoder pair with space-to-depth folding; the decoder mirrors the encoder.
CHANNELS = [3, 2, 4, 5]
KERNEL_SIZE = [5, 5, 6]
conv = Conv2dBlockLOCAL(channels=CHANNELS, kernel_size=KERNEL_SIZE, space_to_depth=True)
print(conv)
print("out:", conv(image).shape)
convt = Conv2dBlockLOCAL(
    channels=CHANNELS[::-1],
    kernel_size=KERNEL_SIZE[::-1],
    space_to_depth=True,
    transpose=True,
)
print(convt)
print(convt(conv(image)).shape)

In [None]:
# NOTE(review): scratch arithmetic (evaluates to 288); its purpose is not stated — document or remove
32*3*3

In [None]:
# Shape check: a plain conv followed by its transposed counterpart
c1 = Conv2dBlock([3, 1], 5)
c2 = Conv2dBlock([1, 3], 5, transpose=True)
i1 = torch.rand(1, 3, 64, 64)
i2 = c1(i1)
i3 = c2(i2)
print(" -> ".join(str(t.shape) for t in (i1, i2, i3)))

In [None]:
# NOTE(review): torch.load unpickles the file — only load trusted files.
# The file name suggests uint8 64x64 images; multiply=1/255 presumably rescales to [0, 1].
ds = TransformDataset(
    TensorDataset(
        torch.load("../datasets/kali-uint8-64x64.pt")
    ),
    multiply=1. / 255.,
    dtype=torch.float,
)

In [None]:
class Sobel(nn.Module):
    """High-pass filter: the positive part of ``x - gaussian_blur(x)``.

    Despite the name, this is not a Sobel gradient operator. It subtracts a
    Gaussian-blurred copy of the input and clamps negatives to 0.

    Args:
        kernel_size: side length of the square Gaussian kernel.
        sigma: Gaussian standard deviation used for both axes.
    """

    def __init__(self, kernel_size: int = 5, sigma: float = 5.):
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma

    def forward(self, x):
        ksize = [self.kernel_size] * 2
        sigmas = [self.sigma] * 2
        residual = x - VF.gaussian_blur(x, ksize, sigmas)
        return residual.clamp_min(0)
    
# Show dataset sample 23 next to its high-pass residual
image = ds[23][0]
display(
    VF.to_pil_image(image),
    VF.to_pil_image(Sobel()(image)),
)

In [None]:
# NOTE(review): IPython help lookup (`?`) left over from development; remove before sharing
VF.gaussian_blur?