In [36]:
import torch 
import torch.nn as nn
import numpy as np
from torch.nn import init
import os
from PIL import Image
import matplotlib.pyplot as plt
from torchinfo import summary
%matplotlib inline

In [37]:
# Generator 2

class Unet(nn.Module):
    """Four-level U-Net generator with concatenation skip connections.

    Encoder: four stride-2 ``Conv2d -> BatchNorm2d -> ReLU`` stages whose widths
    are ``conv_dim``, ``2*conv_dim``, ``4*conv_dim`` and ``8*conv_dim``; each
    halves H and W.
    Decoder: four stride-2 ``ConvTranspose2d`` stages that each double H and W.
    The outputs of the first three decoder stages are concatenated along the
    channel dimension after the matching encoder feature map (skip first). The
    last stage maps to ``out_dim`` channels through a Sigmoid, so outputs lie
    in (0, 1).

    Inputs whose H and W are multiples of 16 keep every skip connection aligned
    and yield an output of the same spatial size. Other sizes may fail inside
    ``torch.cat``.

    Args:
        in_dim: number of input channels.
        conv_dim: channel width of the first encoder stage.
        out_dim: number of output channels.
    """

    def __init__(self, in_dim=1, conv_dim=64, out_dim=1):
        super().__init__()
        width = conv_dim

        # Encoder; trailing comments give H/W for a 128x128 input.
        self.conv1 = self._encoder_stage(in_dim, width)          # 64
        self.conv2 = self._encoder_stage(width, width * 2)       # 32
        self.conv3 = self._encoder_stage(width * 2, width * 4)   # 16
        self.conv4 = self._encoder_stage(width * 4, width * 8)   # 8

        # Decoder; input widths include the concatenated skip channels.
        self.deconv1 = self._decoder_stage(width * 8, width * 8)
        self.deconv2 = self._decoder_stage(width * (8 + 4), width * 4)
        self.deconv3 = self._decoder_stage(width * (4 + 2), width * 2)
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose2d(width * (2 + 1), out_dim, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.Sigmoid(),
        )

        self._init_weights()

    @staticmethod
    def _encoder_stage(in_ch, out_ch):
        """3x3 stride-2 Conv2d -> BatchNorm2d -> ReLU (halves H and W)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _decoder_stage(in_ch, out_ch):
        """3x3 stride-2 ConvTranspose2d (output_padding=1) -> BatchNorm2d -> ReLU (doubles H and W)."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def _init_weights(self):
        """Kaiming-normal Conv2d weights with zero bias; BatchNorm2d set to identity."""
        # nn.ConvTranspose2d is not a subclass of nn.Conv2d, so the decoder's
        # transposed convolutions keep PyTorch's default initialization.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, a=0)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the U-Net on ``x`` of shape (N, in_dim, H, W)."""
        enc1 = self.conv1(x)
        enc2 = self.conv2(enc1)
        enc3 = self.conv3(enc2)
        enc4 = self.conv4(enc3)

        # Skip tensor goes first in each concatenation, upsampled map second.
        dec = self.deconv1(enc4)
        dec = self.deconv2(torch.cat([enc3, dec], dim=1))
        dec = self.deconv3(torch.cat([enc2, dec], dim=1))
        return self.deconv4(torch.cat([enc1, dec], dim=1))

In [38]:
# Smoke test: Unet should return a map with the same shape as its input.
# The check is enforced by an assertion rather than only eyeballed from the print.
x = torch.rand((1, 1, 128, 128))
unet = Unet()
y = unet(x)
assert y.shape == x.shape, f"Unet changed the shape: {tuple(x.shape)} -> {tuple(y.shape)}"
print(y.shape)

torch.Size([1, 1, 128, 128])


In [39]:
# Per-layer output shapes and parameter counts for one 1x128x128 input.
summary(unet, input_size=(1, 1, 128, 128))

Layer (type:depth-idx)                   Output Shape              Param #
Unet                                     [1, 1, 128, 128]          --
├─Sequential: 1-1                        [1, 64, 64, 64]           --
│    └─Conv2d: 2-1                       [1, 64, 64, 64]           640
│    └─BatchNorm2d: 2-2                  [1, 64, 64, 64]           128
│    └─ReLU: 2-3                         [1, 64, 64, 64]           --
├─Sequential: 1-2                        [1, 128, 32, 32]          --
│    └─Conv2d: 2-4                       [1, 128, 32, 32]          73,856
│    └─BatchNorm2d: 2-5                  [1, 128, 32, 32]          256
│    └─ReLU: 2-6                         [1, 128, 32, 32]          --
├─Sequential: 1-3                        [1, 256, 16, 16]          --
│    └─Conv2d: 2-7                       [1, 256, 16, 16]          295,168
│    └─BatchNorm2d: 2-8                  [1, 256, 16, 16]          512
│    └─ReLU: 2-9                         [1, 256, 16, 16]          --
├─

In [40]:
# G1 testing 

xb = torch.randn(1, 1, 256, 256)

print(xb.shape)

# Encoder specs (in_ch, out_ch, stride): six stride-2 3x3 convs take 256x256
# down to 4x4, then two stride-1 convs keep 4x4. Layers are bare convs with
# no normalization/activation -- this cell only traces shapes.
_encoder_specs = [
    (1, 16, 2), (16, 16, 2), (16, 32, 2), (32, 64, 2),
    (64, 128, 2), (128, 256, 2), (256, 512, 1), (512, 512, 1),
]
c0, c1, c2, c3, c4, c5, c6, c7 = [
    nn.Conv2d(cin, cout, kernel_size=3, stride=s, padding=1, bias=False)
    for cin, cout, s in _encoder_specs
]

# Decoder specs (in_ch, out_ch, kernel, stride): two size-preserving 3x3
# stride-1 transposed convs, then six 4x4 stride-2 ones that double H and W.
_decoder_specs = [
    (512, 512, 3, 1), (512, 256, 3, 1),
    (256, 128, 4, 2), (128, 64, 4, 2), (64, 32, 4, 2),
    (32, 16, 4, 2), (16, 16, 4, 2), (16, 1, 4, 2),
]
c8, c9, c10, c11, c12, c13, c14, c15 = [
    nn.ConvTranspose2d(cin, cout, kernel_size=k, stride=s, padding=1, bias=False)
    for cin, cout, k, s in _decoder_specs
]
c16 = nn.Tanh()

# Trace: "Layer k" is 1-based, so it corresponds to c{k-1}.
out = c0(xb)
print("Layer 1: ", xb.shape, out.shape, sep="\t")

c = [c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15, c16]

for layer_no, layer in enumerate(c, start=2):
    prev = out
    out = layer(prev)
    print(f"Layer {layer_no}: ", prev.shape, out.shape, sep="\t")
    if layer is c7:
        # Blank line separates the encoder trace from the decoder trace.
        print()



torch.Size([1, 1, 256, 256])
Layer 1: 	torch.Size([1, 1, 256, 256])	torch.Size([1, 16, 128, 128])
Layer 2: 	torch.Size([1, 16, 128, 128])	torch.Size([1, 16, 64, 64])
Layer 3: 	torch.Size([1, 16, 64, 64])	torch.Size([1, 32, 32, 32])
Layer 4: 	torch.Size([1, 32, 32, 32])	torch.Size([1, 64, 16, 16])
Layer 5: 	torch.Size([1, 64, 16, 16])	torch.Size([1, 128, 8, 8])
Layer 6: 	torch.Size([1, 128, 8, 8])	torch.Size([1, 256, 4, 4])
Layer 7: 	torch.Size([1, 256, 4, 4])	torch.Size([1, 512, 4, 4])
Layer 8: 	torch.Size([1, 512, 4, 4])	torch.Size([1, 512, 4, 4])

Layer 9: 	torch.Size([1, 512, 4, 4])	torch.Size([1, 512, 4, 4])
Layer 10: 	torch.Size([1, 512, 4, 4])	torch.Size([1, 256, 4, 4])
Layer 11: 	torch.Size([1, 256, 4, 4])	torch.Size([1, 128, 8, 8])
Layer 12: 	torch.Size([1, 128, 8, 8])	torch.Size([1, 64, 16, 16])
Layer 13: 	torch.Size([1, 64, 16, 16])	torch.Size([1, 32, 32, 32])
Layer 14: 	torch.Size([1, 32, 32, 32])	torch.Size([1, 16, 64, 64])
Layer 15: 	torch.Size([1, 16, 64, 64])	torch.Size(

In [41]:
def _trace(name, tensor):
    """Print ``name`` with the tensor's shape, then return the tensor unchanged."""
    print(f"{name}: ", tensor.shape)
    return tensor


xb = torch.randn(1, 1, 256, 256)

# Encoder and bottleneck.
x0 = _trace("x0", c0(xb))
x1 = _trace("x1", c1(x0))
x2 = _trace("x2", c2(x1))
x3 = _trace("x3", c3(x2))
x4 = _trace("x4", c4(x3))
x5 = _trace("x5", c5(x4))
x6 = _trace("x6", c6(x5))
x7 = _trace("x7", c7(x6))
x8 = _trace("x8", c8(x7))

# Decoder with additive skips: each of c9..c13 is summed element-wise with an
# encoder output (x5..x1); the traced shapes must match for the sum to be a
# true skip connection rather than a broadcast.
x9 = _trace("x9", c9(x8) + x5)
x10 = _trace("x10", c10(x9) + x4)
x11 = _trace("x11", c11(x10) + x3)
x12 = _trace("x12", c12(x11) + x2)
x13 = _trace("x13", c13(x12) + x1)

# Final upsampling back to the input resolution, then tanh.
x14 = _trace("x14", c14(x13))
x15 = _trace("x15", c15(x14))
x16 = _trace("x16", c16(x15))

x0:  torch.Size([1, 16, 128, 128])
x1:  torch.Size([1, 16, 64, 64])
x2:  torch.Size([1, 32, 32, 32])
x3:  torch.Size([1, 64, 16, 16])
x4:  torch.Size([1, 128, 8, 8])
x5:  torch.Size([1, 256, 4, 4])
x6:  torch.Size([1, 512, 4, 4])
x7:  torch.Size([1, 512, 4, 4])
x8:  torch.Size([1, 512, 4, 4])
x9:  torch.Size([1, 256, 4, 4])
x10:  torch.Size([1, 128, 8, 8])
x11:  torch.Size([1, 64, 16, 16])
x12:  torch.Size([1, 32, 32, 32])
x13:  torch.Size([1, 16, 64, 64])
x14:  torch.Size([1, 16, 128, 128])
x15:  torch.Size([1, 1, 256, 256])
x16:  torch.Size([1, 1, 256, 256])


In [42]:
class UnetGen(nn.Module):
    """U-Net-style generator with additive (element-wise sum) skip connections.

    Layout (spatial sizes given for a 256x256 input):

    * ``c0``-``c5``: 3x3 stride-2 ``Conv2d -> BatchNorm2d -> ReLU`` stages, each
      halving H and W (256 -> 4).
    * ``c6``-``c7``: the same block with stride 1 (bottleneck stays 4x4).
    * ``c8``-``c9``: 3x3 stride-1 ``ConvTranspose2d -> BatchNorm2d -> ReLU``
      (size-preserving).
    * ``c10``-``c14``: 4x4 stride-2 transposed-conv blocks, each doubling H and W.
    * ``c15``: 4x4 stride-2 transposed conv to ``out_channels`` (no norm/act).
    * ``c16``: tanh, so outputs lie in (-1, 1).

    In ``forward`` the outputs of ``c9``..``c13`` are added to the encoder
    outputs ``x5``..``x1``.

    Inputs whose H and W are multiples of 64 keep every skip connection aligned
    and yield an output of the input's spatial size. Other sizes may raise on the
    additions (TODO confirm which sizes are supported, if any besides these).

    Args:
        in_channels: channels of the input image (default 1, as before).
        out_channels: channels of the generated image (default 1, as before).
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Encoder: stride-2 stages (trailing comment = H/W for a 256x256 input).
        self.c0 = self._down(in_channels, 16)   # 128
        self.c1 = self._down(16, 16)            # 64
        self.c2 = self._down(16, 32)            # 32
        self.c3 = self._down(32, 64)            # 16
        self.c4 = self._down(64, 128)           # 8
        self.c5 = self._down(128, 256)          # 4

        # Bottleneck: stride 1 keeps c5's resolution, so c9's output has the
        # same shape as x5 in forward(). With stride 2 here the bottleneck
        # shrank to 1x1 and ``+ x5`` silently broadcast a 1x1 map over 4x4.
        self.c6 = self._down(256, 512, stride=1)
        self.c7 = self._down(512, 512, stride=1)

        # Decoder: two size-preserving stages, then stride-2 stages doubling H, W.
        self.c8 = self._up(512, 512, kernel_size=3, stride=1)
        self.c9 = self._up(512, 256, kernel_size=3, stride=1)
        self.c10 = self._up(256, 128)
        self.c11 = self._up(128, 64)
        self.c12 = self._up(64, 32)
        self.c13 = self._up(32, 16)
        self.c14 = self._up(16, 16)

        # Output projection: deliberately no BatchNorm/ReLU before the Tanh.
        # A ReLU here would zero all negatives and limit tanh to [0, 1).
        self.c15 = nn.Sequential(
            nn.ConvTranspose2d(16, out_channels, kernel_size=4, stride=2,
                               padding=1, bias=False),
        )
        self.c16 = nn.Tanh()

    @staticmethod
    def _down(in_ch, out_ch, stride=2):
        """3x3 Conv2d -> BatchNorm2d -> ReLU; stride 2 halves H and W, stride 1 keeps them."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _up(in_ch, out_ch, kernel_size=4, stride=2):
        """ConvTranspose2d (padding 1) -> BatchNorm2d -> ReLU.

        kernel 4 / stride 2 doubles H and W; kernel 3 / stride 1 keeps them.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride,
                               padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Generate an image from ``x`` of shape (N, in_channels, H, W)."""
        x0 = self.c0(x)
        x1 = self.c1(x0)
        x2 = self.c2(x1)
        x3 = self.c3(x2)
        x4 = self.c4(x3)
        x5 = self.c5(x4)
        x6 = self.c6(x5)
        x7 = self.c7(x6)
        x8 = self.c8(x7)
        # Additive skip connections: shapes match element-wise (no broadcasting)
        # when H and W are multiples of 64.
        x9 = self.c9(x8) + x5
        x10 = self.c10(x9) + x4
        x11 = self.c11(x10) + x3
        x12 = self.c12(x11) + x2
        x13 = self.c13(x12) + x1
        x14 = self.c14(x13)
        x15 = self.c15(x14)
        return self.c16(x15)
        

In [43]:
# Smoke test: UnetGen should return a batch with the same shape as its input.
# The check is enforced by an assertion rather than only eyeballed from the print.
x = torch.rand((32, 1, 256, 256))
unetgen = UnetGen()
y = unetgen(x)
assert y.shape == x.shape, f"UnetGen changed the shape: {tuple(x.shape)} -> {tuple(y.shape)}"
print(y.shape)

torch.Size([32, 1, 256, 256])


In [46]:
# Per-layer output shapes and parameter counts for one 1x256x256 input.
summary(unetgen, input_size=(1, 1, 256, 256))

Layer (type:depth-idx)                   Output Shape              Param #
UnetGen                                  [1, 1, 256, 256]          --
├─Sequential: 1-1                        [1, 16, 128, 128]         --
│    └─Conv2d: 2-1                       [1, 16, 128, 128]         144
│    └─BatchNorm2d: 2-2                  [1, 16, 128, 128]         32
│    └─ReLU: 2-3                         [1, 16, 128, 128]         --
├─Sequential: 1-2                        [1, 16, 64, 64]           --
│    └─Conv2d: 2-4                       [1, 16, 64, 64]           2,304
│    └─BatchNorm2d: 2-5                  [1, 16, 64, 64]           32
│    └─ReLU: 2-6                         [1, 16, 64, 64]           --
├─Sequential: 1-3                        [1, 32, 32, 32]           --
│    └─Conv2d: 2-7                       [1, 32, 32, 32]           4,608
│    └─BatchNorm2d: 2-8                  [1, 32, 32, 32]           64
│    └─ReLU: 2-9                         [1, 32, 32, 32]           --
├─Sequen