# Summary

```Objective```

Change the generators to be same as diffusion Unet

```Methods```

todo

```Results```

todo

# Config

# Libs

In [9]:
import torch
import torch.nn as nn
import math
import numpy as np

class Downsample(nn.Module):
    
    def __init__(self, C):
        """
        :param C (int): number of input and output channels
        """
        super(Downsample, self).__init__()
        self.conv = nn.Conv2d(C, C, 3, stride=2, padding=1)
        
    def forward(self, x):
        B, C, H, W = x.shape
        x = self.conv(x)
        assert x.shape == (B, C, H // 2, W // 2)
        return x
    
class Upsample(nn.Module):
    
    def __init__(self, C):
        """
        :param C (int): number of input and output channels
        """
        super(Upsample, self).__init__()
        self.conv = nn.Conv2d(C, C, 3, stride=1, padding=1)
        
    def forward(self, x):
        B, C, H, W = x.shape
        
        x = nn.functional.interpolate(x, size=None, scale_factor=2, mode='nearest')
  
        x = self.conv(x)
        assert x.shape == (B, C, H * 2, W * 2)
        return x
    
class Nin(nn.Module):
    
    def __init__(self, in_dim, out_dim, scale = 1e-10):
        super(Nin, self).__init__()
        
        n = (in_dim + out_dim) / 2
        limit = np.sqrt(3 * scale / n)
        self.W = torch.nn.Parameter(torch.zeros((in_dim, out_dim), dtype=torch.float32
                                               ).uniform_(-limit, limit))
        self.b = torch.nn.Parameter(torch.zeros((1, out_dim, 1, 1), dtype=torch.float32))

    def forward(self, x):    
        return torch.einsum('bchw, co->bowh', x, self.W) + self.b
    
class ResNetBlock(nn.Module):
    
    def __init__(self, in_ch, out_ch, dropout_rate=0.0):
        super(ResNetBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1)
        
        if not (in_ch == out_ch):
            self.nin = Nin(in_ch, out_ch)
            
        self.dropout_rate = dropout_rate
        self.nonlinearity = torch.nn.SiLU()
        
    def forward(self, x):
        """
        :param x: (B, C, H, W)
        """
        h = self.nonlinearity(nn.functional.instance_norm(x))
        h = self.conv1(h)
        h = self.nonlinearity(nn.functional.instance_norm(h))
        h = nn.functional.dropout(h, p=self.dropout_rate)
        h = self.conv2(h)
        
        if not (x.shape[1] == h.shape[1]):
            x = self.nin(x)
            
        assert x.shape == h.shape
        return x + h
    
class AttentionBlock(nn.Module):
    
    def __init__(self, ch):
        super(AttentionBlock, self).__init__()
        
        self.Q = Nin(ch, ch)
        self.K = Nin(ch, ch)
        self.V = Nin(ch, ch)
        
        self.ch = ch
        
        self.nin = Nin(ch, ch, scale=0.)
        
    def forward(self, x):
        
        B, C, H, W = x.shape
        assert C == self.ch
        
        h = nn.functional.group_norm(x, num_groups=32)
        q = self.Q(h)
        k = self.K(h)
        v = self.V(h)
        
        w = torch.einsum('bchw,bcHW->bhwHW', q, k) * (int(C) ** (-0.5)) # [B, H, W, H, W]
        w = torch.reshape(w, [B, H, W, H * W])
        w = torch.nn.functional.softmax(w, dim=-1)
        w = torch.reshape(w, [B, H, W, H, W])
        
        h = torch.einsum('bhwHW,bcHW->bchw', w, v)
        h = self.nin(h)
        
        assert h.shape == x.shape
        return x + h
    
class UNet(nn.Module):
    
    def __init__(self, input_nc=1, output_nc=1, ngf=128, *args, **kwargs):
        super(UNet, self).__init__()
        
        self.ch = ngf
        ch = ngf
        self.conv1 = nn.Conv2d(input_nc, ch, 3, stride=1, padding=1)
        self.down = nn.ModuleList([ResNetBlock(ch, 1 * ch),
                                   ResNetBlock(1 * ch, 1 * ch),
                                   Downsample(1 * ch),
                                   ResNetBlock(1 * ch, 2 * ch),
                                   ResNetBlock(2 * ch, 2 * ch),
                                   Downsample(2 * ch),
                                   ResNetBlock(2 * ch, 2 * ch),
                                   ResNetBlock(2 * ch, 2 * ch),
                                   Downsample(2 * ch),
                                   ResNetBlock(2 * ch, 2 * ch),
                                   ResNetBlock(2 * ch, 2 * ch)])
        
        self.middle = nn.ModuleList([ResNetBlock(2 * ch, 2 * ch),
                                     ResNetBlock(2 * ch, 2 * ch)])
        
        self.up = nn.ModuleList([ResNetBlock(4 * ch, 2 * ch),
                                 ResNetBlock(4 * ch, 2 * ch),
                                 ResNetBlock(4 * ch, 2 * ch),
                                 Upsample(2 * ch),
                                 ResNetBlock(4 * ch, 2 * ch),
                                 ResNetBlock(4 * ch, 2 * ch),
                                 ResNetBlock(4 * ch, 2 * ch),
                                 Upsample(2 * ch),
                                 ResNetBlock(4 * ch, 2 * ch),
                                 AttentionBlock(2 * ch),
                                 ResNetBlock(4 * ch, 2 * ch),
                                 AttentionBlock(2 * ch),
                                 ResNetBlock(3 * ch, 2 * ch),
                                 AttentionBlock(2 * ch),
                                 Upsample(2 * ch),
                                 ResNetBlock(3 * ch, ch),
                                 ResNetBlock(2 * ch, ch),
                                 ResNetBlock(2 * ch, ch)]) 
        
        self.final_conv = nn.Conv2d(ch, output_nc, 3, stride=1, padding=1)
        
    def forward(self, x):
        """
        :param x (torch.Tensor): batch of images [B, C, H, W]
        :param t (torch.Tensor): tensor of time steps (torch.long) [B]
        """
        x1 = self.conv1(x)

        # Down
        x2 = self.down[0](x1)
        x3 = self.down[1](x2)
        x4 = self.down[2](x3)
        x5 = self.down[3](x4)
        x6 = self.down[4](x5)  
        x7 = self.down[5](x6)
        x8 = self.down[6](x7)
        x9 = self.down[7](x8)
        x10 = self.down[8](x9)
        x11 = self.down[9](x10)
        x12 = self.down[10](x11)
        
        # Middle
        x = self.middle[0](x12)
        x = self.middle[1](x)
        
        # Up
        x = self.up[0](torch.cat((x, x12), dim=1))
        x = self.up[1](x)
        x = self.up[2](torch.cat((x, x11), dim=1))
        x = self.up[3](torch.cat((x, x10), dim=1))
        x = self.up[4](x)
        x = self.up[5](torch.cat((x, x9), dim=1))
        x = self.up[6](x)
        x = self.up[7](torch.cat((x, x8), dim=1))
        x = self.up[8](torch.cat((x, x7), dim=1))
        x = self.up[9](x)
        x = self.up[10](torch.cat((x, x6), dim=1))
        x = self.up[11](x)
        x = self.up[12](torch.cat((x, x5), dim=1))
        x = self.up[13](x)
        x = self.up[14](torch.cat((x, x4), dim=1))
        x = self.up[15](x)
        x = self.up[16](x)
        x = self.up[17](torch.cat((x, x3), dim=1))
        x = self.up[18](torch.cat((x, x2), dim=1))
        x = self.up[19](torch.cat((x, x1), dim=1))
        
        x = nn.functional.silu(nn.functional.instance_norm(x))
        x = self.final_conv(x)
        x = nn.functional.tanh(x)
        
        return x


# Analysis

In [6]:
unet = UNet(in_channels=1, out_channels=1, ch=32)

In [11]:
test = torch.randn(1, 1, 256, 256)

out = unet(test)

AssertionError: 