[hotfix/rotor] fix variable names #1597
…o feature/better_flop_tensor
…o feature/better_flop_tensor
…o hotfix/rotor_variable_names
@@ -5,7 +5,7 @@
 from .memory import activation_size


-class Stage(Enum):
+class Phase(Enum):
`Stage` should be `Phase` with respect to RPC phase.
 # TODO: the attribute node_size should be removed in the future
 setattr(n, 'node_size', n.meta.get('fwd_mem_tmp', 0) + n.meta.get('fwd_mem_out', 0))
 for par in n.all_input_nodes:
-    par.meta['fwd_mem_out'] = par.meta.get('fwd_mem_out', 0) + n.meta.get('fwd_mem_in', 0)
+    par.meta['fwd_mem_out'] = max(par.meta.get('fwd_mem_out', 0), n.meta.get('fwd_mem_in', 0))
`max` is more plausible for this calculation.
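A minimal sketch of why `max` and `+` differ here, using plain dictionaries in place of `node.meta`. The helper name `accumulate_fwd_mem_out` and the byte counts are hypothetical, and the sketch assumes every consumer reads the same output tensor of its parent, so summing would count that tensor once per consumer.

```python
def accumulate_fwd_mem_out(parent_meta, consumer_fwd_mem_in, combine):
    """Fold each consumer's fwd_mem_in into the parent's fwd_mem_out."""
    for fwd_mem_in in consumer_fwd_mem_in:
        parent_meta['fwd_mem_out'] = combine(parent_meta.get('fwd_mem_out', 0), fwd_mem_in)
    return parent_meta['fwd_mem_out']


# Hypothetical case: two consumers each read the same 1024-byte parent output.
consumers = [1024, 1024]

summed = accumulate_fwd_mem_out({}, consumers, lambda a, b: a + b)
maxed = accumulate_fwd_mem_out({}, consumers, max)

print(summed)  # 2048: the shared tensor is counted twice
print(maxed)   # 1024: the shared tensor is counted once
```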
@@ -94,12 +94,11 @@ def extract_tensor_meta(obj):

 tensor_meta = tree_map(extract_tensor_meta, result)
 n.meta['tensor_meta'] = tensor_meta
-n.meta = {**n.meta, **asdict(meta_info)}    # extend MetaInfo to `n.meta`
+n.meta = {**n.meta, **asdict(meta_info), 'fwd_mem_out': 0}    # extend MetaInfo to `n.meta`
Avoid doubled `MetaInfoProp` that introduces doubled `fwd_mem_out`.
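A rough sketch of the doubling problem, assuming the pass is run twice over the same graph and that parents accumulate `fwd_mem_out` additively, as in the removed line. The dictionaries stand in for FX nodes, and `propagate`, `make_graph`, and `run_twice` are hypothetical names, not part of the project.

```python
def propagate(nodes, reset):
    """One pass over nodes in topological order with additive accumulation."""
    for node in nodes:
        if reset:
            # Mirrors the `'fwd_mem_out': 0` reset applied when a node is visited.
            node['meta']['fwd_mem_out'] = 0
        for parent in node['inputs']:
            parent['meta']['fwd_mem_out'] = (
                parent['meta'].get('fwd_mem_out', 0) + node['meta']['fwd_mem_in']
            )


def make_graph():
    # Hypothetical two-node graph: `b` consumes 512 bytes of `a`'s output.
    a = {'inputs': [], 'meta': {'fwd_mem_in': 0}}
    b = {'inputs': [a], 'meta': {'fwd_mem_in': 512}}
    return [a, b]


def run_twice(reset):
    nodes = make_graph()
    propagate(nodes, reset)
    propagate(nodes, reset)
    return nodes[0]['meta']['fwd_mem_out']


print(run_twice(reset=False))  # 1024: the second pass adds on top of the first
print(run_twice(reset=True))   # 512: each pass starts from zero
```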
print(chain)
print(node_list)
oops
colossalai/fx/profiler/dataflow.py
Outdated
 def is_forward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.FORWARD
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == Phase.FORWARD


 def is_loss(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.LOSS
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == Phase.LOSS


 def is_placeholder(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.PLACEHOLDER
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == Phase.PLACEHOLDER


 def is_backward(n: Node):
-    assert 'stage' in n.meta, f'Node meta of {n} has no key `stage`!'
-    return n.meta['stage'] == Stage.BACKWARD
+    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
+    return n.meta['phase'] == Phase.BACKWARD
These can be merged into one function, e.g. `is_stage(node: Node, stage: Phase)`.
I see
The variable name in my suggestion should be `phase` for consistency.
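A minimal sketch of the merged predicate, with the parameter named `phase`. The name `is_phase`, the stand-in `FakeNode` class, and the enum member values are assumptions for illustration. The real code would take a `torch.fx` `Node` and use the project's own `Phase` definition.

```python
from enum import Enum


class Phase(Enum):
    # Member values are placeholders; only the names appear in the diff.
    FORWARD = 0
    LOSS = 1
    BACKWARD = 2
    PLACEHOLDER = 3


class FakeNode:
    """Hypothetical stand-in for a node that carries a `meta` dict."""

    def __init__(self, phase):
        self.meta = {'phase': phase}


def is_phase(n, phase: Phase) -> bool:
    """Return True if node `n` is tagged with the given phase."""
    assert 'phase' in n.meta, f'Node meta of {n} has no key `phase`!'
    return n.meta['phase'] == phase


node = FakeNode(Phase.BACKWARD)
print(is_phase(node, Phase.BACKWARD))  # True
print(is_phase(node, Phase.FORWARD))   # False
```

The four existing helpers could then remain as thin wrappers, e.g. `is_forward(n)` returning `is_phase(n, Phase.FORWARD)`, so call sites need not change.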
Great work!
What's fixed?
The previous PR, #1587, renamed several variables. This hotfix corrects those names.
Testing