In [1]:
import numpy as np
import torch
import torch.nn as nn
import math

torch.set_printoptions(profile="full", linewidth=200, precision=2)

In [2]:
def get_timestep_embedding(timesteps, embedding_dim: int):
    """
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1
    
    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = timesteps.type(torch.float32)[:, None] * emb[None, :]
    emb = torch.concat([torch.sin(emb), torch.cos(emb)], axis=1)

    if embedding_dim % 2 == 1:  # zero pad
      emb = torch.pad(emb, [[0, 0], [0, 1]])
      
    assert emb.shape == (timesteps.shape[0], embedding_dim), f"{emb.shape}"
    return emb

In [3]:
t = (torch.rand(100)*10).long()
t.shape

torch.Size([100])

In [4]:
get_timestep_embedding(t, 64).shape

torch.Size([100, 64])

In [5]:
class Downsample(nn.Module):
    def __init__(self, C):
        """
        param C:input and output channels
        """
        super(Downsample, self).__init__()
        self.C = C
        self.conv = nn.Conv2d(in_channels=C, out_channels=C, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        B, C, H, W = x.shape
        x =  self.conv(x)
        assert x.shape == (B, C, H // 2, W // 2)
        return x

In [6]:
t = (torch.rand(100)*10).long()
emb = get_timestep_embedding(t, 64)
print(emb.shape)

model = Downsample(64)
img = torch.randn((10, 64, 400, 400))
out = model(img)
out.shape

torch.Size([100, 64])


torch.Size([10, 64, 200, 200])

In [7]:
class Upsample(nn.Module):
    def __init__(self, C):
        """
        param C:input and output channels
        """
        super(Upsample, self).__init__()
        self.C = C
        self.conv = nn.Conv2d(in_channels=C, out_channels=C, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        B, C, H, W = x.shape
        x = nn.functional.interpolate(x, size=None, scale_factor=2, mode='nearest-exact')
        assert x.shape == (B, C, H * 2, W * 2)
        x =  self.conv(x)
        assert x.shape == (B, C, H * 2, W * 2)
        return x

In [8]:
t = (torch.rand(100)*10).long()
emb = get_timestep_embedding(t, 64)
print(emb.shape)

downsample = Downsample(64)
img = torch.randn((10, 64, 400, 400))
h = downsample(img)
print(h.shape)

upsample = Upsample(64)
out = upsample(h)
print(out.shape)

torch.Size([100, 64])
torch.Size([10, 64, 200, 200])
torch.Size([10, 64, 400, 400])


In [9]:
class Nin(nn.Module):
    """
    Basic block of ResNet
    it is like applying MLP to a 2D image and modifying the number of channels
    """
    def __init__(self, in_dim, out_dim, scale = 1e-10):
        super(Nin, self).__init__()
        n= (in_dim + out_dim) / 2
        limit = np.sqrt(3 * scale / n)
        self.W = torch.nn.Parameter(torch.zeros((in_dim, out_dim), dtype= torch.float32
                                              ).uniform_(-limit, limit))
        self.b = torch.nn.Parameter(torch.zeros((1, out_dim, 1 , 1), dtype= torch.float32))

    def forward(self, x):
        return torch.einsum("bchw, co->bowh", x, self.W) + self.b

In [10]:
t = (torch.rand(100)*10).long()
emb = get_timestep_embedding(t, 64)
print(emb.shape)

downsample = Downsample(64)
img = torch.randn((10, 64, 400, 400))
h = downsample(img)
print(h.shape)

upsample = Upsample(64)
img = upsample(h)
print(img.shape)

nin = Nin(64, 128)
print(nin(img).shape)

torch.Size([100, 64])
torch.Size([10, 64, 200, 200])
torch.Size([10, 64, 400, 400])
torch.Size([10, 128, 400, 400])


In [11]:
class ResNetBlock(nn.Module):
    def __init__(self, in_ch, out_ch, dropout_rate=0.1):
        super(ResNetBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
        self.dense = nn.Linear(512, out_ch)
        self.conv2 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)

        # needed to do the skip connection
        if not (in_ch == out_ch):
            self.nin = Nin(in_ch, out_ch)

        self.dropout_rate = dropout_rate
        self.nonlinearity = nn.SiLU()

    def forward(self, x, temb): # temb [batch_size, 512] -> B, out_ch
        """
        param x:    (B, C, H, W)
        param temb: (B, dim)
                
        """
        print(f"\nx shape: {x.shape}, temb shape: {temb.shape}")
        #print(f"{}")
        h = self.nonlinearity(nn.functional.group_norm(x, num_groups=32))
        h= self.conv1(h)
        print(f"after conv1: {h.shape}")
        
        # add in timestep embedding
        h +=  self.dense(self.nonlinearity(temb))[:, :, None, None]
        h = self.nonlinearity(nn.functional.group_norm(h, num_groups=32))
        
        h= nn.functional.dropout(h, p=self.dropout_rate)
        print(f"after non linearity: {self.dense(self.nonlinearity(temb))[:, :, None, None].shape}")
        h = self.conv2(h)
        print(f"after conv2: {h.shape}")

        if not (x.shape[1] == h.shape[1]):
            print("reshaping x")
            x = self.nin(x)
            
        print(f"before output x shape: {x.shape}, temb shape: {h.shape}")
        assert x.shape == h.shape
        return x + h

In [12]:
t = (torch.rand(10)*10).long()
print(f"t shape: {t.shape}")
temb = get_timestep_embedding(t, 512)
print(f"embedded t shape: {temb.shape}")

downsample = Downsample(64)
img = torch.randn((10, 64, 128, 128))
h = downsample(img)
print(f"downsampled image shape: {h.shape}")

upsample = Upsample(64)
img = upsample(h)
print(f"upsampled image shape: {img.shape}")

nin = Nin(64, 128)
img = nin(img)
print(f"after Nin application shape: {img.shape}")

resnet = ResNetBlock(128, 128, 0.1)
img = resnet(img, temb)
print(f"final x + h shape: {img.shape}")

resnet = ResNetBlock(128, 64, 0.1)
img = resnet(img, temb)
print(f"final x + h shape: {img.shape}")

t shape: torch.Size([10])
embedded t shape: torch.Size([10, 512])
downsampled image shape: torch.Size([10, 64, 64, 64])
upsampled image shape: torch.Size([10, 64, 128, 128])
after Nin application shape: torch.Size([10, 128, 128, 128])

x shape: torch.Size([10, 128, 128, 128]), temb shape: torch.Size([10, 512])
after conv1: torch.Size([10, 128, 128, 128])
after non linearity: torch.Size([10, 128, 1, 1])
after conv2: torch.Size([10, 128, 128, 128])
before output x shape: torch.Size([10, 128, 128, 128]), temb shape: torch.Size([10, 128, 128, 128])
final x + h shape: torch.Size([10, 128, 128, 128])

x shape: torch.Size([10, 128, 128, 128]), temb shape: torch.Size([10, 512])
after conv1: torch.Size([10, 64, 128, 128])
after non linearity: torch.Size([10, 64, 1, 1])
after conv2: torch.Size([10, 64, 128, 128])
reshaping x
before output x shape: torch.Size([10, 64, 128, 128]), temb shape: torch.Size([10, 64, 128, 128])
final x + h shape: torch.Size([10, 64, 128, 128])


In [13]:
class AttentionBlock(nn.Module):
    def __init__(self, ch):
        super(AttentionBlock, self).__init__()

        self.Q = Nin(ch, ch)
        self.K = Nin(ch, ch)
        self.V = Nin(ch, ch)
        self.ch = ch
        self.nin = Nin(ch, ch, scale = 0.)

    def forward(self, x):
        B, C, H, W = x.shape
        assert C == self.ch

        h = nn.functional.group_norm(x, num_groups=32)
        q = self.Q(h)
        k = self.K(h)
        v = self.V(h)

        w = torch.einsum('bchw,bcHW->bhwHW', q, k) * (int(C) ** (-0.5)) # [B, H, W, H, W]
        w = torch.reshape(w, [B, H, W, H * W])
        w = torch.nn.functional.softmax(w, dim=-1)
        w = torch.reshape(w, [B, H, W, H, W])

        h = torch.einsum('bhwHW,bcHW->bchw', w, v)
        h = self.nin(h)

        assert h.shape == x.shape
        
        return x + h

In [14]:
t = (torch.rand(10)*10).long()
print(f"t shape: {t.shape}")
temb = get_timestep_embedding(t, 512)
print(f"embedded t shape: {temb.shape}")

downsample = Downsample(64)
img = torch.randn((10, 64, 16, 16))
h = downsample(img)
print(f"downsampled image shape: {h.shape}")

upsample = Upsample(64)
img = upsample(h)
print(f"upsampled image shape: {img.shape}")

nin = Nin(64, 128)
img = nin(img)
print(f"after Nin application shape: {img.shape}")

resnet = ResNetBlock(128, 128, 0.1)
img = resnet(img, temb)
print(f"final x + h shape: {img.shape}")

resnet = ResNetBlock(128, 64, 0.1)
img = resnet(img, temb)
print(f"final x + h shape: {img.shape}")


att = AttentionBlock(64)
img = att(img)
print(f"attention block output shape: {img.shape}")

t shape: torch.Size([10])
embedded t shape: torch.Size([10, 512])
downsampled image shape: torch.Size([10, 64, 8, 8])
upsampled image shape: torch.Size([10, 64, 16, 16])
after Nin application shape: torch.Size([10, 128, 16, 16])

x shape: torch.Size([10, 128, 16, 16]), temb shape: torch.Size([10, 512])
after conv1: torch.Size([10, 128, 16, 16])
after non linearity: torch.Size([10, 128, 1, 1])
after conv2: torch.Size([10, 128, 16, 16])
before output x shape: torch.Size([10, 128, 16, 16]), temb shape: torch.Size([10, 128, 16, 16])
final x + h shape: torch.Size([10, 128, 16, 16])

x shape: torch.Size([10, 128, 16, 16]), temb shape: torch.Size([10, 512])
after conv1: torch.Size([10, 64, 16, 16])
after non linearity: torch.Size([10, 64, 1, 1])
after conv2: torch.Size([10, 64, 16, 16])
reshaping x
before output x shape: torch.Size([10, 64, 16, 16]), temb shape: torch.Size([10, 64, 16, 16])
final x + h shape: torch.Size([10, 64, 16, 16])
attention block output shape: torch.Size([10, 64, 16, 1

### UNet Neural Network

In [15]:
class ResNetBlock(nn.Module):
    def __init__(self, in_ch, out_ch, dropout_rate=0.1):
        super(ResNetBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
        self.dense = nn.Linear(512, out_ch)
        self.conv2 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)

        # needed to do the skip connection
        if not (in_ch == out_ch):
            self.nin = Nin(in_ch, out_ch)

        self.dropout_rate = dropout_rate
        self.nonlinearity = nn.SiLU()

    def forward(self, x, temb): # temb [batch_size, 512] -> B, out_ch
        """
        param x:    (B, C, H, W)
        param temb: (B, dim)
                
        """
        h = self.nonlinearity(nn.functional.group_norm(x, num_groups=32))
        h= self.conv1(h)
        # add in timestep embedding
        h +=  self.dense(self.nonlinearity(temb))[:, :, None, None]
        h = self.nonlinearity(nn.functional.group_norm(h, num_groups=32))
        h= nn.functional.dropout(h, p=self.dropout_rate)
        h = self.conv2(h)


        if not (x.shape[1] == h.shape[1]):
            x = self.nin(x)

        assert x.shape == h.shape
        return x + h

In [30]:
class UNet(nn.Module):
    def __init__(self, ch=128, in_ch=1):
        super(UNet, self).__init__()

        self.ch = ch
        self.in_ch = in_ch
        self.linear1 = nn.Linear(ch, 4 * ch)
        self.linear2 = nn.Linear(ch * 4, ch * 4)
    
        self.conv1 = nn.Conv2d(in_ch, ch, 3, stride=1, padding=1)
        self.down = nn.ModuleList([ ResNetBlock(ch, 1 * ch), # [32, 32] # first block
                                    ResNetBlock(1 * ch, 1 * ch),  # first block
                                    Downsample(1 * ch), # [16, 16] # first block
                                   
                                    ResNetBlock(1 * ch, 2 * ch), # second block
                                    AttentionBlock(2 * ch), # second block
                                    ResNetBlock(2 * ch, 2 * ch), # second block
                                    AttentionBlock(2 * ch), # second block
                                    Downsample(2 * ch), # [16, 16] # second block

                                    ResNetBlock(2 * ch, 2 * ch), # third block
                                    ResNetBlock(2 * ch, 2 * ch), # third block
                                    Downsample(2 * ch), # [16, 16] # third block

                                    ResNetBlock(2 * ch, 2 * ch), # forth block
                                    ResNetBlock(2 * ch, 2 * ch), # forth block
        ])
        
        self.middle = nn.ModuleList([ResNetBlock(2 * ch, 2 * ch),
                                     AttentionBlock(2 * ch),
                                     ResNetBlock(2 * ch, 2 * ch)
        ])

        self.up = nn.ModuleList([ResNetBlock(4 * ch, 2 * ch), # first block [4, 4]
                                 ResNetBlock(4 * ch, 2 * ch), # first block
                                 ResNetBlock(4 * ch, 2 * ch), # first block
                                 Upsample(2 * ch), # first block [8, 8]

                                 ResNetBlock(4 * ch, 2 * ch), # second block [8, 8]
                                 ResNetBlock(4 * ch, 2 * ch), # second block
                                 ResNetBlock(4 * ch, 2 * ch), # second block
                                 Upsample(2 * ch), # second block [16, 16]

                                 ResNetBlock(4 * ch, 2 * ch), # third block [16, 16]
                                 AttentionBlock(2 * ch),  # third block
                                 ResNetBlock(4 * ch, 2 * ch), # third block
                                 AttentionBlock(2 * ch),  # third block
                                 ResNetBlock(3 * ch, 2 * ch), # third block
                                 AttentionBlock(2 * ch),  # third block
                                 Upsample(2 * ch), # third block
                                 
                                 ResNetBlock(3 * ch, ch), # forth block
                                 ResNetBlock(2 * ch, ch), # forth block
                                 ResNetBlock(2 * ch, ch), # forth block     
        ])

        self.final_conv = nn.Conv2d(ch, in_ch, 3, stride=1, padding=1)

    def forward(self, x, t):
        """
        param x (torch.Tensor): batch of of images [B, C, H, W]
        param t (torch.Tensor): tensor of time steps (torch.long) [B]
        """
        temb = get_timestep_embedding(t, self.ch)
        temb = torch.nn.functional.silu(self.linear1(temb))
        temb = self.linear2(temb)
        assert temb.shape == (t.shape[0], self.ch*4)
        print(f"\ntemb after tansformation shape: {temb.shape}")

        x1 = self.conv1(x)
        print(f"x1 after conv1 shape: {x1.shape}")

        # DownSampling
        x2  = self.down[0](x1, temb)
        x3  = self.down[1](x2, temb)
        x4  = self.down[2](x3)
        x5  = self.down[3](x4, temb)
        x6  = self.down[4](x5) # Attention
        x7  = self.down[5](x6, temb)
        x8  = self.down[6](x7) # Attention
        x9  = self.down[7](x8)
        x10 = self.down[8](x9, temb)
        x11 = self.down[9](x10, temb)
        x12 = self.down[10](x11)
        x13 = self.down[11](x12, temb)
        x14 = self.down[12](x13, temb)
        print(f"output from Downsampling block {x14.shape}")

        # Middle
        x = self.middle[0](x14, temb)
        x = self.middle[1](x)
        x = self.middle[2](x, temb)
        print(f"output from Middle block {x.shape}")

        # UpSampling
        x = self.up[0](torch.cat((x, x14), dim=1), temb)
        x = self.up[1](torch.cat((x, x13), dim=1), temb)
        x = self.up[2](torch.cat((x, x12), dim=1), temb)
        x = self.up[3](x)
        x = self.up[4](torch.cat((x, x11), dim=1), temb)
        x = self.up[5](torch.cat((x, x10), dim=1), temb)
        x = self.up[6](torch.cat((x, x9), dim=1), temb)
        x = self.up[7](x)
        x = self.up[8](torch.cat((x, x8), dim=1), temb)
        x = self.up[9](x)
        x = self.up[10](torch.cat((x, x6), dim=1), temb)
        x = self.up[11](x)
        x = self.up[12](torch.cat((x, x4), dim=1), temb)
        x = self.up[13](x)
        x = self.up[14](x)
        x = self.up[15](torch.cat((x, x3), dim=1), temb)
        x = self.up[16](torch.cat((x, x2), dim=1), temb)
        x = self.up[17](torch.cat((x, x1), dim=1), temb)

        x = torch.nn.functional.silu(nn.functional.group_norm(x, num_groups=32))
        x = self.final_conv(x)
        
        return x

In [42]:
img = torch.randn((10, 1, 32, 32))
print(f"image input shape to UNet {img.shape}")
unet =UNet(in_ch=1)
out = unet(img, t)
print(f"image output shape from UNet {out.shape}")

image input shape to UNet torch.Size([10, 1, 32, 32])

temb after tansformation shape: torch.Size([10, 512])
x1 after conv1 shape: torch.Size([10, 128, 32, 32])
output from Downsampling block torch.Size([10, 256, 4, 4])
output from Middle block torch.Size([10, 256, 4, 4])
image output shape from UNet torch.Size([10, 1, 32, 32])


In [48]:
sum([p.numel() for p in unet.parameters()]) / 1e6

35.713281