In [22]:
import torch
import torch.nn as nn
import math
import numpy as np

In [23]:
def get_timestep_embedding(timesteps, embedding_dim:int):
    '''
    Build sinusoidal timestep (positional) embeddings.

    Answers the question "what is the embedding vector of the k-th timestep?":
    each timestep is multiplied by a geometric series of frequencies and the
    sine and cosine of the result are concatenated along the feature axis.

    param timesteps (torch.Tensor): 1-D tensor of timesteps, shape (B,)
    param embedding_dim (int): width of the returned embedding
    return (torch.Tensor): float32 tensor of shape (B, embedding_dim), on the
        same device as `timesteps`; columns are [sin | cos] (+ one zero column
        when `embedding_dim` is odd)
    '''
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2

    # Frequencies go from 1 down to 1/10000. The denominator is clamped so that
    # embedding_dim in {2, 3} (half_dim == 1) does not divide by zero.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    # Build the frequencies on the input's device so CUDA timesteps work.
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
    )

    angles = timesteps.to(torch.float32)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    # sin/cos only fill 2 * half_dim columns; zero-pad the last one for odd
    # widths so the result always has exactly `embedding_dim` columns.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))

    return emb


In [24]:
class DownSampling(nn.Module):
    '''
    Halve the spatial size of a feature map while keeping the channel count.

    Uses a 3x3 convolution with stride 2 and padding 1, whose output size is
    (input + 2*padding - kernel_size) // stride + 1. For even H and W that is
    exactly H/2 and W/2; for odd sizes the convolution yields (H+1)/2, which
    trips the shape assertion in `forward`.
    '''

    def __init__(self, C):
        '''
        param C (int): number of input (and output) channels
        '''
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=C,
            out_channels=C,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        '''
        param x (torch.Tensor): feature map of shape (B, C, H, W)
        return (torch.Tensor): feature map of shape (B, C, H//2, W//2)
        '''
        batch, channels, height, width = x.shape
        out = self.conv(x)
        expected_shape = (batch, channels, height // 2, width // 2)
        assert out.shape == expected_shape
        return out


class UpSampling(nn.Module):
    '''
    Double the spatial size of a feature map while keeping the channel count.

    The input is upsampled by nearest-neighbour interpolation (factor 2) and
    then passed through a shape-preserving 3x3 convolution.
    '''

    def __init__(self, C):
        '''
        param C (int): number of input (and output) channels
        '''
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=C,
            out_channels=C,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        '''
        param x (torch.Tensor): feature map of shape (B, C, H, W)
        return (torch.Tensor): feature map of shape (B, C, 2*H, 2*W)
        '''
        batch, channels, height, width = x.shape
        enlarged = nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        out = self.conv(enlarged)
        assert out.shape == (batch, channels, 2 * height, 2 * width)
        return out



In [34]:
class Nin(nn.Module):
    '''
    "Network in network" layer: a per-pixel linear map over the channel axis
    (equivalent to a 1x1 convolution), followed by a per-channel bias.

    The weight is drawn uniformly from [-limit, limit] with
    limit = sqrt(3 * scale / n), n = (in_dim + out_dim) / 2; the bias starts
    at zero.
    '''

    def __init__(self, in_dim, out_dim, scale=1e-10):
        '''
        param in_dim (int): number of input channels
        param out_dim (int): number of output channels
        param scale (float): variance scale of the weight initialisation. The
            default (1e-10) keeps the previously hard-coded value, which makes
            the initial weights practically zero.
        '''
        super().__init__()

        n = (in_dim + out_dim) / 2
        limit = float(np.sqrt(3 * scale / n))
        self.W = nn.Parameter(
            torch.zeros((in_dim, out_dim), dtype=torch.float32).uniform_(-limit, limit)
        )
        self.b = nn.Parameter(torch.zeros((1, out_dim, 1, 1), dtype=torch.float32))

    def forward(self, x):
        '''
        param x (torch.Tensor): feature map of shape (B, in_dim, H, W)
        return (torch.Tensor): feature map of shape (B, out_dim, H, W)
        '''
        # Output subscripts must be 'bohw': writing 'bowh' would swap the two
        # spatial axes, i.e. return a transposed image of shape (B, O, W, H).
        return torch.einsum('bchw,co->bohw', x, self.W) + self.b

In [43]:
class ResNetBlock(nn.Module):
    '''
    Residual block conditioned on a timestep embedding.

    Pipeline: GroupNorm -> SiLU -> conv -> (+ projected timestep embedding)
    -> GroupNorm -> SiLU -> dropout -> conv, then added to the (optionally
    channel-projected) input. GroupNorm uses 32 groups without learned affine
    parameters, so both `in_ch` and `out_ch` must be divisible by 32.
    '''

    def __init__(self, in_ch, out_ch, dropout_rate, temb_dim=512):
        '''
        param in_ch (int): number of input channels
        param out_ch (int): number of output channels
        param dropout_rate (float): dropout probability used while training
        param temb_dim (int): width of the timestep embedding passed to
            `forward`; defaults to the previously hard-coded 512
        '''
        super().__init__()

        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding =1)
        self.dense = nn.Linear(temb_dim, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding =1)

        # The skip connection only needs a projection when the channel count changes.
        if not (in_ch == out_ch):
            self.nin = Nin(in_ch, out_ch)

        self.dropout_rate = dropout_rate
        self.nonlinearity = nn.SiLU()

    def forward(self, x, temb): #temb: Batch, dim
        '''
        param x: (B, C, H, W)
        param temb: (B, dim)
        return: (B, out_ch, H, W)
        '''
        h = self.nonlinearity(nn.functional.group_norm(x,num_groups=32))
        h = self.conv1(h)

        # add timestep embedding, broadcast over the spatial dimensions
        h = h + self.dense(self.nonlinearity(temb))[:, :, None, None]

        h = self.nonlinearity(nn.functional.group_norm(h,num_groups=32))
        # functional dropout defaults to training=True; tie it to the module's
        # mode so that model.eval() really disables dropout.
        h = nn.functional.dropout(h, p=self.dropout_rate, training=self.training)
        h = self.conv2(h)

        if not (x.shape[1]==h.shape[1]):
            x = self.nin(x)

        return x + h


In [None]:
class AttentionBlock(nn.Module):
    '''
    Single-head self-attention over the spatial positions of a feature map,
    with a residual connection.

    The input is group-normalised (32 groups, so `ch` must be divisible by
    32), projected to queries, keys and values with per-pixel linear layers,
    and every position attends to every other position. The attention weights
    have B * (H*W)^2 entries, so memory grows quadratically with image area.
    '''

    def __init__(self, ch):
        '''
        param ch (int): number of channels of the input feature map
        '''
        super().__init__()
        self.Q = Nin(ch, ch)
        self.K = Nin(ch, ch)
        self.V = Nin(ch, ch)
        self.ch = ch
        self.nin = Nin(ch, ch)

    def forward(self, x):
        '''
        param x (torch.Tensor): feature map of shape (B, ch, H, W)
        return (torch.Tensor): feature map of the same shape as `x`
        '''
        _, C, _, _ = x.shape

        assert C == self.ch

        h = nn.functional.group_norm(x, num_groups=32)
        q = self.Q(h)
        k = self.K(h)
        v = self.V(h)

        # Collapse the two spatial axes into one "position" axis: (B, C, N).
        # Working on flattened positions avoids any H/W subscript mix-up.
        q_flat = q.flatten(start_dim=2)
        k_flat = k.flatten(start_dim=2)
        v_flat = v.flatten(start_dim=2)

        # w[b, i, j]: scaled dot product between query position i and key position j
        w = torch.einsum('bci,bcj->bij', q_flat, k_flat) * (int(C)**(-0.5))
        # normalise over the key positions
        w = nn.functional.softmax(w, dim=-1)

        # weighted sum of the values, restored to the spatial layout of v
        h = torch.einsum('bij,bcj->bci', w, v_flat).reshape(v.shape)
        h = self.nin(h)

        assert x.shape == h.shape
        return x + h

In [None]:
class UNet(nn.Module):
    '''
    Work-in-progress UNet: only the timestep-embedding MLP, the input
    convolution and the down-sampling path are defined so far.

    NOTE(review): as written elsewhere in this file, ResNetBlock projects a
    512-wide timestep embedding, so `ch * 4` must equal 512 (ch=128) unless
    that block is told otherwise - confirm before changing `ch`.
    '''

    def __init__(self, ch=128, in_ch=1, dropout_rate=0.1): #black white
        '''
        param ch (int): base channel count
        param in_ch (int): channels of the input image (1 = grayscale)
        param dropout_rate (float): dropout probability handed to every
            ResNetBlock (it is a required argument of that block)
        '''
        super().__init__()
        self.ch = ch
        self.linear1 = nn.Linear(ch, ch*4)
        self.linear2 = nn.Linear(ch*4, ch*4)

        self.conv1 = nn.Conv2d(in_ch, ch, 3, stride=1, padding=1)

        self.down = nn.ModuleList([
            ResNetBlock(ch, ch, dropout_rate),
            ResNetBlock(ch, ch, dropout_rate),
            DownSampling(ch),
            ResNetBlock(1 * ch, 2 * ch, dropout_rate),
            AttentionBlock(2 * ch),
            ResNetBlock(2 * ch, 2 * ch, dropout_rate),
            AttentionBlock(2 * ch),
            DownSampling(2 * ch),
            ResNetBlock(2 * ch, 2 * ch, dropout_rate),
            ResNetBlock(2 * ch, 2 * ch, dropout_rate),
            DownSampling(2 * ch),
            ResNetBlock(2 * ch, 2 * ch, dropout_rate),
            ResNetBlock(2 * ch, 2 * ch, dropout_rate),
        ])


    def forward(self, x, t):
        '''
        param x (torch.Tensor): batch of images (B,C,H,W)
        param t (torch.Tensor): tensor of time steps (torch.long) [B]

        NOTE(review): unfinished - only the timestep embedding is computed;
        `x`, `self.conv1` and `self.down` are not used yet and the method
        implicitly returns None.
        '''

        #timestep embedding
        temb = get_timestep_embedding(t, self.ch)
        # keep the embedding on the same device as the layers that consume it
        temb = temb.to(self.linear1.weight.device)
        temb = torch.nn.functional.silu(self.linear1(temb))
        temb = self.linear2(temb)
        assert temb.shape == (t.shape[0], self.ch*4)

In [None]:
# Smoke test: push a random batch through every building block and check
# that each one returns the expected shape.
torch.manual_seed(0)  # make the random inputs and weights reproducible

t = (torch.rand(10)*10).long()
temb = get_timestep_embedding(t, 512)
assert temb.shape == (10, 512)

downsample = DownSampling(64)
img = torch.randn(10,64,64,64)
h = downsample(img)
assert h.shape == (10, 64, 32, 32)

upsample = UpSampling(64)
img = upsample(h)
assert img.shape == (10, 64, 64, 64)

nin = Nin(64, 64)
img = nin(img)
assert img.shape == (10, 64, 64, 64)

resnet = ResNetBlock(64, 64, 0.1)
img = resnet(img, temb)
assert img.shape == (10, 64, 64, 64)

# channel-changing variant: exercises the projected skip connection
resnet = ResNetBlock(64, 32, 0.1)
img = resnet(img, temb)
assert img.shape == (10, 32, 64, 64)

# Attention over 64x64 positions builds a (10, 4096, 4096) weight tensor,
# roughly 640 MiB per float32 copy - shrink the batch if memory is tight.
att = AttentionBlock(32)
img = att(img)
assert img.shape == (10, 32, 64, 64)
