In [1]:
import torch
import torch.nn as nn
import math
import numpy as np

In [2]:
def get_timestep_embedding(timesteps, embedding_dim: int):
    """
    Build sinusoidal timestep embeddings.

    Every timestep is multiplied by ``embedding_dim // 2`` geometrically spaced
    frequencies running from 1 down to 1/10000. The sines of those products
    fill the first half of the embedding and the cosines fill the second half.
    When ``embedding_dim`` is odd, one trailing zero column is appended.

    Retrieved from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py#LL90C1-L109C13
    Retrieved from https://www.udemy.com/course/diffusion-models/learn/lecture/37971218#overview

    :param timesteps: 1-D tensor of timesteps, shape (N,)
    :param embedding_dim (int): width of the returned embedding
    :return: float32 tensor of shape (N, embedding_dim), on the same device as ``timesteps``
    """

    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2

    # With a single frequency (embedding_dim 2 or 3) half_dim - 1 is 0; clamp
    # the divisor so that case yields [sin(t), cos(t)] instead of dividing by zero.
    log_step = math.log(10000) / max(half_dim - 1, 1)

    # Build the frequency table on the device of the input so the product below
    # also works for timesteps that do not live on the CPU.
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
    )
    angles = timesteps.type(torch.float32)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    if embedding_dim % 2 == 1:  # zero pad
        # (0, 1) pads the last dimension with nothing on the left and one zero on the right.
        emb = torch.nn.functional.pad(emb, (0, 1))

    assert emb.shape == (timesteps.shape[0], embedding_dim), f"{emb.shape}"

    return emb

In [3]:
class Downsample(nn.Module):
    """Halve height and width with a stride-2 3x3 convolution, keeping the channel count."""

    def __init__(self, C):
        """
        :param C (int): number of input and output channels
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=C,
            out_channels=C,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """
        :param x: tensor of shape (B, C, H, W)
        :return: tensor of shape (B, C, H // 2, W // 2)

        The shape check below only holds for even H and W: for an odd size the
        convolution produces (size + 1) // 2 outputs and the assertion fails.
        """
        batch, channels, height, width = x.shape
        out = self.conv(x)
        expected_shape = (batch, channels, height // 2, width // 2)
        assert out.shape == expected_shape
        return out

In [4]:
class Upsample(nn.Module):
    """Double height and width by nearest-neighbour interpolation, then apply a 3x3 convolution."""

    def __init__(self, C):
        """
        :param C (int): number of input and output channels
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=C,
            out_channels=C,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """
        :param x: tensor of shape (B, C, H, W)
        :return: tensor of shape (B, C, 2 * H, 2 * W)
        """
        batch, channels, height, width = x.shape

        # Enlarge first, then convolve; the stride-1, padding-1 convolution
        # leaves the enlarged spatial size untouched.
        enlarged = nn.functional.interpolate(x, scale_factor=2, mode="nearest")
        out = self.conv(enlarged)

        assert out.shape == (batch, channels, 2 * height, 2 * width)
        return out

In [5]:
class Nin(nn.Module):
    """
    Network-in-network layer: one learned linear map over the channel axis,
    applied independently at every spatial position.
    """

    def __init__(self, in_dim, out_dim, scale=1e-10):
        """
        :param in_dim (int): number of input channels
        :param out_dim (int): number of output channels
        :param scale (float): variance scale of the uniform weight initialisation;
            with the default of 1e-10 the weights start out very close to zero
        """
        super().__init__()

        # Uniform init on [-limit, limit) with limit = sqrt(3 * scale / n),
        # where n is the average of the input and output channel counts.
        n = (in_dim + out_dim) / 2
        limit = math.sqrt(3 * scale / n)
        self.W = torch.nn.Parameter(
            torch.zeros((in_dim, out_dim), dtype=torch.float32).uniform_(-limit, limit)
        )
        # Bias is shaped (1, out_dim, 1, 1) so it broadcasts over batch, height and width.
        self.b = torch.nn.Parameter(torch.zeros((1, out_dim, 1, 1), dtype=torch.float32))

    def forward(self, x):
        """
        :param x: tensor of shape (B, in_dim, H, W)
        :return: tensor of shape (B, out_dim, H, W)
        """
        # Contract the channel axis only; the output keeps height before width.
        return torch.einsum('bchw,co->bohw', x, self.W) + self.b

In [None]:
class ResNetBlock(nn.Module):
    """
    Layers of a residual block: two 3x3 convolutions, a linear projection, and
    a channel-matching ``Nin`` layer when the input and output channel counts differ.

    NOTE(review): only the constructor exists so far; no ``forward`` is defined,
    so calling the block raises until one is added.
    """

    def __init__(self, in_ch, out_ch, dropout_rate):
        """
        :param in_ch (int): number of input channels
        :param out_ch (int): number of output channels
        :param dropout_rate (float): stored on the instance; no dropout layer is created here
        """
        super().__init__()

        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
        # NOTE(review): 512 is hard-coded; presumably the width of the timestep
        # embedding fed to this projection — confirm and consider a parameter.
        self.dense = nn.Linear(512, out_ch)
        # Takes out_ch inputs because conv1 already emits out_ch channels
        # (assumes conv2 is applied after conv1 — TODO confirm once forward exists).
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1)

        if in_ch != out_ch:
            # Assigning the layer registers it as a submodule, so its parameters
            # are visible to optimisers and state_dict.
            self.nin = Nin(in_ch, out_ch)

        self.dropout_rate = dropout_rate



In [8]:
# Smoke test: run each building block once and look at the resulting shapes.

# 100 integer timesteps in [0, 10). The embedding is computed but discarded,
# because only a cell's last expression is displayed.
t = (torch.rand(100) * 10).long()
get_timestep_embedding(t, 64)

# Batch of 10 random 64-channel 400x400 maps (roughly 410 MB at the default float32).
downsample = Downsample(64)
img = torch.randn((10, 64, 400, 400))
h = downsample(img)

# `img` is rebound here: from this line on it holds the upsampled output, not the random input.
upsample = Upsample(64)
img = upsample(h)
print(img.shape)

# NOTE(review): H == W in this input, so the displayed shape cannot reveal a
# height/width swap in any layer's output; a non-square input would.
nin = Nin(64, 128)
nin(img).shape

torch.Size([10, 64, 400, 400])


torch.Size([10, 128, 400, 400])