Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fx] supported data-dependent control flow in model tracing #1185

Merged
merged 2 commits into from
Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions colossalai/fx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .tracer import ColoTracer
30 changes: 30 additions & 0 deletions colossalai/fx/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def has_meta_tensor(self):
def _assert_has_meta(self):
assert self.has_meta_tensor, f'Meta tensor is not set for {self.node.name}'

@property
def device(self):
# Hack so we can track when devices are used. During meta-tensor propagation,
# replace these values with a constant 'meta'
return MetaDeviceAttribute(self, "device")

@property
def dtype(self):
self._assert_has_meta()
Expand Down Expand Up @@ -72,3 +78,27 @@ def __getattr__(self, k):

def __setitem__(self, indices, values):
return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {})


class ColoAttribute(ColoProxy):

def __init__(self, root, attr: str):
# this class is copied from torch.fx.Attribute
# but inherits ColoProxy
self.root = root
self.attr = attr
self.tracer = root.tracer
self._node = None

@property
def node(self):
if self._node is None:
self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node

def __call__(self, *args, **kwargs):
return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)


class MetaDeviceAttribute(ColoAttribute):
pass
1 change: 1 addition & 0 deletions colossalai/fx/tracer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .tracer import ColoTracer
31 changes: 31 additions & 0 deletions colossalai/fx/tracer/_tracer_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import List, Union, Any
from ..proxy import ColoProxy, MetaDeviceAttribute

__all__ = ['is_element_in_list', 'extract_meta']


def is_element_in_list(elements: Union[List[Any], Any], list_: List[Any]):
if isinstance(elements, (tuple, list, set)):
for ele in elements:
if ele not in list_:
return False, ele
else:
if elements not in list_:
return False, elements

return True, None


def extract_meta(*args, **kwargs):

def _convert(val):
if isinstance(val, MetaDeviceAttribute):
return 'meta'
elif isinstance(val, ColoProxy):
assert val.meta_tensor is not None
return val.meta_tensor
return val

new_args = [_convert(val) for val in args]
new_kwargs = {k: _convert(v) for k, v in kwargs.items()}
return new_args, new_kwargs
4 changes: 4 additions & 0 deletions colossalai/fx/tracer/meta_patch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from sys import meta_path
from .registry import *
from .patched_function import *
from .patched_module import *
Empty file.
7 changes: 7 additions & 0 deletions colossalai/fx/tracer/meta_patch/patched_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import torch
from .registry import meta_patched_module


@meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input):
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
25 changes: 25 additions & 0 deletions colossalai/fx/tracer/meta_patch/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
class PatchRegistry:

def __init__(self, name):
self.name = name
self.store = {}

def register(self, source):

def wrapper(func):
self.store[source] = func
return func

return wrapper

def get(self, source):
assert source in self.store
target = self.store[source]
return target

def has(self, source):
return source in self.store


meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution')
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution')
Loading