Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Distributed] all_reduce op and distributed info in graphs #284

Merged
merged 42 commits into from Jun 29, 2023

Conversation

soodoshll
Copy link
Collaborator

@soodoshll soodoshll commented Jun 19, 2023

  • add the all_reduce op
  • add nccl-related headers and libs when building tasks (as a new pass include_nccl_pass)
  • support grouping in distributed (by using ncclCommSplit)
  • We now have a example of all_reduce(relu(x * w)) in ./examples/distributed/test.py

@soodoshll soodoshll changed the title [Distributed] DistributedFlowGraph all_reduce op [Distributed] all_reduce op and distributed info in graphs Jun 22, 2023
@soodoshll
Copy link
Collaborator Author

@yaoyaoding this pr is ready for review :)

soodoshll and others added 2 commits June 22, 2023 15:45
Merely assigning environment variables is insufficient for setting up
dev environment now. We need to run pip to install hidet package in
develop mode.

Users still need to build source files written in C++ manually. Consider
integrating that into `setup.py` in the future?
Copy link
Member

@yaoyaoding yaoyaoding left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @soodoshll !

I left some suggestions on the data organization and implementation.

Comment on lines 96 to 99
def init_unique_id(unqie_id: NcclUniqueId) -> None:
if not nccl_available():
raise RuntimeError("NCCL is not available")
nccl_runtime_api.get_unique_id(unqie_id)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we define init_unique_id(...) as

def create_unique_id() -> NcclUniqueId:
    ...

I feel the current API is not very intuitive.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point here is now we need the NcclUniqueId to be shared by all processes. And the current solution is

  1. Create a shared NcclUniqueId object;
  2. Launch multiple processes with the shared uniqueid object as one argument;
  3. Init the shared uniqueid object in process 0, which need the reference to the shared object
    If we create the NcclUniqueId in process 0 after processes have been launched, it's not so easy to do the broadcast (if there's an elegant way of broadcasting, please let me know).

A workaround is to 1) create the shared object; 2) launch processes; 3) create a unique id object; 4) copy its value back to the shared object.

Comment on lines 125 to 129
# For distributed graphs
self.nrank = nrank
self.rank = rank
self.groups = groups

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's define a new class called FlowGraphAttrs and define these attributes in that class. Then add a field in FlowGraph with FlowGraphAttrs type.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

something like

class FlowGraph:
    def __init__(..., attrs=None):
        ...
        self.attrs: FlowGraphAttrs = attrs if attrs else FlowGraphAttrs()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment on lines 195 to 202
def is_distributed(self):
return self.nrank is not None or self.rank is not None

def set_dist_attrs(self, nrank: int, rank: int, groups: Optional[List[List[int]]] = None):
self.nrank = nrank
self.rank = rank
self.groups = groups

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's define thses functions at the module that will use these functionality, instead of defining them as FlowGraph methods.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have replaced them with set_attrs

self.comm_id = comm_id
self.op = op

super().__init__('all_reduce', inputs=[x], outputs=[y], attributes={})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better also add comm_id and op to attributes, so that the user can see the comm_id and op when compiling the task.

return f"all_reduce_{self.op}_{self.comm_id}"

def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModule]:
# we may need current rank here to avoid duplicated working_dirs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you clarify the problem here? Thanks.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we add the comm_id to attributes, then the op hash would be different.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we run the compilation concurrently in multiple processes, for the same op, there might be race conditions in the local filesystem.

Comment on lines 315 to 316
comms_array = comms_to_array(self.nccl_comms)
runtime_api.set_nccl_comms(comms_array)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's create this when initialize the dist-related info, to avoid repeating creating the comm Array.

@@ -105,6 +114,10 @@ def __init__(
self.cuda_workspace: Optional[Storage] = None
self.cpu_workspace: Optional[Storage] = None

# distributed properties
self.dist_info: Optional[GraphDistributedInfo] = dist_info
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to put this in GraphMetaData.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a better idea is to put the FlowGraphAttr in the GraphMetaData as a whole, instead of reiterating all attributes. But then where should we put FlowGraphAttr? Putting it in flow_graph.py will cause circular import.

@@ -105,6 +114,10 @@ def __init__(
self.cuda_workspace: Optional[Storage] = None
self.cpu_workspace: Optional[Storage] = None

# distributed properties
self.dist_info: Optional[GraphDistributedInfo] = dist_info
self.nccl_comms: List[NcclCommunicator] = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

store it as Array of NcclCommunicator directly, to avoid repeating creating the Array in run_async.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Array of NcclCommunicator cannot be directly passed into C++. C++ needs an array of ncclComm_t, which is basically the handle of NcclCommunicator. And to avoid NcclCommunicators being released by GC, we need to maintain the list of NcclCommunicator. If we also maintain the ncclComm_t array, we will have two redundant arrays which almost save the same value

Comment on lines 19 to 29
def _recursive_find(root: Stmt):
if isinstance(root, BlackBoxStmt):
if root.template_string.startswith('nccl'):
return True
for child in dir(root):
if isinstance(child, Stmt):
if _recursive_find(child):
return True
return False

ret = _recursive_find(func.body)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use hidet.ir.tools.collect to collect all BlackStmt.

@@ -80,5 +81,6 @@ def lower(ir_module: IRModule) -> IRModule:
rule_based_simplify_pass(),
inline_let_stmt_pass(),
simplify_stmt_pass(),
include_nccl_pass(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Later, we will use a pass to add the header information. Let's make this pass a general one and give a name like "annotate_headers" or "annotate_include_headers". Or "annotate_header_and_libs".

soodoshll and others added 8 commits June 22, 2023 23:00
)

Previously, if a primitive function calls a primitive function, the
`instantiate_symbols` pass will update the corresponding
`hidet.ir.primitives.func.PrimitiveFunctionRegistry.function` in-place
(I am not sure exactly how it's done, but this is what I observed),
adding symbol variables to its parameters. The primitive function pool
is a global variable, therefore this effect is cumulative across tuning
candidates. So while candidate 0 will have no problem, candidate 1 will
have two extra copies of symbol params, and so on, leading to compile
errors.

Since primitive functions do not need symbol vars, a quick fix is just
to not instantiate any symbols for them.
Copy link
Member

@yaoyaoding yaoyaoding left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @soodoshll !

I left some comments.

python/hidet/distributed/distributed.py Show resolved Hide resolved
python/hidet/distributed/group.py Show resolved Hide resolved
Comment on lines 63 to 64
NCCL_COMMS = []
_NCCL_ARRAY = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
NCCL_COMMS = []
_NCCL_ARRAY = None
NCCL_COMMS: List[NcclCommunicator] = []
_NCCL_ARRAY: 'Array' = None

python/hidet/distributed/store.py Show resolved Hide resolved
Comment on lines 53 to 59
self._filename = filename
self._lock_filename = filename + '.lock'
self._world_size = world_size

self._lock = filelock.FileLock(self._lock_filename)
self._cache = {}
self._timeout = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to add some type annotations to reduce the time of the code reader.

key = self.REGULAR_PREFIX + key
with self._lock:
with open(self._filename, "ab+") as f:
f.seek(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f.seek(0)

f.seek(0)
self._update(f)
has_key = key in self._cache
print(has_key, self._cache[key])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print(has_key, self._cache[key])

if k is None:
return
v = self._read(f)
k = str(k, encoding='raw_unicode_escape')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I know why we choose this encoding, instead of encoding like 'utf-8'?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also better to add the reason to the comments.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No special reasons besides pickle uses that. Switching to utf-8 since it is the default value of encoding/decoding.

python/hidet/distributed/store.py Show resolved Hide resolved
@@ -0,0 +1,137 @@
# Licensed under the Apache License, Version 2.0 (the "License");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to place this test to hidet/tests/distributed/test_file_store.py.

manually if required.

We use a 4-byte integer to record the length of each (encoded) key and value. So do not insert
more than 32768 bytes for each entry.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 byte integer could represent up to 2^31-1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. let me fix it


Deletion of an entry is done by adding a new entry with a suffix '-' (DELETE_PREFIX). It will
overwrite the insertion of the given entry when we scanning the file.
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comments, now the design is very clear!

@yaoyaoding
Copy link
Member

Thanks @soodoshll !

Looks good to me now. Good job!

There seems is a typo in the comment. Feel free to merge this PR by yourself after fixing that.

@soodoshll soodoshll merged commit 7c52c9d into hidet-org:main Jun 29, 2023
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants