[fx] Add concrete info prop #1677

Merged: 44 commits merged into hpcaitech:main on Oct 4, 2022


@Cypher30 Cypher30 commented Oct 3, 2022

What's New?

This PR adds ConcreteInfoProp to facilitate meta info estimation, along with a time profiler that tracks the running time of each node. See the docstring of the function _profile_concrete for more information.

For example, you can now run

import torch
from torch.fx import symbolic_trace
from colossalai.fx.passes import ConcreteInfoProp

BATCH_SIZE = 2
DIM_IN = 4
DIM_HIDDEN = 16
DIM_OUT = 16
model = torch.nn.Sequential(
    torch.nn.Linear(DIM_IN, DIM_HIDDEN), 
    torch.nn.Linear(DIM_HIDDEN, DIM_OUT),
    ).cuda()
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device="cuda")
gm = symbolic_trace(model)
interp = ConcreteInfoProp(gm)
interp.run(input_sample)
print(interp.summary(unit='kb')) 

and get the following results

 Op type       Op             Forward time             Backward time    SAVE_FWD_IN    FWD_OUT    FWD_TMP    BWD_OUT    BWD_TMP
-----------  -------  -----------------------  ------------------------  -------------  ---------  ---------  ---------  ---------
placeholder  input_1                    0.0 s                     0.0 s          False    0.00 KB    0.00 KB    0.00 KB    0.00 KB
call_module       _0  0.0003993511199951172 s     0.00706791877746582 s          False    0.50 KB    0.00 KB    0.03 KB    0.66 KB
call_module       _1   6.29425048828125e-05 s  0.00018286705017089844 s          False    0.50 KB    0.00 KB    0.12 KB    0.81 KB
     output   output                    0.0 s                     0.0 s           True    0.00 KB    0.00 KB    0.00 KB    0.00 KB
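
To illustrate the idea behind the per-node forward time column, here is a minimal sketch that times each node with a plain torch.fx.Interpreter subclass. This is not the ConcreteInfoProp implementation: the class name NodeTimer and its fwd_time attribute are hypothetical, it only measures forward time, and it runs on CPU so it can be tried without a GPU.

```python
import time

import torch
from torch.fx import Interpreter, symbolic_trace


class NodeTimer(Interpreter):
    """Hypothetical helper: records wall-clock forward time for each node."""

    def __init__(self, module):
        super().__init__(module)
        self.fwd_time = {}    # node name -> seconds

    def run_node(self, n):
        start = time.perf_counter()
        result = super().run_node(n)
        # CUDA kernels launch asynchronously, so wait before reading the clock.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.fwd_time[n.name] = time.perf_counter() - start
        return result


model = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.Linear(16, 16))
timer = NodeTimer(symbolic_trace(model))
out = timer.run(torch.rand(2, 4))
for name, seconds in timer.fwd_time.items():
    print(f"{name}: {seconds:.6f} s")
```

The measured times will vary from run to run, and the first call usually includes one-off warm-up cost, so single measurements like these are only a rough guide.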

Cypher30 and others added 30 commits July 14, 2022 16:07


def _profile_concrete(target: Callable, *args, **kwargs) -> Tuple[Tuple[Any, ...], GraphInfo]:
    """Profile a Callable function with args and kwargs on concrete devices by https://github.com/Cypher30

Here is the shortcut for you guys to check what I'm doing!!

# find the grad for parameter in args and kwargs
param_size = 0

def get_param_size(x):

This part retrieves the total parameter size from args and kwargs.
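
As a rough illustration of what retrieving the parameter size from args and kwargs could look like, here is a minimal sketch. The helper name param_bytes is hypothetical and the recursion over lists, tuples, and dicts is an assumption; the actual get_param_size in this PR is truncated in the diff above and may work differently.

```python
import torch


def param_bytes(obj) -> int:
    """Hypothetical helper: total bytes of all nn.Parameter objects nested in obj."""
    if isinstance(obj, torch.nn.Parameter):
        return obj.numel() * obj.element_size()
    if isinstance(obj, (list, tuple)):
        return sum(param_bytes(item) for item in obj)
    if isinstance(obj, dict):
        return sum(param_bytes(value) for value in obj.values())
    # Plain tensors and non-tensor values do not count as parameters.
    return 0


linear = torch.nn.Linear(4, 16)
args = (torch.rand(2, 4), linear.weight)    # the input tensor is not a Parameter
kwargs = {"bias": linear.bias}
total = param_bytes(args) + param_bytes(kwargs)
print(total)    # 16*4*4 + 16*4 = 320 bytes for float32
```

Only torch.nn.Parameter instances are counted, so the plain input tensor in args contributes nothing to the total.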

@Cypher30 Cypher30 merged commit 132b430 into hpcaitech:main Oct 4, 2022