In [2]:
import torch
import torch.nn.functional as F
import torchvision.models.resnet as resnet

In [16]:
class DecoderDeeplabV3p(torch.nn.Module):
    def __init__(self, bottleneck_ch, skip_4x_ch, num_out_ch):
        super(DecoderDeeplabV3p, self).__init__()

        # TODO: Implement a proper decoder with skip connections instead of the following
        #self.features_to_predictions = torch.nn.Conv2d(bottleneck_ch, num_out_ch, kernel_size=1, stride=1)
        self.features_to_predictions = torch.nn.Conv2d(num_out_ch, num_out_ch, kernel_size=1, stride=1)
        self.conv1x1 = torch.nn.Conv2d(skip_4x_ch, num_out_ch, kernel_size = 1, stride = 1)
        self.conv3x3 = torch.nn.Conv2d(num_out_ch + bottleneck_ch, num_out_ch, kernel_size = 3, stride = 1)
        
    def forward(self, features_bottleneck, features_skip_4x):
        """
        DeepLabV3+ style decoder
        :param features_bottleneck: bottleneck features of scale > 4 coming from aspp module
        :param features_skip_4x: features of encoder of scale == 4 coming from DCNN
        :return: features with 256 channels and the final tensor of predictions
        """
        # TODO: Implement a proper decoder with skip connections instead of the following; keep returned
        #       tensors in the same order and of the same shape.
        
        
        features_aspp_4x = F.interpolate(
                        features_bottleneck, size=features_skip_4x.shape[2:], mode='bilinear', align_corners=False)
        dcnn_features = selv.conv1x1(features_skip_4x)
        concat = torch.cat([features_aspp_4x, dcnn_features], dim = 1)
        features = self.conv3x3(concat)
        features_4x = F.interpolate(
                        features, size=features_skip_4x.shape[2:], mode='bilinear', align_corners=False)
        predictions_4x = self.features_to_predictions(features_4x)
        return predictions_4x, features_4x


class ASPPpart(torch.nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation):
        super().__init__(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
        )


class ASPP(torch.nn.Module):
    def __init__(self, in_channels, out_channels, rates=(3, 6, 9)):
        super().__init__()
        # TODO: Implement ASPP properly instead of the following
        modules = []
        rates = [2*x for x in rates]
        modules.append(ASPPpart(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1))
        for rate in rates:
            modules.append(ASPPpart(in_channels, out_channels, kernel_size=3, stride=1, padding=rate, dilation=rate))
       
        #global_avg = torch.nn.AdaptiveAvgPool2d(1) does not work gives [256,512, H, W] instead of [256,256, H,W]
        # therefore apply convolution with correct output channels
        global_avg = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1),
                                         torch.nn.Conv2d(in_channels, out_channels, kernel_size = 1))
        modules.append(global_avg)
        self.aspp_convs = torch.nn.ModuleList(modules)
        print("In channels", in_channels)
        print("Out channels", out_channels)
        # At this stage when called, already concatenated so we know how many out channels we have for each conv.
        # So total after concatenation of all diff layers of conv is len(self.aspp_convs)*out_channels
        self.conv_1x1 = torch.nn.Conv2d(out_channels*len(self.aspp_convs), out_channels, kernel_size = 1)

    def forward(self, x):
        # TODO: Implement ASPP properly instead of the following
        res = []
        resolution_h_w = (x.shape[2], x.shape[3]) # height and width of feature map image
        for layer in self.aspp_convs:
            res.append(layer(x))
            print(layer)
            print("Tensor size = ", layer(x).size())
        #res[4] is the output of the average pooling but has h= 1, w = 1 so we upsample it to the needed height and width
        res[4] = F.interpolate(res[4],resolution_h_w, mode = 'bilinear', align_corners=False)
        res = torch.cat(res, dim = 1)
        print("after cat")
        return self.conv_1x1(res)

In [13]:
aspp = ASPP(512,256) # instanciate class

In channels 512
Out channels 256


In [14]:
features = torch.randn((256,512,10,10)) # format [BatchSize, Channels, Height, Width]
features_tasks = aspp(features) # call forward method
print("return size of forward = ", features_tasks.size())

ASPPpart(
  (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)
Tensor size =  torch.Size([256, 256, 10, 10])
ASPPpart(
  (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6), bias=False)
  (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)
Tensor size =  torch.Size([256, 256, 10, 10])
ASPPpart(
  (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(12, 12), dilation=(12, 12), bias=False)
  (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)
Tensor size =  torch.Size([256, 256, 10, 10])
ASPPpart(
  (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(18, 18), dilation=(18, 18), bias=False)
  (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)
Tensor size =  