## Introduction
This notebook compares two approaches to calculating FLOPS:
* `torchprofile.profile_macs` (additional dependency)
* `torch.profiler.profile(..., with_flops=True)` (part of pytorch)

`torchprofile.profile_macs` uses static code analysis, whereas `torch.profiler.profile` performs runtime-based profiling.

`idlmav` was initially developed to use `torchprofile.profile_macs`, but switching to `torch.profiler.profile` seems to promise several benefits:
* It reduces the number of dependencies
* Because it is used as a context manager, it is not restricted to descendants of `nn.Module`
  * In most modern models, `torch.fx.symbolic_trace` cannot produce a graph and we need to fall back to `torch.compile` (see [16_explore_multiple_control_paths.ipynb](./16_explore_multiple_control_paths.ipynb))
  * In this configuration, `nn.Module` operations are typically "optimized out", i.e. replaced by direct function calls
  * `torchprofile.profile_macs` would therefore be unable to provide FLOPS estimates at all for models that cannot be traced with `torch.fx.symbolic_trace`
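
The context-manager point above can be sketched as follows: the profiler wraps an arbitrary block of code, so a plain function works as well as an `nn.Module`. This is a minimal sketch, not the `idlmav` implementation; `scaled_sum` is a hypothetical helper invented for illustration, and the reported value depends on which operators the installed PyTorch version has FLOPS formulas for.

```python
import torch


def scaled_sum(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: a plain function, not an nn.Module
    return torch.add(torch.mul(a, b), b)


a = torch.randn(4, 8)
b = torch.randn(4, 8)

# The profiler only needs a block of code to wrap, not a module object
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    record_shapes=True,  # input shapes are needed for FLOPS estimation
    with_flops=True,
) as prof:
    out = scaled_sum(a, b)

# Sum of the FLOPS estimates over all recorded operators
total_flops = prof.key_averages().total_average().flops
print(f"Estimated FLOPS for scaled_sum: {total_flops}")
```

The same pattern is used further down in `Interpreter2.call_function`, where the wrapped call is the `target` of an `fx` node.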

## Observations
* `torch.profiler.profile` seems to execute slightly faster than `torchprofile.profile_macs`
* Results are very similar in general and exactly equal for many modules, e.g. `nn.Conv2d`, `nn.Linear`
* For batch normalization and average pooling, `torch.profiler.profile` estimates zero FLOPS, whereas `torchprofile.profile_macs` often estimates substantial FLOPS
* For function calls such as `mul()` or `add()`, `torch.profiler.profile` reports FLOPS, which `torchprofile.profile_macs` is unable to do
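
The exact agreement on `nn.Conv2d` and `nn.Linear` is plausible because both layers have simple closed-form MAC counts. The sketch below computes them by hand for cross-checking. The helper names are hypothetical, the layer shapes are chosen for illustration, and bias additions are ignored; neither profiler is guaranteed to follow exactly this convention.

```python
def conv2d_output_size(size: int, kernel: int, stride: int = 1,
                       padding: int = 0, dilation: int = 1) -> int:
    # Standard output-size formula for a convolution along one spatial axis
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_macs(c_in: int, c_out: int, kernel: int,
                out_h: int, out_w: int, groups: int = 1) -> int:
    # Each output element needs (c_in / groups) * kernel * kernel multiply-accumulates
    return c_out * out_h * out_w * (c_in // groups) * kernel * kernel


def linear_macs(in_features: int, out_features: int, batch: int = 1) -> int:
    # One multiply-accumulate per weight per sample
    return batch * in_features * out_features


# Illustrative layer: 3 -> 16 channels, 3x3 kernel, stride 2, padding 1, 224x224 input
out_size = conv2d_output_size(224, kernel=3, stride=2, padding=1)
macs = conv2d_macs(c_in=3, c_out=16, kernel=3, out_h=out_size, out_w=out_size)
print(out_size)            # 112
print(macs)                # 5419008
print(2 * macs)            # 10838016 FLOPS, counting 1 MAC as 2 FLOPS
print(linear_macs(1024, 1000))  # 1024000
```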

Some discrepancies are explained by the fact that `torchprofile.profile_macs` performs static code analysis whereas `torch.profiler.profile` performs runtime-based profiling. For example, in inference mode, many computationally intensive operations in `nn.BatchNorm` may be pre-calculated and/or optimized out.
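
One way to see how much of batch normalization can be pre-calculated in inference mode: with fixed running statistics, the whole operation collapses to a per-channel scale and shift that can be computed once. The sketch below shows this for a single scalar in pure Python. It illustrates the arithmetic only and makes no claim about what either profiler does internally; `fold_batchnorm` and `batchnorm_eval` are hypothetical names.

```python
import math


def batchnorm_eval(x: float, gamma: float, beta: float,
                   running_mean: float, running_var: float,
                   eps: float = 1e-5) -> float:
    # Textbook inference-mode batch normalization for one element
    return (x - running_mean) / math.sqrt(running_var + eps) * gamma + beta


def fold_batchnorm(gamma: float, beta: float,
                   running_mean: float, running_var: float,
                   eps: float = 1e-5) -> tuple:
    # Everything that does not depend on x is computed once, ahead of time
    scale = gamma / math.sqrt(running_var + eps)
    shift = beta - running_mean * scale
    return scale, shift


# Illustrative per-channel parameters
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0
scale, shift = fold_batchnorm(gamma, beta, mean, var)

for x in (-1.0, 0.0, 2.5):
    folded = scale * x + shift  # a single multiply-add per element
    reference = batchnorm_eval(x, gamma, beta, mean, var)
    print(x, folded, reference)
```

After folding, the per-element cost is one multiply-add, and the scale and shift can in turn be merged into a preceding convolution's weights and bias.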

## Decision
It was ultimately decided to replace `torchprofile.profile_macs` with `torch.profiler.profile`, mainly because the latter can also produce FLOPS estimates in cases where `torch.compile` has optimized out `nn.Module` operations, i.e. replaced them by direct function calls.

## FLOPS calculation alternatives

In [1]:
import sys
sys.path.append('..')
from typing import List, Dict, Set, Mapping, Tuple, Optional, Any, Union
from transformers import CLIPProcessor, CLIPModel
import torch
from torch import nn, fx, Tensor
import torchprofile
from idlmav import merge_graph_nodes, layout_graph_nodes, color_graph_nodes, FigureRenderer, WidgetRenderer, plotly_renderer
from idlmav.tracing import MavTracer, ShapeMacInterpreter, get_num_trainable_params, warnings
from idlmav.mavtypes import MavNode, MavGraph, MavConnection



### Using `torchprofile.profile_macs`
* `MavTracer1.run()` is a copy of `MavTracer.run()` with the interpreter initialization changed to use `Interpreter1`

In [2]:
class Interpreter1(ShapeMacInterpreter):
    def call_module(self, target, args, kwargs):
        # Run the module
        result = fx.Interpreter.call_module(self, target, args, kwargs)

        # Estimate the FLOPS
        try:
            submod = self.fetch_attr(target)
            macs = torchprofile.profile_macs(submod, args)
        except Exception as e:
            # Report `target` rather than `submod`, which is unbound if `fetch_attr` itself failed
            warnings.warn(f'FLOPS calculation via torchprofile.profile_macs failed for module {target}: {e}')
            macs = 0
        self.cur_macs = macs

        # Return the result
        return result

class MavTracer1(MavTracer):
    def run(self, concrete_args: Optional[Dict[str, Any]]=None):
        # 1st pass: symbolic tracing using torch.fx
        self.gm = fx.symbolic_trace(self.model, concrete_args)
        self.interp = Interpreter1(self.gm)

        # 2nd pass: iterate through `nn.Module` and update module types and parameter counts
        try:
            for n in self.gm.graph.nodes:
                if n.op == 'call_module':
                    m:nn.Module = self.interp.fetch_attr(n.target)
                    self.target_names[n] = m.__class__.__name__
                    self.param_counts[n] = get_num_trainable_params(m)
                elif n.op == 'call_function':
                    self.target_names[n] = n.target.__name__
        except Exception as e:
            self.err_node = n
            warnings.warn(f'2nd tracing pass failed for module {n.target}: {e}')

        # 3rd pass: forward pass using torch.fx.Interpreter
        try:
            self.interp.run(self.inputs)
        except Exception as e:
            msg = 'Forward pass failed.'
            n1 = self.interp.last_successful_node
            if n1:
                target_name = self.target_names.get(n1, None)
                node_name = f'{n1.name}:{target_name}' if target_name else n1.name
                msg += f' Last successful node: "{node_name}".'
            n2 = self.interp.running_node
            if n2:
                target_name = self.target_names.get(n2, None)
                node_name = f'{n2.name}:{target_name}' if target_name else n2.name
                msg += f' Possible error node: "{node_name}".'
            self.err_node = self.interp.running_node
            warnings.warn(f'{msg}: {e}')

### Using `torch.profiler.profile`
* `MavTracer2.run()` is a copy of `MavTracer.run()` with the interpreter initialization changed to use `Interpreter2`

In [3]:
class Interpreter2(ShapeMacInterpreter):
    def run_node(self, n:fx.Node) -> Any:
        # Run the node
        self.cur_macs = None
        self.running_node = n
        result = fx.Interpreter.run_node(self, n)
        self.running_node = None

        # Retrieve the shape
        if isinstance(result, Tensor):
            shape = tuple(result.shape)
        else:
            shape = (0,0,0,0)
        self.shapes[n] = shape

        # Store the number of MACs if calculated
        if n.op == 'call_module' or n.op == 'call_function':
            if self.cur_macs is not None: self.macs[n] = self.cur_macs

        # Update the state and return the result
        self.last_successful_node = n
        return result
    

    def call_module(self, target, args, kwargs):
        # Run the module
        result = fx.Interpreter.call_module(self, target, args, kwargs)

        # Estimate the FLOPS
        try:
            submod = self.fetch_attr(target)
            with torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CPU],  # Use ProfilerActivity.CUDA for GPU
                record_shapes=True,
                with_flops=True
            ) as prof:
                submod(*args)
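            flops = prof.key_averages().total_average().flops
            macs = int(flops/2)  # 1 MAC (multiply-accumulate) is counted as 2 FLOPS
        except Exception as e:
            # Report `target` rather than `submod`, which is unbound if `fetch_attr` itself failed
            warnings.warn(f'FLOPS calculation via torch.profiler.profile failed for module {target}: {e}')
            macs = 0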
        self.cur_macs = macs

        # Return the result
        return result
    
    def call_function(self, target, args, kwargs):
        result = fx.Interpreter.call_function(self, target, args, kwargs)
        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],  # Use ProfilerActivity.CUDA for GPU
            record_shapes=True,
            with_flops=True
        ) as prof:
            target(*args, **kwargs)
        flops = prof.key_averages().total_average().flops
        self.cur_macs = int(flops/2)
        return result

class MavTracer2(MavTracer):
    def run(self, concrete_args: Optional[Dict[str, Any]]=None):
        # 1st pass: symbolic tracing using torch.fx
        self.gm = fx.symbolic_trace(self.model, concrete_args)
        self.interp = Interpreter2(self.gm)

        # 2nd pass: iterate through `nn.Module` and update module types and parameter counts
        try:
            for n in self.gm.graph.nodes:
                if n.op == 'call_module':
                    m:nn.Module = self.interp.fetch_attr(n.target)
                    self.target_names[n] = m.__class__.__name__
                    self.param_counts[n] = get_num_trainable_params(m)
                elif n.op == 'call_function':
                    self.target_names[n] = n.target.__name__
        except Exception as e:
            self.err_node = n
            warnings.warn(f'2nd tracing pass failed for module {n.target}: {e}')

        # 3rd pass: forward pass using torch.fx.Interpreter
        try:
            self.interp.run(self.inputs)
        except Exception as e:
            msg = 'Forward pass failed.'
            n1 = self.interp.last_successful_node
            if n1:
                target_name = self.target_names.get(n1, None)
                node_name = f'{n1.name}:{target_name}' if target_name else n1.name
                msg += f' Last successful node: "{node_name}".'
            n2 = self.interp.running_node
            if n2:
                target_name = self.target_names.get(n2, None)
                node_name = f'{n2.name}:{target_name}' if target_name else n2.name
                msg += f' Possible error node: "{node_name}".'
            self.err_node = self.interp.running_node
            warnings.warn(f'{msg}: {e}')

## MobileNetV3 small

In [4]:
import sys
sys.path.append('..')
import torch
from torchvision.models import mobilenet_v3_small
from idlmav import MAV, plotly_renderer

model = mobilenet_v3_small(weights=None)
model.eval()
inputs = torch.randn(1,3,224,224)
device = 'cpu'

### Current library implementation

In [5]:
mav = MAV(model, inputs, device)
with plotly_renderer('notebook_connected'): mav.show_figure()

INFO:2025-02-11 19:59:25 26166:26166 init.cpp:181] If you see CUPTI_ERROR_INSUFFICIENT_PRIVILEGES, refer to https://developer.nvidia.com/nvidia-development-tools-solutions-err-nvgpuctrperm-cupti


### Using `torchprofile.profile_macs`

In [6]:
tracer = MavTracer1(model, inputs, device)
merge_graph_nodes(tracer.g)
layout_graph_nodes(tracer.g)
color_graph_nodes(tracer.g)
with plotly_renderer('notebook_connected'):
    renderer = FigureRenderer(tracer.g)
    renderer.render().show()


No handlers found: "aten::hardswish_". Skipped.


No handlers found: "aten::hardsigmoid". Skipped.



In [7]:
nn.BatchNorm2d

torch.nn.modules.batchnorm.BatchNorm2d

### Using `torch.profiler.profile`

In [8]:
tracer = MavTracer2(model, inputs, device)
merge_graph_nodes(tracer.g)
layout_graph_nodes(tracer.g)
color_graph_nodes(tracer.g)
with plotly_renderer('notebook_connected'):
    renderer = FigureRenderer(tracer.g)
    renderer.render().show()

In [9]:

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],  # Use ProfilerActivity.CUDA for GPU
    record_shapes=True,
    with_flops=True
) as prof:
    model(inputs)
ka = prof.key_averages()
total_flops = ka.total_average().flops

In [10]:
[a.flops for a in ka]

[109793152,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 2296,
 0,
 0,
 0,
 0,
 0,
 0,
 3227648,
 0,
 0,
 0]
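
The bare list above does not show which operators the non-zero entries belong to. A small sketch of how the estimates could be paired with operator names, using the `key` attribute of each averaged event: the tiny `nn.Sequential` model here is a stand-in chosen for illustration, not one of the models profiled in this notebook, and which operators report FLOPS depends on the installed PyTorch version.

```python
import torch
from torch import nn

# Stand-in model for illustration only
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()
inputs = torch.randn(1, 3, 16, 16)

with torch.no_grad(), torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    record_shapes=True,
    with_flops=True,
) as prof:
    model(inputs)

# Keep only operators with a non-zero FLOPS estimate, keyed by operator name
flops_by_op = {a.key: a.flops for a in prof.key_averages() if a.flops}
for name, flops in sorted(flops_by_op.items(), key=lambda kv: -kv[1]):
    print(f"{name}: {flops}")
```

Applied to the `ka` object from the cell above, the same dictionary comprehension would show which operators account for the three non-zero entries.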

## EfficientNetV2 small

In [13]:
import sys
sys.path.append('..')
import torch
from torchvision.models import efficientnet_v2_s
from idlmav import MAV, plotly_renderer

model = efficientnet_v2_s(weights=None)
model.eval()
inputs = torch.randn(1, 3, 224, 224)
device = 'cpu'


### Current library implementation

In [14]:
mav = MAV(model, inputs, device)
with plotly_renderer('notebook_connected'): mav.show_figure()

### Using `torchprofile.profile_macs`

In [15]:
tracer = MavTracer1(model, inputs, device)
merge_graph_nodes(tracer.g)
layout_graph_nodes(tracer.g)
color_graph_nodes(tracer.g)
with plotly_renderer('notebook_connected'):
    renderer = FigureRenderer(tracer.g)
    renderer.render().show()


No handlers found: "aten::silu_". Skipped.



### Using `torch.profiler.profile`

In [16]:
tracer = MavTracer2(model, inputs, device)
merge_graph_nodes(tracer.g)
layout_graph_nodes(tracer.g)
color_graph_nodes(tracer.g)
with plotly_renderer('notebook_connected'):
    renderer = FigureRenderer(tracer.g)
    renderer.render().show()