In [4]:
import torch
from torchvision.models import resnet50  # Example model

from thop import clever_format, profile

# Load a pre-trained model (e.g., ResNet50)
model = resnet50()

# Create a dummy input tensor matching the model's expected input shape
dummy_input = torch.randn(1, 3, 224, 224)

# Profile the model
macs, params = profile(model, inputs=(dummy_input,))

# Format the numbers into a readable format (e.g., 4.14 GMac, 25.56 MParams)
macs_readable, params_readable = clever_format([macs, params], "%.3f")

print(f"Formatted MACs: {macs_readable}, Formatted Parameters: {params_readable}")
# Expected output: Formatted MACs: 4.140G, Formatted Parameters: 25.557M
# Expected output: MACs: 4139975680.0, Parameters: 25557032.0

  warn(


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Formatted MACs: 4.134G, Formatted Parameters: 25.557M


In [3]:
import torch
import torch.nn as nn

from thop import profile

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.conv(x)

def count_your_custom_module(module, x, y):
    macs = 0
    if isinstance(module, nn.Conv2d):
        _, _, H, W = y.shape  # Output shape
        k_h, k_w = module.kernel_size
        in_c = module.in_channels
        out_c = module.out_channels
        groups = module.groups
        macs = (k_h * k_w * in_c * out_c * H * W) / groups
    module.total_ops += torch.DoubleTensor([macs])  

model = MyModule()  # Or a larger model incorporating this module
dummy_input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(dummy_input,), custom_ops={MyModule: count_your_custom_module})
print(f"Custom MACs: {macs}, Parameters: {params}")


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Customize rule count_your_custom_module() <class '__main__.MyModule'>.
Custom MACs: 86704128.0, Parameters: 1792.0
