In [1]:
import torch
from models.utils.continual_model import save_input_size
from backbone.ResNet18 import resnet18


In [2]:
def BN_flops(B, C, H, W):
    """
    Estimate the FLOPs of one BatchNorm2d layer on a (B, C, H, W) input.

    Let S = B * C * H * W be the number of input elements. Per-operation cost:
    - mean:       S + C
    - var:        3S + C
    - normalize:  2S + 2C
    - affine:     2S
    - EMA:        6C   (running-statistics update)

    Returns:
        The total, 8S + 10C.
    """
    n_elems = B * C * H * W
    # Coefficients of S, one term per operation: mean, var, normalize, affine.
    element_ops = 1 + 3 + 2 + 2
    # Coefficients of C, one term per operation: mean, var, normalize, EMA.
    channel_ops = 1 + 1 + 2 + 6
    return element_ops * n_elems + channel_ops * C


def AdaB2N_flops(B, C, H, W):
    """
    Estimate the FLOPs of one AdaB2N normalization layer on a (B, C, H, W) input.

    Let S = B * C * H * W be the number of input elements. Per-operation cost:
    - mean:       S + BC
    - var:        3S + BC
    - normalize:  2S + 2C
    - affine:     2S
    - EMA:        6C
    - Loss:       6C   (extra per-channel loss term; exact form not shown here)

    Returns:
        The total, 8S + 2BC + 14C.
    """
    n_elems = B * C * H * W
    # Coefficients of S: mean, var, normalize, affine.
    element_ops = 1 + 3 + 2 + 2
    # Coefficients of B*C: mean, var.
    sample_channel_ops = 1 + 1
    # Coefficients of C: normalize, EMA, loss.
    channel_ops = 2 + 6 + 6
    return element_ops * n_elems + sample_channel_ops * B * C + channel_ops * C


def GN_flops(B, C, H, W, G=32, affline=True):
    """
    Estimate the FLOPs of one GroupNorm layer on a (B, C, H, W) input.

    Let S = B * C * H * W be the number of input elements. Per-operation cost:
    - mean:       S + BG
    - var:        3S + BG
    - normalize:  2S + 2BG
    - affine:     2S   (only counted when `affline` is truthy)

    Args:
        B, C, H, W: input dimensions.
        G: number of groups.
        affline: whether to count the affine transform. The parameter name is
            kept as spelled (sic) so existing callers keep working.

    Returns:
        8S + 4BG with the affine term, otherwise 6S + 4BG.
    """
    n_elems = B * C * H * W
    # Coefficients of S: mean, var, normalize, and optionally affine.
    element_ops = 1 + 3 + 2 + (2 if affline else 0)
    # Coefficients of B*G: mean, var, normalize.
    group_ops = 1 + 1 + 2
    return element_ops * n_elems + group_ops * B * G



def CN_flops(B, C, H, W, G=32):
    """
    Estimate the FLOPs of a CN layer on a (B, C, H, W) input.

    CN (presumably Continual Normalization) is costed as a full BatchNorm pass
    (see `BN_flops`) plus a GroupNorm pass with `G` groups and no affine
    transform (see `GN_flops`).
    """
    batch_norm_cost = BN_flops(B, C, H, W)
    group_norm_cost = GN_flops(B, C, H, W, G, False)
    return batch_norm_cost + group_norm_cost

In [3]:
# Total normalization FLOPs of ResNet-18 if every BatchNorm2d layer were
# costed under each candidate normalization scheme.
input_size = (10, 3, 32, 32)
model = resnet18(100)
# NOTE(review): `save_input_size` is a project helper; presumably it makes each
# module record its input shape in `module._input_size` during the forward
# pass below — confirm against models.utils.continual_model.
model.apply(save_input_size)

with torch.no_grad():
    model(torch.randn(*input_size))

norm_modules = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]

# Explicit name -> function table instead of resolving names with eval().
flops_fns = {
    "BN": BN_flops,
    "GN": GN_flops,
    "CN": CN_flops,
    "AdaB2N": AdaB2N_flops,
}

for method, flops_fn in flops_fns.items():
    # Assumes `_input_size` unpacks to (B, C, H, W), matching the *_flops signatures.
    flops = sum(flops_fn(*mod._input_size) for mod in norm_modules)
    print(f"FLOPs of {method}: {flops}")

FLOPs of BN: 49200000
FLOPs of GN: 49177600
FLOPs of CN: 86089600
FLOPs of AdaB2N: 49315200
