In [1]:
import torch.nn.functional as F
from torch import nn, einsum
import torch
from einops import rearrange
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import math
from inspect import isfunction
from functools import partial

%matplotlib inline


  from .autonotebook import tqdm as notebook_tqdm


## diffusion-modelとは？

- 単純な分布から得たノイズをデータサンプルに変換するモデル
- ニューラルネットワーク（NN）は、純粋なノイズから少しずつノイズを取り除いてデータを復元する方法を学習する

以下の二つのプロセスから構成される

### Forward diffusion process
- データにガウス分布から生成したノイズを少しずつ加えていき、純粋なノイズと見なせるようになるまでこれを $T$ ステップ繰り返す
    - DDPMでは $T=1000$
    - $T$ が十分大きく、各ステップで加えるノイズのスケジュールが適切（a well behaved schedule for adding noise at each time step）であれば、最終的な分布は等方性ガウス分布（isotropic Gaussian distribution）と見なせるとされる
        - https://math.stackexchange.com/questions/1991961/gaussian-distribution-is-isotropic

### Learned reverse denoising diffusion process
- 純粋なノイズから少しずつノイズを除去して元の画像を復元できるように、NNを学習させる（forward processの逆過程を学習する）

In [2]:
# NN実装用のヘルパー関数

def exists(x):
    """Return True when ``x`` is anything other than ``None``."""
    return x is not None

def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain Python function (as judged by ``inspect.isfunction``,
    so lambdas count but classes and ``functools.partial`` objects do not), it
    is called with no arguments and its result is returned. Any other ``d`` is
    returned as-is.
    """
    if exists(val):
        return val
    if isfunction(d):
        return d()
    return d

class Residual(nn.Module):
    """Wrap ``fn`` with a skip connection, computing ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        # When ``fn`` is an nn.Module this registers it as a submodule.
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """Apply ``fn`` to ``x`` (forwarding any extra arguments) and add ``x`` back."""
        branch = self.fn(x, *args, **kwargs)
        return branch + x



def Upsample(dim):
    """Return a learnable 2x spatial upsampling layer that keeps ``dim`` channels.

    A transposed convolution with kernel 4, stride 2 and padding 1 maps an
    (H, W) feature map to (2H, 2W).
    """
    layer = nn.ConvTranspose2d(
        in_channels=dim, out_channels=dim, kernel_size=4, stride=2, padding=1
    )
    return layer

def Downsample(dim):
    """Return a learnable 2x spatial downsampling layer that keeps ``dim`` channels.

    A strided convolution (kernel 4, stride 2, padding 1) halves even spatial
    sizes, e.g. (8, 6) -> (4, 3).
    """
    layer = nn.Conv2d(
        in_channels=dim, out_channels=dim, kernel_size=4, stride=2, padding=1
    )
    return layer


## tの埋め込み

- Transformerと同様に、正弦波による位置埋め込み（Sinusoidal Positional Embeddings）でタイムステップ $t$ をベクトルに変換する

In [3]:
# tのpositionalEmbeddingクラス

class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of time steps, as used for Transformer positions.

    Each scalar time step ``t`` is mapped to
    ``[sin(t*f_0), ..., sin(t*f_{k-1}), cos(t*f_0), ..., cos(t*f_{k-1})]``
    with ``k = dim // 2`` and geometrically spaced frequencies
    ``f_i = 10000 ** (-i / (k - 1))``.
    """

    def __init__(self, dim):
        """
        Args:
            dim: size of the vector each time step is embedded into.
        """
        super().__init__()
        self.dim = dim

    def forward(self, time: torch.Tensor):
        """Embed a batch of time steps.

        Args:
            time: tensor of time steps, expected to be 1-D with shape (batch,).

        Returns:
            Float tensor of shape (batch, dim). When ``dim`` is odd, the last
            column is zero so that the width always equals ``dim``.
        """
        device = time.device
        half_dim = self.dim // 2

        # For dim in {2, 3}, half_dim - 1 == 0 would raise ZeroDivisionError.
        # With a single frequency the scale is irrelevant (exp(0) == 1).
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)

        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)

        if self.dim % 2 == 1:
            # sin/cos halves only give 2 * (dim // 2) columns. Pad with a zero
            # column so the output width matches self.dim (and thus any
            # following nn.Linear(dim, ...)).
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings


In [4]:
from time import time


class Block(nn.Module):
    """3x3 conv -> GroupNorm -> optional scale/shift modulation -> SiLU.

    Basic building block used twice inside ``ResNetBlock``.
    """

    def __init__(self, dim, dim_out, groups=8):
        """
        Args:
            dim: number of input channels.
            dim_out: number of output channels (must be divisible by ``groups``).
            groups: number of groups for ``nn.GroupNorm``.
        """
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        """
        Args:
            x: input feature map with ``dim`` channels (presumably (B, C, H, W)).
            scale_shift: optional ``(scale, shift)`` pair. The normalized
                features become ``x * (scale + 1) + shift``, so both must
                broadcast against the normalized feature map.

        Returns:
            Feature map with ``dim_out`` channels and the same spatial size
            (3x3 conv with padding 1, stride 1).
        """
        # The parameter used to be named ``c`` while the body read ``x``,
        # which raised NameError (or silently picked up a global ``x``).
        x = self.proj(x)
        x = self.norm(x)

        if scale_shift is not None:
            scale, shift = scale_shift
            # "+ 1" so that a zero scale leaves the normalized features unchanged.
            x = x * (scale + 1) + shift

        return self.act(x)


class ResNetBlock(nn.Module):
    """Two stacked ``Block``s with an optional time-embedding bias and a skip path.

    Output is ``block2(block1(x) [+ mlp(time_emb)]) + res_conv(x)``.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        """
        Args:
            dim: number of input channels.
            dim_out: number of output channels.
            time_emb_dim: size of the time embedding. When None, no projection
                MLP is created and any ``time_emb`` passed to forward is ignored.
            groups: GroupNorm groups used by both inner ``Block``s.
        """
        super().__init__()
        if exists(time_emb_dim):
            # Project the time embedding to one value per output channel.
            self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
        else:
            self.mlp = None

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)

        if dim != dim_out:
            # 1x1 conv so the skip path matches the output channel count.
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        """
        Args:
            x: input feature map with ``dim`` channels.
            time_emb: optional time embedding; presumably shaped
                (batch, time_emb_dim) given the "b c" rearrange below.

        Returns:
            Feature map with ``dim_out`` channels.
        """
        h = self.block1(x)

        conditioned = exists(self.mlp) and exists(time_emb)
        if conditioned:
            # Broadcast the per-channel bias over the spatial dimensions.
            bias = rearrange(self.mlp(time_emb), "b c -> b c 1 1")
            h = bias + h

        return self.block2(h) + self.res_conv(x)


class ConvNextBlock(nn.Module):
    """ConvNeXt-style residual block (https://arxiv.org/abs/2201.03545).

    7x7 depthwise conv -> optional time-embedding bias -> norm/conv/GELU stack,
    plus a (possibly 1x1-projected) skip connection.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        """
        Args:
            dim: number of input channels.
            dim_out: number of output channels.
            time_emb_dim: size of the time embedding; None disables conditioning.
            mult: expansion factor of the hidden width (``dim_out * mult``).
            norm: when False, the leading GroupNorm is replaced by ``nn.Identity``.
        """
        super().__init__()
        if exists(time_emb_dim):
            # Projected to ``dim`` (not dim_out): the bias is added right after
            # the depthwise conv, before the channel count changes.
            self.mlp = nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
        else:
            self.mlp = None

        # groups=dim -> one 7x7 filter per channel (depthwise convolution).
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        hidden = dim_out * mult
        leading_norm = nn.GroupNorm(1, dim) if norm else nn.Identity()
        self.net = nn.Sequential(
            leading_norm,
            nn.Conv2d(dim, hidden, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, hidden),
            nn.Conv2d(hidden, dim_out, 3, padding=1),
        )

        if dim != dim_out:
            self.res_conv = nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = nn.Identity()

    def forward(self, x, time_emb=None):
        """
        Args:
            x: input feature map with ``dim`` channels.
            time_emb: optional time embedding; presumably (batch, time_emb_dim).

        Returns:
            Feature map with ``dim_out`` channels.
        """
        h = self.ds_conv(x)

        if exists(self.mlp) and exists(time_emb):
            h = h + rearrange(self.mlp(time_emb), "b c -> b c 1 1")

        return self.net(h) + self.res_conv(x)


In [5]:
class Attention(nn.Module):
    """Multi-head softmax self-attention over all spatial positions of a feature map.

    The similarity matrix is (positions x positions), so cost is quadratic in H * W.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        """
        Args:
            dim: channels of the input and output feature map.
            heads: number of attention heads.
            dim_head: channels per head.
        """
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it ``self.to_qkv = nn.Conv2d(...)`` raises AttributeError.
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # A single 1x1 conv produces queries, keys and values together.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        """
        Args:
            x: feature map, unpacked as (b, c, h, w) with c == dim.

        Returns:
            Tensor of shape (b, dim, h, w).
        """
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # (b, heads*d, h, w) -> (b, heads, d, h*w)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        # Subtract the per-row maximum for numerical stability before softmax.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)

class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a feature map.

    Queries are softmax-normalized over the channel axis and keys over the
    position axis. This lets keys and values be contracted first into a small
    per-head (d x d) context, so cost grows linearly with H * W.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        """
        Args:
            dim: channels of the input and output feature map.
            heads: number of attention heads.
            dim_head: channels per head.
        """
        super().__init__()
        inner = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        # One 1x1 conv yields q, k and v stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, inner * 3, 1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(inner, dim, 1),
            nn.GroupNorm(1, dim),
        )

    def forward(self, x):
        """
        Args:
            x: feature map, unpacked as (b, c, h, w) with c == dim.

        Returns:
            Tensor of shape (b, dim, h, w).
        """
        _, _, height, width = x.shape
        split_heads = "b (h c) x y -> b h c (x y)"
        q, k, v = [
            rearrange(part, split_heads, h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        ]

        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)

        # Contract keys with values over positions n -> (b, heads, d, e).
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(
            out, "b h c (x y) -> b (h c) x y", h=self.heads, x=height, y=width
        )
        return self.to_out(out)


In [6]:
class PreNorm(nn.Module):
    """Normalize the input with ``GroupNorm(1, dim)``, then apply ``fn``."""

    def __init__(self, dim, fn):
        """
        Args:
            dim: number of channels of the input (normalized as a single group).
            fn: module or callable applied to the normalized input.
        """
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        """Return ``fn(norm(x))``."""
        normalized = self.norm(x)
        return self.fn(normalized)

In [7]:
from tokenize import group


class Unet(nn.Module):
    """
    docstring
    """
    def __init__(
        self, 
        dim,
        init_dim=None,
        out_dim = None,
        dim_mults=(1,2,4,8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2
    ):
        """
        docstring
        """
        super().__init__()

        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [
            init_dim, *map(lambda m: dim*m, dim_mults)
        ]
        in_out = list(zip(dims[:-1], dims[1:]))

        if use_convnext:
            block_klass = partial(
                ConvNextBlock, mult=convnext_mult
            )
        else:
            block_klass = partial(
                ResNetBlock, groups=resnet_block_groups
            )

        if with_time_emb:
            time_dim = dim*4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim)
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

            mid_dim = dims[-1]
            self.mid_block1 = block_klass(
                mid_dim, mid_dim, time_emb_dim=time_dim)
            self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
            self.mid_block2 = block_klass(
                mid_dim, mid_dim, time_emb_dim=time_dim)
