# Implementacion de U-NET
@juan1rving

En este notebook implemento la red conocida como Unet, la cual es utiliza para realizar segmentación semántica. 

![Drag Racing](unet.png)

El paper original es:

> Ronneberger, O., Fischer, P., & Brox, T. (2015, October). U-net: Convolutional networks for biomedical image segmentation. In International Conference on Medical image computing and computer-assisted intervention (pp. 234-241). Springer, Cham.

In [6]:
import torch
import torch.nn as nn


In [7]:
# Esta funcion lleva a cabo la doble convolución que se utiliza en el paper original. 
# En cada nivel se implementa.

def double_conv(input_c, output_c):
    conv = nn.Sequential(
        nn.Conv2d(input_c, output_c, kernel_size = 3),
        nn.ReLU(inplace = True),
        nn.Conv2d(output_c, output_c, kernel_size = 3),
        nn.ReLU(inplace = True)
    )
    return conv



In [11]:
# Corta el contorno del sensor para que se pueda concatenar a la convolución transpuesta

def crop_tensor(input_tensor, target_tensor):
    target_size = target_tensor.size()[2]
    input_size = input_tensor.size()[2]
    delta = input_size - target_size # assuming that input is larger
    delta = delta // 2
    return input_tensor[:, :, delta:input_size-delta, delta:input_size-delta]
    

In [46]:
# Implementacion de UNET 

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # Maxpool reutilizable
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size = 2, stride = 2)

        # Encoder
        self.down_conv_1 = double_conv(1, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)
        
        # Decoder
        self.up_tconv_5 = nn.ConvTranspose2d(in_channels = 1024, out_channels = 512, kernel_size = 2, stride = 2)
        self.up_conv_4 = double_conv(1024,512)
        
        self.up_tconv_4 = nn.ConvTranspose2d(in_channels = 512, out_channels = 256, kernel_size = 2, stride = 2)
        self.up_conv_3 = double_conv(512,256)
        
        self.up_tconv_3 = nn.ConvTranspose2d(in_channels = 256, out_channels = 128, kernel_size = 2, stride = 2)
        self.up_conv_2 = double_conv(256,128)
        
        self.up_tconv_2 = nn.ConvTranspose2d(in_channels = 128, out_channels = 64, kernel_size = 2, stride = 2)
        self.up_conv_1 = double_conv(128,64)
        
        # Conversion a 2 canales
        self.out = nn.Conv2d(in_channels = 64, out_channels = 2, kernel_size = 1)
        
    def forward(self, image):
        # encoder
        
        x1 = self.down_conv_1(image)
        print(x1.size())
        x1_down = self.max_pool_2x2(x1)
        print(x1_down.size())
        
        x2 = self.down_conv_2(x1_down)
        x2_down = self.max_pool_2x2(x2)
        print(x2_down.size())
        
        x3 = self.down_conv_3(x2_down)
        x3_down = self.max_pool_2x2(x3)
        print(x3_down.size())
        
        x4 = self.down_conv_4(x3_down)
        x4_down = self.max_pool_2x2(x4)
        print(x4_down.size())
        
        x5 = self.down_conv_5(x4_down)
        print(x5.size())
        
        #decoder
        y4 = self.up_tconv_5(x5)
        print(y4.size())
        crop4 = crop_tensor(x4, y4)

        y4_up = self.up_conv_4(torch.cat([crop4, y4], 1))
        print(y4_up.size())
        
        y3 = self.up_tconv_4(y4_up)
        crop3 = crop_tensor(x3, y3)
        y3_up = self.up_conv_3(torch.cat([crop3, y3],1) )
        print(y3_up.size())
                
        y2 = self.up_tconv_3(y3_up)
        crop2 = crop_tensor(x2, y2)
        y2_up = self.up_conv_2(torch.cat([crop2, y2],1) )
        print(y2_up.size())
        
        y1 = self.up_tconv_2(y2_up)
        crop1 = crop_tensor(x1, y1)
        
        y1_fin = self.up_conv_1(torch.cat([crop1, y1],1) )
        print(y1_fin.size())
        
        output =  self.out(y1_fin)
        print(output.size())
        
        return output                 

In [47]:
# batch channels y x
image = torch.rand((1,1,572,572))

In [48]:
model = UNet()

#print(model)

In [49]:
model(image)

torch.Size([1, 64, 568, 568])
torch.Size([1, 64, 284, 284])
torch.Size([1, 128, 140, 140])
torch.Size([1, 256, 68, 68])
torch.Size([1, 512, 32, 32])
torch.Size([1, 1024, 28, 28])
torch.Size([1, 512, 56, 56])
torch.Size([1, 512, 52, 52])
torch.Size([1, 256, 100, 100])
torch.Size([1, 128, 196, 196])
torch.Size([1, 64, 388, 388])
torch.Size([1, 2, 388, 388])


tensor([[[[0.1318, 0.1312, 0.1292,  ..., 0.1303, 0.1284, 0.1317],
          [0.1297, 0.1334, 0.1307,  ..., 0.1280, 0.1305, 0.1279],
          [0.1303, 0.1321, 0.1306,  ..., 0.1310, 0.1286, 0.1308],
          ...,
          [0.1303, 0.1313, 0.1310,  ..., 0.1317, 0.1285, 0.1292],
          [0.1290, 0.1295, 0.1294,  ..., 0.1289, 0.1291, 0.1307],
          [0.1309, 0.1271, 0.1318,  ..., 0.1280, 0.1293, 0.1331]],

         [[0.0526, 0.0530, 0.0540,  ..., 0.0548, 0.0540, 0.0537],
          [0.0544, 0.0560, 0.0537,  ..., 0.0538, 0.0501, 0.0560],
          [0.0543, 0.0534, 0.0534,  ..., 0.0544, 0.0539, 0.0545],
          ...,
          [0.0547, 0.0533, 0.0518,  ..., 0.0516, 0.0506, 0.0539],
          [0.0538, 0.0527, 0.0561,  ..., 0.0533, 0.0549, 0.0557],
          [0.0543, 0.0513, 0.0568,  ..., 0.0518, 0.0493, 0.0518]]]],
       grad_fn=<MkldnnConvolutionBackward>)