In [8]:
import torch
# torch.set_grad_enabled(False)

from lvae import get_model

Initialize model.

The value of $\lambda$ (the `lmb` argument) is arbitrary here, since it does not affect the FLOPs count.

In [9]:
# Build the pretrained 'qres34m' model. Per the note above this cell, the
# particular `lmb` value is irrelevant for FLOPs counting.
model = get_model('qres34m', lmb=2048, pretrained=True)

# Switch to evaluation mode, then raise the model's own FLOPs-mode flag.
# NOTE(review): `_flops_mode` is a private attribute of the lvae model; what
# it changes in the forward pass is not visible from this notebook.
model.eval()
model._flops_mode = True

# One random image of shape `input_shape`, wrapped in a 1-tuple so that later
# cells can pass it as `model(*inputs)` or as `inputs=inputs`.
input_shape = (3, 256, 256)
inputs = (torch.randn((1, *input_shape)),)

PyTorch profiler

In [10]:
import torch.profiler as tp

# Run a single forward pass while recording CPU activity with FLOPs estimation on.
with tp.profile(
    activities=[tp.ProfilerActivity.CPU],
    with_flops=True,
) as prof:
    model(*inputs)

# Add up the per-event FLOPs estimates and halve the total -- presumably to
# express the result in MACs (1 MAC = 2 FLOPs), matching the "(MACs)" label
# printed below.
torch_flops = sum(evt.flops for evt in prof.events()) / 2
# Total element count over every tensor returned by model.parameters().
torch_param = sum(param.numel() for param in model.parameters())

print(f'torch estimated FLOPs (MACs) = {torch_flops/1e9:.3g}B, parameters = {torch_param/1e6:.3f}M')


torch estimated FLOPs (MACs) = 23.2B, parameters = 34.037M


THOP: PyTorch-OpCounter

https://github.com/Lyken17/pytorch-OpCounter

In [11]:
# `clever_format` is imported alongside `profile` but is not used in this cell.
from thop import profile, clever_format

# thop returns a (MACs, parameter count) pair for one forward pass on `inputs`.
thop_macs, thop_params = profile(
    model,
    inputs=inputs,
    verbose=False,
)

print(f'thop estimated FLOPs (MACs) = {thop_macs/1e9:.3g}B, parameters = {thop_params/1e6:.3f}M')

thop estimated FLOPs (MACs) = 23.2B, parameters = 33.984M


flops-counter.pytorch

https://github.com/sovrasov/flops-counter.pytorch

In [12]:
from ptflops import get_model_complexity_info

# ptflops takes the per-sample shape (no batch dimension) rather than a tensor.
# as_strings=False keeps both results numeric so they can be rescaled in the
# print below. The call itself also emits the per-layer breakdown that appears
# in this cell's output.
ptfl_macs, ptfl_params = get_model_complexity_info(
    model,
    input_shape,
    as_strings=False,
    verbose=False,
)

print(f'ptflops estimated FLOPs (MACs) = {ptfl_macs/1e9:.3g}B, parameters = {ptfl_params/1e6:.3f}M')

HierarchicalVAE(
  33.98 M, 99.845% Params, 23.25 GMac, 100.000% MACs, 
  (encoder): BottomUpEncoder(
    15.84 M, 46.535% Params, 10.69 GMac, 45.997% MACs, 
    (enc_blocks): ModuleList(
      15.84 M, 46.535% Params, 10.69 GMac, 45.997% MACs, 
      (0): Conv2d(9.41 k, 0.028% Params, 38.54 MMac, 0.166% MACs, 3, 192, kernel_size=(4, 4), stride=(4, 4))
      (1): MyConvNeXtBlock(
        157.63 k, 0.463% Params, 643.3 MMac, 2.767% MACs, 
        (conv_dw): Conv2d(9.6 k, 0.028% Params, 39.32 MMac, 0.169% MACs, 192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192)
        (norm): LayerNorm(0, 0.000% Params, 0.0 Mac, 0.000% MACs, (192,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          148.03 k, 0.435% Params, 603.98 MMac, 2.598% MACs, 
          (fc1): Linear(74.11 k, 0.218% Params, 301.99 MMac, 1.299% MACs, in_features=192, out_features=384, bias=True)
          (act): GELU(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
          (drop1): Dropout(0, 0.000%