In [1]:
import torch

In [73]:
class Down_Conv(torch.nn.Module):
    """One down-sampling stage: two 3x3 convolutions followed by 2x2 max-pooling.

    The first convolution doubles the channel count, the second keeps it, and
    both use ``padding = 1`` so only the pooling changes the spatial size
    (it halves it).

    Optionally a 1-pixel border is cropped from the pooled map. The crop is
    applied when ``(side // 4)`` is odd: in that case ``side // 2`` is not a
    multiple of 4, but ``side // 2 - 2`` is, so the cropped map satisfies the
    divisible-by-4 requirement of a following ``Down_Conv`` stage.

    Args:
        in_dims: ``(channels, height, width)`` of the input; height and width
            must be equal and divisible by 4.
        allow_crop: when False the border crop is never applied, regardless
            of the input size.
        verbose: when True (default) the constructor prints the input dims
            with the output dims before and after the crop decision.

    Attributes:
        crop: True when the input size alone asks for a crop.
        allow_crop: the constructor argument, stored unchanged.
        out_dims: ``(channels, height, width)`` of the second tensor returned
            by ``forward``.
    """

    def __init__(self, in_dims, allow_crop = True, verbose = True):

        assert in_dims[1] == in_dims[2]  # for now, only allow quadratic input images

        super(Down_Conv, self).__init__()
        self.in_channels, self.in_img_dim1, self.in_img_dim2 = in_dims
        assert (self.in_img_dim1 % 4, self.in_img_dim2 % 4) == (0, 0)
        self.allow_crop = allow_crop
        # Size-based request for a crop; it only takes effect together with allow_crop.
        self.crop = ((self.in_img_dim1 // 4) % 2 == 1)

        out_channels = self.in_channels * 2
        self.conv1 = torch.nn.Conv2d(self.in_channels, out_channels, kernel_size = 3, padding = 1)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1)
        self.maxpool = torch.nn.MaxPool2d(kernel_size = 2, stride = 2)

        in_summary = (self.in_channels, self.in_img_dim1, self.in_img_dim2)
        self.out_dims = (out_channels, self.in_img_dim1 // 2, self.in_img_dim2 // 2)
        if verbose:
            print(in_summary, self.out_dims)
        if (self.crop and self.allow_crop):
            # One pixel is removed on every side, i.e. two per spatial dimension.
            self.out_dims = (self.out_dims[0], self.out_dims[1] - 2, self.out_dims[2] - 2)
        if verbose:
            print(in_summary, self.out_dims)
            print()

    def forward(self, x):
        """Run the stage on ``x``.

        Returns:
            A pair ``(pooled, out)``: ``pooled`` is the max-pooled feature map,
            ``out`` is the same map with the 1-pixel border removed when
            cropping is both requested and allowed, otherwise ``pooled``
            itself. The shape of ``out`` (without the batch axis) equals
            ``self.out_dims``.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool(x)

        # Must mirror the condition used for out_dims in __init__; checking
        # self.crop alone would crop even when allow_crop is False.
        if self.crop and self.allow_crop:
            return x, x[:, :, 1:-1, 1:-1]
        return x, x


class AUnet(torch.nn.Module):
    """Stack of ``Down_Conv`` stages behind a 1x1 input-adapter convolution.

    NOTE(review): the name suggests this is the contracting half of a U-Net;
    no up-sampling path is defined in this class.

    Args:
        depth: number of ``Down_Conv`` stages that ``forward`` runs.
        input_dims: ``(channels, height, width)`` of the input images; height
            and width must be equal.
        top_channels: channel count produced by the first stage; must be even
            because the adapter convolution emits ``top_channels // 2``
            channels, which the first stage then doubles.
    """

    def __init__(self, depth = 4, input_dims = (3, 300, 300), top_channels = 64):

        assert top_channels % 2 == 0

        super(AUnet, self).__init__()
        self.depth = depth
        self.top_channels = top_channels
        self.input_adapter_conv = torch.nn.Conv2d(input_dims[0], top_channels // 2, kernel_size = 1)

        assert input_dims[1] == input_dims[2] # for now, only allow quadratic input images

        # ModuleList (not a plain list) so the stages are registered as
        # submodules: their weights then show up in parameters() and
        # state_dict() and follow .to()/.cuda() calls.
        self.down_convs = torch.nn.ModuleList()
        in_dims = (top_channels // 2, input_dims[1], input_dims[2])
        for _ in range(self.depth - 1):
            stage = Down_Conv(in_dims)
            self.down_convs.append(stage)
            # Each stage reports the dims its successor will receive.
            in_dims = stage.out_dims
        # The final stage is built with cropping disabled.
        self.down_convs.append(Down_Conv(in_dims, allow_crop = False))

    def forward(self, x):
        """Apply the adapter convolution and then ``depth`` stages to ``x``.

        Returns:
            The second output of the last stage that was run.
        """
        x1 = self.input_adapter_conv(x)
        for i in range(self.depth):
            # Every stage returns (pooled map, map for the next stage); the
            # first element is not consumed here.
            _, x1 = self.down_convs[i](x1)
        return x1

In [74]:
# Random stand-in batch: one image, 3 channels, 300 x 300 pixels.
img = torch.rand(1, 3, 300, 300)

# Build a network per depth and push the batch through it once.
# The range currently covers depth 4 only; widen it to sweep more depths.
for depth in range(4, 5):
    aunet = AUnet(
        depth = depth,
        input_dims = tuple(img.shape[1:]),  # (3, 300, 300), taken from the batch itself
        top_channels = 64,
    )
    out = aunet(img)



(32, 300, 300) (64, 150, 150)
(32, 300, 300) (64, 148, 148)

(64, 148, 148) (128, 74, 74)
(64, 148, 148) (128, 72, 72)

(128, 72, 72) (256, 36, 36)
(128, 72, 72) (256, 36, 36)

(256, 36, 36) (512, 18, 18)
(256, 36, 36) (512, 18, 18)

