
[fx] add profiler for fx nodes. #1480

Merged 43 commits into hpcaitech:main on Aug 24, 2022

Conversation

@super-dainiu (Contributor) commented Aug 23, 2022

What's new?

After patching all possible ops, we can now profile the memory cost and FLOPs with a few lines of code. Only the original torch.nn.functional and torch.nn operators are supported for now, but it is not too challenging to profile your own model using MetaInfoProp.

import torch
from colossalai.fx.profiler import profile_function, profile_module


input = torch.rand(100, 100, 100, 100, device='meta')
func = torch.nn.functional.relu
output, profile = profile_function(func)(input, inplace=False)
print(f"Profiling function {func},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")

output, profile = profile_function(func)(input, inplace=True)
print(f"Profiling function {func},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")

input = torch.rand(4, 3, 224, 224, device='meta')
mod = torch.nn.Conv2d(3, 128, 3)
output, profile = profile_module(mod)(input)
print(f"Profiling module {mod},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")

===============================================================================
Result:
Profiling function <function relu at 0x7f3b6f8ead30>,
Param size: 0.000 MB, Activation size: 381.470 MB, 100000000 FLOPs, 0 MACs
Profiling function <function relu at 0x7f3b6f8ead30>,
Param size: 0.000 MB, Activation size: 0.000 MB, 100000000 FLOPs, 0 MACs
Profiling module Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1)),
Param size: 0.014 MB, Activation size: 96.258 MB, 1387837440 FLOPs, 681302016 MACs
===============================================================================
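As a quick sanity check (not part of the PR), the Conv2d and ReLU numbers above can be reproduced analytically. For a stride-1, unpadded 3x3 convolution, MACs = N * C_out * H_out * W_out * (C_in * kH * kW), and the FLOP count here assumes 2 ops per MAC plus one add per output element for the bias:

```python
# Analytic reproduction of the profiler output above.
N, C_in, H, W = 4, 3, 224, 224       # input shape from the example
C_out, k = 128, 3                    # Conv2d(3, 128, 3)
H_out = W_out = H - k + 1            # stride 1, no padding -> 222

out_elems = N * C_out * H_out * W_out
macs = out_elems * C_in * k * k      # one MAC per (input-channel x kernel) tap
flops = 2 * macs + out_elems         # multiply+add per MAC, plus bias add

print(macs, flops)                   # 681302016 1387837440

# ReLU activation size: 100**4 float32 elements, 4 bytes each.
act_mb = 100**4 * 4 / 1024**2
print(f"{act_mb:.3f} MB")            # 381.470 MB
```

Both values match the profiled 681302016 MACs and 1387837440 FLOPs, and the 381.470 MB activation for out-of-place ReLU.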

Also, using MetaInfoProp we can trace the model entirely on device='meta' and obtain all the required results without allocating any real memory.

from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tm
from torch.fx import symbolic_trace
import torch.fx
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from colossalai.fx.passes.meta_info_prop import MetaInfoProp


def _forward_mem(gm: torch.fx.GraphModule):
    # Sum the meta-propagated per-node sizes (bytes); return (total, params) in MB.
    node_size = 0
    param_size = 0
    for node in gm.graph.nodes:
        node_size += getattr(node, '__param__', 0) + getattr(node, '__activation__', 0)
        param_size += getattr(node, '__param__', 0)
    return node_size / 1024**2, param_size / 1024**2


def _forward_flops(gm: torch.fx.GraphModule):
    # Sum the meta-propagated per-node counts; return (GFLOPs, GMACs).
    flops = 0
    macs = 0
    for node in gm.graph.nodes:
        flops += getattr(node, '__flops__', 0)
        macs += getattr(node, '__macs__', 0)
    return flops / 1e9, macs / 1e9


def data_gen(batch_size: int, shape: Tuple[int, int, int], device='cuda'):
    data = torch.rand(batch_size, *shape, device=device)
    label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
    return data, label


def test_forward(gm: torch.fx.GraphModule, num_steps: int=5):
    def get_gpu_mem():
        result = torch.cuda.max_memory_allocated() / 1024**2
        torch.cuda.reset_peak_memory_stats()
        return result

    get_gpu_mem()   # reset
    forward_mem = -get_gpu_mem()
    param_mem = -get_gpu_mem()
    gm.train()
    gm.cuda()
    param_mem += get_gpu_mem()
    criterion = CrossEntropyLoss()
    optimizer = Adam(gm.parameters(), lr=1e-3)
    for n in range(num_steps):
        data, label = data_gen(1, (3, 224, 224))
        output = gm(data)
        optimizer.zero_grad()
        loss = criterion(output, label)
        forward_mem += get_gpu_mem() / num_steps
        loss.backward()
        optimizer.step()
    return forward_mem, param_mem

        
def test_meta_info_prop():
    for M in [tm.densenet121, tm.densenet161, tm.densenet169, tm.densenet201]:
        model = M()
        data = torch.rand(1, 3, 224, 224, device='meta')
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(data)
        meta_forward_mem, meta_param_mem = _forward_mem(gm)
        flops, macs = _forward_flops(gm)
        concrete_forward_mem, concrete_param_mem = test_forward(gm, num_steps=1)

        print(f'|{M}|{meta_forward_mem:.3f} MB|{meta_param_mem:.3f} MB|{concrete_forward_mem:.3f} MB|{concrete_param_mem:.3f} MB|{flops:.3f}GFLOPs|{macs:.3f}GMACs|')
    
        
if __name__ == '__main__':
    test_meta_info_prop()

===============================================================================
Result:
|<function densenet121 at 0x7f99d58f7b80>|158.786 MB|30.437 MB|156.183 MB|30.859 MB|5.717GFLOPs|2.834GMACs|
|<function densenet161 at 0x7f99d58f7d30>|347.533 MB|109.409 MB|349.309 MB|112.571 MB|15.546GFLOPs|7.728GMACs|
|<function densenet169 at 0x7f99d58f7ee0>|208.338 MB|53.976 MB|209.491 MB|54.724 MB|6.778GFLOPs|3.360GMACs|
|<function densenet201 at 0x7f99d58ff0d0>|274.686 MB|76.347 MB|277.507 MB|77.392 MB|8.659GFLOPs|4.291GMACs|
===============================================================================
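The param-memory column can be cross-checked by arithmetic alone. Assuming torchvision's densenet121 has 7,978,856 parameters (the commonly reported count; an assumption here, not taken from the PR), at 4 bytes per fp32 parameter:

```python
# Rough cross-check of the densenet121 param-memory column.
# Assumption: torchvision densenet121 has 7,978,856 parameters (fp32, 4 bytes each).
n_params = 7_978_856
param_mb = n_params * 4 / 1024**2
print(f"{param_mb:.3f} MB")   # ~30.437 MB, matching the meta-propagated value
```

This agrees with the 30.437 MB reported by MetaInfoProp; the small gap to the concrete 30.859 MB comes from CUDA allocator granularity and buffers.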

super-dainiu and others added 30 commits August 9, 2022 23:23
* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.
* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] merge development into main (#1)

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] fix test and algorithm bugs in activation checkpointing.

* [fx] polish ckpt_test.

* [fx] add rules to linearize computation graphs for searching.
@Cypher30 (Contributor) left a comment:

🙌GREAT🙌

@FrankLeeeee (Contributor) commented:

Great work!

@super-dainiu (Author) commented:

[screenshot: local test results]
I passed all tests/test_fx tests locally on an A100.

@FrankLeeeee FrankLeeeee merged commit 32efe8e into hpcaitech:main Aug 24, 2022