
[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug #1710

Merged — 20 commits merged into hpcaitech:main on Oct 18, 2022

Conversation

@super-dainiu (Contributor) commented on Oct 14, 2022

Why move?

Currently, _meta_registration.py is only used by MetaTensor in fx.profiler. Users on PyTorch 1.11 or lower cannot use MetaTensor and will get an error message. I believe we should show the error message only when they import colossalai.fx, so moving it from colossalai to colossalai.fx is a good choice.
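The pattern being proposed (probe the optional import once, inside the subpackage's own __init__, instead of at the top level) might be sketched like this; the helper name and probed module path are illustrative, not the PR's actual code:

```python
import importlib

def probe_import(module_name: str) -> bool:
    """Attempt an import and report success, instead of printing at import time."""
    try:
        importlib.import_module(module_name)
        return True
    except Exception:
        # Older PyTorch raises while registering meta kernels, so any
        # failure here simply means the feature is unavailable.
        return False

# In the PR, this flag lives in colossalai/fx/__init__.py, so it is only
# evaluated when `colossalai.fx` is imported, not on plain `import colossalai`.
META_COMPATIBILITY = probe_import("colossalai.fx._meta_registrations")
```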

Other changes

I removed all of the color_debug code because it causes conflicts in our CI. (Should I have?) cc @LSTM-Kirigaya

Tests

[Screenshot: test results]

Comment on lines 1 to 7
```python
try:
    from . import _meta_registrations
    META_COMPATIBILITY = True
except Exception:  # a bare `except:` would also swallow SystemExit/KeyboardInterrupt
    import torch
    META_COMPATIBILITY = False
    print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
```
Contributor:

This would still print a statement to the user upon importing colossalai if colossalai.fx is used by any other module in the library.

Contributor:

Let's not print something like this. Instead, we can offer an API like xxx_is_compatible to check the compatibility.

Contributor:

Do not expose a constant to the user, whether they are a user of colossalai or a developer who wants to build on colossalai.fx.

Contributor (Author):

I see. To address the compatibility concern properly, I will probably need a global registration mechanism. In this PR, I will just define an API for the check.

Contributor (Author):

But are there any existing modules in ColossalAI for compatibility checks?

Contributor:

This should live only within the fx module, as the compatibility concern is specific to this module.

@FrankLeeeee (Contributor):

The CI failed, please check it out.

@FrankLeeeee (Contributor):

Remember to remove the test label when your PR is not completely ready for review, as it will trigger test runs on every push. I will remove it for you for now.

@super-dainiu changed the title several times on Oct 17, 2022, from [fx] move _meta_registration.py to fx folder to the final [fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug
Comment on lines +8 to +31
```python
import functools  # `Callable` is imported from `typing` at the top of the file


def compatibility(is_backward_compatible: bool = False) -> Callable:
    """A decorator that gates a function on meta-tensor compatibility with the
    installed PyTorch version.

    Args:
        is_backward_compatible (bool, optional): Whether the function is backward
            compatible, i.e. safe to call on older PyTorch. Defaults to False.

    Returns:
        Callable: The decorated function, or a wrapper that raises ``RuntimeError``
            when called on an incompatible PyTorch.
    """

    def decorator(func):
        # Compatible installs, and backward-compatible functions, pass through.
        if META_COMPATIBILITY or is_backward_compatible:
            return func

        @functools.wraps(func)  # preserve the original name and docstring
        def wrapper(*args, **kwargs):
            raise RuntimeError(f'Function `{func.__name__}` is not compatible with PyTorch {torch.__version__}')

        return wrapper

    return decorator
```
Contributor (Author):

This is a decorator that wraps our defined functions. If the PyTorch version requirement is not satisfied, the wrapper raises a RuntimeError when the function is called.
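A minimal, self-contained sketch of that behavior (the flag and version string are hard-coded here to simulate an incompatible PyTorch; in the PR they come from the module itself):

```python
import functools

META_COMPATIBILITY = False  # simulate an incompatible PyTorch install
TORCH_VERSION = "1.11.0"    # illustrative stand-in for torch.__version__

def compatibility(is_backward_compatible: bool = False):
    def decorator(func):
        # Compatible installs, and backward-compatible functions, pass through.
        if META_COMPATIBILITY or is_backward_compatible:
            return func

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            raise RuntimeError(
                f"Function `{func.__name__}` is not compatible with PyTorch {TORCH_VERSION}")
        return wrapper
    return decorator

@compatibility(is_backward_compatible=True)
def always_available():
    return "ok"

@compatibility()
def needs_meta_tensors():
    return "meta"
```

Calling `always_available()` succeeds, while calling `needs_meta_tensors()` raises the RuntimeError with the offending function's name.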

Contributor:

What is the difference between this and the PyTorch one?

Comment on lines 34 to 42
```python
def check_meta_compatibility() -> bool:
    """Check the meta compatibility. Normally it should be called before importing some of the `colossalai.fx`
    modules. If the meta compatibility is not satisfied, the `colossalai.fx` modules will be replaced by their
    experimental counterparts.

    Returns:
        bool: The meta compatibility
    """
    return META_COMPATIBILITY
```
Contributor (Author):

An API for the check.
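How a caller might branch on such an API (a sketch; the fallback module path is illustrative, though the docstring above does mention experimental counterparts):

```python
META_COMPATIBILITY = False  # stand-in for the flag set at import time

def check_meta_compatibility() -> bool:
    """Report whether the meta registrations imported successfully."""
    return META_COMPATIBILITY

# Pick the experimental fallback when meta tensors are unavailable
# (both module paths here are illustrative).
if check_meta_compatibility():
    profiler_module = "colossalai.fx.profiler"
else:
    profiler_module = "colossalai.fx.profiler.experimental"
```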

@@ -0,0 +1,32 @@
import torch
Contributor:

The file colossalai/fx/profiler/constant.py is deleted and then re-created. You should modify the file directly instead of deleting it, to minimize the git diff.

Contributor (Author):

I split this file into two, so git cannot identify it as a rename.

Comment on lines +1 to +5
```python
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .profiler import profile_function, profile_method, profile_module
from .profiler_function import *
from .profiler_module import *
from .registry import meta_profiler_function, meta_profiler_module
```
Contributor:

Add some documentation in this file to explain why this folder is experimental.

Contributor (Author):

I might refactor this experimental/ folder later; it hasn't been updated in a long time.

@FrankLeeeee merged commit 393f594 into hpcaitech:main on Oct 18, 2022