In [4]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

In [12]:
class DoubleConv(nn.Module):
  """Two back-to-back (3x3 conv -> BatchNorm -> ReLU) stages.

  Each conv uses kernel 3, stride 1, padding 1, so H and W are unchanged; only
  the channel count goes from ``in_channels`` to ``out_channels`` (first conv)
  and stays at ``out_channels`` (second conv). Conv bias is disabled
  (bias=False), presumably because the following BatchNorm makes it redundant.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    layers = []
    # First stage reads in_channels, second stage reads the first stage's output.
    for stage_in in (in_channels, out_channels):
      layers += [
          nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(inplace=True),
      ]
    # The attribute name ``Conv`` is kept on purpose: it prefixes the state_dict keys.
    self.Conv = nn.Sequential(*layers)

  def forward(self, x):
    return self.Conv(x)



In [13]:
class UNET(nn.Module):
  """U-Net style encoder/decoder producing a per-pixel output map.

  Structure:
    * encoder (``self.downs``): one ``DoubleConv`` per entry of ``features``;
      in ``forward`` each is followed by 2x2 max-pooling (``self.pool``);
    * bottleneck: ``DoubleConv(features[-1], 2 * features[-1])``;
    * decoder (``self.ups``): for each width, deepest first, a 2x2 stride-2
      ``ConvTranspose2d`` followed by a ``DoubleConv`` applied to the channel
      concatenation of the matching encoder output and the upsampled map;
    * ``final_conv``: 1x1 convolution to ``out_channels``.

  Args:
    in_channels: channels of the input tensor.
    out_channels: channels of the output tensor.
    features: encoder widths, shallowest first; any non-empty sequence of
      ints. Widths no longer need to double from stage to stage. The default
      is a tuple so no shared mutable default exists across instances.

  Raises:
    ValueError: if ``features`` is empty.
  """

  def __init__(self,in_channels=3,out_channels=1,features=(64,128,256,512)):
    super().__init__()
    features = list(features)
    if not features:
      raise ValueError("features must contain at least one channel width")
    self.downs = nn.ModuleList()
    # ``ups`` interleaves [upsample, DoubleConv, upsample, DoubleConv, ...];
    # forward() relies on this even/odd layout.
    self.ups = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    # Encoder: each stage maps the previous width to the next.
    channels = in_channels
    for feature in features:
      self.downs.append(DoubleConv(channels, feature))
      channels = feature

    # Decoder: each upsampler consumes the width produced by the stage below
    # it (the bottleneck's 2*features[-1] for the first one). The original
    # hard-coded feature*2 here, which only matched when widths doubled.
    channels = features[-1]*2
    for feature in reversed(features):
      self.ups.append(nn.ConvTranspose2d(channels, feature, kernel_size=2, stride=2))
      # Skip connection (feature ch) + upsampled map (feature ch) = 2*feature.
      self.ups.append(DoubleConv(feature*2, feature))
      channels = feature

    self.bottleneck = DoubleConv(features[-1], features[-1]*2)
    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

  def forward(self,x):
    """Run the network on a batch ``x`` of shape (N, in_channels, H, W).

    Returns a tensor of shape (N, out_channels, H, W): every decoder stage is
    aligned to its encoder skip connection, and the shallowest one keeps the
    input's H and W.
    """
    skip_connections = []
    for down in self.downs:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)
    x = self.bottleneck(x)

    # Pair decoder stages (deepest first) with encoder outputs (deepest first).
    stages = zip(self.ups[0::2], self.ups[1::2], reversed(skip_connections))
    for upsample, decode, skip_connection in stages:
      x = upsample(x)
      # Max-pooling floors odd sizes, so the upsampled map can be smaller than
      # the skip connection; only the spatial dims (H, W) need to agree.
      if x.shape[2:] != skip_connection.shape[2:]:
        x = TF.resize(x, size=skip_connection.shape[2:])
      x = decode(torch.cat((skip_connection, x), dim=1))
    return self.final_conv(x)


In [24]:
model = UNET(in_channels=1,out_channels=1)
def test(batch_size=3, height=160, width=160, in_channels=1):
  """Smoke-test the global ``model`` on a random batch.

  Prints the prediction shape, then the input shape, and asserts that the
  prediction keeps the input's batch size and spatial size with
  ``model.final_conv.out_channels`` channels.

  Note: ``model`` is left in whatever train/eval mode it is currently in; in
  train mode its BatchNorm running statistics are updated by this call.
  """
  x = torch.randn((batch_size, in_channels, height, width))
  # A shape check needs no gradients; skip building the autograd graph.
  with torch.no_grad():
    preds = model(x)
  print(preds.shape)
  print(x.shape)
  expected = (batch_size, model.final_conv.out_channels, height, width)
  assert preds.shape == expected, f"expected {expected}, got {tuple(preds.shape)}"

In [25]:
if __name__ == "__main__":
  # A notebook kernel executes cells as ``__main__``, so the smoke test runs
  # when this cell is executed; importing the code as a module skips it.
  test()

torch.Size([3, 1, 160, 160])
torch.Size([3, 1, 160, 160])


In [26]:
try:
  import torchinfo
except:
  !pip install torchinfo
  import torchinfo
from torchinfo import summary
summary(model, input_size=[3,1,160,160])

Layer (type:depth-idx)                   Output Shape              Param #
UNET                                     [3, 1, 160, 160]          --
├─ModuleList: 1-7                        --                        (recursive)
│    └─DoubleConv: 2-1                   [3, 64, 160, 160]         --
│    │    └─Sequential: 3-1              [3, 64, 160, 160]         37,696
├─MaxPool2d: 1-2                         [3, 64, 80, 80]           --
├─ModuleList: 1-7                        --                        (recursive)
│    └─DoubleConv: 2-2                   [3, 128, 80, 80]          --
│    │    └─Sequential: 3-2              [3, 128, 80, 80]          221,696
├─MaxPool2d: 1-4                         [3, 128, 40, 40]          --
├─ModuleList: 1-7                        --                        (recursive)
│    └─DoubleConv: 2-3                   [3, 256, 40, 40]          --
│    │    └─Sequential: 3-3              [3, 256, 40, 40]          885,760
├─MaxPool2d: 1-6                         [3,