In [17]:
import torch
import torch.nn as nn
import math
import numpy as np

In [18]:
def get_timestep_embedding(timesteps, embedding_dim: int):
    """
    Build sinusoidal timestep embeddings (first half sine, second half cosine).

    Retrieved from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py#LL90C1-L109C13
    Retrieved from https://www.udemy.com/course/diffusion-models/learn/lecture/37971218#overview

    :param timesteps: 1-D tensor of shape (B,) holding the timestep of each sample
    :param embedding_dim (int): width of the returned embedding; odd widths are
        zero-padded by one trailing column
    :return: float32 tensor of shape (B, embedding_dim) on the same device as
        ``timesteps``
    """

    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2

    # Log-spaced frequencies running from 1 down to 1/10000. The max() guard
    # avoids a division by zero when half_dim == 1 (embedding_dim of 2 or 3).
    log_step = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
    )

    # Outer product: (B, 1) * (1, half_dim) -> (B, half_dim)
    emb = timesteps.type(torch.float32)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)

    if embedding_dim % 2 == 1:  # zero pad
        # torch has no top-level `pad`; F.pad takes (left, right) for the last dim.
        emb = torch.nn.functional.pad(emb, (0, 1))

    assert emb.shape == (timesteps.shape[0], embedding_dim), f"{emb.shape}"

    return emb

In [19]:
class Downsample(nn.Module):
    """Halve the spatial resolution with a strided 3x3 convolution.

    The forward pass asserts an output of exactly (B, C, H // 2, W // 2), so
    inputs with an odd height or width trip the assertion.
    """

    def __init__(self,C):
        """
        :param C (int): number of input and output channels
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=C,
            out_channels=C,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """
        :param x: feature map of shape (B, C, H, W)
        :return: feature map of shape (B, C, H // 2, W // 2)
        """
        batch, channels, height, width = x.shape
        out = self.conv(x)
        expected_shape = (batch, channels, height // 2, width // 2)
        assert out.shape == expected_shape
        return out

In [20]:
class Upsample(nn.Module):
    """Double the spatial resolution: nearest-neighbour upsampling, then a 3x3 conv."""

    def __init__(self,C):
        """
        :param C (int): number of input and output channels
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=C,
            out_channels=C,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """
        :param x: feature map of shape (B, C, H, W)
        :return: feature map of shape (B, C, 2 * H, 2 * W)
        """
        batch, channels, height, width = x.shape

        # Each pixel is repeated into a 2x2 patch before the convolution.
        enlarged = nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        out = self.conv(enlarged)

        assert out.shape == (batch, channels, height * 2, width * 2)
        return out

In [21]:
class Nin(nn.Module):
    """Network-in-network layer: a per-pixel linear map over the channel axis.

    Equivalent to a 1x1 convolution with a bias; spatial positions are never
    mixed or reordered.
    """

    def __init__(self, in_dim, out_dim, scale = 1e-10):
        """
        :param in_dim (int): number of input channels
        :param out_dim (int): number of output channels
        :param scale (float): variance scale of the uniform weight init; the
            weights are drawn from U(-limit, limit) with
            limit = sqrt(3 * scale / fan_avg). A scale of 0 yields all-zero weights.
        """
        super(Nin, self).__init__()

        # NOTE(review): the default scale of 1e-10 makes every Nin start out
        # essentially as a zero map; confirm this is intended for all call sites.
        n = (in_dim + out_dim) / 2  # average of fan-in and fan-out
        limit = float(np.sqrt(3 * scale / n))
        self.W = torch.nn.Parameter(torch.zeros((in_dim, out_dim), dtype=torch.float32).uniform_(-limit, limit))
        self.b = torch.nn.Parameter(torch.zeros((1, out_dim, 1, 1), dtype=torch.float32))

    def forward(self,x ):
        """
        :param x: feature map of shape (B, in_dim, H, W)
        :return: feature map of shape (B, out_dim, H, W)
        """
        # Output spec must keep (h, w) in order; 'bowh' would transpose the image.
        return torch.einsum('bchw,co->bohw', x, self.W) + self.b

In [22]:
class ResNetBlock(nn.Module):
    """Residual block conditioned on a timestep embedding.

    Pipeline: group-norm -> SiLU -> conv1 -> add projected timestep embedding
    -> group-norm -> SiLU -> dropout -> conv2, summed with a skip connection.
    Group norm uses 32 groups, so both channel counts must be divisible by 32.
    """

    def __init__(self, in_ch, out_ch, dropout_rate, temb_dim=512):
        """
        :param in_ch (int): number of input channels
        :param out_ch (int): number of output channels
        :param dropout_rate (float): dropout probability used while training
        :param temb_dim (int): width of the timestep embedding fed to ``forward``
        """
        super(ResNetBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
        self.dense = nn.Linear(temb_dim, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1)

        # The skip path only needs a projection when the channel count changes.
        if in_ch != out_ch:
            self.nin = Nin(in_ch, out_ch)

        self.dropout_rate = dropout_rate
        self.nonlinearity = torch.nn.SiLU()

    def forward(self, x, temb):
        """
        :param x: (B, C, H, W)
        :param temb: (B, dim)
        :return: (B, out_ch, H, W)
        """

        h = self.nonlinearity(nn.functional.group_norm(x, num_groups=32))
        # Convolve the normalised activations, not the raw input.
        h = self.conv1(h)

        # add in timestep embedding, broadcast over the spatial dimensions
        h = h + self.dense(self.nonlinearity(temb))[:, :, None, None]

        h = self.nonlinearity(nn.functional.group_norm(h, num_groups=32))
        # Functional dropout defaults to training=True; follow the module mode
        # so that eval() actually disables it.
        h = nn.functional.dropout(h, p=self.dropout_rate, training=self.training)
        h = self.conv2(h)

        if x.shape[1] != h.shape[1]:
            x = self.nin(x)

        assert x.shape == h.shape
        return x + h

In [23]:
from matplotlib import scale


class AttentionBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Queries, keys and values are per-pixel channel projections (``Nin``); each
    position attends to all positions and the result is added back to the
    input. Group norm uses 32 groups, so ``ch`` must be divisible by 32.
    """

    def __init__(self, ch):
        """
        :param ch (int): number of input and output channels
        """
        super(AttentionBlock, self).__init__()

        self.Q = Nin(ch, ch)
        self.K = Nin(ch, ch)
        self.V = Nin(ch, ch)

        self.ch = ch

        # Output projection initialised with scale 0. This must be the keyword
        # `scale=0.`; a comparison `scale==0.` would pass a boolean instead.
        self.nin = Nin(ch, ch, scale=0.)

    def forward(self, x):
        """
        :param x: feature map of shape (B, ch, H, W)
        :return: feature map of the same shape as ``x``
        """

        B, C, H, W = x.shape
        assert C == self.ch

        h = nn.functional.group_norm(x, num_groups=32)
        q = self.Q(h)
        k = self.K(h)
        v = self.V(h)

        # Scaled dot-product scores between every query and every key position.
        w = torch.einsum('bchw,bcHW->bhwHW', q, k) * (int(C) ** (-0.5))

        # Softmax over all key positions: flatten the two key axes, normalise,
        # then restore. Using the tensor's own shape keeps this valid whatever
        # spatial layout the projections return.
        scores_shape = w.shape
        w = torch.nn.functional.softmax(w.flatten(start_dim=3), dim=-1)
        w = w.reshape(scores_shape)

        h = torch.einsum('bhwHW,bcHW->bchw', w, v)
        h = self.nin(h)

        assert h.shape == x.shape
        return x + h

In [24]:
# Smoke test: push a random batch through every block defined above and rely
# on the shape assertions inside each forward pass.

# Ten random integer timesteps in [0, 10) and their 512-wide embeddings.
t = (torch.rand(10) * 10).long()
temb = get_timestep_embedding(t, 512)

# (10, 64, 16, 16) -> (10, 64, 8, 8)
downsample = Downsample(64)
img = torch.randn((10, 64, 16, 16))
img = downsample(img)

# (10, 64, 8, 8) -> (10, 64, 16, 16)
upsample = Upsample(64)
img = upsample(img)

# Channel projection 64 -> 128; spatial size stays 16 x 16.
nin = Nin(64, 128)
img = nin(img)

# Residual block with equal in/out channels (no projection on the skip path).
resnet = ResNetBlock(128, 128, 0.1)
img = resnet(img, temb)

# Residual block that changes channels 128 -> 64 (skip path goes through Nin).
# Note: this rebinds `resnet`, so the previous block is no longer reachable.
resnet = ResNetBlock(128, 64, 0.1)
img = resnet(img, temb)

# Self-attention keeps the shape: (10, 64, 16, 16).
att = AttentionBlock(64)
img = att(img)

In [25]:
# Display the final tensor shape of the smoke-test pipeline.
img.shape

torch.Size([10, 64, 16, 16])