In [2]:
import torch
from torch.nn import Conv2d, Sequential, ReLU
import numpy as np
import cv2

In [64]:
def ContextModule(input_channels, kernel_size=[3, 3], img_size=512):
    """
    Builds the context module block block for MobileNets
    based on: "MULTI-SCALE CONTEXT AGGREGATION BYDILATED CONVOLUTIONS"
    http://vladlen.info/papers/dilated-convolutions.pdf
    Architecture:
    Layer           1   2   3   4       5       6       7       8
    Convolution     3×3 3×3 3×3 3×3     3×3     3×3     3×3     1×1
    Dilation        1   1   2   4       8       16      1       1 
    Truncation      Yes Yes Yes Yes     Yes     Yes     Yes     No
    Receptive field 3×3 5×5 9×9 17×17   33×33   65×65   67×67   67×67
    Convolution feature map size formula: 
    The context module is designed to increase the performance of dense prediction architectures by
    aggregating multi-scale contextual information. The module takes C feature maps as input and
    produces C feature maps as output. The input and output have the same form, thus the module can
    be plugged into existing dense prediction architectures.
    """
    net = []
    dilations = [1,1,2,4,8,16]

    for d in dilations:
        # Get padding to keep output shape same as input shape
        # o = [i + 2*p - k - (k-1)*(d-1)]/s + 1
        k = kernel_size if isinstance(kernel_size, int) else kernel_size[0]
        #pad = ((img_size - 1 ) / s - img_size + k + (k-1)*(d-1)) / 2  --- since s will be 1, simplify:
        pad = int((k + (k-1)*(d-1) - 1) / 2)
        print("padding: %d"%pad)
        net.append( Conv2d(in_channels = input_channels, 
                           out_channels = input_channels, 
                           kernel_size=kernel_size, 
                           padding=pad, 
                           dilation=d) )

    # Pointwise
    net.append( Conv2d(in_channels = input_channels, out_channels = input_channels, kernel_size=[1, 1]) )

    #net.append( BatchNorm2d(input_channels) )

    #net.append( ReLU() )

    return Sequential(*net)
    #####################
    #      End ContextModule     
    #####################

In [65]:
cmod = ContextModule(1)

padding: 1
padding: 1
padding: 2
padding: 4
padding: 8
padding: 16


In [66]:
img = np.random.rand(1,1,512,512)

In [67]:
img.shape, img.max(), img.min()

((1, 1, 512, 512), 0.9999980216352864, 1.2225622584294271e-06)

In [68]:
imgTensor = torch.FloatTensor(img)

In [69]:
out = cmod(imgTensor)

In [70]:
out

tensor([[[[-0.3652, -0.3625, -0.3657,  ..., -0.3331, -0.3319, -0.3357],
          [-0.3656, -0.3619, -0.3658,  ..., -0.3319, -0.3341, -0.3319],
          [-0.3652, -0.3641, -0.3642,  ..., -0.3348, -0.3312, -0.3334],
          ...,
          [-0.3252, -0.3141, -0.3234,  ..., -0.3556, -0.3515, -0.3524],
          [-0.3163, -0.3205, -0.3190,  ..., -0.3517, -0.3521, -0.3539],
          [-0.3224, -0.3158, -0.3214,  ..., -0.3527, -0.3541, -0.3501]]]],
       grad_fn=<ThnnConv2DBackward>)

In [71]:
imgTensor

tensor([[[[0.4839, 0.2239, 0.2471,  ..., 0.0078, 0.8321, 0.8030],
          [0.8905, 0.6168, 0.6412,  ..., 0.6550, 0.6975, 0.6839],
          [0.6794, 0.9819, 0.7558,  ..., 0.3268, 0.9318, 0.5718],
          ...,
          [0.5202, 0.6571, 0.9101,  ..., 0.1338, 0.1855, 0.0876],
          [0.1861, 0.5415, 0.9671,  ..., 0.5816, 0.4401, 0.3209],
          [0.6829, 0.9410, 0.0272,  ..., 0.8861, 0.6290, 0.0989]]]])

In [72]:
imgTensor.shape

torch.Size([1, 1, 512, 512])

In [73]:
out.shape

torch.Size([1, 1, 512, 512])