In [None]:
import pathlib
from typing import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2 as cv
import pandas as pd
import matplotlib.pyplot as plt
from torchsummary import summary
from matplotlib import rc

# IPython magics: inline matplotlib output in the notebook.
%matplotlib inline

# IPython autoreload extension (mode 2), so edits to imported modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Default font applied to every subsequent matplotlib figure.
font = {'family': 'Times New Roman', 'weight': 'bold', 'size': 12}
rc('font', **font)

In [None]:
from maskrcnn_benchmark.layers import modulated_deform_conv

class ModulatedDeformXCorrDepthwise(nn.Module):
    """Depthwise cross-correlation computed with a modulated deformable conv.

    Every channel of every template is used as a ``t x t`` kernel. It is slid
    ("valid", stride 1, no padding) over the matching channel of the matching
    search-region sample, like a plain depthwise cross-correlation. In
    addition, per-position sampling offsets and modulation masks are
    predicted from the search-region features by a 3x3 convolution.

    The predictor is zero-initialised, so initially every offset is 0 and
    every mask value is ``sigmoid(0) == 0.5``.

    Args:
        n_channels: channel count C shared by both feature maps.
        template_size: spatial size t of the (square) template features.
    """

    def __init__(self, n_channels, template_size):
        super().__init__()

        # Per kernel position: 2 offset channels + 1 mask channel.
        out_channels = (template_size ** 2) * 3
        self._conv_offset_mask = nn.Conv2d(
            n_channels, out_channels, kernel_size=3, padding=1
        )

        self._init_offset()

    def forward(self, sr_features, template_features):
        """Correlate each template channel with its search-region channel.

        Args:
            sr_features: search-region features, shape (B, C, H, W).
            template_features: template features, shape (B, C, t, t) with
                t == template_size.

        Returns:
            Tensor of shape (B, C, H - t + 1, W - t + 1).
        """
        batch_size, n_channels, sr_height, sr_width = sr_features.shape
        *_, t_height, t_width = template_features.shape
        # stride 1, padding 0, dilation 1 -> "valid" output size.
        out_height = sr_height - t_height + 1
        out_width = sr_width - t_width + 1

        # The 3x3 predictor keeps the search-region size, but the deformable
        # conv reads offsets/masks on the *output* grid, so align them first.
        offset_mask_pred = self._crop_to_output(
            self._conv_offset_mask(sr_features),
            t_height, t_width, out_height, out_width,
        )
        offset, mask = self._split_offset_mask(offset_mask_pred)

        # Fold the batch into channels so a single grouped conv handles every
        # (sample, channel) pair: input (1, B*C, H, W), weight (B*C, 1, t, t).
        # The underlying kernel reads raw memory, hence .contiguous().
        groups = batch_size * n_channels
        sr_folded = sr_features.contiguous().view(1, groups, sr_height, sr_width)
        weight = template_features.contiguous().view(groups, 1, t_height, t_width)

        out = modulated_deform_conv(
            sr_folded, offset, mask, weight,
            None,  # bias
            1,  # stride
            0,  # padding
            1,  # dilation
            groups,  # groups: one per (sample, channel)
            # deformable groups: one per sample (C channels each), so every
            # sample is sampled with its own offsets/mask instead of sample 0's.
            batch_size,
        )
        return out.view(batch_size, n_channels, out.size(2), out.size(3))

    @staticmethod
    def _crop_to_output(offset_mask_pred, t_height, t_width, out_height, out_width):
        """Crop a search-region-sized prediction map to the output grid.

        Output position (i, j) correlates the window whose top-left corner is
        (i, j). It uses the prediction taken at that window's centre,
        (i + t_height // 2, j + t_width // 2).
        """
        top = t_height // 2
        left = t_width // 2
        return offset_mask_pred[:, :, top:top + out_height, left:left + out_width]

    @staticmethod
    def _split_offset_mask(offset_mask_pred):
        """Split predictor output into batch-folded offset and mask tensors.

        Args:
            offset_mask_pred: tensor of shape (B, 3K, h, w), K = kernel positions.

        Returns:
            ``(offset, mask)`` with shapes (1, B*2K, h, w) and (1, B*K, h, w),
            contiguous. Sample b's block starts at channel b*2K (offset) and
            b*K (mask), i.e. one deformable group per sample. The mask is
            passed through a sigmoid.
        """
        offset_part_1, offset_part_2, mask = torch.chunk(
            offset_mask_pred, chunks=3, dim=1
        )
        offset = torch.cat((offset_part_1, offset_part_2), dim=1)
        mask = torch.sigmoid(mask)
        height, width = offset_mask_pred.shape[-2:]
        offset = offset.reshape(1, -1, height, width).contiguous()
        mask = mask.reshape(1, -1, height, width).contiguous()
        return offset, mask

    def _init_offset(self):
        """Zero the predictor: offsets start at 0 and masks at sigmoid(0)."""
        nn.init.zeros_(self._conv_offset_mask.weight)
        nn.init.zeros_(self._conv_offset_mask.bias)


def xcorr_depthwise(x, kernel, *, verbose=True):
    """Depthwise cross-correlation of ``x`` with ``kernel``.

    Channel c of sample b in ``x`` is correlated ("valid", stride 1, no
    padding) only with channel c of sample b in ``kernel``. The batch is
    folded into the channel axis so one grouped ``F.conv2d`` call handles
    every (sample, channel) pair.

    Args:
        x: search features, shape (B, C, H, W).
        kernel: template features, shape (B, C, h, w).
        verbose: if True (the default, matching the original behaviour),
            print the intermediate tensor shapes.

    Returns:
        Tensor of shape (B, C, H - h + 1, W - w + 1).

    Raises:
        ValueError: if ``x`` and ``kernel`` disagree on batch size or channel
            count. Otherwise channels would be silently mis-paired or fail
            with an opaque reshape error.
    """
    def log(message):
        if verbose:
            print(message)

    log(f"x: {x.shape}")
    log(f"kernel: {kernel.shape}")
    batch, channel = kernel.shape[:2]
    if x.shape[:2] != kernel.shape[:2]:
        raise ValueError(
            "x and kernel must share batch and channel dims, got "
            f"{tuple(x.shape)} and {tuple(kernel.shape)}"
        )
    groups = batch * channel
    # Input (1, B*C, H, W) and weight (B*C, 1, h, w) with groups=B*C gives an
    # independent correlation per (sample, channel). reshape, unlike view,
    # also accepts non-contiguous inputs.
    x = x.reshape(1, groups, x.size(2), x.size(3))
    log(f"x view: {x.shape}")
    kernel = kernel.reshape(groups, 1, kernel.size(2), kernel.size(3))
    log(f"kernel view: {kernel.shape}")
    out = F.conv2d(x, kernel, groups=groups)
    log(f"out: {out.shape}")
    out = out.view(batch, channel, out.size(2), out.size(3))
    log(f"out view: {out.shape}")
    return out

# Sanity check: with the zero-initialised offset/mask predictor every offset
# is 0 and every mask value is sigmoid(0) == 0.5, so the deformable version is
# expected to equal half of the plain depthwise cross-correlation.
batch_size = 8
n_channels = 128
template_size = 15
sr_size = template_size * 2

template_features = torch.rand((batch_size, n_channels, template_size, template_size)).cuda()
sr_features = torch.rand((batch_size, n_channels, sr_size, sr_size)).cuda()

mdxcorr_module = ModulatedDeformXCorrDepthwise(n_channels, template_size).cuda()
res_orig = xcorr_depthwise(sr_features, template_features) * 0.5
res_new = mdxcorr_module(sr_features, template_features)
# Only a cell's last expression is displayed, so print the shapes explicitly.
print(f"res_orig: {tuple(res_orig.shape)}, res_new: {tuple(res_new.shape)}")
# The summed error grows with tensor size; the max error and allclose verdict
# do not, which makes them easier to judge.
abs_diff = torch.abs(res_orig - res_new)
print(f"max abs diff: {abs_diff.max().item():.3e}")
print(f"allclose: {torch.allclose(res_orig, res_new, rtol=1e-4, atol=1e-3)}")
torch.sum(abs_diff)
