# In [37]:
""" Full assembly of the parts to form the complete network """

from unet.unet_parts import *


class UNetpp(nn.Module):
    """Nested U-Net (UNet++-style) assembled from ``DoubleConv`` and ``Up_``.

    Nodes are named ``x{i}{j}``: ``i`` is the depth (0 is the input
    resolution, each deeper level sits behind one more ``MaxPool2d(2)``) and
    ``j`` is the position along the skip path of that depth. The backbone
    nodes ``x{i}0`` form the encoder. Every node with ``j > 0`` is a
    ``DoubleConv`` applied to the channel-wise concatenation of all earlier
    nodes at the same depth plus the ``Up_``-processed node ``x{i+1}{j-1}``.
    ``forward`` returns the last node of depth 0, ``x04``.

    Args:
        n_channels: Number of channels of the input tensor.
        n_classes: Number of output channels produced by the final block.
        n_filters: Base width; depth ``i`` of the backbone uses
            ``n_filters * 2**i`` channels.
        bilinear: Forwarded to every ``Up_`` block; also halves several
            channel counts through ``factor`` (see the note in ``__init__``).
    """

    # Class-level default so that instances created without running
    # ``__init__`` (e.g. whole-model pickles saved earlier) still work.
    _checkpointing = False

    def __init__(self, n_channels, n_classes, n_filters=16, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # NOTE(review): the input widths of the conv{i}{j} blocks below match
        # the concatenations in forward() when each Up_(in, out) emits `out`
        # channels and factor == 1. With bilinear=True the `// factor` terms
        # change the Up_/conv40 widths but not the conv{i}{j} input widths --
        # presumably this only works if Up_ compensates; verify against the
        # Up_ implementation before relying on bilinear=True.
        factor = 2 if bilinear else 1

        # Backbone (encoder) column: x00 .. x40.
        self.inc = DoubleConv(n_channels, n_filters)
        self.conv10 = DoubleConv(n_filters, n_filters * 2)
        self.conv20 = DoubleConv(n_filters * 2, n_filters * 4)
        self.conv30 = DoubleConv(n_filters * 4, n_filters * 8)
        self.conv40 = DoubleConv(n_filters * 8, n_filters * 16 // factor)

        # One parameter-free pooling layer shared by all encoder stages.
        self.max_pool = nn.MaxPool2d(2)

        # Up blocks feeding depth 0 (one per nested column).
        self.up01 = Up_(n_filters * 2, n_filters // factor, bilinear)
        self.up02 = Up_(n_filters * 2, n_filters // factor, bilinear)
        self.up03 = Up_(n_filters * 2, n_filters // factor, bilinear)
        self.up04 = Up_(n_filters * 2, n_filters // factor, bilinear)

        # Up blocks feeding depth 1.
        self.up11 = Up_(n_filters * 4, n_filters * 2 // factor, bilinear)
        self.up12 = Up_(n_filters * 4, n_filters * 2 // factor, bilinear)
        self.up13 = Up_(n_filters * 4, n_filters * 2 // factor, bilinear)

        # Up blocks feeding depth 2.
        self.up21 = Up_(n_filters * 8, n_filters * 4 // factor, bilinear)
        self.up22 = Up_(n_filters * 8, n_filters * 4 // factor, bilinear)

        # Up block feeding depth 3.
        self.up31 = Up_(n_filters * 16, n_filters * 8 // factor, bilinear)

        # Nested nodes at depth 0; input width grows by n_filters per column
        # because every earlier depth-0 node is concatenated in.
        self.conv01 = DoubleConv(n_filters * 2, n_filters)
        self.conv02 = DoubleConv(n_filters * 3, n_filters)
        self.conv03 = DoubleConv(n_filters * 4, n_filters)
        # Final node: maps straight to n_classes and is the network output.
        self.conv04 = DoubleConv(n_filters * 5, n_classes)

        # Nested nodes at depth 1.
        self.conv11 = DoubleConv(n_filters * 4, n_filters * 2)
        self.conv12 = DoubleConv(n_filters * 6, n_filters * 2)
        self.conv13 = DoubleConv(n_filters * 8, n_filters * 2)

        # Nested nodes at depth 2.
        self.conv21 = DoubleConv(n_filters * 8, n_filters * 4)
        self.conv22 = DoubleConv(n_filters * 12, n_filters * 4)

        # Nested node at depth 3.
        self.conv31 = DoubleConv(n_filters * 16, n_filters * 8)

    def _run(self, block, x):
        """Apply ``block`` to ``x``, via activation checkpointing if enabled."""
        # Checkpointing only pays off when a backward pass will follow, so
        # plain inference (no_grad) always takes the direct path.
        if self._checkpointing and torch.is_grad_enabled():
            from torch.utils.checkpoint import checkpoint

            # Non-reentrant mode keeps gradients flowing to the block's
            # parameters even when ``x`` itself does not require grad
            # (the case for the raw network input).
            return checkpoint(block, x, use_reentrant=False)
        return block(x)

    def forward(self, x):
        """Compute the nested U-Net output ``x04`` for input ``x``.

        Args:
            x: Input tensor; it is pooled four times with ``MaxPool2d(2)``
                and concatenated along ``dim=1`` with upsampled features.

        Returns:
            The output of ``conv04`` applied to the concatenation of
            ``x00, x01, x02, x03`` and the upsampled ``x13``.
        """
        run = self._run
        pool = self.max_pool

        # Column 0: encoder backbone.
        x00 = run(self.inc, x)
        x10 = run(self.conv10, pool(x00))
        x20 = run(self.conv20, pool(x10))
        x30 = run(self.conv30, pool(x20))
        x40 = run(self.conv40, pool(x30))

        # Column 1.
        u01 = run(self.up01, x10)
        u11 = run(self.up11, x20)
        u21 = run(self.up21, x30)
        u31 = run(self.up31, x40)

        x01 = run(self.conv01, torch.cat([x00, u01], dim=1))
        x11 = run(self.conv11, torch.cat([x10, u11], dim=1))
        x21 = run(self.conv21, torch.cat([x20, u21], dim=1))
        x31 = run(self.conv31, torch.cat([x30, u31], dim=1))

        # Column 2.
        u02 = run(self.up02, x11)
        u12 = run(self.up12, x21)
        u22 = run(self.up22, x31)

        x02 = run(self.conv02, torch.cat([x00, x01, u02], dim=1))
        x12 = run(self.conv12, torch.cat([x10, x11, u12], dim=1))
        x22 = run(self.conv22, torch.cat([x20, x21, u22], dim=1))

        # Column 3.
        u03 = run(self.up03, x12)
        u13 = run(self.up13, x22)

        x03 = run(self.conv03, torch.cat([x00, x01, x02, u03], dim=1))
        x13 = run(self.conv13, torch.cat([x10, x11, x12, u13], dim=1))

        # Column 4: output node.
        u04 = run(self.up04, x13)
        x04 = run(self.conv04, torch.cat([x00, x01, x02, x03, u04], dim=1))

        return x04

    def use_checkpointing(self):
        """Enable activation checkpointing for every sub-block in ``forward``.

        After this call, each ``DoubleConv``/``Up_`` invocation made while
        gradients are enabled goes through
        ``torch.utils.checkpoint.checkpoint``, so its intermediate
        activations are recomputed during backward instead of being stored.
        Sub-modules are not wrapped or replaced, so ``state_dict`` keys stay
        the same. Requires a PyTorch version whose ``checkpoint`` accepts
        ``use_reentrant``.
        """
        self._checkpointing = True