In [1]:
import torch

In [7]:
class Down_Conv(torch.nn.Module):
            def __init__(self, in_dims, allow_crop = True):
            
                assert in_dims[1] == in_dims[2]  # for now, only allow quadratic input images
                
                super(Down_Conv, self).__init__()
                self.in_channels, self.in_img_dim1, self.in_img_dim2 = in_dims
                assert (self.in_img_dim1 % 4, self.in_img_dim2 % 4) == (0, 0)
                self.allow_crop = allow_crop
                self.crop = ( (self.in_img_dim1 // 4) % 2 == 1)
                #self.crop = False # --------------------
                
                self.conv1 = torch.nn.Conv2d(self.in_channels, self.in_channels * 2, kernel_size = 3, padding = 1)
                self.conv2 = torch.nn.Conv2d(self.in_channels * 2, self.in_channels  * 2, kernel_size = 3, padding = 1)
                self.maxpool = torch.nn.MaxPool2d(kernel_size = 2, stride = 2)
                
                self.conv_out_dims = (self.in_channels * 2, self.in_img_dim1, self.in_img_dim2)
                
                self.out_dims = (self.in_channels * 2, self.in_img_dim1 // 2, self.in_img_dim2 // 2)
                #print((self.in_channels, self.in_img_dim1, self.in_img_dim2), self.out_dims)
                if (self.crop and self.allow_crop):
                    self.out_dims = (self.out_dims[0], self.out_dims[1] - 2, self.out_dims[2] - 2)
                
                print('conv_out', self.conv_out_dims, 'out', self.out_dims, 'cropped', self.crop and self.allow_crop)
                
                #print((self.in_channels, self.in_img_dim1, self.in_img_dim2), self.out_dims)
                #print()
                
            def forward(self, x):
                #print(x.shape)
                x = self.conv1(x)
                #print(x.shape)
                c = self.conv2(x)
                #print(x.shape)
                x = self.maxpool(c)
                #print(x.shape)
                #print()
                
                if(self.crop and self.allow_crop):
                    return c, x[:, :, 1:-1, 1:-1 ]
                return c, x
            
class Up_Conv(torch.nn.Module):
            def __init__(self, in_dims, pad = False):
                
                super(Up_Conv, self).__init__()
                self.in_channels, self.in_img_dim1, self.in_img_dim2 = in_dims
                self.pad = pad
                
                self.up_conv = torch.nn.ConvTranspose2d(self.in_channels, self.in_channels // 2, kernel_size = 2, stride = 2, padding = 0)
                self.conv1 = torch.nn.Conv2d(self.in_channels, self.in_channels // 2, kernel_size = 3, padding = 1)
                self.conv2 = torch.nn.Conv2d(self.in_channels // 2, self.in_channels  // 2, kernel_size = 3, padding = 1)
                
                self.out_dims = (self.in_channels // 2, self.in_img_dim1 * 2, self.in_img_dim2 * 2)
                if (self.pad):
                    self.out_dims = (self.out_dims[0], self.out_dims[1] + 4, self.out_dims[2] + 4)
                    
            def forward(self, concat_layer, x):
                    if self.pad:
                        x = torch.nn.ZeroPad2d(1)(x)
                    x = self.up_conv(x)
                    x = torch.cat([concat_layer, x], axis=1)
                    x = self.conv1(x)
                    x = self.conv2(x)
                    
                    return x
            
                

class AUnet(torch.nn.Module):
    def __init__(self, depth = 4, input_dims = (3, 300, 300), top_channels = 64, out_channels = 1):
        
        assert top_channels % 2 == 0
        assert depth > 0
        
        super(AUnet, self).__init__()
        self.depth = depth
        self.top_channels = top_channels
        self.out_channels = out_channels
        
        self.input_adapter_conv = torch.nn.Conv2d(input_dims[0], top_channels // 2, kernel_size = 1)
        
        self.down_convs = []
        assert input_dims[1] == input_dims[2] # for now, only allow quadratic input images
        in_dims = (top_channels // 2, input_dims[1], input_dims[2]) 
        for i in range(self.depth - 1):
            self.down_convs.append(Down_Conv(in_dims))
            #print(in_dims, self.down_convs[-1].out_dims)
            in_dims = self.down_convs[-1].out_dims
            #print(in_dims)
        self.down_convs.append(Down_Conv(in_dims, allow_crop = False))
        in_dims = self.down_convs[-1].out_dims
        
        #print(f'H{in_dims}')
        self.connector_conv1 = torch.nn.Conv2d(in_dims[0], in_dims[0] * 2, kernel_size = 3, padding = 1)
        in_dims = (in_dims[0] * 2, in_dims[1], in_dims[2])
        self.connector_conv2 = torch.nn.Conv2d(in_dims[0], in_dims[0], kernel_size = 3, padding = 1)
        
        
        self.up_convs = []
        for i in range(self.depth):
            print('Assert', self.down_convs[-i-1].conv_out_dims[1], in_dims[1] * 2)
            assert (self.down_convs[-i-1].conv_out_dims[1] == in_dims[1] * 2) or (self.down_convs[-i-1].conv_out_dims[1] == in_dims[1] * 2 + 4)
            pad = self.down_convs[-i-1].conv_out_dims[1] == in_dims[1] * 2 + 4
            print(pad)
            self.up_convs.append(Up_Conv(in_dims, pad))
            #print(in_dims, self.down_convs[-1].out_dims)
            in_dims = self.up_convs[-1].out_dims
            print(in_dims)
            
        #self.up_convs.append(Up_Conv(in_dims, pad=False))
        #in_dims = self.up_convs[-1].out_dims #-----------------------------
        #print(in_dims)
        
        self.head_conv = torch.nn.Conv2d(in_dims[0], out_channels, kernel_size = 1, padding = 0)
        
            
    def forward(self, x):
        concats = []
        
        print(x.shape)
        
        x1 = self.input_adapter_conv(x)
        print(x1.shape)
        for i in range(self.depth):
            x0, x1 = self.down_convs[i](x1)
            print(x0.shape, x1.shape)
            #print()
            concats.append(x0)
        print(x1.shape)    
        x = self.connector_conv1(x1)
        print(x.shape)
        x = self.connector_conv2(x)
        print(x.shape)
        
        for i in range(self.depth):
            x = self.up_convs[i](concats[-i-1], x)
            
        x = self.head_conv(x)
        return x
            
            
    
    #def forward(self, x):
    #    x1 = torch.nn.Conv2d(x.shape[-3], 512, kernel_size = 3)(x)
    #    return x1

In [None]:
img = torch.rand((727, 3, 300, 300))

for depth in range(4,5): 
    aunet = AUnet(depth = depth, input_dims = (3, 300, 300), top_channels = 64)
    out = aunet(img)
    print('out.shape', out.shape)



conv_out (64, 300, 300) out (64, 148, 148) cropped True
conv_out (128, 148, 148) out (128, 72, 72) cropped True
conv_out (256, 72, 72) out (256, 36, 36) cropped False
conv_out (512, 36, 36) out (512, 18, 18) cropped False
Assert 36 36
False
(512, 36, 36)
Assert 72 72
False
(256, 72, 72)
Assert 148 144
True
(128, 148, 148)
Assert 300 296
True
(64, 300, 300)
torch.Size([727, 3, 300, 300])
torch.Size([727, 32, 300, 300])


In [4]:
img = torch.rand((1, 6, 148, 148))
x = torch.nn.ZeroPad2d(1)(img)
x = torch.nn.ConvTranspose2d(6, 3, kernel_size = 2, stride = 2, padding = 0)(x)
print(x.shape)

torch.Size([1, 3, 300, 300])
