In [1]:
import torch
import torch.nn as nn
import math
import numpy as np

In [2]:
def get_timestep_embedding(timesteps, embedding_dim: int):
    """
    Build sinusoidal timestep embeddings.

    Retrieved from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py#LL90C1-L109C13
    Retrieved from https://www.udemy.com/course/diffusion-models/learn/lecture/37971218#overview

    :param timesteps (torch.Tensor): 1-D tensor of timesteps, shape [B]
    :param embedding_dim (int): width of the returned embedding
    :return (torch.Tensor): float32 tensor of shape [B, embedding_dim], created on
        the same device as ``timesteps``. The first ``embedding_dim // 2`` columns
        hold sines, the next ``embedding_dim // 2`` hold cosines, and an odd
        ``embedding_dim`` gets one trailing column of zeros.
    """

    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2

    # Frequencies decay geometrically from 1 to 1/10000 across the half_dim slots.
    # max(..., 1) avoids a division by zero when half_dim == 1 (embedding_dim 2 or 3).
    log_step = math.log(10000) / max(half_dim - 1, 1)
    # Build the frequency table on the timesteps' device so GPU inputs work.
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
    )
    angles = timesteps.type(torch.float32)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    if embedding_dim % 2 == 1: # zero pad
        # (0, 1) pads the last dimension with nothing on the left and one zero on the right.
        emb = torch.nn.functional.pad(emb, (0, 1))

    assert emb.shape == (timesteps.shape[0], embedding_dim), f"{emb.shape}"

    return emb

In [3]:
class Downsample(nn.Module):
    """Strided 3x3 convolution that halves the height and width of a feature map."""

    def __init__(self,C):
        """
        :param C (int): number of input and output channels
        """
        super().__init__()
        self.conv = nn.Conv2d(in_channels=C, out_channels=C, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """
        :param x: (B, C, H, W)
        :return: (B, C, H // 2, W // 2)
        """
        batch, channels, height, width = x.shape
        out = self.conv(x)
        # Sanity check: the channel count is kept and both spatial dims are halved.
        expected_shape = (batch, channels, height // 2, width // 2)
        assert out.shape == expected_shape
        return out

In [4]:
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self,C):
        """
        :param C (int): number of input and output channels
        """
        super().__init__()
        self.conv = nn.Conv2d(in_channels=C, out_channels=C, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """
        :param x: (B, C, H, W)
        :return: (B, C, 2 * H, 2 * W)
        """
        batch, channels, height, width = x.shape

        # Double both spatial dims by repeating pixels, then apply the convolution.
        enlarged = nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        out = self.conv(enlarged)

        assert out.shape == (batch, channels, 2 * height, 2 * width)
        return out

In [5]:
class Nin(nn.Module):
    """
    "Network in network" layer: a learned linear map over the channel axis that is
    applied independently at every spatial position (the same operation as a 1x1
    convolution), plus a per-channel bias.
    """

    def __init__(self, in_dim, out_dim, scale = 1e-10):
        """
        :param in_dim (int): number of input channels
        :param out_dim (int): number of output channels
        :param scale (float): scale of the uniform weight initialisation; the
            weights are drawn from U(-limit, limit) with
            limit = sqrt(3 * scale / ((in_dim + out_dim) / 2)), so scale=0 gives
            all-zero weights
        """
        super(Nin, self).__init__()

        n = (in_dim + out_dim) / 2
        limit = np.sqrt(3 * scale / n)
        self.W = torch.nn.Parameter(torch.zeros((in_dim, out_dim), dtype=torch.float32).uniform_(-limit, limit))
        self.b = torch.nn.Parameter(torch.zeros((1, out_dim, 1, 1), dtype=torch.float32))

    def forward(self, x):
        """
        :param x: (B, in_dim, H, W)
        :return: (B, out_dim, H, W)
        """
        # The output subscripts must keep h before w. Writing 'bowh' would return a
        # spatially transposed map, which misaligns residual additions and breaks
        # non-square inputs.
        return torch.einsum('bchw,co->bohw', x, self.W) + self.b

In [6]:
class ResNetBlock(nn.Module):
    """
    Residual block conditioned on a timestep embedding.

    Main path: group norm -> SiLU -> 3x3 conv -> add projected timestep embedding
    -> group norm -> SiLU -> dropout -> 3x3 conv. The result is added to the input,
    which is first passed through a Nin layer when the channel count changes.
    """
    def __init__(self, in_ch, out_ch, dropout_rate=0.1, temb_dim=512):
        """
        :param in_ch (int): number of input channels (must be divisible by 32 for group norm)
        :param out_ch (int): number of output channels (must be divisible by 32 for group norm)
        :param dropout_rate (float): dropout probability used while training
        :param temb_dim (int): width of the timestep embedding passed to forward
        """
        super(ResNetBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
        self.dense = nn.Linear(temb_dim, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1)

        # The shortcut needs a channel projection only when the widths differ.
        if in_ch != out_ch:
            self.nin = Nin(in_ch, out_ch)

        self.dropout_rate = dropout_rate
        self.nonlinearity = torch.nn.SiLU()

    def forward(self, x, temb):
        """
        :param x: (B, C, H, W)
        :param temb: (B, dim)
        :return: (B, out_ch, H, W)
        """

        # Normalise and activate first, then convolve that result (not the raw input).
        h = self.nonlinearity(nn.functional.group_norm(x, num_groups=32))
        h = self.conv1(h)

        # add in timestep embedding, broadcast over the spatial dimensions
        h = h + self.dense(self.nonlinearity(temb))[:, :, None, None]

        h = self.nonlinearity(nn.functional.group_norm(h, num_groups=32))
        # functional.dropout defaults to training=True; tie it to the module's mode
        # so that model.eval() really disables dropout.
        h = nn.functional.dropout(h, p=self.dropout_rate, training=self.training)
        h = self.conv2(h)

        if x.shape[1] != h.shape[1]:
            x = self.nin(x)

        assert x.shape == h.shape
        return x + h

In [7]:
class AttentionBlock(nn.Module):
    """Single-head self-attention across all spatial positions, with a residual connection."""

    def __init__(self, ch):
        """
        :param ch (int): number of channels of the input (and output) feature map
        """
        super().__init__()

        # Per-position channel projections producing queries, keys and values.
        self.Q = Nin(ch, ch)
        self.K = Nin(ch, ch)
        self.V = Nin(ch, ch)

        self.ch = ch

        # Output projection, created with scale=0.
        self.nin = Nin(ch, ch, scale=0.)

    def forward(self, x):
        """
        :param x: (B, C, H, W) with C equal to the ``ch`` given at construction
        :return: tensor with the same shape as x
        """
        B, C, H, W = x.shape
        assert C == self.ch

        normed = nn.functional.group_norm(x, num_groups=32)
        query, key, value = self.Q(normed), self.K(normed), self.V(normed)

        # Dot product between every query position (h, w) and key position (H, W),
        # scaled by 1 / sqrt(C).
        scores = torch.einsum('bchw,bcHW->bhwHW', query, key) * (int(C) ** (-0.5)) # [B, H, W, H, W]

        # Softmax jointly over all key positions: flatten them, normalise, restore.
        attn = torch.nn.functional.softmax(scores.reshape(B, H, W, H * W), dim=-1)
        attn = attn.reshape(B, H, W, H, W)

        # Weighted sum of the values, then the output projection.
        out = self.nin(torch.einsum('bhwHW,bcHW->bchw', attn, value))

        assert out.shape == x.shape
        return x + out

In [8]:
class UNet(nn.Module):
    """
    Time-conditioned U-Net built from ResNetBlock, AttentionBlock, Downsample and
    Upsample modules.

    The encoder (``down``) stores intermediate feature maps; the decoder (``up``)
    concatenates them back in reverse order in front of each of its ResNetBlocks.

    NOTE(review): the ResNetBlocks are constructed without an explicit embedding
    width while the embedding produced here is 4 * ch wide. The ResNetBlock in this
    notebook expects 512 inputs, so ch=128 looks like the only working value.
    Confirm before changing ``ch``.
    """

    def __init__(self, ch=128, in_ch=1):
        """
        :param ch (int): base channel width
        :param in_ch (int): number of image channels (used for input and output)
        """
        super().__init__()

        self.ch = ch
        wide = 2 * ch       # channel width used below the first resolution level
        temb_ch = 4 * ch    # width of the processed timestep embedding

        # Two-layer MLP applied to the sinusoidal timestep embedding.
        self.linear1 = nn.Linear(ch, temb_ch)
        self.linear2 = nn.Linear(temb_ch, temb_ch)

        self.conv1 = nn.Conv2d(in_ch, ch, kernel_size=3, stride=1, padding=1)

        # Encoder. Resolution comments assume a 32 x 32 input.
        encoder = [
            ResNetBlock(ch, ch),            # 32 x 32
            ResNetBlock(ch, ch),
            Downsample(ch),                 # 16 x 16

            ResNetBlock(ch, wide),
            AttentionBlock(wide),
            ResNetBlock(wide, wide),
            AttentionBlock(wide),
            Downsample(wide),               # 8 x 8

            ResNetBlock(wide, wide),
            ResNetBlock(wide, wide),
            Downsample(wide),               # 4 x 4

            ResNetBlock(wide, wide),
            ResNetBlock(wide, wide),
        ]
        self.down = nn.ModuleList(encoder)

        # Bottleneck
        bottleneck = [
            ResNetBlock(wide, wide),
            AttentionBlock(wide),
            ResNetBlock(wide, wide),
        ]
        self.middle = nn.ModuleList(bottleneck)

        # Decoder. Every ResNetBlock input is the running features concatenated
        # with one stored encoder map, hence the 2 * wide / wide + ch / 2 * ch widths.
        decoder = [
            ResNetBlock(2 * wide, wide),    # 4 x 4
            ResNetBlock(2 * wide, wide),
            ResNetBlock(2 * wide, wide),
            Upsample(wide),

            ResNetBlock(2 * wide, wide),    # 8 x 8
            ResNetBlock(2 * wide, wide),
            ResNetBlock(2 * wide, wide),
            Upsample(wide),                 # 16 x 16

            ResNetBlock(2 * wide, wide),
            AttentionBlock(wide),
            ResNetBlock(2 * wide, wide),
            AttentionBlock(wide),
            ResNetBlock(wide + ch, wide),   # skip here has only ch channels
            AttentionBlock(wide),
            Upsample(wide),                 # 32 x 32

            ResNetBlock(wide + ch, ch),
            ResNetBlock(2 * ch, ch),
            ResNetBlock(2 * ch, ch),
        ]
        self.up = nn.ModuleList(decoder)

        self.final_conv = nn.Conv2d(ch, in_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t):
        """
        :param x (torch.Tensor): batch of images [B, C, H, W]
        :param t (torch.Tensor): tensor of time steps (torch.long) [B]
        :return (torch.Tensor): tensor produced by the final convolution, with
            ``in_ch`` channels
        """

        # Timestep embedding: sinusoidal features -> linear -> SiLU -> linear
        temb = self.linear2(
            torch.nn.functional.silu(self.linear1(get_timestep_embedding(t, self.ch)))
        )
        assert temb.shape == (t.shape[0], self.ch *4)

        # Encoder. `skips` is a stack of feature maps for the decoder. The output of
        # a ResNetBlock that feeds straight into an AttentionBlock is not stored;
        # the attention output is stored instead.
        feat = self.conv1(x)
        skips = [feat]
        last_idx = len(self.down) - 1
        for idx, layer in enumerate(self.down):
            if isinstance(layer, ResNetBlock):
                feat = layer(feat, temb)
            else:
                feat = layer(feat)
            feeds_attention = idx < last_idx and isinstance(self.down[idx + 1], AttentionBlock)
            if not feeds_attention:
                skips.append(feat)

        # Bottleneck
        res_in, attn, res_out = self.middle
        feat = res_out(attn(res_in(feat, temb)), temb)

        # Decoder. Each ResNetBlock consumes the most recently stored encoder map;
        # attention and upsampling layers take the running features only.
        for layer in self.up:
            if isinstance(layer, ResNetBlock):
                feat = layer(torch.cat((feat, skips.pop()), dim=1), temb)
            else:
                feat = layer(feat)

        feat = nn.functional.silu(nn.functional.group_norm(feat, num_groups=32))
        return self.final_conv(feat)


In [9]:
# Smoke test: push a random batch through each building block and through the
# full UNet, asserting the output shape after every step.
t = (torch.rand(10) * 10).long()
temb = get_timestep_embedding(t, 512)
assert temb.shape == (10, 512)

downsample = Downsample(64)
img = torch.randn((10, 64, 16, 16))
img = downsample(img)
assert img.shape == (10, 64, 8, 8)

upsample = Upsample(64)
img = upsample(img)
assert img.shape == (10, 64, 16, 16)

nin = Nin(64, 128)
img = nin(img)
assert img.shape == (10, 128, 16, 16)

# Same-width block: no shortcut projection.
resnet = ResNetBlock(128, 128, 0.1)
img = resnet(img, temb)
assert img.shape == (10, 128, 16, 16)

# Width-changing block: exercises the Nin shortcut.
resnet = ResNetBlock(128, 64, 0.1)
img = resnet(img, temb)
assert img.shape == (10, 64, 16, 16)

att = AttentionBlock(64)
img = att(img)
assert img.shape == (10, 64, 16, 16)

# Full model: the prediction must have the same shape as the input images.
img = torch.randn((10, 1, 32, 32))
model = UNet()
img = model(img, t)
assert img.shape == (10, 1, 32, 32)

In [10]:
# Shape of the UNet output produced by the smoke test above.
img.shape

torch.Size([10, 1, 32, 32])

In [11]:
# Total number of model parameters, in millions.
sum(p.numel() for p in model.parameters()) / 1e6

35.713281