Fix split module interaction with dead code (pytorch#104554)
Summary:
Pull Request resolved: pytorch#104554

This change fixes split_module's interaction with dead code. Previously, if a dead region was split out, split_module would throw an IndexError while attempting to access the outputs of the partition, even though the partition has no outputs.

This change adds a new unit test to cover the dead-code case and relaxes the output check to allow a partition with no outputs. A split submodule with no outputs now returns None, like a normal Python function.
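
For intuition (an illustrative sketch, not part of the patch): a partition that produces no outputs now behaves like any Python function that never executes a return statement.

```
# Sketch: a function with no return statement evaluates to None when called,
# which is how a split-out dead-code partition now behaves.
def dead_partition(x):
    _ = x + 2  # work happens, but nothing is returned

print(dead_partition(3))  # prints: None
```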

Unit Test Added:
test_split_module_dead_code

A module with dead code:
```
class ModWithDeadCode(torch.nn.Module):
    def forward(self, x):
        output = x * 2 # we want this
        dead_line = x + 2 # this is dead
        return output
```

Before:
```
torch/fx/passes/split_module.py, line 357, in split_module
base_mod_env[list(partition.outputs)[0]] = output_val
IndexError: list index out of range
```

After:
```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        submod_2 = self.submod_2(x)
        submod_1 = self.submod_1(x);  x = None
        return submod_1

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            add = x + 2;  x = None
            return None

    class GraphModule(torch.nn.Module):
        def forward(self, x):
            # No stacktrace found for following nodes
            mul = x * 2;  x = None
            return mul
```
Submod 2, which contains the dead code, is correctly extracted into its own submodule.
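
As a usage sketch (assuming `split` is the GraphModule produced by split_module in the unit test below, with the submodule names shown in the listing above):

```
x = torch.randn(5)
assert split.submod_2(x) is None                      # dead partition (x + 2) returns None
torch.testing.assert_close(split.submod_1(x), x * 2)  # live partition still computes x * 2
```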

Test Plan: Tested with the new unit test.

Reviewed By: mustafaozdal

Differential Revision: D47196732

fbshipit-source-id: 81bb570788a4baf6ba8953688d663dfc8fdd8d4d
benghaem authored and facebook-github-bot committed Aug 2, 2023
1 parent 30442c0 commit 89ff0e6
Showing 2 changed files with 40 additions and 4 deletions.
test/test_fx_experimental.py: 32 additions, 0 deletions
```
@@ -785,6 +785,38 @@ def mod_partition(node: Node):
 
         self.assertEqual(orig_out, submodules_out)
 
+    def test_split_module_dead_code(self):
+        class ModWithDeadCode(torch.nn.Module):
+            def forward(self, x):
+                output = x * 2 # we want this
+                dead_line = x + 2 # this is dead
+                return output
+
+        mod = ModWithDeadCode()
+        traced = torch.fx.symbolic_trace(mod)
+
+        # split into before (0), target (1), and after(2)
+        saw_mul = False
+
+        def split_callback(n):
+            nonlocal saw_mul
+            if n.target == operator.mul:
+                saw_mul = True
+                return 1
+
+            if not saw_mul:
+                return 0
+            if saw_mul:
+                return 2
+
+        split = split_module(traced, mod, split_callback)
+
+        x = torch.randn((5,))
+        torch.testing.assert_close(
+            split(x), traced(x)
+        )
+
+
     def test_split_module_kwargs_expansion(self):
         class ModuleWithKwargsExpansion(torch.nn.Module):
             def forward(self, x, **kwargs):
```
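
Outside the test suite, the same split can be reproduced and inspected; a minimal sketch assuming the patched split_module (the callback below condenses the one in the test, and print_readable produces a nested listing like the "After" section above):

```
import operator

import torch
from torch.fx.passes.split_module import split_module

class ModWithDeadCode(torch.nn.Module):
    def forward(self, x):
        output = x * 2      # live
        dead_line = x + 2   # dead
        return output

mod = ModWithDeadCode()
traced = torch.fx.symbolic_trace(mod)

saw_mul = False

def split_callback(n):
    # before the mul -> partition 0, the mul -> 1, everything after -> 2
    global saw_mul
    if n.target == operator.mul:
        saw_mul = True
        return 1
    return 2 if saw_mul else 0

split = split_module(traced, mod, split_callback)
split.print_readable()  # nested GraphModule listing; the dead partition returns None
```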
torch/fx/passes/split_module.py: 8 additions, 4 deletions
```
@@ -320,8 +320,10 @@ def record_cross_partition_use(
         output_vals = tuple(
             partition.environment[orig_nodes[name]] for name in partition.outputs
         )
-        output_vals = output_vals[0] if len(output_vals) == 1 else output_vals  # type: ignore[assignment]
-        partition.graph.output(output_vals)
+        num_output_vals = len(output_vals)
+        if num_output_vals > 0:
+            output_vals = output_vals[0] if num_output_vals == 1 else output_vals
+            partition.graph.output(output_vals)
 
         if keep_original_order:
             # first get the attr nodes required by this partition
@@ -346,12 +348,14 @@ def record_cross_partition_use(
             partition.submod_name,
             tuple(base_mod_env[name] for name in partition.inputs),
         )
-        if len(partition.outputs) > 1:
+
+        num_outputs = len(partition.outputs)
+        if num_outputs > 1:
             # Unpack multiple return values from submodule
             output_val_proxy = torch.fx.proxy.Proxy(output_val)
             for i, output_name in enumerate(partition.outputs):
                 base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
-        else:
+        elif num_outputs == 1:
             base_mod_env[list(partition.outputs)[0]] = output_val
 
     for node in m.graph.nodes:
```
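
Why skipping the output node is enough (an illustrative sketch, not part of the diff): an fx.Graph that never receives an output node still compiles into a GraphModule, and its generated forward returns None.

```
import operator

import torch

g = torch.fx.Graph()
x = g.placeholder("x")
g.call_function(operator.add, (x, 2))  # dead computation, mirrors the x + 2 node
# note: no g.output(...) call, mirroring the num_output_vals == 0 branch above

gm = torch.fx.GraphModule(torch.nn.Module(), g)
print(gm.code)             # generated forward for the dead computation
print(gm(torch.randn(3)))  # None
```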
