In [1]:

import torch
import torch.nn as nn
import torchvision


In [2]:
class DeformableConv2d(nn.Module):
    """Modulated deformable 2-D convolution built on ``torchvision.ops.deform_conv2d``.

    Two auxiliary convolutions look at the input and predict, for every output
    location, (a) an (offset) pair per kernel tap and (b) one modulation scalar
    per kernel tap. Those are fed, together with the weights of a regular
    ``nn.Conv2d``, to ``torchvision.ops.deform_conv2d``.

    Both auxiliary convolutions are initialised with all-zero weight and bias,
    so right after construction every offset is 0 and every modulation scalar
    is ``2 * sigmoid(0) == 1``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the layer.
        kernel_size: side of the square kernel; must be an ``int`` because it
            is used arithmetically to size the offset/modulator outputs.
        stride: stride shared by the auxiliary convs and the deformable conv.
        padding: zero-padding shared by the auxiliary convs and the deformable conv.
        bias: whether the main (regular) convolution carries a bias term.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 bias=False):
        super().__init__()

        # Remembered so forward() samples on the same output grid the
        # offset/modulator convs produce; without it any stride != 1 yields
        # offset and mask tensors whose spatial size does not match the output.
        self.stride = stride
        self.padding = padding

        # Two values per kernel tap -> 2 * k * k offset channels.
        self.offset_conv = nn.Conv2d(in_channels,
                                     2 * kernel_size * kernel_size,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=self.padding,
                                     bias=True)

        # Zero init: the predicted offsets are exactly 0 at the start.
        nn.init.constant_(self.offset_conv.weight, 0.)
        nn.init.constant_(self.offset_conv.bias, 0.)

        # One modulation scalar per kernel tap -> k * k channels.
        self.modulator_conv = nn.Conv2d(in_channels,
                                        1 * kernel_size * kernel_size,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=self.padding,
                                        bias=True)

        # Zero init: the modulator logits are 0, i.e. the mask is 1 after
        # the 2 * sigmoid(.) in forward().
        nn.init.constant_(self.modulator_conv.weight, 0.)
        nn.init.constant_(self.modulator_conv.bias, 0.)

        # Only used as a container for the main weight/bias parameters;
        # its own forward() is never called.
        self.regular_conv = nn.Conv2d(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      padding=self.padding,
                                      bias=bias)

    def forward(self, x):
        """Apply the modulated deformable convolution to ``x`` and return the result.

        ``x`` is indexed as ``x.shape[2:]`` for its two spatial sizes, so it is
        expected to be a 4-D batch of feature maps.
        """
        h, w = x.shape[2:]
        # Offsets are limited to a quarter of the longer spatial side.
        max_offset = max(h, w) / 4.

        offset = self.offset_conv(x).clamp(-max_offset, max_offset)
        # 2 * sigmoid keeps the modulation in (0, 2), centred on 1.
        modulator = 2. * torch.sigmoid(self.modulator_conv(x))

        return torchvision.ops.deform_conv2d(input=x,
                                             offset=offset,
                                             weight=self.regular_conv.weight,
                                             bias=self.regular_conv.bias,
                                             stride=self.stride,
                                             padding=self.padding,
                                             mask=modulator
                                             )

In [20]:
# Random stand-in for a single 3-channel 224x224 image; torch.rand draws
# uniformly from [0, 1).
x = torch.rand(1, 3, 224, 224)
print(x.shape, x, sep="\n")

torch.Size([1, 3, 224, 224])
tensor([[[[0.6637, 0.9092, 0.0223,  ..., 0.1978, 0.6091, 0.1626],
          [0.6127, 0.6896, 0.6134,  ..., 0.4563, 0.0801, 0.0066],
          [0.0064, 0.8163, 0.5050,  ..., 0.9591, 0.7924, 0.8178],
          ...,
          [0.3406, 0.5611, 0.5172,  ..., 0.8362, 0.1288, 0.1834],
          [0.6166, 0.6991, 0.3054,  ..., 0.5943, 0.4763, 0.2973],
          [0.0375, 0.5085, 0.3162,  ..., 0.1681, 0.4985, 0.3899]],

         [[0.4493, 0.1283, 0.0991,  ..., 0.9186, 0.8901, 0.3539],
          [0.7714, 0.7382, 0.6651,  ..., 0.6157, 0.7067, 0.1467],
          [0.3317, 0.1995, 0.2478,  ..., 0.7689, 0.7709, 0.8859],
          ...,
          [0.1119, 0.5473, 0.4853,  ..., 0.4273, 0.4752, 0.7518],
          [0.5964, 0.4107, 0.2402,  ..., 0.7750, 0.4884, 0.2953],
          [0.5156, 0.5605, 0.7330,  ..., 0.6830, 0.1961, 0.6793]],

         [[0.8006, 0.1372, 0.0043,  ..., 0.5283, 0.9285, 0.3222],
          [0.3775, 0.4026, 0.3167,  ..., 0.5116, 0.7442, 0.6276],
          [0.

In [6]:
# Spatial extent of the input: everything after the first two dimensions of x.
h, w = x.shape[2:]
print(f"{h} {w}")

224 224


In [7]:
# Bound used to clamp the predicted offsets: a quarter of the longer spatial side.
max_offset = 0.25 * max(h, w)
print(max_offset)

56.0


In [21]:
# Hyperparameters reused by the stand-alone layers built in this and later cells.
in_channels = 3
out_channels = 3
kernel_size = 3
stride = 1
padding = 1
bias = False

# Offset branch on its own: 2 * k * k output channels (two values per kernel tap).
# Unlike DeformableConv2d.__init__, the weights keep nn.Conv2d's default
# (random) initialisation here, so the printed offsets are non-zero.
offset_conv = nn.Conv2d(
    in_channels,
    out_channels=2 * kernel_size * kernel_size,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    bias=True,
)

# Predict the offsets and limit them to [-max_offset, max_offset] in one step.
offset = torch.clamp(offset_conv(x), min=-max_offset, max=max_offset)
print(offset.shape)
print(offset)

torch.Size([1, 18, 224, 224])
tensor([[[[-1.7576e-02,  3.4286e-02,  2.6968e-01,  ...,  5.6142e-04,
           -3.9582e-01, -2.2050e-01],
          [ 2.4251e-03, -4.4305e-02,  4.7127e-01,  ...,  3.0009e-01,
            1.3429e-01, -1.1702e-01],
          [ 1.5913e-01, -1.1701e-01,  2.3523e-01,  ...,  1.7050e-01,
            2.6533e-01, -2.5244e-01],
          ...,
          [ 9.9858e-02,  2.7623e-01,  3.4265e-01,  ...,  1.5353e-01,
            1.4916e-01, -1.0651e-01],
          [ 1.3632e-01, -1.0728e-01,  2.3423e-02,  ...,  9.1937e-02,
            3.8159e-01, -1.2229e-01],
          [-1.6383e-02,  8.8191e-02,  5.5780e-02,  ...,  7.8856e-02,
            3.4026e-02,  4.8762e-02]],

         [[-3.4662e-01, -6.1628e-01, -5.8592e-01,  ..., -8.9743e-01,
           -8.1005e-01, -4.5383e-01],
          [-4.1726e-01, -5.3475e-01, -8.7211e-01,  ..., -8.2527e-01,
           -7.2876e-01, -2.1188e-01],
          [-4.2770e-01, -4.2792e-01, -6.7187e-01,  ..., -5.3343e-01,
           -9.3921e-01, -5.0

In [16]:
# Modulation branch on its own: one scalar per kernel tap -> k * k output channels.
modulator_conv = nn.Conv2d(
    in_channels,
    out_channels=kernel_size * kernel_size,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    bias=True,
)

# sigmoid squashes the logits into (0, 1); doubling gives a mask in (0, 2).
modulator = modulator_conv(x).sigmoid().mul(2.)
print(modulator.shape)

torch.Size([1, 9, 224, 224])


In [19]:
# Main convolution: only its weight (and optional bias) are used below,
# its own forward() is never called.
regular_conv = nn.Conv2d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         bias=bias)

# The result is bound to a NEW name instead of overwriting `x`: the offset and
# modulator cells read `x`, so rebinding it here made this cell non-idempotent
# and made any re-run of those cells silently operate on the conv output.
# `stride` is forwarded so the sampling grid matches the one used by the
# convs that produced `offset` and `modulator`.
x_deformed = torchvision.ops.deform_conv2d(input=x,
                                           offset=offset,
                                           weight=regular_conv.weight,
                                           bias=regular_conv.bias,
                                           stride=stride,
                                           padding=padding,
                                           mask=modulator)

print(x_deformed.shape)
print(x_deformed)

torch.Size([1, 3, 224, 224])
tensor([[[[ 0.0020,  0.0129,  0.0209,  ..., -0.0469, -0.0309,  0.0324],
          [ 0.0091,  0.0049,  0.0068,  ..., -0.0462, -0.0217,  0.0256],
          [ 0.0211,  0.0278, -0.0075,  ..., -0.0520, -0.0376,  0.0213],
          ...,
          [-0.0032, -0.0152, -0.0401,  ..., -0.0536, -0.0333, -0.0121],
          [ 0.0296, -0.0050, -0.0164,  ..., -0.0568, -0.0620, -0.0634],
          [-0.0083, -0.0204, -0.0537,  ..., -0.1018, -0.0949, -0.0885]],

         [[-0.0335, -0.0154, -0.0455,  ..., -0.0135,  0.0054, -0.0253],
          [-0.0770, -0.0726, -0.0355,  ..., -0.0852, -0.0790, -0.0297],
          [-0.1309, -0.0898, -0.1053,  ..., -0.1011, -0.1116, -0.0053],
          ...,
          [-0.0845, -0.1279, -0.1345,  ..., -0.1476, -0.1351,  0.0063],
          [-0.0652, -0.0859, -0.0831,  ..., -0.1181, -0.1347, -0.0074],
          [-0.0432, -0.0426, -0.0542,  ..., -0.0916, -0.1086, -0.0304]],

         [[ 0.0176,  0.0344,  0.0258,  ..., -0.0337,  0.0126,  0.0132],
 

In [25]:
# Toy tensors for calling torchvision.ops.deform_conv2d directly.
# NOTE: `input` shadows the Python builtin; the name is kept because the
# following cell passes it to deform_conv2d.
kh = kw = 3
input = torch.rand(4, 3, 10, 10)            # batch of 4, 3 channels, 10x10
weight = torch.rand(5, 3, kh, kw)           # 5 output channels, 3 input channels
# 8 = 10 - 3 + 1: output side of an unpadded, stride-1 3x3 kernel on a 10x10 map.
offset = torch.rand(4, 2 * kh * kw, 8, 8)   # two values per kernel tap
mask = torch.rand(4, kh * kw, 8, 8)         # one value per kernel tap

print(weight)


tensor([[[[0.0569, 0.7440, 0.0140],
          [0.0148, 0.6590, 0.4926],
          [0.0950, 0.8827, 0.6277]],

         [[0.4516, 0.7992, 0.5346],
          [0.8951, 0.3522, 0.6384],
          [0.1848, 0.9502, 0.1848]],

         [[0.2099, 0.1302, 0.6574],
          [0.1031, 0.8875, 0.5087],
          [0.5917, 0.9960, 0.9076]]],


        [[[0.8687, 0.6939, 0.4749],
          [0.7885, 0.3227, 0.4976],
          [0.9828, 0.9661, 0.3260]],

         [[0.0874, 0.5775, 0.1695],
          [0.5461, 0.6191, 0.0130],
          [0.7924, 0.8268, 0.5407]],

         [[0.5587, 0.9888, 0.0940],
          [0.6395, 0.1014, 0.9123],
          [0.8542, 0.1739, 0.9641]]],


        [[[0.6092, 0.5315, 0.5009],
          [0.0752, 0.9386, 0.0849],
          [0.5047, 0.5962, 0.7937]],

         [[0.0541, 0.0705, 0.9983],
          [0.2636, 0.4181, 0.3774],
          [0.2157, 0.8887, 0.9758]],

         [[0.9315, 0.0792, 0.2303],
          [0.1017, 0.6964, 0.3520],
          [0.7284, 0.5469, 0.0180]]],


    

In [27]:
# Direct call with the toy tensors; stride, padding and dilation are left at
# the function's defaults.
out = torchvision.ops.deform_conv2d(
    input=input,
    offset=offset,
    weight=weight,
    mask=mask,
)
print(out.shape)

torch.Size([4, 5, 8, 8])
