In [9]:
import torch
from torch import nn
from module import Inv2d
from Inv import involution
from Inv_cuda import involution as inv_cuda
from fvcore.nn import FlopCountAnalysis, flop_count_table

In [10]:
class Conv(nn.Module):
    def __init__(self, c_in, kernel_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(c_in, c_in, kernel_size, padding='same', groups=c_in),
            nn.Conv2d(c_in, c_in, kernel_size=1),
        )

    def forward(self, x):
        x = self.net(x)
        return x
    
data = torch.randn(1, 64, 32, 32)
md1 = Conv(64, 3)
# md2 = Inv2d(64, 64, kernel_size=3, padding=1)
md2 = involution(64, 3, stride=1)
md3 = inv_cuda(64, 3, stride=1)

In [11]:
print(flop_count_table(FlopCountAnalysis(md1, data)))

| module     | #parameters or shape   | #flops   |
|:-----------|:-----------------------|:---------|
| net        | 4.8K                   | 4.194M   |
|  0         |  0.64K                 |  0       |
|   0.weight |   (64, 1, 3, 3)        |          |
|   0.bias   |   (64,)                |          |
|  1         |  4.16K                 |  4.194M  |
|   1.weight |   (64, 64, 1, 1)       |          |
|   1.bias   |   (64,)                |          |


In [12]:
print(flop_count_table(FlopCountAnalysis(md2, data)))

| module               | #parameters or shape   | #flops   |
|:---------------------|:-----------------------|:---------|
| model                | 1.668K                 | 1.72M    |
|  conv1               |  1.056K                |  1.13M   |
|   conv1.conv         |   1.024K               |   1.049M |
|    conv1.conv.weight |    (16, 64, 1, 1)      |          |
|   conv1.bn           |   32                   |   81.92K |
|    conv1.bn.weight   |    (16,)               |          |
|    conv1.bn.bias     |    (16,)               |          |
|  conv2.conv          |  0.612K                |  0.59M   |
|   conv2.conv.weight  |   (36, 16, 1, 1)       |          |
|   conv2.conv.bias    |   (36,)                |          |


In [15]:
import cupy

print(cupy.__version__)

13.3.0


In [14]:
device = torch.device('cuda')
md3 = md3.to(device)
data = data.to(device)
print(flop_count_table(FlopCountAnalysis(md3, data)))

AttributeError: module 'cupy.cuda' has no attribute 'compile_with_cache'