In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(torch.nn.Module):
    def __init__(self, in_channels: int=32):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, 64, bias=False)
        self.fc2 = nn.Linear(64, 128, bias=False)
        self.fc3 = nn.Linear(128, 64, bias=False)
        self.fc4 = nn.Linear(64, 10, bias=False)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        return x

In [6]:
@torch.no_grad
def sparsify_mlp(mod: nn.Module, sparsity=0.5):
    for n,p in mod.named_parameters():
        if "weight" not in n:
            continue
        mask = torch.zeros_like(p)
        idx = torch.randperm(mask.numel())
        ones = int(mask.numel() * (1-sparsity))
        idx = idx[:ones]
        mask.flatten()[idx] = 1
        p.data = p*mask
        p.mask = mask
    return mod

from copy import deepcopy
mlp = MLP()
mlp_sparse = deepcopy(mlp)
mlp_sparse = sparsify_mlp(mlp_sparse)

In [10]:
from deepspeed.profiling.flops_profiler import get_model_profile
input_shape = (1,32)
flops, macs, params = get_model_profile(
        model=mlp,  # model
        input_shape=input_shape,  # input shape to the model. If specified, the model takes a tensor with this shape as the only positional argument.
        args=None,  # list of positional arguments to the model.
        kwargs=None,  # dictionary of keyword arguments to the model.
        print_profile=False,  # prints the model graph with the measured profile attached to each module
        detailed=False,  # print the detailed profile
        module_depth=-1,  # depth into the nested modules, with -1 being the inner most modules
        top_modules=1,  # the number of top modules to print aggregated profile
        warm_up=0,  # the number of warm-ups before measuring the time of each module
        as_string=False,  # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
        output_file=None,  # path to the output file. If None, the profiler prints to stdout.
        ignore_modules=None,
    )  # the list of modules to ignore in the profiling
print(flops, macs, params)

[2024-10-24 13:40:34,162] [INFO] [profiler.py:1222:get_model_profile] Flops profiler warming-up...
[2024-10-24 13:40:34,164] [INFO] [profiler.py:83:start_profile] Flops profiler started
[2024-10-24 13:40:34,166] [INFO] [profiler.py:229:end_profile] Flops profiler finished
38410 19072 19072


In [11]:
sflops, smacs, sparams = get_model_profile(
        model=mlp_sparse,  # model
        input_shape=input_shape,  # input shape to the model. If specified, the model takes a tensor with this shape as the only positional argument.
        args=None,  # list of positional arguments to the model.
        kwargs=None,  # dictionary of keyword arguments to the model.
        print_profile=False,  # prints the model graph with the measured profile attached to each module
        detailed=False,  # print the detailed profile
        module_depth=-1,  # depth into the nested modules, with -1 being the inner most modules
        top_modules=1,  # the number of top modules to print aggregated profile
        warm_up=0,  # the number of warm-ups before measuring the time of each module
        as_string=False,  # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
        output_file=None,  # path to the output file. If None, the profiler prints to stdout.
        ignore_modules=None,
    )  # the list of modules to ignore in the profiling
print(flops, macs, params)

[2024-10-24 13:40:37,860] [INFO] [profiler.py:1222:get_model_profile] Flops profiler warming-up...
[2024-10-24 13:40:37,861] [INFO] [profiler.py:83:start_profile] Flops profiler started
sparse linear:
Input sparsity 0, weight sparsity 0.5
Input shape torch.Size([1, 32]), weight shape torch.Size([64, 32])
sparse linear:
Input sparsity 0, weight sparsity 0.5
Input shape torch.Size([1, 64]), weight shape torch.Size([128, 64])
sparse linear:
Input sparsity 0, weight sparsity 0.5
Input shape torch.Size([1, 128]), weight shape torch.Size([64, 128])
sparse linear:
Input sparsity 0, weight sparsity 0.5
Input shape torch.Size([1, 64]), weight shape torch.Size([10, 64])
[2024-10-24 13:40:37,864] [INFO] [profiler.py:229:end_profile] Flops profiler finished
38410 19072 19072


In [12]:
flops/sflops

1.9862446995552798

In [13]:
macs/smacs

2.0

In [15]:
params/sparams

1.0