In [2]:
import torch
import torch.nn as nn

In [8]:
# Convolutional block:
 #   It consists of two 3x3 convolutional layers, each followed by batch normalization and a ReLU activation.

class conv_block(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use ``padding=1`` with a 3x3 kernel, so the height and
    width of the input pass through unchanged; only the channel count moves
    from ``in_c`` to ``out_c``.

    Args:
        in_c: Number of channels of the incoming feature map.
        out_c: Number of channels produced by both convolutions.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        # Stage 1 maps in_c -> out_c channels.
        self.conv1 = nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=out_c)

        # Stage 2 keeps the channel count at out_c.
        self.conv2 = nn.Conv2d(in_channels=out_c, out_channels=out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=out_c)

        # A single ReLU module is reused after both normalisations.
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Run conv -> batch norm -> ReLU twice and return the final activation."""
        first_stage = self.relu(self.bn1(self.conv1(inputs)))
        second_stage = self.relu(self.bn2(self.conv2(first_stage)))
        return second_stage

In [9]:
    # Encoder block
#   It consists of a conv_block followed by a max pooling.
#   Here the number of filters doubles and the height and width are halved after every block.


    class encoder_block(nn.Module):
        """One encoder stage: a ``conv_block`` followed by 2x2 max pooling.

        Args:
            in_c: Number of channels of the incoming feature map.
            out_c: Number of channels produced by the convolutional block.
        """

        def __init__(self, in_c, out_c):
            super().__init__()

            self.conv = conv_block(in_c, out_c)
            self.pool = nn.MaxPool2d(kernel_size=(2, 2))

        def forward(self, inputs):
            """Return ``(features, pooled)``.

            ``features`` is the output of the convolutional block (the caller
            keeps it as the skip connection); ``pooled`` is that same tensor
            after max pooling and is what the next stage consumes.
            """
            features = self.conv(inputs)
            pooled = self.pool(features)
            return features, pooled

In [10]:
# Decoder block
# The decoder block begins with a transpose convolution, followed by a concatenation with the skip
# connection from the encoder block. Next comes the conv_block.
# Here the number of filters is halved and the height and width are doubled.


class decoder_block(nn.Module):
  """One decoder stage: upsample, concatenate the skip connection, convolve.

  The transpose convolution (kernel 2, stride 2) doubles the height and width
  of ``inputs`` and maps ``in_c`` channels to ``out_c``. The result is
  concatenated with ``skip`` along the channel dimension and passed through a
  ``conv_block`` that expects ``out_c + out_c`` input channels, so ``skip``
  must carry ``out_c`` channels.

  Args:
    in_c: Number of channels of the tensor coming from the previous stage.
    out_c: Number of channels produced by this stage (and carried by ``skip``).
  """

  def __init__(self, in_c, out_c):
    super().__init__()

    self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
    self.conv = conv_block(out_c+out_c, out_c)

  def forward(self, inputs, skip):
    """Upsample ``inputs``, merge it with ``skip`` and return the convolved result.

    Args:
      inputs: 4-D tensor (batch, in_c, H, W) from the previous stage.
      skip: 4-D tensor (batch, out_c, H', W') saved by the matching encoder.

    Returns:
      Tensor with ``out_c`` channels and the spatial size of ``skip``.
    """
    x = self.up(inputs)

    # Max pooling floors odd sizes, so after upsampling x can be smaller than
    # the skip map (e.g. 25 -> 12 -> 24). Pad x symmetrically so it matches the
    # skip size; a negative amount crops instead. When the sizes already
    # agree this branch is skipped and the tensor is left untouched.
    if x.shape[2:] != skip.shape[2:]:
      diff_h = skip.size(2) - x.size(2)
      diff_w = skip.size(3) - x.size(3)
      x = nn.functional.pad(
        x,
        [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
      )

    x = torch.cat([x, skip], dim=1)
    x = self.conv(x)
    return x


       

In [14]:
 class build_unet(nn.Module):
   """U-Net built from four encoder stages, a bottleneck and four decoder stages.

   Channel widths are 64 -> 128 -> 256 -> 512 in the encoder, 1024 in the
   bottleneck, and 512 -> 256 -> 128 -> 64 in the decoder. A final 1x1
   convolution maps the 64 decoder channels to ``num_classes`` output
   channels; no activation is applied to that output.

   Four 2x2 poolings are applied on the way down, so spatial sizes that are
   divisible by 16 halve cleanly at every stage.

   Args:
     in_channels: Number of channels of the input image (default 3).
     num_classes: Number of channels of the output map (default 1).
   """

   def __init__(self, in_channels=3, num_classes=1):
     super().__init__()

     # Encoder
     self.e1 = encoder_block(in_channels, 64)
     self.e2 = encoder_block(64, 128)
     self.e3 = encoder_block(128, 256)
     self.e4 = encoder_block(256, 512)

     # Bottleneck
     self.b = conv_block(512, 1024)

     # Decoder
     self.d1 = decoder_block(1024, 512)
     self.d2 = decoder_block(512, 256)
     self.d3 = decoder_block(256, 128)
     self.d4 = decoder_block(128, 64)

     # Classifier: 1x1 convolution producing one channel per class.
     self.outputs = nn.Conv2d(64, num_classes, kernel_size=1, padding=0)

   def forward(self, inputs):
     """Return the un-activated output map for a (batch, in_channels, H, W) input."""
     # Encoder: each stage returns (skip features, pooled features).
     s1, p1 = self.e1(inputs)
     s2, p2 = self.e2(p1)
     s3, p3 = self.e3(p2)
     s4, p4 = self.e4(p3)

     # Bottleneck
     b = self.b(p4)

     # Decoder: skips are consumed in reverse order, deepest first.
     d1 = self.d1(b, s4)
     d2 = self.d2(d1, s3)
     d3 = self.d3(d2, s2)
     d4 = self.d4(d3, s1)

     # Classifier
     outputs = self.outputs(d4)

     return outputs

     #print(s1.shape, p1.shape)
     #print(s2.shape, p2.shape)
     #print(s3.shape, p3.shape)
     #print(s4.shape, p4.shape)


#if __name__ == "__main__":
 #model = build_unet()
  #model(inputs)







In [16]:
if __name__ == "__main__":
    # Smoke test: push a random batch through the full network and print the
    # output shape.
    #
    # PyTorch lays tensors out as [batch, channels, height, width]
    # (TensorFlow uses [batch, height, width, channels]).
    #
    # The individual blocks can be checked the same way, for example:
    #   conv_block(3, 64) or encoder_block(3, 64) on a (2, 3, 512, 512) input
    #   decoder_block(64, 32) on a (2, 64, 256, 256) input with a
    #   (2, 32, 512, 512) skip tensor
    inputs = torch.randn(2, 3, 512, 512)
    model = build_unet()
    y = model(inputs)
    print(y.shape)


torch.Size([2, 1, 512, 512])
