In [3]:
import torch 
from torch import nn
from torch.nn import functional as F


class Block(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use stride 1 and padding 1 with no bias term.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        stride: accepted for interface compatibility.
            NOTE(review): neither convolution reads it; both are fixed at
            stride 1.
        expansion: stored on the instance; not used by ``forward``.
        downsample: optional module stored on the instance; not used by
            ``forward``.
    """

    def __init__(self, in_channels, out_channels, stride=1, expansion=1, downsample: nn.Module = None):
        super().__init__()
        self.expansion = expansion
        self.downsample = downsample

        # First conv maps in_channels -> out_channels.
        self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0 = nn.BatchNorm2d(out_channels)
        # One stateless ReLU instance is shared by both activation sites.
        self.relu = nn.ReLU(inplace=True)
        # Second conv keeps the channel count.
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply conv -> BN -> ReLU twice and return the result."""
        hidden = self.relu(self.bn0(self.conv0(x)))
        return self.relu(self.bn1(self.conv1(hidden)))
    
class Downsample(nn.Module):
    """Encoder stage: 2x2 max-pool (stride 2) followed by a ``Block``.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the convolution block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.maxpool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv0 = Block(in_channels=in_channels, out_channels=out_channels)

    def forward(self, x):
        """Pool ``x`` and pass the pooled tensor through the conv block."""
        pooled = self.maxpool0(x)
        return self.conv0(pooled)

class Upsample(nn.Module):
    """Decoder stage: upsample ``x1``, align it with ``x2``, concatenate, convolve.

    Args:
        in_channels: channel count of the concatenated tensor (``x2`` plus the
            upsampled ``x1``) that is fed to the convolution block.
        out_channels: channels produced by the convolution block.
        bilinear: when True (default) ``x1`` is upsampled with fixed bilinear
            interpolation; when False a learned 2x2, stride-2 transposed
            convolution is used instead.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            # Learned alternative that also doubles dimensions 2 and 3.
            # Assumes x1 carries in_channels // 2 channels (the concatenation
            # is an even split between x1 and x2) -- TODO confirm for callers
            # that use a different split.
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2,
                                         kernel_size=2, stride=2)

        # Block's third positional parameter is `stride`, not a channel count,
        # so only the two channel arguments are passed here.
        self.conv = Block(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s size, concatenate and convolve.

        Args:
            x1: tensor coming from the deeper stage.
            x2: skip tensor whose size in dimensions 2 and 3 is the target.

        Returns:
            Output of the convolution block applied to ``cat([x2, x1], dim=1)``.
        """
        x1 = self.up(x1)

        # Size mismatch left over after upsampling (e.g. odd input sizes).
        diffY = x2.size(2) - x1.size(2)
        diffX = x2.size(3) - x1.size(3)

        # F.pad pads the last dimension first: [left, right, top, bottom].
        # The mismatch is split across both sides, extra pixel on the far side.
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)
    

class UNet(nn.Module):
    """U-shaped encoder/decoder built from ``Block``, ``Downsample`` and ``Upsample``.

    The encoder widens 64 -> 128 -> 256 -> 512 -> 512 while pooling four
    times; the decoder upsamples four times, concatenating the matching
    encoder output at each step, and a 1x1 convolution maps the final
    64 channels to ``n_classes``.

    Args:
        img_channels: channels of the input image tensor.
        n_classes: channels of the output tensor (default 64).
    """

    def __init__(self, img_channels, n_classes = 64):
        super().__init__()

        self.img_channels = img_channels
        self.expansion = 2
        grow = self.expansion

        # Channel widths of the encoder levels.
        w1 = 64
        w2 = w1 * grow  # 128
        w3 = w2 * grow  # 256
        w4 = w3 * grow  # 512

        # Left side of the U.
        self.layer1 = Block(in_channels=img_channels, out_channels=w1)   # img -> 64
        self.layer2 = Downsample(in_channels=w1, out_channels=w2)        # 64 -> 128
        self.layer3 = Downsample(in_channels=w2, out_channels=w3)        # 128 -> 256
        self.layer4 = Downsample(in_channels=w3, out_channels=w4)        # 256 -> 512
        self.layer5 = Downsample(in_channels=w4, out_channels=w4)        # 512 -> 512

        # Right side of the U; in_channels counts the concatenated skip tensor.
        self.layer6 = Upsample(in_channels=w4 * grow, out_channels=w4 // grow)  # 1024 -> 256
        self.layer7 = Upsample(in_channels=w4, out_channels=w3 // grow)         # 512 -> 128
        self.layer8 = Upsample(in_channels=w3, out_channels=w2 // grow)         # 256 -> 64
        self.layer9 = Upsample(in_channels=w2, out_channels=w1)                 # 128 -> 64

        # Output head.
        self.out_channels = w1
        self.outc = nn.Conv2d(self.out_channels, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the 1x1 head."""
        skip1 = self.layer1(x)
        skip2 = self.layer2(skip1)
        skip3 = self.layer3(skip2)
        skip4 = self.layer4(skip3)
        decoded = self.layer5(skip4)

        # Each decoder stage consumes the encoder output of the same depth.
        decoder_path = (
            (self.layer6, skip4),
            (self.layer7, skip3),
            (self.layer8, skip2),
            (self.layer9, skip1),
        )
        for stage, skip in decoder_path:
            decoded = stage(decoded, skip)

        return self.outc(decoded)


In [9]:
class UNet(nn.Module):
    """U-shaped encoder/decoder built from ``Block``, ``Downsample`` and ``Upsample``.

    NOTE(review): this cell re-declares ``UNet`` with the same layers as the
    earlier definition in this notebook; when cells run in order, this later
    definition is the one bound to the name.

    The encoder widens 64 -> 128 -> 256 -> 512 -> 512 while pooling four
    times; the decoder upsamples four times, concatenating the matching
    encoder output at each step, and a 1x1 convolution maps the final
    64 channels to ``n_classes``.

    Args:
        img_channels: channels of the input image tensor.
        n_classes: channels of the output tensor (default 64).
    """

    def __init__(self, img_channels, n_classes = 64):
        super().__init__()

        self.img_channels = img_channels
        self.expansion = 2
        grow = self.expansion

        # Channel widths of the encoder levels.
        w1 = 64
        w2 = w1 * grow  # 128
        w3 = w2 * grow  # 256
        w4 = w3 * grow  # 512

        # Left side of the U.
        self.layer1 = Block(in_channels=img_channels, out_channels=w1)   # img -> 64
        self.layer2 = Downsample(in_channels=w1, out_channels=w2)        # 64 -> 128
        self.layer3 = Downsample(in_channels=w2, out_channels=w3)        # 128 -> 256
        self.layer4 = Downsample(in_channels=w3, out_channels=w4)        # 256 -> 512
        self.layer5 = Downsample(in_channels=w4, out_channels=w4)        # 512 -> 512

        # Right side of the U; in_channels counts the concatenated skip tensor.
        self.layer6 = Upsample(in_channels=w4 * grow, out_channels=w4 // grow)  # 1024 -> 256
        self.layer7 = Upsample(in_channels=w4, out_channels=w3 // grow)         # 512 -> 128
        self.layer8 = Upsample(in_channels=w3, out_channels=w2 // grow)         # 256 -> 64
        self.layer9 = Upsample(in_channels=w2, out_channels=w1)                 # 128 -> 64

        # Output head.
        self.out_channels = w1
        self.outc = nn.Conv2d(self.out_channels, n_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the 1x1 head."""
        skip1 = self.layer1(x)
        skip2 = self.layer2(skip1)
        skip3 = self.layer3(skip2)
        skip4 = self.layer4(skip3)
        decoded = self.layer5(skip4)

        # Each decoder stage consumes the encoder output of the same depth.
        decoder_path = (
            (self.layer6, skip4),
            (self.layer7, skip3),
            (self.layer8, skip2),
            (self.layer9, skip1),
        )
        for stage, skip in decoder_path:
            decoded = stage(decoded, skip)

        return self.outc(decoded)

In [10]:
# Smoke test: push one random 3-channel 572x572 input through the network
# and display the output shape as the cell result.
tensor = torch.rand(1, 3, 572, 572)
model = UNet(img_channels=3)  # n_classes is left at its default of 64
output = model(tensor)
output.shape

torch.Size([1, 64, 572, 572])