Refactor funcol for readability and dynamo tracing (pytorch#104387)
Move the eager kernel impls to a separate file, which is easier to read
(since users may be confused by two versions of each kernel in the same file)
and makes it easier to set a dynamo policy that, for now, traces only the first file.

Pull Request resolved: pytorch#104387
Approved by: https://github.com/wanchaol, https://github.com/fduwjj, https://github.com/kumpera
wconstab authored and pytorchmergebot committed Jul 6, 2023
1 parent 456ecef commit d64bada
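
For orientation, here is a minimal sketch of the layering this commit moves toward. The module paths match the PR, but the function bodies are simplified illustrations rather than the actual PyTorch source: the public functional-collectives module stays thin enough for dynamo to trace, while the eager kernel bodies and async-work bookkeeping live in a separate impl module that a dynamo policy can skip.

import torch.distributed as dist

# torch/distributed/_functional_collectives_impl.py (illustrative eager-only side)
_tensor_to_work = {}  # illustrative registry: output tensor -> in-flight Work handle

def _all_reduce_impl(tensor, group):
    # Requires an initialized process group; launches the collective asynchronously
    # and records the Work handle so a later wait can synchronize on it.
    work = dist.all_reduce(tensor, group=group, async_op=True)
    _tensor_to_work[tensor] = work
    return tensor

# torch/distributed/_functional_collectives.py (illustrative public side, traced by dynamo)
def all_reduce(tensor, group):
    out = tensor.clone()  # functional collectives return a new tensor instead of mutating
    return _all_reduce_impl(out, group)
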
Showing 4 changed files with 288 additions and 264 deletions.
14 changes: 7 additions & 7 deletions test/distributed/test_inductor_collectives.py
@@ -248,7 +248,7 @@ def func(inp, *, tag, ranks, group_size):
.check("buf0.copy_(arg0_1)") \
.check("buf1 = buf0") \
.check("buf1_work = dist.all_reduce(buf1") \
.check("fun_col._register_tensor_work(buf1, buf1_work)") \
.check("fun_col_impl._register_tensor_work(buf1, buf1_work)") \
.check("_wait_tensor(buf0)") \
.check("return (buf2, )") \
.run(code)
@@ -280,7 +280,7 @@ def func(inp, *, tag, ranks, group_size):
.check_not("buf1.copy_(") \
.check("buf2 = buf1") \
.check("buf2_work = dist.all_reduce(buf2") \
.check("fun_col._register_tensor_work(buf2, buf2_work)") \
.check("fun_col_impl._register_tensor_work(buf2, buf2_work)") \
.check("_wait_tensor(buf1)") \
.check("buf3 = buf1") \
.check("buf4 = empty_strided") \
@@ -319,7 +319,7 @@ def func(inp, *, tag, ranks, group_size):
.check("buf1 = buf0; del buf0 # reuse") \
.check("buf2 = buf1") \
.check("buf2_work = dist.all_reduce(buf2") \
.check("fun_col._register_tensor_work(buf2, buf2_work)") \
.check("fun_col_impl._register_tensor_work(buf2, buf2_work)") \
.check("_wait_tensor(buf1)") \
.check("buf3 = buf1") \
.check("return (buf3, buf4, buf5") \
@@ -561,9 +561,9 @@ def func(inp, *, tag, ranks, group_size):
.check_not("copy_(") \
.check("buf3_inputs = [buf0,arg0_1]") \
.check("buf3 = [buf1,buf2]") \
.check("buf3_work = fun_col._all_gather_into_tensor_coalesced_fallback("
.check("buf3_work = fun_col_impl._all_gather_into_tensor_coalesced_fallback("
"output_tensors=buf3, input_tensors=buf3_inputs") \
.check("fun_col._register_tensor_work(buf3, buf3_work)") \
.check("fun_col_impl._register_tensor_work(buf3, buf3_work)") \
.check("_wait_tensor(buf1)") \
.check("buf4 = buf1") \
.check("buf6 = buf0; del buf0 # reuse") \
@@ -605,9 +605,9 @@ def func(inp, *, tag, ranks, group_size):
.check("buf2 = empty_strided") \
.check_not("copy_(") \
.check("buf3 = [buf1,buf2]") \
.check("buf3_work = fun_col._reduce_scatter_tensor_coalesced_fallback("
.check("buf3_work = fun_col_impl._reduce_scatter_tensor_coalesced_fallback("
"output_tensors=buf3, input_tensors=buf3_inputs") \
.check("fun_col._register_tensor_work(buf3, buf3_work)") \
.check("fun_col_impl._register_tensor_work(buf3, buf3_work)") \
.check("_wait_tensor(buf1)") \
.check("buf4 = buf1") \
.check("buf6 = buf0; del buf0 # reuse") \
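
The test expectations above are written with torch.testing.FileCheck, which scans the Python wrapper code that inductor generates and asserts that it now goes through fun_col_impl. As a standalone illustration of that API (the checked string below is made up for the example, not real inductor output):

from torch.testing import FileCheck

generated_code = (
    "buf1_work = dist.all_reduce(buf1, async_op=True)\n"
    "fun_col_impl._register_tensor_work(buf1, buf1_work)\n"
)

# Each .check() must match, in order, somewhere in the string passed to .run();
# a missing pattern raises an error identifying which check failed.
FileCheck() \
    .check("dist.all_reduce(buf1") \
    .check("fun_col_impl._register_tensor_work(buf1, buf1_work)") \
    .run(generated_code)
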
18 changes: 9 additions & 9 deletions torch/_inductor/ir.py
@@ -4598,7 +4598,7 @@ def should_allocate(self):

def codegen(self, wrapper):
wrapper.add_import_once(
"from torch.distributed._functional_collectives import _wait_tensor"
"from torch.distributed._functional_collectives_impl import _wait_tensor"
)
(input_collective,) = [t.codegen_reference() for t in self.inputs]
wrapper.writeline(f"{input_collective} = _wait_tensor({input_collective})")
@@ -4657,7 +4657,7 @@ def codegen(self, wrapper):
wrapper.add_import_once("import torch.distributed as dist")
wrapper.add_import_once("import torch.distributed.distributed_c10d as c10d")
wrapper.add_import_once(
"import torch.distributed._functional_collectives as fun_col"
"import torch.distributed._functional_collectives_impl as fun_col_impl"
)
# extract references to our args in string form for codegen output
input_names = [t.codegen_reference() for t in self.inputs]
@@ -4672,7 +4672,7 @@
self.codegen_output(wrapper, output_name, input_names)
self.codegen_collective(wrapper, output_name, input_names)
wrapper.writeline(
f"fun_col._register_tensor_work({output_name}, {output_name}_work)"
f"fun_col_impl._register_tensor_work({output_name}, {output_name}_work)"
)


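The emitted fun_col_impl._register_tensor_work call pairs with the _wait_tensor import added above: the collective launches asynchronously, its Work handle is stashed under the output buffer, and the wait node later blocks on that handle before the buffer is consumed. A simplified, hypothetical rendering of that pattern (the real bookkeeping in _functional_collectives_impl.py differs in detail):

_outstanding_work = {}  # hypothetical table: tensor data pointer -> async Work handle

def _register_tensor_work(tensor, work):
    _outstanding_work[tensor.data_ptr()] = work

def _wait_tensor(tensor):
    work = _outstanding_work.pop(tensor.data_ptr(), None)
    if work is not None:
        work.wait()  # block until the asynchronous collective has filled this tensor
    return tensor
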
@@ -4842,7 +4842,7 @@ def codegen_collective(self, wrapper, output_name, input_names):
wrapper.writeline(
f"{output_name}_work = dist.all_reduce_coalesced("
f"{output_name}, "
f"op=fun_col._str_to_reduce_op('{str(self.reduce_op)}'), "
f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'), "
f"group={output_name}_pg, "
"async_op=True)"
)
@@ -4871,7 +4871,7 @@ def create(
def codegen_collective(self, wrapper, output_name, input_names):
wrapper.writeline(
f"{output_name}_work = dist.all_reduce("
f"{output_name}, async_op=True, group={output_name}_pg, op=fun_col._str_to_reduce_op('{str(self.reduce_op)}'))"
f"{output_name}, async_op=True, group={output_name}_pg, op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'))"
)


@@ -4941,7 +4941,7 @@ def codegen_collective(self, wrapper, output_name, input_names):
wrapper.writeline(
f"{output_name}_work = dist.reduce_scatter_tensor("
f"{output_name}[0], {output_name}_inputs[0], "
f"async_op=True, group={output_name}_pg, op=fun_col._str_to_reduce_op('{str(self.reduce_op)}'))"
f"async_op=True, group={output_name}_pg, op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'))"
)


@@ -4978,7 +4978,7 @@ def compute_size(new_size):

def codegen_collective(self, wrapper, output_name, input_names):
wrapper.writeline(
f"{output_name}_work = fun_col._all_gather_into_tensor_coalesced_fallback("
f"{output_name}_work = fun_col_impl._all_gather_into_tensor_coalesced_fallback("
f"output_tensors={output_name}, "
f"input_tensors={output_name}_inputs, "
f"group={output_name}_pg, "
@@ -5021,10 +5021,10 @@ def compute_size(new_size):

def codegen_collective(self, wrapper, output_name, input_names):
wrapper.writeline(
f"{output_name}_work = fun_col._reduce_scatter_tensor_coalesced_fallback("
f"{output_name}_work = fun_col_impl._reduce_scatter_tensor_coalesced_fallback("
f"output_tensors={output_name}, "
f"input_tensors={output_name}_inputs, "
f"op=fun_col._str_to_reduce_op('{str(self.reduce_op)}'), "
f"op=fun_col_impl._str_to_reduce_op('{str(self.reduce_op)}'), "
f"group={output_name}_pg, "
"async_op=True)"
)
