In [1]:
import torch 
import torch.nn as nn


In [24]:
    
def double_conv(in_c, out_c):
    
    convolution = nn.Sequential(nn.Conv2d(in_c,out_c,kernel_size=3),
                        nn.ReLU(inplace=True),
                        nn.Conv2d(out_c,out_c, kernel_size=3),
                        nn.ReLU(inplace=True))
    return convolution

def crop_image(tensor,target_tensor):
    
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size
    delta = delta//2
    return tensor[:,:,delta:tensor_size-delta,delta:tensor_size-delta]
    
    
class U_NET(nn.Module):
    def __init__(self):
        super(U_NET, self).__init__()
        
        self.maxpool2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.downconv1 = double_conv(1,64)
        self.downconv2 = double_conv(64,128)
        self.downconv3 = double_conv(128,256)
        self.downconv4 = double_conv(256,512)
        self.downconv5 = double_conv(512,1024)
        
        self.uptrans1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.upconv1 = double_conv(1024,512)
        
        self.uptrans2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.upconv2 = double_conv(512,256)
        
        self.uptrans3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.upconv3 = double_conv(256,128)
        
        self.uptrans4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.upconv4 = double_conv(128,64)
        
        self.out = nn.Conv2d(in_channels=64,out_channels=2, kernel_size=1)
              
        
        
    def forward(self, x):

        x1 = self.downconv1(x)
        x2 = self.maxpool2x2(x1)
        x3 = self.downconv2(x2)
        x4 = self.maxpool2x2(x3)  
        x5 = self.downconv3(x4)
        x6 = self.maxpool2x2(x5)    
        x7 = self.downconv4(x6)
        x8 = self.maxpool2x2(x7)    
        x9 = self.downconv5(x8)
        
        x = self.uptrans1(x9)
        y = crop_image(x7,x)
        x = self.upconv1(torch.cat([x,y],1))
        
        x = self.uptrans2(x)
        y = crop_image(x5,x)
        x = self.upconv2(torch.cat([x,y],1))
        
        x = self.uptrans3(x)
        y = crop_image(x3,x)
        x = self.upconv3(torch.cat([x,y],1))
        
        x = self.uptrans4(x)
        y = crop_image(x1,x)
        x = self.upconv4(torch.cat([x,y],1))
        
        x = self.out(x)
        print(x.size())
        
if __name__ == "__main__":
    
    image = torch.rand((1,1,572,572))
    model = U_NET()
    print(model(image))

torch.Size([1, 2, 388, 388])
None
