In [1]:
import torch
from torch import nn
from typing import Tuple

Let $\mathcal{R}$ be the regular sampling grid of a $3 \times 3$ kernel, i.e. the set of offsets used to sample a small region of the input.

$$
\mathcal{R} = \{ (-1, -1), (-1, 0), \cdots, (0, 1), (1, 1) \}
$$

The standard 2D convolution is then given by the equation below, where $w$ denotes the kernel weights, $x$ is the input feature map, $y$ is the output of the convolution, $p_0$ is the location at which the kernel is applied, and $p_n$ enumerates the positions in $\mathcal{R}$.

$$
y(p_0) = \sum_{p_n \in \mathcal{R}} w(p_n) \cdot x(p_0 + p_n)
$$

The equation describes the convolution operation: each value on the sampled grid is multiplied by the corresponding kernel weight, and the products are summed to give a scalar output. Repeating the same operation at every location $p_0$ of the image gives the new feature map.
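To make the sum concrete, here is a minimal sketch that evaluates the formula at a single location and compares it with `torch.nn.functional.conv2d`. The toy tensors and the helper name `conv_at` are assumptions made for illustration; they are not part of the notebook.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy data: one 5x5 single-channel feature map and one 3x3 kernel.
x = torch.randn(5, 5)
w = torch.randn(3, 3)

# The sampling grid R of a 3x3 kernel: (-1, -1), (-1, 0), ..., (1, 1).
R = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1)]


def conv_at(x, w, p0):
    """Evaluate y(p0) = sum over p_n in R of w(p_n) * x(p0 + p_n)."""
    y0, x0 = p0
    total = torch.tensor(0.0)
    for dy, dx in R:
        # Kernel weights are indexed 0..2, so shift each offset by 1.
        total = total + w[dy + 1, dx + 1] * x[y0 + dy, x0 + dx]
    return total


y_manual = conv_at(x, w, (2, 2))

# Reference: conv2d without padding; output index (1, 1) is centred on input (2, 2).
y_ref = F.conv2d(x[None, None], w[None, None])[0, 0, 1, 1]
print(torch.allclose(y_manual, y_ref, atol=1e-6))
```

The explicit loop is only meant to mirror the notation; in practice the vectorised `conv2d` call does the same work for every $p_0$ at once.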


Instead of using a fixed sampling grid, deformable convolution adds 2D offsets to the sampling positions of the standard convolution described above.

If $\mathcal{R}$ is the regular grid, deformable convolution augments it with learned offsets, thereby deforming the sampling positions of the grid.

The deformable convolution is given by the equation below, where $\Delta p_n$ denotes the offset added to the $n$-th sampling position.

$$
y(p_0) = \sum_{p_n \in \mathcal{R}} w(p_n) \cdot x(p_0 + p_n + \Delta p_n)
$$

Because sampling now happens at irregular, offset locations and $\Delta p_n$ is generally fractional, the equation above is implemented with bilinear interpolation.

**Bilinear interpolation** is needed because adding offsets to the regular sampling positions produces fractional coordinates, which are not defined locations on the grid. To estimate the value at such a position, bilinear interpolation uses the $2 \times 2$ block of neighbouring grid values.
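As an illustration of the idea, the sketch below interpolates a value at a fractional position from its four neighbours. The helper `bilinear_sample` is hypothetical, and it assumes the point lies strictly inside the grid, so no boundary handling is included.

```python
import math

import torch


def bilinear_sample(x, py, px):
    """Estimate x at the fractional position (py, px) from its 2x2 neighbours."""
    y0, x0 = math.floor(py), math.floor(px)
    y1, x1 = y0 + 1, x0 + 1
    # Weights are the distances to the opposite neighbour along each axis.
    wy1, wx1 = py - y0, px - x0
    wy0, wx0 = 1.0 - wy1, 1.0 - wx1
    return (
        wy0 * wx0 * x[y0, x0]
        + wy0 * wx1 * x[y0, x1]
        + wy1 * wx0 * x[y1, x0]
        + wy1 * wx1 * x[y1, x1]
    )


# A 2x2 grid whose values follow 2 * row + col, so interpolation is exact.
grid = torch.tensor([[0.0, 1.0], [2.0, 3.0]])

centre = bilinear_sample(grid, 0.5, 0.5)      # average of the four values
off_centre = bilinear_sample(grid, 0.25, 0.75)
print(centre.item(), off_centre.item())
```

Because the toy grid is a linear function of its coordinates, the interpolated values can be checked by hand: $2 \cdot 0.5 + 0.5 = 1.5$ and $2 \cdot 0.25 + 0.75 = 1.25$.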

The equation used to perform bilinear interpolation and estimate the value at the fractional position is given below, where $p = p_0 + p_n + \Delta p_n$ is the deformed position, $q$ enumerates all integer positions on the input feature map, and $G(\cdot, \cdot)$ is the bilinear interpolation kernel.

$$
x(p) = \sum_q G(q, p) \cdot x(q)
$$

Note: $G(\cdot, \cdot)$ is two-dimensional and separates along the axes into two one-dimensional kernels, as shown below.

$$
G(q, p) = g(q_x, p_x) \cdot g(q_y, p_y)
$$

![Visual Representation of Deformable Convolution](deform_convolution.png)

As shown in the figure above, the offsets are obtained by applying a convolutional layer to the same input feature map. The convolution kernel has the same spatial resolution and dilation as those of the current convolutional layer. The output offset field has the same spatial resolution as the input feature map and has $2N$ channels, which correspond to $N$ 2D offsets.

In [2]:
in_channels = 4
kernel_size = 3
stride = 1
padding = 1

lr_ratio = 1.

offset_conv = nn.Conv2d(
    in_channels=in_channels,
    out_channels=2 * kernel_size * kernel_size,  # 2N = 2 * 3 * 3, where N is the number of elements in the kernel
    kernel_size=kernel_size,  # Same spatial resolution and dilation as the current convolutional layer
    stride=stride,
    padding=padding
)

offset_conv

Conv2d(4, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

During training, the conv and fc layers added for offset learning are initialized with zero weights.

In [3]:
# This function allows a custom learning rate for the offset layers
def _set_lr(module, grad_input, grad_output):
    new_grad_input = []

    for i in range(len(grad_input)):
        if grad_input[i] is not None:
            new_grad_input.append(grad_input[i] * lr_ratio)
        else:
            new_grad_input.append(grad_input[i])
    new_grad_input = tuple(new_grad_input)

    return new_grad_input

nn.init.constant_(offset_conv.weight, 0)  # the offset layer is initialized with zero weights

offset_conv.register_backward_hook(_set_lr)

<torch.utils.hooks.RemovableHandle at 0x7f6adb01a910>

As illustrated, the first step is to apply the offset convolution to the sample. Initially the offsets contain only the bias term, since all weights were set to 0.

In [4]:
sample = torch.randn((1, 4, 100, 100))  # One sample with 4 channels and a spatial size of 100x100

offset = offset_conv(sample)
offset.shape  # The output has shape (1, 2N, 100, 100)



torch.Size([1, 18, 100, 100])
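The claim that the initial offsets contain only the bias term can be checked directly. The following is a small sketch on a separate, hypothetical layer (`demo_offset_conv`) built with the same settings as `offset_conv`; the input size is arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the notebook's offset layer (same hyperparameters).
demo_offset_conv = nn.Conv2d(4, 18, kernel_size=3, stride=1, padding=1)
nn.init.constant_(demo_offset_conv.weight, 0)

with torch.no_grad():
    demo_offset = demo_offset_conv(torch.randn(1, 4, 10, 10))

# With zero weights, every channel is constant over space and equal to its bias.
expected = demo_offset_conv.bias.detach().view(1, -1, 1, 1).expand_as(demo_offset)
only_bias = torch.allclose(demo_offset, expected)
print(only_bias)
```

If the bias were also set to zero (an extra step that the notebook does not take), all offsets would start at exactly zero and the layer would initially sample the regular grid.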

The offsets obtained here correspond to $\Delta p_n$ in the equations.

In [5]:
dtype = offset.data.type()
ks = kernel_size
N = offset.size(1) // 2 # Number of elements that are in the kernel (3x3 kernel = 9 elements)

dtype, ks, N 

('torch.FloatTensor', 3, 9)

To obtain $p$, we still need $p_0$ and $p_n$, since we already have the $\Delta p_n$ values. Let's start with $p_n$.

In [6]:
h, w = offset.size(2), offset.size(3)  # Height and width of the offset field (same as the input here)

# meshgrid enumerates every coordinate pair of the kernel grid R,
# i.e. all combinations of the offsets -1, 0, 1 along both axes.
p_n_x, p_n_y = torch.meshgrid(
    torch.arange(-(kernel_size - 1) // 2, (kernel_size - 1) // 2 + 1),
    torch.arange(-(kernel_size - 1) // 2, (kernel_size - 1) // 2 + 1),
    indexing="ij",  # same behaviour as the historical default, stated explicitly
)

# print(p_n_x)
# print(p_n_y)

# (2N,): the N first coordinates followed by the N second coordinates
p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)

# (1, 2N, 1, 1)
p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)

p_n.requires_grad = False

print(p_n.shape)

torch.Size([1, 18, 1, 1])
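To see why a meshgrid is used here, the sketch below builds the same grid for a $3 \times 3$ kernel and lists the resulting coordinate pairs. The variable names are chosen for illustration, and `indexing="ij"` assumes a PyTorch version that accepts this argument.

```python
import torch

ks = 3  # kernel size, as in the notebook

# One-dimensional offsets along a single axis: tensor([-1, 0, 1]).
axis_offsets = torch.arange(-(ks - 1) // 2, (ks - 1) // 2 + 1)

# meshgrid pairs every offset along the first axis with every offset along the second.
first, second = torch.meshgrid(axis_offsets, axis_offsets, indexing="ij")

# Flattening both grids in the same order yields the N = 9 elements of R.
pairs = list(zip(first.flatten().tolist(), second.flatten().tolist()))
print(pairs)
```

The pairs come out in the same order as the definition of $\mathcal{R}$, starting at $(-1, -1)$ and ending at $(1, 1)$. Concatenating the two flattened grids, as the cell above does, stores the nine first coordinates followed by the nine second coordinates.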


Next, let's compute $p_0$.

In [7]:
p_0_x, p_0_y = torch.meshgrid(
    torch.arange(1, h * stride + 1, stride),
    torch.arange(1, w * stride + 1, stride),
    indexing="ij",  # same behaviour as the historical default, stated explicitly
)

p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)

p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

p_0.requires_grad = False

p_0.shape

torch.Size([1, 18, 100, 100])

Then we calculate $p = p_0 + p_n + \Delta p_n$.

In [8]:
p = p_0 + p_n + offset
p.shape

torch.Size([1, 18, 100, 100])

We permute the dimensions so that the $2N$ coordinates sit in the last dimension, which makes the interpolation easier.

In [9]:
p = p.contiguous().permute(0, 2, 3, 1)
p.shape

torch.Size([1, 100, 100, 18])

Since $p$ is generally fractional, bilinear interpolation uses the four integer grid positions that surround it (we are working on an image grid). These four positions are the left top, right bottom, left bottom, and right top neighbours, denoted $q_{lt}, q_{rb}, q_{lb}, q_{rt}$. To calculate $q_{lt}$, we simply take the floor of $p$.

```
(y,   x)   (y+1,   x)
(y, x+1)   (y+1, x+1)
```

In [10]:
q_lt = p.detach().floor()
q_lt.shape

torch.Size([1, 100, 100, 18])

Now that we have $q_{lt}$, the next value is $q_{rb}$:

In [11]:
q_rb = q_lt + 1
q_rb.shape

torch.Size([1, 100, 100, 18])

In [12]:
# The last dimension holds 2N values: the first N are the first coordinates, the last N the second coordinates.
# We ensure that q_lt lies in the valid interval 0 <= p_y <= h - 1
# and 0 <= p_x <= w - 1.

q_lt = torch.cat([
    torch.clamp(q_lt[..., :N], 0, sample.size(2) - 1),
    torch.clamp(q_lt[..., N:], 0, sample.size(3) - 1)], dim=-1).long()

q_lt.shape

torch.Size([1, 100, 100, 18])

In [13]:
# Same thing here

q_rb = torch.cat([
    torch.clamp(q_rb[..., :N], 0, sample.size(2) - 1),
    torch.clamp(q_rb[..., N:], 0, sample.size(3) - 1)], dim=-1).long()

q_rb.shape

torch.Size([1, 100, 100, 18])

In [14]:
# For q_lb, x is equal to that of the right bottom point and y is equal to that of the left top point.
# Therefore its y comes from q_lt and its x comes from q_rb.

q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], -1)
print(q_lb.shape)

# For q_rt, x is equal to that of the left top point and y is equal to that of the right bottom point.
# Therefore its y comes from q_rb and its x comes from q_lt.

q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], -1)
print(q_rt.shape)

torch.Size([1, 100, 100, 18])
torch.Size([1, 100, 100, 18])


In [15]:
"""
Find the positions where p_y < padding or p_y > h - 1 - padding, and where p_x < padding or p_x > w - 1 - padding.
These are the points that fall in the border area, where interpolated pixel values are not meaningful.
"""
# (b, h, w, N)
mask = torch.cat([
    p[..., :N].lt(padding) + p[..., :N].gt(sample.size(2) - 1 - padding),
    p[..., N:].lt(padding) + p[..., N:].gt(sample.size(3) - 1 - padding)], dim=-1).type_as(p)

mask = mask.detach()
mask.shape

torch.Size([1, 100, 100, 18])

In [16]:
floor_p = torch.floor(p)
floor_p.shape

torch.Size([1, 100, 100, 18])

In [17]:
"""
When mask is 1, take floor_p;
when mask is 0, keep the original p.
When the point lies in the padding area, interpolation is not meaningful, so we take the floor of p,
a nearby grid point that is the most likely to hold a meaningful value.
"""
p = p * (1 - mask) + floor_p * mask
p = torch.cat([
    torch.clamp(p[..., :N], 0, sample.size(2) - 1),
    torch.clamp(p[..., N:], 0, sample.size(3) - 1)], dim=-1)
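The effect of the mask can be seen on a few coordinates along a single axis. This is a reduced sketch with hypothetical values (`p_axis`, `border`), using the same axis length and padding as the notebook.

```python
import torch

size, padding = 100, 1  # axis length and padding, matching the values used above

# Three coordinates along one axis: near the start, in the interior, near the end.
p_axis = torch.tensor([0.4, 3.7, 98.6])

# 1.0 where the coordinate lies in the border band, 0.0 elsewhere.
border = (p_axis.lt(padding) | p_axis.gt(size - 1 - padding)).type_as(p_axis)

# Border coordinates are replaced by their floor; interior ones stay fractional.
snapped = p_axis * (1 - border) + p_axis.floor() * border
snapped = torch.clamp(snapped, 0, size - 1)
print(border, snapped)
```

Only the interior coordinate keeps its fractional part, so only that one is actually interpolated between neighbouring pixels.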

Next we apply bilinear interpolation, starting with the kernel weight of each of the four neighbouring grid points.

In [18]:
# bilinear kernel (b, h, w, N)
g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

In [34]:
q_lt[:].shape

torch.Size([1, 100, 100, 18])

In [35]:
g_lt.shape

torch.Size([1, 100, 100, 9])

In [31]:
p[0,0,0,:]

tensor([0.0000, 0.0000, 0.0000, 1.0101, 0.0000, 1.0119, 2.0175, 1.8959, 2.0392,
        0.0000, 1.1488, 1.8343, 0.0000, 0.0000, 2.1384, 0.0000, 1.0210, 1.9317],
       grad_fn=<SliceBackward0>)

In the paper, the bilinear interpolation kernel is defined as:

$$
G(q, p) = g(q_x, p_x) \cdot g(q_y, p_y)
$$

$$
g(a, b) = \max(0, 1 - |a - b|)
$$
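The kernel can also be evaluated literally, summing over every grid position $q$. The sketch below does this on a toy feature map; the function names and the (row, column) ordering of the coordinates are assumptions made for illustration.

```python
import torch


def g(a, b):
    """One-dimensional kernel g(a, b) = max(0, 1 - |a - b|)."""
    return torch.clamp(1 - (a - b).abs(), min=0)


def interpolate_with_kernel(x, p):
    """Compute x(p) = sum over q of G(q, p) * x(q) for a 2D map x."""
    height, width = x.shape
    q_rows = torch.arange(height, dtype=x.dtype)
    q_cols = torch.arange(width, dtype=x.dtype)
    # G is separable: an outer product of the two one-dimensional kernels.
    G = g(q_rows, p[0])[:, None] * g(q_cols, p[1])[None, :]
    return (G * x).sum(), G


# Toy map whose values follow 4 * row + col, so interpolation is exact.
feature_map = torch.arange(16, dtype=torch.float32).view(4, 4)
point = torch.tensor([1.25, 2.5])

value, G = interpolate_with_kernel(feature_map, point)
print(value.item(), int((G > 0).sum()))
```

Only four entries of $G$ are non-zero, which is why the notebook gathers just the four neighbours $q_{lt}, q_{rb}, q_{lb}, q_{rt}$ instead of summing over the whole map.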

In [19]:
def _get_x_q(x, q, N):
    b, h, w, _ = q.size()
    padded_w = x.size(3)

    c = x.size(1)

    # Flatten the spatial dimensions: (b, c, h*w)
    x = x.contiguous().view(b, c, -1)

    # Flattened index of each sampling point: first coordinate * width + second coordinate
    # (b, h, w, N)
    index = q[..., :N] * padded_w + q[..., N:]
    # Repeat the index for every channel: (b, c, h*w*N)
    index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

    # Gather the value at every sampling point: (b, c, h, w, N)
    x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

    return x_offset


# (b, c, h, w, N)
x_q_lt = _get_x_q(sample, q_lt, N)
x_q_rb = _get_x_q(sample, q_rb, N)
x_q_lb = _get_x_q(sample, q_lb, N)
x_q_rt = _get_x_q(sample, q_rt, N)

In [36]:
x_q_lt.shape

torch.Size([1, 4, 100, 100, 9])
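The lookup in `_get_x_q` relies on turning a two-dimensional position into an index of the flattened feature map. A minimal sketch of that trick on a toy tensor, with hypothetical coordinates:

```python
import torch

# Toy feature map (b=1, c=1, H=3, W=4) whose values equal their flattened index.
x = torch.arange(12, dtype=torch.float32).view(1, 1, 3, 4)
width = x.size(3)

# Hypothetical integer positions to read, given as (row, col) pairs.
rows = torch.tensor([0, 2, 1])
cols = torch.tensor([3, 0, 2])

# A position (row, col) lives at index row * width + col once H and W are flattened.
flat_index = rows * width + cols

gathered = x.view(1, 1, -1).gather(dim=-1, index=flat_index.view(1, 1, -1))

# Direct indexing gives the same values.
direct = x[0, 0, rows, cols]
print(gathered[0, 0], direct)
```

`gather` is used instead of direct indexing in the notebook because it handles the batch and channel dimensions in a single call.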

In [32]:
"""
    In the paper, x(p) = Σ_q G(q, p) * x(q), where G is the bilinear kernel.
"""
# (b, c, h, w, N)
x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
    g_rb.unsqueeze(dim=1) * x_q_rb + \
    g_lb.unsqueeze(dim=1) * x_q_lb + \
    g_rt.unsqueeze(dim=1) * x_q_rt
x_offset.shape

torch.Size([1, 4, 100, 100, 9])

In [21]:
def _reshape_x_offset(x_offset, ks):
    b, c, h, w, N = x_offset.size()
    x_offset = torch.cat([x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)],
                            dim=-1)
    x_offset = x_offset.contiguous().view(b, c, h * ks, w * ks)

    return x_offset
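One way to check the reshaping is to feed it samples taken with zero offsets, in which case the result should match an ordinary convolution. The sketch below assumes that `F.unfold` can stand in for the zero-offset sampling; the tensor sizes and the copy of the function (`reshape_x_offset`) are chosen for illustration.

```python
import torch
import torch.nn.functional as F


def reshape_x_offset(x_offset, ks):
    """Tile the N = ks * ks samples of every position into a ks x ks block."""
    b, c, h, w, N = x_offset.size()
    x_offset = torch.cat(
        [x_offset[..., s:s + ks].contiguous().view(b, c, h, w * ks) for s in range(0, N, ks)],
        dim=-1,
    )
    return x_offset.contiguous().view(b, c, h * ks, w * ks)


torch.manual_seed(0)
b, c, h, w, ks = 1, 4, 6, 6, 3
x = torch.randn(b, c, h, w)
weight = torch.randn(5, c, ks, ks)  # hypothetical kernel with 5 output channels

# Zero-offset samples: unfold extracts every zero-padded 3x3 neighbourhood.
patches = F.unfold(x, kernel_size=ks, padding=1)                      # (b, c*N, h*w)
x_offset = patches.view(b, c, ks * ks, h, w).permute(0, 1, 3, 4, 2)   # (b, c, h, w, N)

tiled = reshape_x_offset(x_offset, ks)                                # (b, c, h*ks, w*ks)

# A stride-ks convolution visits each ks x ks block exactly once.
out_tiled = F.conv2d(tiled, weight, stride=ks)
out_ref = F.conv2d(x, weight, padding=1)
print(tiled.shape, torch.allclose(out_tiled, out_ref, atol=1e-4))
```

This is why the final convolution uses `stride=kernel_size`: every output pixel reads its own block of sampled values, deformed or not.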

In [22]:
"""
x_offset holds kernel_size * kernel_size (N) sampled values for every position of x.
"""
conv = nn.Conv2d(
    in_channels=in_channels,
    out_channels=4,
    kernel_size=kernel_size,
    stride=kernel_size,
    bias=False
)

# Tile the N samples of every position into a ks x ks block: (b, c, h * ks, w * ks)
x_offset = _reshape_x_offset(x_offset, ks)

out = conv(x_offset)
out.shape

torch.Size([1, 4, 100, 100])