In [1]:
import sys
# Fail fast: everything below assumes Python 3 semantics.
assert sys.version_info.major == 3, 'The major version of Python must be 3 to run this notebook.'

In [2]:
import flops_counter
import flops_counter.nn as nn
from flops_counter.tensorsize import TensorSize

In [3]:
# Dummy input descriptor shared by every demo cell below. The dims presumably follow
# (batch, channels, height, width); the Conv2d layers below take 3 input channels.
# TODO(review): confirm the axis order against TensorSize.
input_dims = (1, 3, 224, 224)
x = TensorSize(*input_dims)

In [4]:
# Single layer: run the same Conv2d twice, calling set_flops_zero() between the passes.
# The second pass should report the same count as the first. The method name suggests
# the counter would otherwise accumulate across calls; that is not verified here.
conv1 = nn.Conv2d(3, 64, 3, 1, 1)
for pass_idx in range(2):
    if pass_idx > 0:
        conv1.set_flops_zero()
    y = conv1(x)
    print(y)
    print(conv1.flops)

flops_counter.TensorSize([1, 64, 224, 224])
173408256
flops_counter.TensorSize([1, 64, 224, 224])
173408256


In [5]:
# Sequential container: its printed repr reports per-layer FLOPs plus the container total.
# Run twice with a counter reset in between, as in the single-layer demo.
module = nn.Sequential(
    nn.Conv2d(3, 64, 3, 1, 1),
    nn.ReLU(),
)
for pass_idx in range(2):
    if pass_idx > 0:
        module.set_flops_zero()
    y = module(x)
    print(y)
    print(module)

flops_counter.TensorSize([1, 64, 224, 224])
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1)), FLOPs = 173,408,256, input = flops_counter.TensorSize([1, 3, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
  (1): ReLU(inplace=False), FLOPs = 6,422,528, input = flops_counter.TensorSize([1, 64, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
), FLOPs = 179,830,784, input = flops_counter.TensorSize([1, 3, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
flops_counter.TensorSize([1, 64, 224, 224])
Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1)), FLOPs = 173,408,256, input = flops_counter.TensorSize([1, 3, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
  (1): ReLU(inplace=False), FLOPs = 6,422,528, input = flops_counter.TensorSize([1, 64, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
), FLOPs = 179,830,784, input = flops_counter.TensorSize([1, 3, 224, 224]), output = 

In [6]:
class SimpleNet(nn.Module):
    """Minimal custom module: a 3 -> 64 channel 3x3 Conv2d (padding 1) followed by a ReLU."""

    def __init__(self):
        super().__init__()
        # The printed repr appears to list submodules in assignment order, so keep conv1 before relu1.
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        """Apply conv1 and then relu1 to ``x``, returning relu1's result."""
        return self.relu1(self.conv1(x))

# Custom module: run it twice, resetting the FLOPs counter between passes.
simplenet = SimpleNet()
for pass_idx in range(2):
    if pass_idx > 0:
        simplenet.set_flops_zero()
    y = simplenet(x)
    print(y)
    print(simplenet)


flops_counter.TensorSize([1, 64, 224, 224])
SimpleNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1)), FLOPs = 173,408,256, input = flops_counter.TensorSize([1, 3, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
  (relu1): ReLU(inplace=False), FLOPs = 6,422,528, input = flops_counter.TensorSize([1, 64, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
), FLOPs = 179,830,784, input = flops_counter.TensorSize([1, 3, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
flops_counter.TensorSize([1, 64, 224, 224])
SimpleNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1)), FLOPs = 173,408,256, input = flops_counter.TensorSize([1, 3, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
  (relu1): ReLU(inplace=False), FLOPs = 6,422,528, input = flops_counter.TensorSize([1, 64, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
), FLOPs = 179,830,784, input = flops_counter.TensorSize([1, 3, 224, 22

In [8]:
class DoubleNet(nn.Module):
    """Nested module: a SimpleNet (3 -> 64 channels) followed by a 1x1 Conv2d (64 -> 3 channels)."""

    def __init__(self):
        super().__init__()

        self.sn1 = SimpleNet()
        # NOTE(review): despite the name, sn2 is a plain Conv2d, not a SimpleNet.
        # The attribute name is kept because it appears in the printed repr.
        self.sn2 = nn.Conv2d(64, 3, 1, 1)

    def forward(self, x):
        """Run ``x`` through sn1 and then sn2, returning sn2's output."""
        out = self.sn1(x)
        out = self.sn2(out)
        # Previously this returned the input ``x``, discarding the computed result.
        # The bug was easy to miss: sn2 maps back to 3 channels, so the result's
        # shape happens to equal the input's.
        return out
    
# Nested modules: the repr shows sn1's inner layers indented beneath it, with per-module FLOPs.
# Run twice, resetting the FLOPs counter between passes.
dn = DoubleNet()
for pass_idx in range(2):
    if pass_idx > 0:
        dn.set_flops_zero()
    y = dn(x)
    print(dn)

DoubleNet(
  (sn1): SimpleNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1)), FLOPs = 173,408,256, input = flops_counter.TensorSize([1, 3, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
    (relu1): ReLU(inplace=False), FLOPs = 6,422,528, input = flops_counter.TensorSize([1, 64, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
  ), FLOPs = 179,830,784, input = flops_counter.TensorSize([1, 3, 224, 224]), output = flops_counter.TensorSize([1, 64, 224, 224])
  (sn2): Conv2d(64, 3, kernel_size=(1, 1)), FLOPs = 19,267,584, input = flops_counter.TensorSize([1, 64, 224, 224]), output = flops_counter.TensorSize([1, 3, 224, 224])
), FLOPs = 199,098,368, input = flops_counter.TensorSize([1, 3, 224, 224]), output = flops_counter.TensorSize([1, 3, 224, 224])
DoubleNet(
  (sn1): SimpleNet(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1)), FLOPs = 173,408,256, input = flops_counter.TensorSize([1, 3, 224, 224]), output = flops_counter.T