In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from fvcore.nn import FlopCountAnalysis


class SimIFAModule(nn.Module):
    def __init__(self, in_channels=128, out_channels=[1024, 512, 256, 128]):
        super().__init__()
        self.weights = nn.Parameter(torch.rand((in_channels, sum(out_channels), 1, 1)))

    def forward(self, p5, p4, p3, p2):
        # torch.cuda.synchronize()
        x5 = F.interpolate(p5, scale_factor=8, align_corners=False, mode='bilinear')
        x4 = F.interpolate(p4, scale_factor=4, align_corners=False, mode='bilinear')
        x3 = F.interpolate(p3, scale_factor=2, align_corners=False, mode='bilinear')
        x2 = p2
        x_fuse = torch.concat([x5, x4, x3, x2], dim=1)
        # x_fuse: [1, 1920, 256, 256]
        output = F.conv2d(x_fuse, self.weights)
        # torch.cuda.synchronize()
        return output


class SimCFAModuleForFlops(nn.Module):
    def __init__(self, in_channels=128, out_channels=[1024, 512, 256, 128]):
        super().__init__()
        self.weights = nn.Parameter(torch.rand((in_channels, sum(out_channels), 1, 1)))
        self.conv = nn.Conv3d(4, 1, 1, bias=False)
        nn.init.constant_(self.conv.weight, 1)
        self.out_channels = out_channels

    def forward(self, p5, p4, p3, p2):
        x5 = F.interpolate(F.conv2d(p5, self.weights[:, :self.out_channels[0]]), scale_factor=8, align_corners=False, mode='bilinear')
        x4 = F.interpolate(F.conv2d(p4, self.weights[:, sum(self.out_channels[:1]):sum(self.out_channels[:2])]), scale_factor=4, align_corners=False, mode='bilinear')
        x3 = F.interpolate(F.conv2d(p3, self.weights[:, sum(self.out_channels[:2]):sum(self.out_channels[:3])]), scale_factor=2, align_corners=False, mode='bilinear')
        x2 = F.conv2d(p2, self.weights[:, sum(self.out_channels[:3]):])
        concat = torch.concat([x5.unsqueeze(0), x4.unsqueeze(0), x3.unsqueeze(0), x2.unsqueeze(0)], dim=0)
        concat = concat.permute(1, 0, 2, 3, 4)
        output = self.conv(concat)
        # output = x5 + x4 + x3 + x2
        # output = torch.add([x5, x4, x3, x2])
        return output


class SimCFAModule(nn.Module):
    def __init__(self, in_channels=128, out_channels=[1024, 512, 256, 128]):
        super().__init__()
        self.weights = nn.Parameter(torch.rand((in_channels, sum(out_channels), 1, 1)))
        self.out_channels = out_channels

    def forward(self, p5, p4, p3, p2):
        # torch.cuda.synchronize()
        x5 = F.interpolate(F.conv2d(p5, self.weights[:, :self.out_channels[0]]), scale_factor=8, align_corners=False, mode='bilinear')
        x4 = F.interpolate(F.conv2d(p4, self.weights[:, sum(self.out_channels[:1]):sum(self.out_channels[:2])]), scale_factor=4, align_corners=False, mode='bilinear')
        x3 = F.interpolate(F.conv2d(p3, self.weights[:, sum(self.out_channels[:2]):sum(self.out_channels[:3])]), scale_factor=2, align_corners=False, mode='bilinear')
        x2 = F.conv2d(p2, self.weights[:, sum(self.out_channels[:3]):])
        output = x5 + x4 + x3 + x2
        # torch.cuda.synchronize()
        return output


In [None]:
# FLOPs
inputs = (
    torch.rand((1, 1024, 32, 32)),
    torch.rand((1, 512, 64, 64)),
    torch.rand((1, 256, 128, 128)),
    torch.rand((1, 128, 256, 256)),
)

IFA = SimIFAModule()
CFA = SimCFAModuleForFlops()

print()
flops = FlopCountAnalysis(IFA, inputs)
print("IFA flops counter: ")
print(flops.total())
print(flops.by_operator())

print()
flops = FlopCountAnalysis(CFA, inputs)
print("CFA flops counter: ")
print(flops.total())
print(flops.by_operator())

In [31]:
# Latency
inputs = (
    torch.rand((1, 1024, 32, 32)).cuda(),
    torch.rand((1, 512, 64, 64)).cuda(),
    torch.rand((1, 256, 128, 128)).cuda(),
    torch.rand((1, 128, 256, 256)).cuda(),
)

# IFA = SimIFAModule().cuda()
CFA = SimCFAModule().cuda()
# IFA.weights.data = CFA.weights.data

# warm up
# ifa_output = IFA(inputs[0], inputs[1], inputs[2], inputs[3])
cfa_output = CFA(inputs[0], inputs[1], inputs[2], inputs[3])
# print("check the same output: ", (cfa_output.sum() - ifa_output.sum()))

# with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False) as prof:
#     IFA(inputs[0], inputs[1], inputs[2], inputs[3])
# # NOTE: some columns were removed for brevity
# print(prof.key_averages().table(sort_by="self_cpu_time_total"))

with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False) as prof:
    CFA(inputs[0], inputs[1], inputs[2], inputs[3])
# NOTE: some columns were removed for brevity
print(prof.key_averages().table(sort_by="self_cpu_time_total"))



-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
      aten::cudnn_convolution        20.78%     446.000us        43.57%     935.000us     233.750us     814.000us        42.82%       1.138ms     284.500us             4  
        cudaDeviceSynchronize        16.22%     348.000us        16.22%     348.000us     348.000us       0.000us         0.00%       0.000us       0.000us             1  
             cudaEventDestroy        14.86%     319.000us        14.86%     319.000us       2.185us       0.000us         0.00%       0.000