In [None]:
import sys
sys.path.append("..")

import random
import math
import time
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.models
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.decoder import *
from src.models.transform import *
from src.models.util import *

In [None]:
def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Rescale the last two dimensions of `img` by the factor `scale`.

    Each of the two trailing extents is multiplied by `scale`, truncated to an
    int and clamped to at least 1 before being handed to `VF.resize`
    (antialiasing disabled).
    """
    new_size = []
    for extent in img.shape[-2:]:
        new_size.append(max(1, int(extent * scale)))
    return VF.resize(img, new_size, interpolation=mode, antialias=False)

def plot_samples(
        iterable,
        total: int = 32,
        nrow: int = 8,
        return_image: bool = False,
        show_compression_ratio: bool = False,
        label: Optional[Callable] = None,
):
    """
    Collect up to `total` images from `iterable` and render them as one grid.

    An entry may be the image itself or a list/tuple whose first element is the
    image; a 4-dim image gets its leading dimension squeezed. Pressing Ctrl-C
    while collecting stops early and plots what has been gathered so far.

    :param iterable: source of images (or of list/tuple entries holding one)
    :param total: maximum number of images to collect
    :param nrow: forwarded as `nrow` to the grid function
    :param return_image: return the PIL image instead of displaying it
    :param show_compression_ratio: caption each image with
        `ImageFilter.calc_compression_ratio`, rounded to 3 digits
    :param label: if given (and no compression ratio is shown), a callable
        applied to the whole entry to produce the caption; a non-callable
        value captions each image with its index
    :return: the PIL image if `return_image` is set, otherwise None
    """
    images = []
    captions = []
    image_filter = ImageFilter()
    try:
        for index, item in enumerate(tqdm(iterable, total=total)):
            tensor = item[0] if isinstance(item, (list, tuple)) else item
            if tensor.ndim == 4:
                tensor = tensor.squeeze(0)
            images.append(tensor)

            if show_compression_ratio:
                ratio = image_filter.calc_compression_ratio(tensor)
                captions.append(round(ratio, 3))
            elif label is not None:
                captions.append(label(item) if callable(label) else index)

            if len(images) >= total:
                break
    except KeyboardInterrupt:
        # deliberate: an interrupted run still plots the collected images
        pass

    if captions:
        grid = make_grid_labeled(images, nrow=nrow, labels=captions)
    else:
        grid = make_grid(images, nrow=nrow)
    picture = VF.to_pil_image(grid)

    if return_image:
        return picture
    display(picture)

In [None]:
class SpatialGraphConv(nn.Module):
    """
    Channel-affinity graph convolution with a residual connection.

    https://arxiv.org/pdf/2308.07946.pdf
    https://github.com/Juntongkuki/Pytorch-DSFNet/blob/main/lib/DSFNet.py

    Three 1x1 convolutions project the input to `num_channels // 2` hidden
    channels. Two of the projections form a per-sample (hidden x hidden)
    affinity matrix that is softmax-normalised and applied to the third one.
    The result is projected back to `num_channels`, batch-normalised, added to
    the input and passed through `act`.

    :param num_channels: number of input (and output) channels
    :param act: activation applied to `x + e`; the default `F.relu_` works
        in-place on that freshly created sum, not on `x`
    """
    def __init__(
            self,
            num_channels: int,
            act: Callable = F.relu_,
    ):
        super().__init__()
        self.act = act
        num_hidden = num_channels // 2
        self.conv_b = nn.Conv2d(num_channels, num_hidden, kernel_size=1)
        self.conv_c = nn.Conv2d(num_channels, num_hidden, kernel_size=1)
        # must emit `num_hidden` channels so the (hidden x hidden) affinity
        # can be applied to it and `conv_e` receives `num_hidden` channels
        self.conv_d = nn.Conv2d(num_channels, num_hidden, kernel_size=1)
        self.conv_e = nn.Conv2d(num_hidden, num_channels, kernel_size=1)
        self.bn_e = nn.BatchNorm2d(num_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: tensor of shape (B, num_channels, H, W)
        :return: tensor of the same shape as `x`
        """
        assert x.ndim == 4, x.ndim                      # B x C x H x W
        b = self.conv_b(x).flatten(-2)                  # B x C' x N   (N = H*W)
        c = self.conv_c(x).flatten(-2)                  # B x C' x N
        d = self.conv_d(x).flatten(-2)                  # B x C' x N

        # batched matmul keeps every sample separate
        affinity = F.softmax(b @ c.transpose(1, 2), dim=-1)   # B x C' x C'
        mixed = affinity @ d                                   # B x C' x N
        mixed = mixed.reshape(x.shape[0], -1, *x.shape[-2:])   # B x C' x H x W

        e = self.bn_e(self.conv_e(mixed))               # B x C x H x W
        return self.act(x + e)
    
# Smoke test: push a random 3x4x10x10 batch through a fresh SpatialGraphConv(4)
# and show inputs and outputs together in one image grid.
img = torch.randn(3, 4, 10, 10)
out = SpatialGraphConv(4)(img)
VF.to_pil_image(make_grid(torch.concat([img, out]), nrow=8))

In [None]:
# IPython help lookup for torch.bmm (interactive scratch cell, no computation).
torch.bmm?

In [None]:
class SpatialGCN(nn.Module):
    """
    Spatial graph-convolution block built from three 1x1 projections
    (`node_k`, `node_v`, `node_q`) of the input.

    `forward` forms a softmax-normalised (inter_plane x inter_plane) matrix
    from q and v, applies it to k, runs the result through a 1-D 1x1 conv and
    batch norm, projects back to `plane` channels and optionally adds the
    input before the activation.

    :param plane: number of input and output channels
    :param inter_plane: number of hidden channels, defaults to `plane // 2`
    :param act: activation applied to the final tensor
    :param residual: add the input to the block output before the activation
    """
    def __init__(
            self,
            plane: int,
            inter_plane: Optional[int] = None,
            act: Callable = F.relu_,
            residual: bool = True,
    ):
        super().__init__()
        if inter_plane is None:
            inter_plane = plane // 2

        self.residual = residual
        self.node_k = nn.Conv2d(plane, inter_plane, kernel_size=1)
        self.node_v = nn.Conv2d(plane, inter_plane, kernel_size=1)
        self.node_q = nn.Conv2d(plane, inter_plane, kernel_size=1)

        self.conv_wg = nn.Conv1d(inter_plane, inter_plane, kernel_size=1, bias=False)
        self.bn_wg = nn.BatchNorm1d(inter_plane)
        self.softmax = nn.Softmax(dim=2)

        self.out = nn.Sequential(
            nn.Conv2d(inter_plane, plane, kernel_size=1),
            nn.BatchNorm2d(plane),
        )
        self.act = act

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: tensor of shape (B, plane, H, W)
        :return: tensor of the same shape as `x`
        """
        assert x.ndim == 4, x.ndim
        node_k = self.node_k(x)
        node_v = self.node_v(x)
        node_q = self.node_q(x)
        b, c, h, w = node_k.size()

        node_k = node_k.flatten(2).permute(0, 2, 1)     # B x N x C'
        node_q = node_q.flatten(2)                      # B x C' x N
        node_v = node_v.flatten(2).permute(0, 2, 1)     # B x N x C'

        # A = k * q
        # AV = k * q * v
        # AVW = k *(q *v) * w
        AV = torch.bmm(node_q, node_v)                  # B x C' x C'
        AV = self.softmax(AV)
        AV = torch.bmm(node_k, AV)                      # B x N x C'
        AV = AV.transpose(1, 2).contiguous()            # B x C' x N
        AVW = self.bn_wg(self.conv_wg(AV))
        AVW = AVW.contiguous().view(b, c, h, -1)        # B x C' x H x W

        out = self.out(AVW)
        if self.residual:
            out = out + x
        return self.act(out)
        
        
# Smoke test: feed a deterministic 1x4x10x10 ramp (values 0..399) through a
# fresh SpatialGCN(4) and show input and output side by side, normalised and
# enlarged by a factor of 10.
img = torch.linspace(0, 399, 400).view(1, 4, 10, 10)
out = SpatialGCN(4)(img)
VF.to_pil_image(resize(make_grid(torch.concat([img, out]), nrow=8, normalize=True), 10))

In [None]:
class EncoderBlock(nn.Module):
    """
    Convolutional block that gates a "value" feature map with a per-sample
    channel-similarity score.

    Query and key maps come from unpadded convolutions, the value map from a
    padded one (`padding = kernel_size // 2`). The softmax-normalised
    query/key channel similarity is reduced by a 1x1 conv to one scale per
    hidden channel, multiplied onto the value map, projected to `num_out`
    channels and added to a residual path (identity if `num_in == num_out`,
    otherwise a 1x1 conv) before a ReLU.
    """
    def __init__(self, num_in, num_out, kernel_size: int = 5):
        super().__init__()
        num_hid = max(2, num_in)
        self.q_conv = nn.Conv2d(num_in, num_hid, kernel_size)
        self.k_conv = nn.Conv2d(num_in, num_hid, kernel_size)
        self.v_conv = nn.Conv2d(num_in, num_hid, kernel_size, padding=kernel_size // 2)
        self.s_conv = nn.Conv2d(num_hid * num_hid, num_hid, 1)
        self.out_conv = nn.Conv2d(num_hid, num_out, 1)
        if num_in == num_out:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Conv2d(num_in, num_out, 1)

    def forward(self, x):
        assert x.ndim == 4, x.shape
        batch_size = x.shape[0]

        queries = self.q_conv(x).flatten(-2)                   # B x C x N
        keys = self.k_conv(x).flatten(-2)                      # B x C x N
        similarity = torch.matmul(queries, keys.permute(0, 2, 1))
        similarity = F.softmax(similarity, dim=1)              # B x C x C

        # one scalar gate per hidden channel
        gate = self.s_conv(similarity.view(batch_size, -1, 1, 1))   # B x C x 1 x 1
        values = self.v_conv(x)                                # B x C x H x W
        projected = self.out_conv(values * gate)               # B x num_out x H x W

        return F.relu(self.residual(x) + projected)
    
class Encoder(nn.Module):
    """
    Stack of `EncoderBlock`s followed by a linear layer producing a code vector.

    :param shape: input shape as (channels, height, width)
    :param num_out: size of the output code
    :param channels_ks: one `(out_channels, kernel_size)` pair per block; the
        first block reads `shape[0]` channels, each following block reads the
        output channels of its predecessor
    """
    def __init__(
            self,
            shape: Tuple[int, int, int],
            num_out: int,
            channels_ks: Tuple[Tuple[int, int], ...] = ((16, 3), (32, 5), (32, 7)),
    ):
        super().__init__()
        self.shape = tuple(shape)
        self.num_out = num_out
        self.channels_ks = tuple(channels_ks)

        self.blocks = nn.Sequential(OrderedDict([
            (
                f"block_{i + 1}",
                EncoderBlock(
                    num_in=self.shape[0] if i == 0 else self.channels_ks[i - 1][0],
                    num_out=ch_out,
                    kernel_size=ks,
                )
            )
            for i, (ch_out, ks) in enumerate(self.channels_ks)
        ]))
        # probe the block stack once to learn the flattened feature size
        # (the stored shape includes the leading batch dimension of 1)
        with torch.no_grad():
            self._conv_shape = self.blocks(torch.zeros(1, *self.shape)).shape

        self.w_out = nn.Linear(math.prod(self._conv_shape), self.num_out)

    def forward(self, x):
        """
        :param x: tensor of shape (B, C, H, W)
        :return: code tensor of shape (B, num_out)
        """
        assert x.ndim == 4, x.shape
        conv = self.blocks(x)
        code = self.w_out(conv.flatten(1))
        # the code was previously computed but never returned
        return code

    def extra_repr(self):
        return (
            f"shape={self.shape}, num_out={self.num_out},\nchannels_ks={self.channels_ks}"
            f",\n_conv_shape={self._conv_shape}"
        )
    
# Instantiate an Encoder for 1x32x32 inputs so the notebook shows its module
# repr; the forward passes below are commented out.
#EncoderBlock(1, 3, 3)(torch.rand(1, 1, 5, 7).bernoulli()).round(decimals=2)
Encoder((1, 32, 32), 10)#(torch.rand(1, 1, 32, 32).bernoulli()).round(decimals=2)

In [None]:
# Visualise the output for 64 random binary 64x64 images.
# NOTE(review): `Encoder` is given the int `1` as `shape`, but
# `Encoder.__init__` calls `tuple(shape)` and unpacks it as (C, H, W) - this
# probably needs a tuple such as (1, 64, 64); confirm before running.
#VF.to_pil_image(make_grid(EncoderBlock(1, 3, 5)(torch.rand(8*8, 1, 64, 64).bernoulli())))
VF.to_pil_image(make_grid(Encoder(1, 3)(torch.rand(8*8, 1, 64, 64).bernoulli())))

In [None]:
# Grid of 64 tiles; every tile comes from a newly constructed (randomly
# initialised) EncoderBlock with a 25x25 kernel applied to a new random binary
# 64x64 image, so the tiles show the spread over initialisations.
VF.to_pil_image(make_grid([
    EncoderBlock(1, 3, 25)(torch.rand(1, 1, 64, 64).bernoulli()).squeeze(0)
    for i in range(64)
]))

In [None]:
# Shape probe: call a 4-head, 100-dim MultiheadAttention with 2-D (1, 100)
# query/key/value tensors and show the shape of the first returned element.
nn.MultiheadAttention(100, 4)(torch.randn(1, 100), torch.randn(1, 100), torch.randn(1, 100))[0].shape

In [None]:
class TransformerEncoder(nn.Module):
    """
    Thin wrapper around `torchvision.models.VisionTransformer` that maps a
    square image to a vector of size `num_out`.

    Requires `torchvision.models` to be imported under the name `torchvision`.

    :param shape: input shape (C, H, W); H must equal W
    :param num_out: output size, passed to the ViT as `num_classes`
    :param patch_size: ViT patch size
    :param num_layers: number of transformer layers
    :param num_heads: number of attention heads
    :param hidden_dim: transformer width
    :param mlp_dim: width of the MLP blocks, defaults to `hidden_dim`
    :param representation_size: forwarded unchanged to the ViT
    """
    def __init__(
            self,
            shape: Tuple[int, int, int],
            num_out: int,
            patch_size: int = 8,
            num_layers: int = 4,
            num_heads: int = 4,
            hidden_dim: int = 64,
            mlp_dim: Optional[int] = None,
            representation_size: Optional[int] = None,
    ):
        super().__init__()
        assert len(shape) == 3, shape
        assert shape[-2] == shape[-1], shape

        self.shape = shape
        self.transformer = torchvision.models.VisionTransformer(
            image_size=shape[-1],
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            num_classes=num_out,
            mlp_dim=mlp_dim or hidden_dim,
            representation_size=representation_size,
        )

    def forward(self, x):
        """
        :param x: tensor of shape (B, C, H, W) with C == 1 or C == 3
        :return: whatever the wrapped transformer returns for the 3-channel input
        """
        assert x.ndim == 4, x.shape
        B, C, H, W = x.shape
        # `expand` can only broadcast a size-1 dimension, so a 2-channel (or
        # empty) input is rejected here with the shape in the assertion
        # message instead of failing inside `expand`
        assert C in (1, 3), x.shape

        if C == 1:
            x = x.expand(B, 3, H, W)
        return self.transformer(x)
    
# Build an encoder for 1x32x32 inputs with a 100-dim output, run one random
# image through it, print the result of `num_module_parameters` (defined
# outside this notebook, presumably the parameter count) and show the module.
enc = TransformerEncoder((1, 32, 32), 100) 
print(enc(torch.rand(1, 1, 32, 32)))
print(f"{num_module_parameters(enc):,}")
enc

In [None]:
import torchvision.models
# IPython help lookup for the VisionTransformer constructor (scratch cell).
torchvision.models.VisionTransformer?

In [None]:
class TransformerDecoder(nn.Module):
    """
    Unfinished decoder draft wrapping `torchvision.models.VisionTransformer`.

    The constructor stores `shape` and builds a VisionTransformer; `forward`
    takes a 4-dim image batch, expands inputs with fewer than 3 channels to 3
    channels and returns the transformer output.

    NOTE(review): `num_in` is accepted but never used in this class.
    NOTE(review): another class with the same name `TransformerDecoder` is
    defined further down in this file; when cells are run in order that later
    definition rebinds the name.
    """
    def __init__(
            self,
            num_in: int,
            shape: Tuple[int, int, int],
            patch_size: int = 8,
            num_layers: int = 4,
            num_heads: int = 4,
            hidden_dim: int = 64,
            mlp_dim: Optional[int] = None,
            representation_size: Optional[int] = None,
    ):
        super().__init__()
        assert len(shape) == 3, shape
        assert shape[-2] == shape[-1], shape
        
        self.shape = shape
        # NOTE(review): `num_out` is not a parameter or local of this
        # constructor, so this line raises NameError unless a global of that
        # name happens to exist - decide what `num_classes` should be.
        self.transformer = torchvision.models.VisionTransformer(
            image_size=shape[-1],
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            num_classes=num_out,
            mlp_dim=mlp_dim or hidden_dim,
            representation_size=representation_size,
        )
        
    def forward(self, x):
        assert x.ndim == 4, x.shape
        B, C, H, W = x.shape
        assert C <= 3, x.shape
        
        # `expand` only broadcasts size-1 dimensions, so this works for C == 1
        if C < 3:
            x = x.expand(B, 3, H, W)
        return self.transformer(x)

In [None]:
# IPython help lookup for nn.TransformerDecoderLayer (scratch cell).
nn.modules.TransformerDecoderLayer?

In [None]:
# Reference decoder stack built from PyTorch's stock layer (width 64, 8 heads).
layer = nn.TransformerDecoderLayer(d_model=64, nhead=8)
# This line used to end in `?`, which IPython treats as a help request for
# `nn.TransformerDecoder`; the assignment never ran, so `dec` stayed unbound
# for the later cell that calls `dec(tgt, memory)` with 64-wide tensors.
# NOTE(review): the depth of 2 is an arbitrary choice - adjust as needed.
dec = nn.TransformerDecoder(layer, num_layers=2)


In [None]:
class TransformerDecoder(nn.Module):
    """
    Decode a code vector into an image via a transformer decoder stack and a
    transposed patch convolution.

    The code is projected to `num_hidden`, passed through
    `nn.TransformerDecoder` (used as both target and memory), projected to the
    patch-grid shape that a `patch_size` convolution would produce for
    `shape`, and turned into pixels by a matching `nn.ConvTranspose2d`.

    :param shape: target image shape (C, H, W)
    :param code_size: size of the input code vector
    :param patch_size: kernel size of the patch (transposed) convolution
    :param stride: stride of the patch convolution, defaults to `patch_size`
    :param num_layers: number of transformer decoder layers
    :param num_hidden: transformer width (`d_model`)
    :param num_heads: number of attention heads
    :param mlp_dim: feed-forward width, defaults to `num_hidden`
    :param dropout: dropout rate of the transformer layers
    :param activation: activation of the transformer layers
    """
    def __init__(
            self,
            shape: Tuple[int, int, int],
            code_size: int,
            patch_size: Union[int, Tuple[int, int]] = 8,
            stride: Union[None, int, Tuple] = None,
            num_layers: int = 4,
            num_hidden: int = 64,
            num_heads: int = 4,
            mlp_dim: Optional[int] = None,
            dropout: float = 0.1,
            activation: Union[None, str, Callable] = F.relu,
    ):
        super().__init__()
        # `conv_in` is only used to find the patch-grid shape; it is not kept
        conv_in = nn.Conv2d(shape[0], num_hidden, kernel_size=patch_size, stride=stride or patch_size)
        conv_out = nn.ConvTranspose2d(num_hidden, shape[0], kernel_size=patch_size, stride=stride or patch_size)
        with torch.no_grad():
            self.conv_shape = conv_in(torch.empty(1, *shape)).shape[-3:]

        self.proj = nn.Linear(code_size, num_hidden)
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=num_hidden,
                nhead=num_heads,
                dim_feedforward=mlp_dim or num_hidden,
                dropout=dropout,
                activation=activation,
            ),
            num_layers=num_layers,
        )
        self.proj_out = nn.Linear(num_hidden, math.prod(self.conv_shape))
        self.patches = conv_out

    def forward(self, x):
        """
        :param x: code tensor of shape (B, code_size)
        :return: image tensor; one image per row of `x`
        """
        assert x.ndim == 2, x.shape
        # A 2-D tensor is read by nn.TransformerDecoder as ONE unbatched
        # sequence of length B, which lets the samples of a batch attend to
        # each other. Add an explicit length-1 sequence axis (the layers use
        # the default sequence-first layout) so every sample is decoded
        # independently.
        y = self.proj(x).unsqueeze(0)                   # 1 x B x num_hidden
        y = self.transformer(y, y).squeeze(0)           # B x num_hidden
        y = self.proj_out(y).view(-1, *self.conv_shape)
        y = self.patches(y)
        return y


class TransformerAutoencoder(nn.Module):
    """
    Autoencoder whose encoder is a patch convolution + linear projection +
    `nn.TransformerEncoder` + linear code head, and whose decoder is a
    `TransformerDecoder` built with the same hyper-parameters.

    :param shape: image shape (C, H, W)
    :param code_size: size of the latent code
    :param patch_size: kernel size of the patch convolution
    :param stride: stride of the patch convolution, defaults to `patch_size`
    :param num_layers: number of transformer layers (encoder and decoder)
    :param num_hidden: transformer width (`d_model`)
    :param num_heads: number of attention heads
    :param mlp_dim: feed-forward width, defaults to `num_hidden`
    :param dropout: dropout rate of the transformer layers
    :param activation: activation of the transformer layers
    """
    def __init__(
            self,
            shape: Tuple[int, int, int],
            code_size: int,
            patch_size: Union[int, Tuple[int, int]] = 8,
            stride: Union[None, int, Tuple] = None,
            num_layers: int = 4,
            num_hidden: int = 64,
            num_heads: int = 4,
            mlp_dim: Optional[int] = None,
            dropout: float = 0.1,
            activation: Union[None, str, Callable] = F.relu,
    ):
        super().__init__()
        assert len(shape) == 3, shape

        self.shape = shape

        conv_in = nn.Conv2d(shape[0], num_hidden, kernel_size=patch_size, stride=stride or patch_size)
        with torch.no_grad():
            conv_shape = conv_in(torch.empty(1, *shape)).shape[-3:]

        # A 2-D (B, num_hidden) tensor would be read by nn.TransformerEncoder
        # as ONE unbatched sequence of length B, so the samples of a batch
        # would attend to each other. The two parameter-free reshaping steps
        # around the transformer give every sample its own length-1 sequence;
        # the names of all parameterised sub-modules are unchanged.
        self.encoder = nn.Sequential(OrderedDict([
            ("patches", conv_in),
            ('flatten', nn.Flatten(-3)),
            ("proj", nn.Linear(math.prod(conv_shape), num_hidden)),
            ("as_sequence", nn.Unflatten(1, (1, num_hidden))),      # B x 1 x num_hidden
            ("transformer", nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=num_hidden,
                    nhead=num_heads,
                    dim_feedforward=mlp_dim or num_hidden,
                    dropout=dropout,
                    activation=activation,
                    batch_first=True,
                ),
                num_layers=num_layers,
            )),
            ("drop_sequence", nn.Flatten(1)),                       # B x num_hidden
            ("proj_out", nn.Linear(num_hidden, code_size)),
        ]))
        self.decoder = TransformerDecoder(
            shape=shape,
            code_size=code_size,
            patch_size=patch_size,
            stride=stride,
            num_layers=num_layers,
            num_hidden=num_hidden,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            dropout=dropout,
            activation=activation,
        )

    def forward(self, x):
        """
        :param x: image batch of shape (B, C, H, W)
        :return: the decoder output for the encoded batch
        """
        assert x.ndim == 4, x.shape
        y = self.encoder(x)
        return self.decoder(y)
    
           
# Shape probe: run one all-ones 1x28x28 image through a small autoencoder and
# show the shape of the reconstruction.
TransformerAutoencoder((1, 28, 28), 76, dropout=0.77, patch_size=8, num_layers=2, num_hidden=256, num_heads=8)(torch.ones(1, 1, 28, 28)).shape

In [None]:
# IPython help lookup for nn.TransformerDecoderLayer (scratch cell).
nn.TransformerDecoderLayer?

In [None]:
# Probe how `dec` reacts to its two arguments by swapping all-ones and
# all-zeros (1, 64) tensors.
# NOTE(review): `dec` must already be bound to a callable taking two 64-wide
# tensors when this cell runs - verify that an earlier cell really assigns it.
print(dec(torch.ones(1, 64), torch.zeros(1, 64)))
print(dec(torch.zeros(1, 64), torch.ones(1, 64)))
print(dec(torch.ones(1, 64), torch.ones(1, 64)))

In [None]:
from src.models.cnn import *

class ConvUpscaleDecoder(nn.Module):
    """
    Decode a code vector into an image with a stack of transposed
    convolutions and periodic `nn.PixelShuffle(2)` upscaling.

    The code is projected linearly to a small `(num_hidden, h, w)` map. The
    start size `(h, w)` is derived by walking the layer stack backwards from
    the target size, so the final output has exactly the spatial size given
    in `shape`. Every layer is an unpadded stride-1 `nn.ConvTranspose2d`;
    every `upscale_every`-th layer emits four times the channels and is
    followed by a pixel shuffle. A final 1x1 transposed convolution maps to
    `shape[0]` channels.

    :param shape: target image shape (C, H, W)
    :param code_size: size of the input code vector
    :param kernel_size: kernel size of the transposed convolutions (int or pair)
    :param stride: accepted for signature compatibility, currently unused
    :param num_layers: number of transposed-convolution layers
    :param num_hidden: number of channels inside the stack
    :param upscale_every: layer `i` (0-based) upscales if `(i + 1) % upscale_every == 0`
    :param mlp_dim: accepted for signature compatibility, currently unused
    :param dropout: accepted for signature compatibility, currently unused
    :param activation: activation after the projection and after every
        convolution; `None` disables the activations
    :raises ValueError: if the target size cannot be reached with this layer stack
    """
    def __init__(
            self,
            shape: Tuple[int, int, int],
            code_size: int,
            kernel_size: Union[int, Tuple[int, int]] = 3,
            stride: Union[None, int, Tuple] = None,
            num_layers: int = 4,
            num_hidden: int = 64,
            upscale_every: int = 3,
            mlp_dim: Optional[int] = None,
            dropout: float = 0.1,
            activation: Union[None, str, Callable] = F.relu,
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size, kernel_size]

        def _is_upscale(i):
            return (i + 1) % upscale_every == 0

        # calculate shape at start of convolution
        s = self._start_size(shape[-2:], kernel_size, num_layers, upscale_every)
        start_shape = (num_hidden, *s)

        self.layers = nn.Sequential(OrderedDict([
            ("proj", nn.Linear(code_size, math.prod(start_shape))),
            ("reshape", Reshape(start_shape)),
        ]))
        if activation is not None:
            self.layers.add_module("act_0", activation_to_module(activation))

        ch_in = start_shape[0]
        ch_out = num_hidden
        for i in range(num_layers):
            self.layers.add_module(f"conv_{i + 1}", nn.ConvTranspose2d(
                ch_in,
                # PixelShuffle(2) divides the channel count by 4
                ch_out * 4 if _is_upscale(i) else ch_out,
                kernel_size
            ))
            if activation is not None:
                self.layers.add_module(f"act_{i + 1}", activation_to_module(activation))
            ch_in = ch_out

            if _is_upscale(i):
                self.layers.add_module(f"up_{i // upscale_every + 1}", nn.PixelShuffle(2))

        self.layers.add_module(f"conv_out", nn.ConvTranspose2d(
            ch_in,
            shape[0],
            1
        ))

    @staticmethod
    def _start_size(size, kernel_size, num_layers: int, upscale_every: int) -> List[int]:
        """Walk the layer stack backwards from the target `size` (H, W) and return the start size."""
        s = list(size)
        for i in range(num_layers - 1, -1, -1):
            if (i + 1) % upscale_every == 0:
                for j in range(2):
                    if s[j] % 2 != 0:
                        raise ValueError(f"Upscaling in layer {i+1} requires scale divisible by 2, got {s}")
                    s[j] //= 2

            for j in range(2):
                # an unpadded stride-1 transposed convolution grows each side
                # by `kernel - 1` (the previous `ceil(kernel / 2)` only
                # matched that for kernel sizes 2 and 3)
                s[j] -= kernel_size[j] - 1
            for j in range(2):
                if s[j] <= 0:
                    raise ValueError(f"Convolution in layer {i+1} requires scale > 0, got {s}")
        return s

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
    
# Build a decoder for 1x32x32 images from a 20-dim code (3x3 kernels, 4 layers,
# upscaling on every 2nd layer), print the output shape for one all-ones code
# and the result of `num_module_parameters`, then show the module repr.
# Note: this rebinds the notebook-level name `dec`.
dec = ConvUpscaleDecoder((1, 32, 32), 20, 3, num_layers=4, upscale_every=2)
print(dec(torch.ones(1, 20)).shape)
print(f"{num_module_parameters(dec):,}")
#display(VF.to_pil_image(make_grid(dec(torch.randn(4, 1, 100, 100)))))
dec

In [None]:
# Binds `encoder` to the `Conv2dBlock` class object itself - no instance is
# created here. `Conv2dBlock` is not defined in this notebook; presumably it
# comes from the `src.models.cnn` star import - confirm.
encoder = Conv2dBlock

In [None]:
# Shape probe: show the output shape of PixelUnshuffle(2) for a 1x4x24x32 input.
nn.PixelUnshuffle(2)(torch.ones(1, 4, 24, 32)).shape

In [None]:
from src.models.encoder import Encoder2d

class ResidualEncoderConv2d(Encoder2d):
    """
    Work-in-progress convolutional encoder assembled from `Conv2dBlock`s.

    The channel list `[shape[0], *channels]` is cut into consecutive slices of
    `layers_per_block` entries and one `Conv2dBlock` is created per slice.
    `forward` currently returns the flattened output of the block stack; the
    final linear layer to `code_size` is commented out.

    `Encoder2d` and `Conv2dBlock` are defined outside this notebook, so their
    exact contracts cannot be confirmed from here.
    """
        
    def __init__(
            self,
            shape: Tuple[int, int, int],
            code_size: int,
            kernel_size: Union[int, Iterable[int]] = 3,
            channels: Iterable[int] = (16, 32),
            stride: int = 1,
            layers_per_block: int = 2,
            #blocks: Iterable[BlockConfig],
            act_fn: Union[None, str, Callable, nn.Module] = nn.ReLU(),
    ):
        super().__init__(shape=shape, code_size=code_size)
        self.channels = tuple(channels)
        self.kernel_size = kernel_size
        self.stride = stride
        # self.act_fn = act_fn
        
        # NOTE(review): `conv_in` is created here but never used in `forward`.
        self.conv_in = nn.Conv2d(shape[0], self.channels[0], 1)
        
        # assumes the base class stored `shape` as `self.shape` - TODO confirm
        channels = [self.shape[0], *self.channels]
        # NOTE(review): the check also rejects equality (`>=`) although the
        # message says "must be <=".
        if layers_per_block >= len(channels):
            raise ValueError(f"`layers_per_block` must be <= number of layers, got {layers_per_block} and {len(channels)} channels")

        self.convolution = nn.Sequential()
        num_blocks = int(math.ceil(len(channels) / layers_per_block))
        for block_idx in range(num_blocks):
            # each block receives the next `layers_per_block` channel entries
            conv = Conv2dBlock(
                channels=channels[:layers_per_block],
                kernel_size=self.kernel_size,
                act_fn=act_fn,
                stride=self.stride,
            )
            self.convolution.add_module(f"conv_{block_idx+1}", conv)
            # NOTE(review): the slices do not overlap, so the last channel of
            # one block is not the first channel of the next; consecutive
            # blocks may not connect, and the last slice can hold a single
            # entry - confirm against Conv2dBlock.
            channels = channels[layers_per_block:]
        
        #with torch.no_grad():
        #    encoded_shape = self.convolution(torch.empty(shape)).shape
        #self.linear = nn.Linear(math.prod(encoded_shape), self.code_size)

    @property
    def device(self):
        # NOTE(review): `self.linear` is only assigned in the commented-out
        # code above; unless the base class provides it, this property fails.
        return self.linear.weight.device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # returns the flattened convolution output; no projection to
        # `code_size` happens while the linear layer is disabled
        #return self.linear
        return (self.convolution(x).flatten(1))

    def get_extra_state(self):
        # NOTE(review): `self.convolution` is an `nn.Sequential` built in
        # `__init__` and this class never sets `_act_fn` on it - the lookup
        # below looks like it would fail; `stride` is also not stored
        # although `from_data` reads it.
        return {
            **super().get_extra_state(),
            "kernel_size": self.kernel_size,
            "channels": self.channels,
            "act_fn": self.convolution._act_fn,
        }

    @classmethod
    def from_data(cls, data: dict):
        """
        Rebuild a model from a state dict whose `"_extra_state"` entry holds
        the constructor arguments, then load the weights from `data`.

        `layers_per_block` is not restored and falls back to its default.
        """
        extra = data["_extra_state"]
        model = cls(
            shape=extra["shape"],
            kernel_size=extra["kernel_size"],
            stride=extra.get("stride", 1),
            channels=extra["channels"],
            code_size=extra["code_size"],
            act_fn=extra["act_fn"],
        )
        model.load_state_dict(data)
        return model

    
# Build the work-in-progress residual encoder for 1x32x32 inputs, print the
# result of `num_module_parameters` and show the module repr; the forward
# pass is commented out. Note: this rebinds the notebook-level name `enc`.
enc = ResidualEncoderConv2d(
    (1, 32, 32), 20, kernel_size=3, 
    channels=(16, 32, 48), layers_per_block=2,        
)
#print(enc(torch.ones(1, 1, 32, 32)).shape)
print(f"{num_module_parameters(enc):,}")
#display(VF.to_pil_image(make_grid(dec(torch.randn(4, 1, 100, 100)))))
enc

In [None]:
import torchvision.models
from src.models.encoder import resnet
    
# Load the project's `resnet18_open` with torchvision's ImageNet-1k V1
# ResNet-18 weights, move it to the CPU, switch it to eval mode, print the
# result of `num_module_parameters` and show the module repr.
# Note: this rebinds the notebook-level name `enc`.
enc = resnet.resnet18_open(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1).cpu()
enc.eval()
print(f"{num_module_parameters(enc):,}")
enc

In [None]:
# Forward one all-ones 1x3x64x64 batch through whatever `enc` is currently bound to.
enc(torch.ones(1, 3, 64, 64))