In [1]:
import math
from numpy import pad
import torch
from torch import nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_

import torch
import torch.nn.functional as F
from torch import nn

HIDDEN_DIM = 256
NUM_PROPOSALS = 100
CONV_KERNEL_SIZE_1D = 3

class MultiHeadCrossAtten(nn.Module):
    """Multi-head cross-attention followed by LayerNorm, for batch-first inputs.

    ``value`` is used as both the key and the value of the attention.

    Args:
        hidden_dim: Embedding size of the attention and the LayerNorm.
            Defaults to the module-level ``HIDDEN_DIM``.
        num_proposals: Stored on the instance only (not used by ``forward``).
            Defaults to the module-level ``NUM_PROPOSALS``.
        conv_kernel_size_1d: Stored on the instance only (not used by
            ``forward``). Defaults to the module-level ``CONV_KERNEL_SIZE_1D``.
        num_heads: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability applied inside the attention.
    """

    def __init__(self, hidden_dim=None, num_proposals=None,
                 conv_kernel_size_1d=None, num_heads=8, dropout=0.0):
        super().__init__()
        # Module-level constants are only the fallback, so the layer can be
        # instantiated with other sizes without editing globals.
        self.hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        self.num_proposals = (
            NUM_PROPOSALS if num_proposals is None else num_proposals
        )
        self.conv_kernel_size_1d = (
            CONV_KERNEL_SIZE_1D
            if conv_kernel_size_1d is None
            else conv_kernel_size_1d
        )

        self.atten = nn.MultiheadAttention(
            embed_dim=self.hidden_dim, num_heads=num_heads, dropout=dropout
        )
        self.f_norm = nn.LayerNorm(self.hidden_dim)

    def forward(self, query, value):
        """Attend from ``query`` to ``value`` and normalize the result.

        Args:
            query: Tensor of shape ``[B, Nq, hidden_dim]``.
            value: Tensor of shape ``[B, Nv, hidden_dim]``.

        Returns:
            Tensor of shape ``[B, Nq, hidden_dim]``.
        """
        # nn.MultiheadAttention is built with its default sequence-first
        # layout, so move the batch axis to position 1 and back afterwards.
        query = query.permute(1, 0, 2)
        value = value.permute(1, 0, 2)

        # TODO: check -- `value` doubles as the attention key here.
        out = self.atten(query, value, value)[0]
        out = out.permute(1, 0, 2)
        return self.f_norm(out)


class DyConvAtten(nn.Module):
    """Dynamic 1-D convolution whose kernels are predicted from ``f``.

    A linear layer maps every row of ``f`` to ``num_proposals * kernel``
    values; per sample these form a ``[num_proposals, num_proposals, kernel]``
    conv1d weight that is applied to ``k`` and followed by LayerNorm.

    Args:
        hidden_dim: Size of the last input dimension (LayerNorm width).
            Defaults to the module-level ``HIDDEN_DIM``.
        num_proposals: Number of rows per sample, used as the conv channel
            count. Defaults to the module-level ``NUM_PROPOSALS``.
        conv_kernel_size_1d: Width of the generated kernels. Defaults to the
            module-level ``CONV_KERNEL_SIZE_1D``.
    """

    def __init__(self, hidden_dim=None, num_proposals=None,
                 conv_kernel_size_1d=None):
        super().__init__()
        self.hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        self.num_proposals = (
            NUM_PROPOSALS if num_proposals is None else num_proposals
        )
        self.conv_kernel_size_1d = (
            CONV_KERNEL_SIZE_1D
            if conv_kernel_size_1d is None
            else conv_kernel_size_1d
        )

        self.f_linear = nn.Linear(
            self.hidden_dim, self.num_proposals * self.conv_kernel_size_1d
        )
        self.f_norm = nn.LayerNorm(self.hidden_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, identity init for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, f, k):
        """Convolve ``k`` with kernels generated from ``f``.

        Args:
            f: Tensor of shape ``[B, num_proposals, hidden_dim]``; source of
                the dynamic kernels.
            k: Tensor of shape ``[B, num_proposals, hidden_dim]``; the signal
                that is convolved along its last dimension.

        Returns:
            Tensor of shape ``[B, num_proposals, hidden_dim]``.
        """
        batch = f.shape[0]
        # [B, N, N * K] -> [B, N_out, N_in, K]: one conv1d weight per sample.
        weight = self.f_linear(f).view(
            batch,
            self.num_proposals,
            self.num_proposals,
            self.conv_kernel_size_1d,
        )
        # [B, 1, N, C]: indexing sample i yields a batch-of-one conv input.
        # Hoisted out of the loop; the original rebuilt it on every iteration.
        k = k.unsqueeze(1)
        # Each sample has its own kernels, hence one conv call per sample.
        res = [
            F.conv1d(input=k[i], weight=weight[i], padding='same')
            for i in range(batch)
        ]
        f_tmp = torch.cat(res, dim=0)  # [B, N, C]
        return self.f_norm(f_tmp)


class DySepConvAtten(nn.Module):
    """Dynamic depthwise-separable 1-D convolution driven by ``query``.

    A single linear layer maps every query row to ``kernel_size +
    num_proposals`` values. The first ``kernel_size`` values are that row's
    depthwise kernel; the remaining ``num_proposals`` values are one row of
    the pointwise (1x1) mixing matrix. Per sample, ``value`` goes through
    depthwise conv -> ReLU -> pointwise conv, then LayerNorm.

    Args:
        hidden_dim: Size of the last input dimension (LayerNorm width).
            Defaults to the module-level ``HIDDEN_DIM``.
        num_proposals: Number of rows per sample, used as the conv channel
            count. Defaults to the module-level ``NUM_PROPOSALS``.
        kernel_size: Width of the depthwise kernels. Defaults to the
            module-level ``CONV_KERNEL_SIZE_1D``.
    """

    def __init__(self, hidden_dim=None, num_proposals=None, kernel_size=None):
        super().__init__()
        self.hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        self.num_proposals = (
            NUM_PROPOSALS if num_proposals is None else num_proposals
        )
        self.kernel_size = (
            CONV_KERNEL_SIZE_1D if kernel_size is None else kernel_size
        )

        # One projection produces both the depthwise and pointwise weights.
        self.weight_linear = nn.Linear(
            self.hidden_dim, self.num_proposals + self.kernel_size
        )
        self.norm = nn.LayerNorm(self.hidden_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, identity init for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, query, value):
        """Apply the query-conditioned separable convolution to ``value``.

        Args:
            query: Tensor of shape ``[B, num_proposals, hidden_dim]``.
            value: Tensor with the same shape as ``query``.

        Returns:
            Tensor of shape ``[B, num_proposals, hidden_dim]``.

        Raises:
            AssertionError: If ``query`` and ``value`` differ in shape.
        """
        assert query.shape == value.shape, (
            "query and value must have the same shape"
        )
        B, N, C = query.shape
        n_prop = self.num_proposals
        k_size = self.kernel_size

        dy_conv_weight = self.weight_linear(query)  # [B, N, K + N]

        res = []
        # [B, 1, N, C]: indexing sample i yields a batch-of-one conv input.
        value = value.unsqueeze(1)
        for i in range(B):
            # Depthwise step: weight [N, 1, K], one kernel per channel.
            depth_weight = dy_conv_weight[i, :, :k_size].view(n_prop, 1, k_size)
            out = F.relu(
                F.conv1d(
                    input=value[i],
                    weight=depth_weight,
                    groups=N,
                    padding="same",
                )
            )
            # Pointwise step: weight [N, N, 1], mixes the channels.
            point_weight = dy_conv_weight[i, :, k_size:].view(n_prop, n_prop, 1)
            out = F.conv1d(input=out, weight=point_weight, padding='same')

            res.append(out)
        point_out = torch.cat(res, dim=0)  # [B, N, C]
        return self.norm(point_out)


class DyDepthwiseConvAtten(nn.Module):
    """Dynamic depthwise 1-D convolution driven by ``query``.

    A linear layer maps every query row to a ``kernel_size``-wide kernel.
    Per sample, each row of ``value`` is convolved along its last dimension
    with the kernel generated from the matching query row, then LayerNorm is
    applied.

    Args:
        hidden_dim: Size of the last input dimension (LayerNorm width).
            Defaults to the module-level ``HIDDEN_DIM``.
        num_proposals: Number of rows per sample, used as the conv channel
            count. Defaults to the module-level ``NUM_PROPOSALS``.
        kernel_size: Width of the generated kernels. Defaults to the
            module-level ``CONV_KERNEL_SIZE_1D``.
    """

    def __init__(self, hidden_dim=None, num_proposals=None, kernel_size=None):
        super().__init__()
        self.hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        self.num_proposals = (
            NUM_PROPOSALS if num_proposals is None else num_proposals
        )
        self.kernel_size = (
            CONV_KERNEL_SIZE_1D if kernel_size is None else kernel_size
        )

        self.weight_linear = nn.Linear(self.hidden_dim, self.kernel_size)
        self.norm = nn.LayerNorm(self.hidden_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, identity init for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, query, value):
        """Apply the query-conditioned depthwise convolution to ``value``.

        Args:
            query: Tensor of shape ``[B, num_proposals, hidden_dim]``.
            value: Tensor with the same shape as ``query``.

        Returns:
            Tensor of shape ``[B, num_proposals, hidden_dim]``.

        Raises:
            AssertionError: If ``query`` and ``value`` differ in shape.
        """
        assert query.shape == value.shape, (
            "query and value must have the same shape"
        )
        B, N, C = query.shape

        # [B, N, K] -> [B, N, 1, K]: a depthwise conv1d weight per sample.
        dy_conv_weight = self.weight_linear(query).view(
            B, self.num_proposals, 1, self.kernel_size
        )

        # [B, 1, N, C]: indexing sample i yields a batch-of-one conv input.
        value = value.unsqueeze(1)
        # groups=N gives every channel its own kernel (depthwise).
        res = [
            F.conv1d(
                input=value[i],
                weight=dy_conv_weight[i],
                groups=N,
                padding="same",
            )
            for i in range(B)
        ]
        point_out = torch.cat(res, dim=0)  # [B, N, C]
        return self.norm(point_out)


class DyPointwiseConvAtten(nn.Module):
    """Dynamic pointwise (kernel width 1) 1-D convolution driven by ``query``.

    A linear layer maps every query row to ``num_proposals`` values; per
    sample these rows form a ``[num_proposals, num_proposals, 1]`` conv1d
    weight that mixes the rows of ``value``. LayerNorm is applied afterwards.

    Args:
        hidden_dim: Size of the last input dimension (LayerNorm width).
            Defaults to the module-level ``HIDDEN_DIM``.
        num_proposals: Number of rows per sample, used as the conv channel
            count. Defaults to the module-level ``NUM_PROPOSALS``.
        kernel_size: Stored on the instance only (not used by ``forward``).
            Defaults to the module-level ``CONV_KERNEL_SIZE_1D``.
    """

    def __init__(self, hidden_dim=None, num_proposals=None, kernel_size=None):
        super().__init__()
        self.hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        self.num_proposals = (
            NUM_PROPOSALS if num_proposals is None else num_proposals
        )
        self.kernel_size = (
            CONV_KERNEL_SIZE_1D if kernel_size is None else kernel_size
        )

        self.weight_linear = nn.Linear(self.hidden_dim, self.num_proposals)
        self.norm = nn.LayerNorm(self.hidden_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights, identity init for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, query, value):
        """Apply the query-conditioned pointwise convolution to ``value``.

        Args:
            query: Tensor of shape ``[B, num_proposals, hidden_dim]``.
            value: Tensor with the same shape as ``query``.

        Returns:
            Tensor of shape ``[B, num_proposals, hidden_dim]``.

        Raises:
            AssertionError: If ``query`` and ``value`` differ in shape.
        """
        assert query.shape == value.shape, (
            "query and value must have the same shape"
        )
        B, N, C = query.shape

        # [B, N, N] -> [B, N_out, N_in, 1]: a 1x1 conv1d weight per sample.
        dy_conv_weight = self.weight_linear(query).view(
            B, self.num_proposals, self.num_proposals, 1
        )

        # [B, 1, N, C]: indexing sample i yields a batch-of-one conv input.
        value = value.unsqueeze(1)
        res = [
            F.conv1d(input=value[i], weight=dy_conv_weight[i], padding='same')
            for i in range(B)
        ]
        point_out = torch.cat(res, dim=0)  # [B, N, C]
        return self.norm(point_out)

In [2]:
# FLOPs
from fvcore.nn import FlopCountAnalysis
from fvcore.nn import ActivationCountAnalysis

from ptflops import get_model_complexity_info

# Compare the operation counts of the five attention variants on a single
# sample of shape [1, NUM_PROPOSALS, HIDDEN_DIM].
q = torch.rand((1, NUM_PROPOSALS, HIDDEN_DIM))
v = torch.rand((1, NUM_PROPOSALS, HIDDEN_DIM))

MHCA = MultiHeadCrossAtten()
DCA = DyConvAtten()
DSCA = DySepConvAtten()
DDCA = DyDepthwiseConvAtten()
DPCA = DyPointwiseConvAtten()

# Multi-head cross-attention: fvcore's count is reported as-is.
flops = FlopCountAnalysis(MHCA, (q, v))
print("MHCA flops counter: ")
print(flops.total())
print(flops.by_operator())

# fvcore reports the 'same'-padded convolutions (aten::_convolution_mode) as
# unsupported, so for the dynamic-conv variants the conv cost is measured
# separately with ptflops on a static Conv1d of matching shape and added on.
# NOTE(review): this sums ptflops MACs with fvcore's count -- confirm that
# both tools use the same unit.
flops = FlopCountAnalysis(DCA, (q, v))
print("DCA flops counter: ")
conv = nn.Conv1d(in_channels=NUM_PROPOSALS, out_channels=NUM_PROPOSALS, kernel_size=CONV_KERNEL_SIZE_1D, bias=False, padding='same')
macs, _ = get_model_complexity_info(conv, (NUM_PROPOSALS, HIDDEN_DIM), as_strings=False, print_per_layer_stat=False, verbose=True)
print(flops.total() + macs)
print(flops.by_operator(), "conv: ", macs)

# Separable variant: one depthwise conv plus one pointwise (kernel_size=1) conv.
flops = FlopCountAnalysis(DSCA, (q, v))
print("DSCA flops counter: ")
depthwise_conv = nn.Conv1d(in_channels=NUM_PROPOSALS, out_channels=NUM_PROPOSALS, kernel_size=CONV_KERNEL_SIZE_1D, bias=False, groups=NUM_PROPOSALS, padding='same')
macs_depthwise, _ = get_model_complexity_info(depthwise_conv, (NUM_PROPOSALS, HIDDEN_DIM), as_strings=False, print_per_layer_stat=False, verbose=True)
pointwise_conv = nn.Conv1d(in_channels=NUM_PROPOSALS, out_channels=NUM_PROPOSALS, kernel_size=1, bias=False, padding='same')
macs_pointwise, _ = get_model_complexity_info(pointwise_conv, (NUM_PROPOSALS, HIDDEN_DIM), as_strings=False, print_per_layer_stat=False, verbose=True)
print(flops.total() + macs_depthwise + macs_pointwise)
print(flops.by_operator(), "depthwise: ", macs_depthwise, "pointwise: ", macs_pointwise)

# Depthwise-only variant: reuses the depthwise conv cost measured above.
flops = FlopCountAnalysis(DDCA, (q, v))
print("DDCA flops counter: ")
print(flops.total() + macs_depthwise)
print(flops.by_operator(), "depthwise: ", macs_depthwise,)

# Pointwise-only variant: reuses the pointwise conv cost measured above.
flops = FlopCountAnalysis(DPCA, (q, v))
print("DPCA flops counter: ")
print(flops.total() + macs_pointwise)
print(flops.by_operator(), "pointwise: ", macs_pointwise)

Unsupported operator aten::div encountered 2 time(s)
Unsupported operator aten::mul encountered 5 time(s)
Unsupported operator aten::softmax encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
atten.out_proj
Unsupported operator aten::_convolution_mode encountered 1 time(s)
Unsupported operator aten::_convolution_mode encountered 2 time(s)
Unsupported operator aten::_convolution_mode encountered 1 time(s)
Unsupported operator aten::_convolution_mode encountered 1 time(s)


MHCA flops counter: 
31462400
Counter({'linear': 26214400, 'bmm': 5120000, 'layer_norm': 128000})
DCA flops counter: 
15488000.0
Counter({'linear': 7680000, 'layer_norm': 128000}) conv:  7680000.0
DSCA flops counter: 
5401600.0
Counter({'linear': 2636800, 'layer_norm': 128000}) depthwise:  76800.0 pointwise:  2560000.0
DDCA flops counter: 
281600.0
Counter({'layer_norm': 128000, 'linear': 76800}) depthwise:  76800.0
DPCA flops counter: 
5248000.0
Counter({'linear': 2560000, 'layer_norm': 128000}) pointwise:  2560000.0


In [157]:
# Latency
import torch
# Profile one forward pass of a single attention variant on the GPU.
# Exactly one variant is active at a time; to measure another one, swap which
# constructor, warm-up call and profiler block are commented out.
q = torch.rand((1, NUM_PROPOSALS, HIDDEN_DIM), requires_grad=False).cuda()
v = torch.rand((1, NUM_PROPOSALS, HIDDEN_DIM), requires_grad=False).cuda()

# MHCA = MultiHeadCrossAtten().cuda()
DCA = DyConvAtten().cuda()
# DSCA = DySepConvAtten().cuda()
# DDCA = DyDepthwiseConvAtten().cuda()
# DPCA = DyPointwiseConvAtten().cuda()

# warm up: one un-profiled forward pass before the measured one
# o = MHCA(q, v)
o = DCA(q, v)
# o = DSCA(q, v)
# o = DDCA(q, v)
# o = DPCA(q, v)

# with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False) as prof:
#     o = MHCA(q, v)
# # NOTE: some columns were removed for brevity
# print(prof.key_averages().table(sort_by="self_cpu_time_total"))

# Profiled forward pass; use_cuda=True also records CUDA timings.
with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False) as prof:
    o = DCA(q, v)
# NOTE: some columns were removed for brevity
# Per-operator summary, sorted by self CPU time.
print(prof.key_averages().table(sort_by="self_cpu_time_total"))

# with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False) as prof:
#     o = DSCA(q, v)
# # NOTE: some columns were removed for brevity
# print(prof.key_averages().table(sort_by="self_cpu_time_total"))

# with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False) as prof:
#     o = DDCA(q, v)
# # NOTE: some columns were removed for brevity
# print(prof.key_averages().table(sort_by="self_cpu_time_total"))

# with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False) as prof:
#     o = DPCA(q, v)
# # NOTE: some columns were removed for brevity
# print(prof.key_averages().table(sort_by="self_cpu_time_total"))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       cudaEventDestroy        26.96%     572.000us        26.96%     572.000us       6.500us       0.000us         0.00%       0.000us       0.000us            88  
                                        cudaEventRecord        13.95%     296.000us        13.95%     296.000us       3.326us       0.000us         0.00%       0.000us       0.000us            89  
         