[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug #1710

Merged: 20 commits merged on Oct 18, 2022
Changes from 4 commits
7 changes: 0 additions & 7 deletions colossalai/__init__.py
@@ -1,10 +1,3 @@
-try:
-    from . import _meta_registrations
-    META_COMPATIBILITY = True
-except:
-    import torch
-    META_COMPATIBILITY = False
-    print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
get_default_parser)

7 changes: 7 additions & 0 deletions colossalai/fx/__init__.py
@@ -1,3 +1,10 @@
+try:
+    from . import _meta_registrations
+    META_COMPATIBILITY = True
+except:
+    import torch
+    META_COMPATIBILITY = False
+    print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
Contributor:
This would still print a statement to the user upon importing colossalai if colossalai.fx is used by any other module in the library.

Contributor:
Let's not print something like this. Instead, we can offer an API like xxx_is_compatible to check the compatibility.

Contributor:
Do not expose a constant to the user, whether they are a user of colossalai or a developer who wants to use colossalai.fx.

Contributor Author:
I see. For the compatibility concerns, I will probably have to come up with a global registration for this. For now, in this PR I will just define an API for checking.

Contributor Author:
But are there any existing modules in ColossalAI for compatibility checks?

Contributor:
This should live only within the fx module, as compatibility is specific to this module.

from .tracer import ColoTracer, meta_trace
from .graph_module import ColoGraphModule
from .passes import MetaInfoProp
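
As the reviewers suggest above, the compatibility signal should be exposed as a query API rather than a print at import time or a bare constant. A minimal sketch of what the body of colossalai/fx/__init__.py could look like under that suggestion; the name is_compatible_with_meta and the private flag are illustrative assumptions, not necessarily the API landed in this PR:

_meta_compatible = False
try:
    # Same relative import as in the diff above; the registrations now live in colossalai/fx.
    from . import _meta_registrations
    _meta_compatible = True
except Exception:
    # The installed PyTorch build does not provide what the meta registrations need; stay silent on import.
    pass


def is_compatible_with_meta() -> bool:
    """Return True if _meta_registrations imported cleanly against the installed PyTorch."""
    return _meta_compatible

Callers would then branch on is_compatible_with_meta() rather than importing a constant, which also avoids the print-on-import behaviour the first reviewer points out.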
2 changes: 1 addition & 1 deletion colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
@@ -10,7 +10,7 @@
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.passes.algorithms.ckpt_solver_rotor import _construct_chain, _compute_table, _rec
-from colossalai import META_COMPATIBILITY
+from colossalai.fx import META_COMPATIBILITY

INF = float("inf")

2 changes: 1 addition & 1 deletion colossalai/fx/profiler/__init__.py
@@ -1,4 +1,4 @@
-from ... import META_COMPATIBILITY
+from .. import META_COMPATIBILITY
if META_COMPATIBILITY:
    from .opcount import flop_mapping
    from .tensor import MetaTensor
@@ -11,7 +11,7 @@
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
import pytest
-from colossalai import META_COMPATIBILITY
+from colossalai.fx import META_COMPATIBILITY
if META_COMPATIBILITY:
    from colossalai.fx.profiler.tensor import MetaTensor

9 changes: 5 additions & 4 deletions tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
@@ -13,7 +13,7 @@
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
import pytest
-from colossalai import META_COMPATIBILITY
+from colossalai.fx import META_COMPATIBILITY
if META_COMPATIBILITY:
    from colossalai.fx.profiler.tensor import MetaTensor

@@ -54,8 +54,9 @@ def _is_graph_linearized(gm: GraphModule):
def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
                               model_cls: Callable[[], torch.nn.Module]):
    criterion = torch.nn.MSELoss()
-    data = torch.rand(2, 3, 32, 32)
-    label = torch.rand(2, 5)
+    m.cuda()
+    data = torch.rand(2, 3, 32, 32).cuda()
+    label = torch.rand(2, 5).cuda()
    loss = criterion(m(data), label)
    loss.backward()
    loss = criterion(gm(data), label)
@@ -77,7 +78,7 @@ def _run_ckpt_solver(rank):
m = model_cls(num_classes=5)
graph = tracer.trace(root=m)
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
-MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
+MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda())
codegen = ActivationCheckpointCodeGen()
gm.graph.set_codegen(codegen)
if solver == solver_rotor:
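
For context on the MetaInfoProp line changed above: the meta input is now built from the real tensor and moved with .cuda(), instead of being constructed with fake_device='cuda'. A rough sketch of the new call pattern, mirroring _run_ckpt_solver with torchvision's resnet18 standing in for model_cls (an assumption for illustration; like the test, it expects a CUDA-enabled PyTorch build):

import copy

import torch
import torchvision.models as tm

from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler.tensor import MetaTensor

# Trace the model into a ColoGraphModule, as the test does.
m = tm.resnet18(num_classes=5)
data = torch.rand(2, 3, 32, 32)
graph = ColoTracer().trace(root=m)
gm = ColoGraphModule(copy.deepcopy(m), graph, m.__class__.__name__)

# Old pattern: MetaInfoProp(gm.cuda()).run(MetaTensor(data, fake_device='cuda'))
# New pattern in this PR: construct the MetaTensor, then move it to the device.
MetaInfoProp(gm.cuda()).run(MetaTensor(data).cuda())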
2 changes: 1 addition & 1 deletion tests/test_fx/test_ckpt_solvers/test_linearize.py
Expand Up @@ -6,7 +6,7 @@
from colossalai.fx.passes.algorithms import solver_rotor, linearize
from colossalai.fx.passes.algorithms.operation import Loss, ForwardCheck, ForwardEnable, ForwardNograd
import pytest
-from colossalai import META_COMPATIBILITY
+from colossalai.fx import META_COMPATIBILITY
if META_COMPATIBILITY:
    from colossalai.fx.profiler.tensor import MetaTensor

2 changes: 1 addition & 1 deletion tests/test_fx/test_comm_size_compute.py
Expand Up @@ -6,7 +6,7 @@
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
from colossalai.fx.passes.utils import get_comm_size
-from colossalai import META_COMPATIBILITY
+from colossalai.fx import META_COMPATIBILITY
import pytest

MODEL_DIM = 16
2 changes: 1 addition & 1 deletion tests/test_fx/test_meta/test_aten.py
Expand Up @@ -2,7 +2,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from colossalai import META_COMPATIBILITY
+from colossalai.fx import META_COMPATIBILITY

import pytest

2 changes: 1 addition & 1 deletion tests/test_fx/test_meta/test_backward.py
@@ -1,7 +1,7 @@
import torchvision.models as tm
import timm.models as tmm
import torch
-from colossalai import META_COMPATIBILITY
+from colossalai.fx import META_COMPATIBILITY
import pytest

if META_COMPATIBILITY:
2 changes: 1 addition & 1 deletion tests/test_fx/test_meta/test_meta_trace.py
@@ -1,7 +1,7 @@
import torchvision.models as tm
import timm.models as tmm
import torch
-from colossalai import META_COMPATIBILITY
+from colossalai.fx import META_COMPATIBILITY
import pytest

if META_COMPATIBILITY:
2 changes: 1 addition & 1 deletion tests/test_fx/test_meta_info_prop.py
@@ -1,6 +1,6 @@
import torch
from torch.fx import symbolic_trace
-from colossalai import META_COMPATIBILITY
+from colossalai.fx import META_COMPATIBILITY
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata

BATCH_SIZE = 2