In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

from src.datasets import *
from src.util import *
from src.util.image import * 
from src.algo import *
from src.datasets.generative import *
from src.models.cnn import *
from src.util.embedding import *
from src.models.encoder import *

In [None]:
def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Rescale an image tensor by a uniform factor.

    :param img: tensor whose last two dimensions are used as the spatial size
    :param scale: factor applied to each of the last two dimensions
    :param mode: interpolation mode handed to ``VF.resize`` (nearest by default)
    :return: the result of ``VF.resize`` with antialiasing switched off
    """
    # Truncate towards zero, but never let a side shrink below one pixel.
    target_size = [
        max(1, int(side * scale))
        for side in img.shape[-2:]
    ]
    return VF.resize(img, target_size, interpolation=mode, antialias=False)

def plot_samples(
        iterable,
        total: int = 32,
        nrow: int = 8,
        return_image: bool = False,
        show_compression_ratio: bool = False,
        label: Optional[Callable] = None,
):
    """
    Collect images from ``iterable`` and render them as a single grid image.

    Each entry is either the image itself or a list/tuple whose first element
    is the image. Images with four dimensions have their leading dimension
    squeezed. Collection stops once ``total`` images have been gathered, when
    the iterable is exhausted, or when the user interrupts the kernel (in which
    case the images gathered so far are plotted).

    :param iterable: source of entries (e.g. a dataset)
    :param total: number of images after which collection stops; also used as
        the progress-bar length
    :param nrow: forwarded to ``make_grid`` / ``make_grid_labeled``
    :param return_image: return the PIL image instead of displaying it
    :param show_compression_ratio: label every image with
        ``ImageFilter.calc_compression_ratio`` rounded to 3 decimals;
        takes precedence over ``label``
    :param label: a callable receives the whole entry and returns its label;
        any other non-None value (e.g. ``True``) labels each image with its
        enumeration index
    :return: the PIL image if ``return_image`` is set, otherwise None
    :raises ValueError: if no image could be collected
    """
    # Only construct the filter when a compression ratio is actually requested.
    ratio_filter = ImageFilter() if show_compression_ratio else None

    samples = []
    labels = []
    try:
        for idx, entry in enumerate(tqdm(iterable, total=total)):
            image = entry[0] if isinstance(entry, (list, tuple)) else entry
            if image.ndim == 4:
                image = image.squeeze(0)
            samples.append(image)

            if ratio_filter is not None:
                labels.append(round(ratio_filter.calc_compression_ratio(image), 3))
            elif label is not None:
                labels.append(label(entry) if callable(label) else idx)

            if len(samples) >= total:
                break
    except KeyboardInterrupt:
        # Deliberate: an interrupt just ends collection early.
        pass

    if not samples:
        # Fail with a clear message instead of an opaque error from make_grid.
        raise ValueError("plot_samples: no samples were collected from the iterable")

    if labels:
        grid = make_grid_labeled(samples, nrow=nrow, labels=labels)
    else:
        grid = make_grid(samples, nrow=nrow)
    image = VF.to_pil_image(grid)

    if return_image:
        return image
    display(image)

# load model

In [None]:
# Image shape and conditioning-code size; later cells read both globals.
SHAPE = (1, 32, 32)
CODE_SIZE = 512
from scripts.train_image_diffusion import DiffusionModel, DiffusionModel1

# The `DiffusionModel1` instance is passed to `DiffusionModel` as its third argument.
model = DiffusionModel(SHAPE, CODE_SIZE, DiffusionModel1(shape=SHAPE, code_size=CODE_SIZE))
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# machine without CUDA; the model itself is created on the CPU above.
model.load_state_dict(
    torch.load("../checkpoints/diff1/best.pt", map_location="cpu")["state_dict"]
)
# This notebook only samples from the model, so switch off training-time
# behaviour (dropout, batch-norm statistics updates).
model.eval()

print(f"params: {num_module_parameters(model):,}")

# load samples

In [None]:
# Pair two tensors stored on disk into one dataset: every item is a tuple
# (entry from the first file, entry from the second file).
# NOTE(review): presumably 64x64 uint8 images and their CLIP embeddings,
# judging by the file names only -- confirm.
ds = TensorDataset(
    torch.load(f"../datasets/kali-uint8-{64}x{64}.pt"),
    torch.load(f"../datasets/kali-uint8-{64}x{64}-CLIP.pt"),
)
# Wrap with the project's TransformDataset.
# NOTE(review): dtype/multiply look like a uint8 -> float [0, 1] conversion of
# the image -- confirm against TransformDataset.
ds = TransformDataset(
    ds,
    dtype=torch.float, multiply=1. / 255.,
    transforms=[
        #VT.CenterCrop(64),
        # Random crop to the spatial size the model was built for.
        VT.RandomCrop(SHAPE[-2:]),
        VT.Grayscale(),
    ],
    num_repeat=1,
)
# label=True is not callable, so plot_samples labels each image with its index.
plot_samples(ds, label=True)

# create images

In [None]:
@torch.no_grad()
def create_image(
        code: torch.Tensor,
        steps: int = 64,
        max_model_step: Optional[int] = 5,
        nrow: int = 16,
        scale: float = 3,
):
    """
    Iteratively denoise random noise with the notebook's global ``model`` and
    display every intermediate image in one grid.

    Starts from ``torch.randn(1, *SHAPE)`` (global ``SHAPE``) and, counting
    ``step`` down from ``steps - 1`` to 0, subtracts ``model.predict_noise``
    from the current image.

    :param code: conditioning code for a single image; a batch dimension is
        added in front
    :param steps: number of denoising iterations (must be >= 1)
    :param max_model_step: upper bound for the step index handed to the model;
        ``None`` passes the loop's step unchanged
    :param nrow: forwarded to ``make_grid`` for the overview image
    :param scale: enlargement factor of the displayed grid
    :raises ValueError: if ``steps`` is smaller than 1
    """
    if steps < 1:
        # An empty image list would otherwise fail inside make_grid.
        raise ValueError(f"steps must be >= 1, got {steps}")

    code = code.unsqueeze(0)
    image = torch.randn(1, *SHAPE)

    images = []
    for step in reversed(range(steps)):
        # NOTE(review): with the default cap the model never sees a step index
        # above 5 -- confirm this matches how the model was trained.
        model_step = step if max_model_step is None else min(max_model_step, step)
        noise = model.predict_noise(image, code, model_step)
        image = image - noise
        images.append(image.squeeze(0))

    display(VF.to_pil_image(
        resize(make_grid(images, nrow=nrow).clamp(0, 1), scale)
    ))


# Generate from the code (second element) of the first dataset item.
create_image(ds[0][1])

In [None]:
import torchvision.models

class DiffusionModel2(nn.Module):
    """
    Experimental noise predictor: a small torchvision ``VisionTransformer``
    encodes the image, the encoding is concatenated with the conditioning code
    and a step encoding, and an MLP maps the result back to image shape.

    :param shape: image shape ``(channels, height, width)``; the last entry is
        used as both image size and patch size of the transformer
    :param code_size: length of the conditioning code per sample
    :param step_encoding_size: number of columns of the step encoding
    """
    def __init__(self, shape: Tuple[int, int, int], code_size: int, step_encoding_size: int = 10):
        super().__init__()
        self.shape = shape
        self.code_size = code_size
        self.step_encoding_size = step_encoding_size
        # Size of the transformer output that feeds the MLP head.
        embedding_size = 512
        self.transformer = torchvision.models.VisionTransformer(
            # Derived from the `shape` argument rather than the notebook-global
            # SHAPE, so the module works for any shape it is constructed with.
            image_size=shape[-1],
            patch_size=shape[-1],
            num_layers=3,
            num_heads=4,
            hidden_dim=128,
            mlp_dim=128,
            dropout=0.1,
            num_classes=embedding_size,
        )
        self.layers = nn.Sequential(
            nn.Linear(embedding_size + code_size + step_encoding_size, 1024),
            nn.GELU(),
            nn.Linear(1024, math.prod(shape)),
        )

    def forward(self, image_batch: torch.Tensor, code_batch: torch.Tensor, step: int) -> torch.Tensor:
        """
        :param image_batch: batch of images; the channel dimension is expanded
            to 3 before it is passed to the transformer
        :param code_batch: batch of conditioning codes, one row per image
        :param step: step index shared by the whole batch
        :return: tensor reshaped to ``(-1, *self.shape)``
        """
        encoded_batch = self.transformer(image_batch.expand(-1, 3, -1, -1))

        # NOTE(review): every column of the step encoding holds the same value,
        # sin(step * 10000) -- confirm that this is the intended encoding.
        step_encoding = torch.full(
            (code_batch.shape[0], self.step_encoding_size),
            float(step),
            device=code_batch.device,
        )

        x = torch.concat([
            encoded_batch,
            code_batch,
            torch.sin(step_encoding * 10_000.),
        ], dim=1)

        return self.layers(x).view(-1, *self.shape)

# Smoke test: build the model, run one forward pass on random inputs to check
# the output shape, then show the module structure as the cell output.
diff2 = DiffusionModel2(shape=SHAPE, code_size=CODE_SIZE)
print(
    diff2(
        torch.randn(1, *SHAPE),
        torch.randn(1, CODE_SIZE),
        1,
    ).shape
)
diff2

In [None]:
# Constructor reference for `torchvision.models.VisionTransformer`: see the
# torchvision API documentation. (The interactive `?` lookup that used to live
# in this cell is IPython-only syntax and was a development aid.)

In [None]:
class Block(nn.Module):
    """
    One U-Net stage: conv -> ReLU -> batch-norm, add the projected conditioning
    vector, conv -> ReLU -> batch-norm, then a strided layer that halves
    (``up=False``) or doubles (``up=True``) the spatial size.

    :param in_ch: nominal input channels; with ``up=True`` the first
        convolution expects ``2 * in_ch`` channels
    :param out_ch: output channels of every layer in the block
    :param time_emb_dim: size of the conditioning vector ``t``
    :param up: build an upsampling block instead of a downsampling one
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Projects the conditioning vector to one value per output channel.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

        first_conv_in = 2 * in_ch if up else in_ch
        resample_layer = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(first_conv_in, out_ch, 3, padding=1)
        # kernel 4, stride 2, padding 1
        self.transform = resample_layer(out_ch, out_ch, 4, 2, 1)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First convolution
        hidden = self.conv1(x)
        hidden = self.bnorm1(self.relu(hidden))
        # Time embedding, extended by two trailing dimensions so that it
        # broadcasts over the spatial axes
        embedding = self.relu(self.time_mlp(t))[..., None, None]
        # Add the embedding to every spatial position
        hidden = hidden + embedding
        # Second convolution
        hidden = self.conv2(hidden)
        hidden = self.bnorm2(self.relu(hidden))
        # Upsample or downsample
        return self.transform(hidden)


class SimpleUnet(nn.Module):
    """
    A simplified version of the U-Net architecture.

    The timestep is embedded and, if ``code_dim`` is given, concatenated with a
    projected conditioning code; the resulting vector is passed to every
    ``Block``.

    :param image_channels: channels of the input and of the output image
    :param down_channels: channel sizes along the downsampling path
    :param up_channels: channel sizes along the upsampling path
    :param time_emb_dim: size of the timestep embedding
    :param code_dim: size of the optional conditioning code; ``None`` builds
        a model without code conditioning
    """
    def __init__(
            self,
            image_channels: int = 3,
            down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
            up_channels: Tuple[int, ...] = (1024, 512, 256, 128, 64),
            time_emb_dim: int = 32,
            code_dim: Optional[int] = None,
    ):
        super().__init__()
        self.code_dim = code_dim

        # Size of the conditioning vector every Block receives.
        big_dim = time_emb_dim
        if code_dim is not None:
            big_dim += code_dim
        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalNumberEmbedding(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )
        if code_dim is not None:
            self.code_mlp = nn.Sequential(
                nn.Linear(code_dim, code_dim),
                nn.ReLU()
            )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsampling
        self.downs = nn.ModuleList([
            Block(down_channels[i], down_channels[i+1], big_dim)
            for i in range(len(down_channels)-1)
        ])
        # Upsampling
        self.ups = nn.ModuleList([
            Block(up_channels[i], up_channels[i+1], big_dim, up=True)
            for i in range(len(up_channels)-1)
        ])

        self.output = nn.Conv2d(up_channels[-1], image_channels, 1)

    def forward(self, x, timestep, code: Optional[torch.Tensor] = None):
        """
        :param x: image batch with ``image_channels`` channels
        :param timestep: input of the timestep embedding
        :param code: conditioning code batch; required when the model was built
            with a non-zero ``code_dim`` and rejected when it was built without
        :return: output of the final 1x1 convolution (``image_channels`` channels)
        :raises ValueError: if ``code`` does not match the constructor's ``code_dim``
        """
        # Time embedding
        t = self.time_mlp(timestep)

        if code is not None:
            if self.code_dim is None:
                raise ValueError("code specified in forward but no code_dim in constructor")
            t = torch.concat([t, self.code_mlp(code)], dim=-1)
        elif self.code_dim:
            # Without the code part `t` is too short for the Blocks' linear
            # layers; fail here with a clear message instead of a shape error.
            raise ValueError("code_dim specified in constructor but no code passed to forward")

        # Initial convolution
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            # The first pop returns the last down output, i.e. `x` itself.
            residual_x = residual_inputs.pop()
            # Add the residual as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)

    
    
# Smoke test: build a single-channel, code-conditioned U-Net, report its size,
# run one 32x32 image through it and show the module structure.
unet = SimpleUnet(code_dim=333, image_channels=1)
print(f"params: {num_module_parameters(unet):,}")
o = unet(
    torch.ones(1, 1, 32, 32),
    torch.zeros(1),
    code=torch.ones(1, 333),
)
print(o.shape)
unet

In [None]:

# Scratch check of the project's SinusoidalNumberEmbedding: feed it the 21
# evenly spaced values 0..20 and show the result rounded to one decimal.
t = torch.linspace(0, 20, 21)#.unsqueeze(1)
SinusoidalNumberEmbedding(8)(t).round(decimals=1)

In [None]:
# Show the raw values of `t` (defined in the previous cell) for comparison.
t[:]