In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_utils import get_param_num

In [2]:
def conv1x1_bn_relu(inp_activation, output_activation, BN=True, activation = True):
    "con 1x1 + Batchnormalization + relu"
    layer = [nn.Conv2d(inp_activation, output_activation, 1)]
    for i, j in zip([nn.BatchNorm2d(output_activation), nn.ReLU(inplace=True)],[BN, activation]):
        if j==True:
            layer.append(i)
    return nn.Sequential(*layer)

In [3]:
def conv3x3_bn_relu(inp_activation, output_activation, BN=True, activation = True):
    "con 3x3 + Batchnormalization + relu"
    layer = [nn.Conv2d(inp_activation, output_activation, 3, padding = 1)]
    for i, j in zip([nn.BatchNorm2d(output_activation), nn.ReLU(inplace=True)],[BN, activation]):
        if j==True:
            layer.append(i)
    return nn.Sequential(*layer)

In [4]:
conv3x3_bn_relu(23,34,BN = False)

Sequential(
  (0): Conv2d(23, 34, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
)

In [4]:
def conv5x5_bn_relu(inp_activation, output_activation, BN = True, activation = True):
    "con 5x5 + Batchnormalization + relu"
    layer = [nn.Conv2d(inp_activation, output_activation, 5, padding = 2)]
    for i, j in zip([nn.BatchNorm2d(output_activation), nn.ReLU(inplace=True)], [BN, activation]):
        if j==True:
            layer.append(i)
    return nn.Sequential(*layer)

In [5]:
class bottleneck_block(nn.Module):
    
    def __init__(self, input_activation, list_filter=[256, 64, 256], down = None):
        super().__init__()
        self.conv1x1_1 = conv1x1_bn_relu(input_activation, list_filter[0])
        self.conv3x3 = conv3x3_bn_relu(list_filter[0], list_filter[1])
        self.conv1x1_2 = conv1x1_bn_relu(list_filter[1], list_filter[2])
        self.down = down
        if down!=None:
            self.contract_conv = down
    def forward(self , inp):
        x = inp
        c = self.conv1x1_1(inp)
        c = self.conv3x3(c)
        result = self.conv1x1_2(c)
        if self.down!=None:
            x = self.contract_conv(inp)
        
        out = result + x
        
        return out
            


In [5]:
class conv_block(nn.Module):
    
    def __init__(self, inp_activation, list_filter, BN = False):
        super().__init__()
        self.conv3x3_1 = conv3x3_bn_relu(inp_activation, list_filter, BN=BN)
        self.conv3x3_2 = conv3x3_bn_relu(list_filter, list_filter, BN=BN)
    def forward(self , inp):
        c = self.conv3x3_1(inp)
        c = self.conv3x3_2(c)
        return c

In [6]:
conv_block(23,34)

conv_block(
  (conv3x3_1): Sequential(
    (0): Conv2d(23, 34, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
  )
  (conv3x3_2): Sequential(
    (0): Conv2d(34, 34, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
  )
)

In [7]:
class conv_block_bottle(nn.Module):
    
    def __init__(self, inp_activation, output):
        super().__init__()
        self.input = inp_activation
        self.output = output
        
        self.bottle1 = bottleneck_block(inp_activation, [inp_activation]*3)
            
        self.bottle2 = bottleneck_block(inp_activation, [inp_activation]*2 + [output], down  = nn.Conv2d(inp_activation
                                                                                                       ,output, 1))

    def forward(self , inp):
        c = self.bottle1(inp)
        c = self.bottle2(c)
        
        return c

In [15]:
class Unet_res_b(nn.Module):
    def __init__(self, n_class):
        
        super().__init__()
        self.en_block1 = conv_block(16,32)
        self.en_block2 = conv_block(32,64)
        self.en_block3 = conv_block(64,128)
        self.en_block4 = conv_block(128,256)
        self.en_block5 = conv_block(256,512)
        self.en_block6 = conv_block(512, 1024)

        
        self.transpose5 = nn.ConvTranspose2d(1024,512,2,2)
        self.transpose4 = nn.ConvTranspose2d(512,256,2,2)

        self.transpose3 = nn.ConvTranspose2d(256,128,2,2)
        self.transpose2 = nn.ConvTranspose2d(128,64,2,2)
        self.transpose1 = nn.ConvTranspose2d(64,32,2,2)
        
        self.de_block1 = conv_block(64,32)
        self.de_block2 = conv_block(128,64)
        self.de_block3 = conv_block(256,128)

        self.de_block4 = conv_block(512, 256)
        self.de_block5 = conv_block(1024, 512)
        self.out_conv = nn.Conv2d(32, n_class, 1)

        
    

    def forward(self, inp):
        el1 = self.en_block1(inp) #  (32,h,w)
        print('el1',el1.shape)
        max1 = nn.MaxPool2d(2)(el1) # (32,h//2, w//2)
        print('max1',max1.shape)

        el2 = self.en_block2(max1)    #(64, h//2, w//2)
        print('el2',el2.shape)

        max2 = nn.MaxPool2d(2)(el2)  #(64, h//4, w//4)
        print('max2',max2.shape)

        el3 = self.en_block3(max2)    #(128, h//4, w//4)
        print('el3',el3.shape)

        max3 = nn.MaxPool2d(2)(el3)  #(128, h//8, w//8)
        print('max3',max3.shape)


        el4 = self.en_block4(max3)    #(256, h//8, w//8)
        print('el4',el4.shape)

        max4 = nn.MaxPool2d(2)(el4)  #(256, h//16, w//16)
        print('max4',max4.shape)

        el5 = self.en_block5(max4)  #(512, h//16, w//16)
        print('el5',el5.shape)

        max5 = nn.MaxPool2d(2)(el5)  #(512, h//32, w//32)
        print('max5',max5.shape)

        
        el6 = self.en_block6(max5)  #(1024, h//32, w//32)
        print('el6',el6.shape)


        tl5 = self.transpose5(el6)  #(512, h//16, w//16)
        print('tl5',tl5.shape)

        cat5 = torch.cat([tl5, el5], 1) #(1024, h//16, h//16 )
        print('cat5',cat5.shape)

        d5 =  self.de_block5(cat5)      #(512, h//16, w//16
        print('d5',d5.shape)

        
        tl4 = self.transpose4(d5)       #(256, h//8, w//8)
        cat4 = torch.cat([tl4, el4], 1) #(512, h//8, w//8)
        d4 =  self.de_block4(cat4)     #(256, h//8, w//8)
        
        tl3 = self.transpose3(d4)        #(128, h//4, w//4)
        cat3 = torch.cat([tl3, el3], 1)  #(256, h//4, w//4)
        d3 =  self.de_block3(cat3)        #(128, h//4, w//4)
        
        
        tl2 = self.transpose2(d3)          #(64, h//2, w//2)
        cat2 = torch.cat([tl2, el2], 1)   #(128, h//2, w//2)
        d2 =  self.de_block2(cat2)         #(64, h//2, w//2)
        
        tl1 = self.transpose1(d2)          #(32, h, w)
        cat1 = torch.cat([tl1, el1], 1) #(64, h, w)
        d1 =  self.de_block1(cat1)        #(32, h, w)
        output = self.out_conv(d1) 

        return output



In [10]:
m=Unet_res_b(2)

In [20]:
x = torch.rand(1,16,192,192)

In [13]:
for i in m.parameters():
    print(i.numel())

4608
32
9216
32
18432
64
36864
64
73728
128
147456
128
294912
256
589824
256
1179648
512
2359296
512
4718592
1024
9437184
1024
2097152
512
524288
256
131072
128
32768
64
8192
32
18432
32
9216
32
73728
64
36864
64
294912
128
147456
128
1179648
256
589824
256
4718592
512
2359296
512
64
2


In [21]:
m(x)

el1 torch.Size([1, 32, 192, 192])
max1 torch.Size([1, 32, 96, 96])
el2 torch.Size([1, 64, 96, 96])
max2 torch.Size([1, 64, 48, 48])
el3 torch.Size([1, 128, 48, 48])
max3 torch.Size([1, 128, 24, 24])
el4 torch.Size([1, 256, 24, 24])
max4 torch.Size([1, 256, 12, 12])
el5 torch.Size([1, 512, 12, 12])
max5 torch.Size([1, 512, 6, 6])
el6 torch.Size([1, 1024, 6, 6])
tl5 torch.Size([1, 512, 12, 12])
cat5 torch.Size([1, 1024, 12, 12])
d5 torch.Size([1, 512, 12, 12])


tensor([[[[ 1.3366e-02,  1.1545e-02,  9.4780e-03,  ...,  1.1809e-02,
            1.1548e-02,  1.6960e-02],
          [ 3.4588e-03, -1.7204e-03,  6.4900e-04,  ...,  3.4514e-03,
            3.8509e-03,  1.5615e-02],
          [-9.5516e-06, -2.8742e-03,  5.8421e-04,  ..., -2.7326e-03,
            6.8820e-03,  1.5244e-02],
          ...,
          [ 1.6847e-03, -3.3051e-03, -3.0177e-03,  ...,  4.1054e-03,
            1.3878e-03,  1.1440e-02],
          [-1.9456e-04,  9.8195e-04, -4.2372e-04,  ...,  2.5283e-03,
            6.5405e-03,  1.1661e-02],
          [ 8.6760e-03,  6.9916e-03,  6.3358e-03,  ...,  8.5623e-03,
            9.8536e-03,  1.2641e-02]],

         [[-3.2431e-02, -3.8877e-02, -4.0658e-02,  ..., -3.6740e-02,
           -3.3202e-02, -3.1175e-02],
          [-2.6763e-02, -3.9406e-02, -3.7323e-02,  ..., -3.5664e-02,
           -3.3962e-02, -3.2038e-02],
          [-2.6033e-02, -3.5824e-02, -3.3835e-02,  ..., -3.4929e-02,
           -3.3021e-02, -3.2600e-02],
          ...,
     

In [10]:
def train(model, optimizer, dataloader, grad_clip, loss):
    "train model"
    for batch in dataloader:
        model(batch)
        loss = loss
    
        optimizer.zero_grad()
        loss.backward()
        if grad_clip!= None:
            nn.utils.clip_grad_value_(model.parameters(), grad_clip)
        optimizer.step()