In [None]:
import os
import rootutils

from calflops import calculate_flops

# Adding the project root to the Python path so the `src.*` imports below
# resolve; this call must therefore stay before those imports.
# os.path.abspath('') is the current working directory (a notebook has no
# __file__ to anchor on). NOTE(review): the root is presumably located by
# searching upward from there for a directory containing one of the
# `indicator` entries — confirm against the rootutils docs.
rootutils.setup_root(
    os.path.abspath(''), indicator=['.git', 'pyproject.toml'], pythonpath=True
)

from src.models.components.cnn_cam_multihead import CNNCAMMultihead
from src.models.components.vit_rollout_multihead import VitRolloutMultihead

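#### CNN

FLOPs, MACs and parameter counts of `CNNCAMMultihead` for each backbone / `return_node` pair, measured on a `(1, 3, 224, 224)` input.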

In [None]:
# Profile CNNCAMMultihead for every (backbone, return_node) pair in a single
# loop instead of one copy-pasted cell per variant. Only these two settings
# differ between variants. multi_head=True, the input shape and the
# calculate_flops options are shared.
CNN_VARIANTS = [
    # (backbone, return_node)
    ('torchvision.models/efficientnet_v2_s', 'features.6.0.block.0'),
    ('torchvision.models/efficientnet_v2_s', 'features.7'),
    ('torchvision.models/mobilenet_v3_large', 'features.16'),
    ('torchvision.models/mobilenet_v3_large', 'features.13.block.0'),
]

# Dummy-input shape handed to calculate_flops; presumably (N, C, H, W) for a
# single 224x224 image — TODO confirm. The name `input_shape` is kept because
# the ViT section below reuses it.
input_shape = (1, 3, 224, 224)

for backbone, return_node in CNN_VARIANTS:
    # All variants now share one cell's output. Printing the label before
    # profiling lets each block of output be attributed to its variant,
    # including anything calculate_flops itself may print.
    print(f"=== {backbone} | return_node={return_node} ===")
    model = CNNCAMMultihead(
        backbone=backbone,
        return_node=return_node,
        multi_head=True,
    )
    flops, macs, params = calculate_flops(model=model,
                                          input_shape=input_shape,
                                          output_as_string=True,
                                          output_precision=4)
    print("FLOPs:%s   MACs:%s   Params:%s \n" % (flops, macs, params))

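#### ViT

The same measurement for `VitRolloutMultihead` with timm ViT backbones, reusing the input shape defined in the CNN section.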

In [None]:
# Profile VitRolloutMultihead for each timm backbone in a single loop instead
# of one copy-pasted cell per backbone; only the backbone string differs.
VIT_BACKBONES = [
    'timm/vit_tiny_patch16_224.augreg_in21k_ft_in1k',
    'timm/deit_tiny_patch16_224.fb_in1k',
]

# NOTE: `input_shape` is defined in the CNN section above. Run that cell first
# (Restart & Run All does this automatically).
for backbone in VIT_BACKBONES:
    # Label first so each backbone's output is identifiable within this cell.
    print(f"=== {backbone} ===")
    model = VitRolloutMultihead(
        backbone=backbone,
        multi_head=True,
    )
    flops, macs, params = calculate_flops(model=model,
                                          input_shape=input_shape,
                                          output_as_string=True,
                                          output_precision=4)
    print("FLOPs:%s   MACs:%s   Params:%s \n" % (flops, macs, params))