# NeuralCompression Flop Counter Example

Welcome! In this notebook we'll walk through using `neuralcompression`'s flop counter to calculate the computational complexity of a compression model.

```python
import torch
from torch import nn

from neuralcompression.functional import count_flops
from neuralcompression.models import ScaleHyperprior
```

## Basic Usage

To get started with the flop counter, instantiate the model you want to evaluate and pass it to `neuralcompression.functional.count_flops`, along with a tuple of the inputs it should be evaluated on.

```python
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),  # keeps the 32x32 resolution
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 32, stride=2, kernel_size=5, padding=2),  # 32x32 -> 16x16
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(32 * 16 * 16, 10),
)

# count_flops takes the model inputs as a tuple (here: a batch of 5 RGB 32x32 images)
inputs = (torch.randn(5, 3, 32, 32),)

results = count_flops(model, inputs)
```

The result returned by the flop counter is a 3-tuple. The first element records the total number of flops performed by the model.

```python
results[0]
```

```
19619840
```

The second element in the return tuple breaks down this count by operation:

```python
results[1]
```

```
{'conv': 18595840, 'batch_norm': 614400, 'linear': 409600}
```
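As a sanity check, the `conv` and `linear` entries can be reproduced by hand. The sketch below assumes the counter charges one flop per multiply-accumulate (MAC) for convolutions and linear layers, which is consistent with the numbers above. The `batch_norm` entry is copied from the output rather than derived, since its exact formula isn't documented here, and the helper names are hypothetical.

```python
def conv2d_macs(batch, out_h, out_w, out_channels, in_channels, kernel_size):
    """Multiply-accumulates for a dense (groups=1) 2D convolution with a square kernel."""
    return batch * out_h * out_w * out_channels * in_channels * kernel_size * kernel_size


def linear_macs(batch, in_features, out_features):
    """Multiply-accumulates for a fully connected layer."""
    return batch * in_features * out_features


batch = 5
# First conv: 3 -> 16 channels, 3x3 kernel, padding 1 keeps the 32x32 resolution.
conv1 = conv2d_macs(batch, 32, 32, 16, 3, 3)
# Second conv: 16 -> 32 channels, 5x5 kernel, stride 2 halves 32x32 to 16x16.
conv2 = conv2d_macs(batch, 16, 16, 32, 16, 5)
linear = linear_macs(batch, 32 * 16 * 16, 10)

batch_norm = 614400  # copied from the counter's output above, not derived here

breakdown = {"conv": conv1 + conv2, "batch_norm": batch_norm, "linear": linear}
total = sum(breakdown.values())

print(breakdown)  # {'conv': 18595840, 'batch_norm': 614400, 'linear': 409600}
print(total)      # 19619840
```

For the outputs shown above, the per-operation breakdown sums exactly to the total reported in `results[0]`.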

The third item returned by the counter records all the operations that the counter didn't know how to count flops for (more detail about these ops below). This dictionary (more specifically, a [collections.Counter](https://docs.python.org/3/library/collections.html#collections.Counter)) maps the names of unknown ops to the number of times those ops were called in the model.

In our example, all of the model's operations are supported by the counter, so this dictionary is empty.

```python
results[2]
```

```
Counter()
```
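Because the third return value is a `collections.Counter`, the usual `Counter` methods work for triaging unsupported ops. The counter below is built by hand with made-up op names purely for illustration; it is not output from `count_flops`.

```python
from collections import Counter

# Hypothetical example of what results[2] might contain for a model with
# unsupported ops; these op names and counts are illustrative only.
unsupported = Counter({"aten::my_custom_op": 3, "aten::another_op": 1})

# Ops sorted by how often they were called, most frequent first.
print(unsupported.most_common())  # [('aten::my_custom_op', 3), ('aten::another_op', 1)]

# Total number of unsupported op calls, i.e. calls whose flops were not counted.
print(sum(unsupported.values()))  # 4

# Missing keys return 0 instead of raising KeyError.
print(unsupported["aten::add"])  # 0

# An empty Counter, like the one returned for the model above, is falsy.
print(bool(Counter()))  # False
```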

In general, most common ML operations are already supported, although some operations' counts are approximations (e.g. one floating-point addition and one square root are each counted as 1 flop). Models much more complicated than the toy example above, such as the [scale hyperprior model](https://arxiv.org/abs/1802.01436), can be evaluated without accumulating many unsupported operations:

```python
bigger_model = ScaleHyperprior(network_channels=32, compression_channels=64)
_, _, unsupported_ops = count_flops(bigger_model, (torch.randn(1, 3, 64, 64),))

len(unsupported_ops) == 0  # verify that there are no unsupported operations
```

```
True
```

However, if you need to add support for more ops or override the counter's default implementation, read on!

## Advanced Usage

### How the Counter Works

The flop counter in `neuralcompression` makes heavy use of the counting utilities in [fvcore](https://github.com/facebookresearch/fvcore). Counting a model's flops is a two-step process:

1. Using PyTorch's [TorchScript](https://pytorch.org/docs/stable/jit.html) capabilities, the model is first JIT-traced into a computational graph. Each node in the graph corresponds to an ATen operation (ATen is PyTorch's core C++ tensor library), such as matrix multiplications, convolutions, and elementwise operations like additions and subtractions.


2. Every node in the traced graph is iterated over, and if a node is associated with a registered flop-counting function, that function is invoked and the counted flops are added to the model's total.
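Step 2 can be pictured as a loop over the graph's nodes with one dictionary lookup per node. The sketch below is a toy model only: `Node` is a hypothetical stand-in for a traced graph node (real TorchScript nodes expose a different API), the handlers receive the node itself rather than `torch._C.Value`s, and charging one flop per output element is an assumption for illustration. It is not `fvcore`'s actual implementation.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for one node of a traced graph."""
    kind: str          # operator name, e.g. "aten::add"
    output_numel: int  # number of elements in the node's output


Handler = Callable[[Node], float]


def elementwise(node: Node) -> float:
    """Toy handler: one flop per output element."""
    return float(node.output_numel)


def count_graph_flops(nodes: List[Node], handles: Dict[str, Handler]) -> Tuple[float, Counter]:
    """Sum flops over nodes with a registered handler; tally the rest as unsupported."""
    total = 0.0
    unsupported: Counter = Counter()
    for node in nodes:
        handle = handles.get(node.kind)
        if handle is None:
            unsupported[node.kind] += 1  # no registered function: record, don't count
        else:
            total += handle(node)
    return total, unsupported


graph = [Node("aten::mul", 160), Node("aten::add", 160), Node("aten::my_custom_op", 160)]
total, unsupported = count_graph_flops(graph, {"aten::add": elementwise, "aten::mul": elementwise})
print(total)        # 320.0
print(unsupported)  # Counter({'aten::my_custom_op': 1})
```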


In this second step, the counter's registered flop functions are stored as a dictionary mapping operator names (e.g. `aten::add`, `aten::matmul`) to counter functions. These functions have a signature of:

```python
def my_counter_function(inputs: List[torch._C.Value], outputs: List[torch._C.Value]) -> float: ...
```

where objects of type `torch._C.Value` represent the symbolic inputs and outputs for the node in the graph (i.e. they're not actual concrete `Tensor`s). Check out the `fvcore` codebase for [examples of how to write counter functions](https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/jit_handles.py).

### Customizing the Counter

To register additional counter functions or override the counter's default implementation for specific operations, use the `counter_overrides` argument, which takes the form of a dictionary mapping operator names to the corresponding counter functions you wish to use. As an example, consider the following simple module:

```python
class MyModule(nn.Module):
    def forward(self, x, y, z):
        return x + y * z
```

By default, the elementwise addition or multiplication of two tensors of shape `N x M` will contribute `N x M` flops to the total computational complexity of a model (since each scalar addition/multiplication is counted as 1 flop):

```python
N = 5
M = 32

inp = torch.randn(N, M)

flops, _, _ = count_flops(MyModule(), (inp, inp, inp))
flops == 2 * N * M
```

```
True
```

Now suppose you want to ignore all the flops coming from addition operations. You can do this by overriding the counter for `aten::add`:

```python
new_flops, _, _ = count_flops(
    MyModule(),
    (inp, inp, inp),
    counter_overrides={
        "aten::add": lambda inputs, outputs: 0.0,
    },
)

new_flops == N * M  # only the flops from multiplications are counted
```

```
True
```
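The same pattern extends to ignoring several ops at once. The helper below, `ignore_ops`, is hypothetical (it is not part of `neuralcompression`); it simply builds a `counter_overrides` dictionary that maps each given operator name to a counter function returning zero.

```python
from typing import Any, Callable, Dict, Iterable, List

CounterFn = Callable[[List[Any], List[Any]], float]


def _zero_flops(inputs: List[Any], outputs: List[Any]) -> float:
    """Counter function that charges nothing for the op."""
    return 0.0


def ignore_ops(op_names: Iterable[str]) -> Dict[str, CounterFn]:
    """Build a counter_overrides dict that counts every listed op as 0 flops."""
    return {name: _zero_flops for name in op_names}


overrides = ignore_ops(["aten::add", "aten::sub"])
print(sorted(overrides))               # ['aten::add', 'aten::sub']
print(overrides["aten::add"]([], []))  # 0.0
```

It would be passed the same way as the inline dictionary above, e.g. `count_flops(MyModule(), (inp, inp, inp), counter_overrides=ignore_ops(["aten::add"]))`, which should reproduce `new_flops`.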