## Counting Floating Point Operations in each model

In [18]:
import torch
import torch.nn as nn

def calculate_flops(model, input_size=(3, 256, 256)):
    flops = 0
    input = torch.randn(1, *input_size).to(next(model.parameters()).device)

    for layer in model.modules():
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.ConvTranspose2d):
            out_h = (input_size[1] - layer.kernel_size[0]) // layer.stride[0] + 1
            out_w = (input_size[2] - layer.kernel_size[1]) // layer.stride[1] + 1
            flops += 2 * layer.in_channels * layer.out_channels * layer.kernel_size[0] * layer.kernel_size[1] * out_h * out_w / layer.stride[0]
            input_size = (layer.out_channels, out_h, out_w)
        elif isinstance(layer, nn.Linear):
            flops += 2 * layer.in_features * layer.out_features
        elif isinstance(layer, nn.ReLU):
            continue  # ReLU doesn't involve FLOPs
        else:
            # print(f"Warning: layer {layer} not counted")
            pass

    return flops


In [19]:
import segmentation_models_pytorch as smp
for i in range(1, 9):
    model = smp.DeepLabV3Plus(
            in_channels=i,
            classes=1,
            activation='sigmoid',
            encoder_name='resnet34',
            encoder_weights=None,
        )
    print(f'{model.__class__.__name__} with {i} input ch., GFLOPs: {calculate_flops(model)/1e9 :.3f}')

print(30*'=')

for i in range(1, 9):
    model = smp.UnetPlusPlus(
            in_channels=i,
            classes=1,
            activation='sigmoid',
            encoder_name='resnet34',
            encoder_weights=None,
        )
    print(f'{model.__class__.__name__} with {i} input ch., GFLOPs: {calculate_flops(model)/1e9 :.3f}')

DeepLabV3Plus with 1 input ch., GFLOPs: 47.356
DeepLabV3Plus with 2 input ch., GFLOPs: 47.405
DeepLabV3Plus with 3 input ch., GFLOPs: 47.454
DeepLabV3Plus with 4 input ch., GFLOPs: 47.503
DeepLabV3Plus with 5 input ch., GFLOPs: 47.552
DeepLabV3Plus with 6 input ch., GFLOPs: 47.601
DeepLabV3Plus with 7 input ch., GFLOPs: 47.650
DeepLabV3Plus with 8 input ch., GFLOPs: 47.699
UnetPlusPlus with 1 input ch., GFLOPs: 17.597
UnetPlusPlus with 2 input ch., GFLOPs: 17.646
UnetPlusPlus with 3 input ch., GFLOPs: 17.695
UnetPlusPlus with 4 input ch., GFLOPs: 17.744
UnetPlusPlus with 5 input ch., GFLOPs: 17.793
UnetPlusPlus with 6 input ch., GFLOPs: 17.842
UnetPlusPlus with 7 input ch., GFLOPs: 17.891
UnetPlusPlus with 8 input ch., GFLOPs: 17.940
