In [None]:
import sys
sys.path.append("..")

import random
import math
from io import BytesIO
from pathlib import Path
from collections import OrderedDict
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid
from IPython.display import display

from src.datasets import *
from src.util.image import *
from src.util import *
from src.algo import *
from src.models.cnn import *

In [None]:
SHAPE = (3, 64, 64)  # presumably (channels, height, width); height x width selects the file below
# The stored tensor is uint8 (per the filename); multiply=1/255 presumably rescales it to [0, 1] floats.
dataset = TransformDataset(
    TensorDataset(torch.load("../datasets/kali-uint8-{}x{}.pt".format(*SHAPE[-2:]))),
    dtype=torch.float,
    multiply=1. / 255.,
)

In [None]:
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a shortcut.

    All convolutions use stride 1. When ``stride > 1``, the spatial reduction is
    done by an ``AvgPool2d(stride)`` after the 3x3 conv. The shortcut then gets its
    own ``AvgPool2d(stride)`` in front of a 1x1 projection.

    Args:
        inplanes: number of input channels.
        planes: width of the inner 1x1/3x3 convolutions.
        stride: spatial downsampling factor (1 = keep resolution).

    Output has ``planes * self.expansion`` channels.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        # 1x1 channel reduction
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        # 3x3 spatial conv (stride 1; downsampling is done by avgpool below)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        # 1x1 channel expansion
        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        # Use self.expansion (not Bottleneck.expansion) so the shortcut decision matches
        # the channels conv3 actually produces, including in subclasses that override it.
        if stride > 1 or inplanes != out_planes:
            # the shortcut projection is preceded by an avgpool and itself has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, out_planes, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(out_planes))
            ]))

    def forward(self, x: torch.Tensor):
        """Apply the block to ``x`` of shape (N, inplanes, H, W)."""
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(self.avgpool(out)))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu3(out + shortcut)

m = Bottleneck(inplanes=3, planes=4, stride=1)
m(torch.rand(1, 3, 4, 6)).size()  # planes * expansion = 16 channels, spatial size kept at stride 1

In [None]:
# Preview the first 3 output channels of `m` for one dataset image, clipped to [0, 1].
# Display only, so skip building an autograd graph.
# Note: `m` is still in training mode here, so this call also updates its BatchNorm running stats.
with torch.no_grad():
    preview = m(dataset[12][0].unsqueeze(0))[0, :3].clip(0, 1)
VF.to_pil_image(preview)

In [None]:
class LayerNorm(nn.LayerNorm):
    """torch ``nn.LayerNorm`` that normalizes in float32 and casts back.

    Lets fp16 (or other low-precision) activations pass through: the input is
    upcast to float32 for the normalization, and the result is returned in the
    input's original dtype.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.to(torch.float32))
        return normalized.to(input_dtype)


class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
    
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then a 4x-wide MLP, each with a residual add.

    Args:
        d_model: token embedding width.
        n_head: number of heads for ``nn.MultiheadAttention``.
        attn_mask: optional mask passed as ``attn_mask`` to ``nn.MultiheadAttention``.
            A boolean mask (True = position may not be attended) keeps its dtype.
            A floating mask is added to the attention scores and is cast to the
            input dtype.

    ``nn.MultiheadAttention`` is built with its default ``batch_first=False``, so
    inputs are sequence-first: (seq, batch, d_model).
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Plain attribute, not a registered buffer: module.to()/.cuda() will not move it,
        # so attention() moves it to the input's device on every call.
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attention over ``x`` using ``self.attn_mask`` (if set).

        The device/dtype-adjusted mask is local to this call; ``self.attn_mask`` is not modified.
        """
        mask = self.attn_mask
        if mask is not None:
            if mask.dtype == torch.bool:
                # Casting a boolean mask to float would turn True ("blocked") into an
                # additive +1.0 bias and silently disable masking, so only move it.
                mask = mask.to(device=x.device)
            else:
                mask = mask.to(dtype=x.dtype, device=x.device)
        return self.attn(x, x, x, need_weights=False, attn_mask=mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks that all share the same ``attn_mask``.

    Args:
        width: token embedding width (``d_model`` of each block).
        layers: number of blocks.
        heads: attention heads per block.
        attn_mask: optional mask forwarded to every block.

    Input/output are sequence-first, (seq, batch, width), as required by the blocks'
    ``nn.MultiheadAttention`` (default ``batch_first=False``).
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, attn_mask))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        # nn.Sequential feeds each block's output into the next
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    """ViT image encoder: patchify, prepend a class token, add positional embeddings, encode.

    Args:
        input_resolution: side length of the (square) input images; only used to
            size the positional embedding, so inputs must match it.
        patch_size: kernel size and stride of the patchifying convolution.
        width: token embedding width.
        layers: number of transformer blocks.
        heads: attention heads per block.
        output_dim: size of the final projected embedding.

    ``forward`` takes (batch, 3, input_resolution, input_resolution) images and
    returns (batch, output_dim) embeddings taken from the class token.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=width,
            kernel_size=patch_size, stride=patch_size, bias=False,
        )

        scale = width ** -0.5
        num_patches = (input_resolution // patch_size) ** 2
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # one position per patch plus one for the class token
        self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]

        # [B, 3, H, W] -> [B, width, grid, grid] -> [B, grid**2, width]
        patches = self.conv1(x).flatten(2).transpose(1, 2)

        # prepend the learned class token to every sequence -> [B, grid**2 + 1, width]
        cls_token = self.class_embedding.to(patches.dtype).expand(batch_size, 1, -1)
        tokens = torch.cat([cls_token, patches], dim=1)
        tokens = self.ln_pre(tokens + self.positional_embedding.to(tokens.dtype))

        # the transformer expects sequence-first input: NLD -> LND -> NLD
        encoded = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)

        pooled = self.ln_post(encoded[:, 0, :])
        return pooled if self.proj is None else pooled @ self.proj

# ViT sized for the dataset images: take the resolution from SHAPE rather than repeating a literal.
assert SHAPE[-2] == SHAPE[-1], "VisionTransformer takes a single input_resolution (square images)"
trans = VisionTransformer(SHAPE[-1], patch_size=32, width=256, layers=10, heads=8, output_dim=128)
print(f"params: {num_module_parameters(trans):,}")
trans


In [None]:
# Output shape of a 7x7 conv (100 -> 94) followed by a 16x16 max-pool (94 -> 5): (1, 5, 5, 5)
nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=5, kernel_size=7),
    nn.MaxPool2d(kernel_size=16),
)(torch.rand(1, 3, 100, 100)).size()

In [None]:
conv = Conv2dBlock([3, 5], pool_kernel_size=5)
# output shape for a single random 3x100x100 input
conv(torch.rand(1, 3, 100, 100)).size()

In [None]:
import clip

In [None]:
# clip.load returns a (model, preprocess) tuple; unpack it so clip_model is the model itself.
clip_model, clip_preprocess = clip.load("ViT-B/32")

In [None]:
# Only a cell's last expression is displayed, so print the parameter count explicitly.
print(f"params: {num_module_parameters(clip_model):,}")
clip_model.visual

In [None]:
clip_visual = clip_model.visual  # CLIP's image encoder
clip_visual

In [None]:

def compressed_size(img, format: str, **kwargs):
    """Return the number of bytes ``img`` occupies when encoded as ``format``.

    Args:
        img: object with a PIL-style ``save(fp, format, **kwargs)`` method,
            e.g. a ``PIL.Image.Image``.
        format: encoder name passed to ``img.save`` (e.g. ``"jpeg"``, ``"png"``).
        **kwargs: encoder options forwarded to ``img.save`` (e.g. ``quality=0``).

    Returns:
        Total length of the encoded stream in bytes. Measured from the buffer
        contents rather than ``fp.tell()``, so it stays correct even if the
        writer leaves the stream position somewhere other than the end.
    """
    with BytesIO() as fp:
        img.save(fp, format, **kwargs)
        return len(fp.getvalue())

img = VF.to_pil_image(dataset[114][0])
jpeg_q0_bytes = compressed_size(img, "jpeg", quality=0)  # size at the lowest JPEG quality
print(jpeg_q0_bytes)
img

In [None]:
# JPEG size of the first 100 images at quality 0..19, then how those sizes correlate across qualities.
rows = []
for i in range(100):
    img = VF.to_pil_image(dataset[i][0])
    rows.append({f"q{q}": compressed_size(img, "jpeg", quality=q) for q in range(20)})
df = pd.DataFrame(rows)
px.imshow(df.corr())

In [None]:
# PNG size of the first 100 images for several (compress_level, optimize) settings.
# Pillow ignores compress_level when optimize=True (it always uses level 9),
# so a single optimized variant is enough; the others would be duplicates.
png_settings = ((0, False), (1, False), (4, False), (9, False), (9, True))
rows = []
for i in range(100):
    img = VF.to_pil_image(dataset[i][0])
    rows.append({
        f"q{q}{o}": compressed_size(img, "png", compress_level=q, optimize=o)
        for q, o in png_settings
    })
df = pd.DataFrame(rows)
df.describe()

In [None]:
px.line(data_frame=df)  # one line per PNG setting, x = image row

In [None]:
img = VF.to_pil_image(pic=dataset[118][0])
img

In [None]:
# Compression ratios for dataset items 1000..1999, for the image and a blurred copy.
# The frame is indexed by dataset position so later cells can map rows back to
# `dataset` via df.index; a default 0..N-1 index would be off by 1000.
sample_indices = range(1000, 2000)
rows = []
ratioer = ImageCompressionRatio()
for i in sample_indices:
    img = VF.to_pil_image(dataset[i][0])
    row = ratioer.all(img)
    row.update(ratioer.all(
        VF.gaussian_blur(img, kernel_size=[21, 21], sigma=10),
        suffix="-blur",
    ))
    rows.append(row)
df = pd.DataFrame(rows, index=sample_indices)
display(px.line(df))
px.imshow(df.corr())

In [None]:
# Rows where "png-low-blur" < "jpeg-low", sorted by "jpeg-low" - "png-high",
# shown as an image grid. df.index is used to look images up in `dataset`,
# so it must hold dataset positions.
df2 = (
    df.loc[df["png-low-blur"] < df["jpeg-low"]]
    .assign(diff=lambda d: d["jpeg-low"] - d["png-high"])
    .sort_values("diff")
)
if df2.empty:
    # make_grid cannot stack an empty list
    print("no images match the filter")
else:
    display(VF.to_pil_image(make_grid([dataset[i][0] for i in df2.index])))