Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix split module interaction with dead code (pytorch#104554)
Summary: Pull Request resolved: pytorch#104554 This change fixes split_module's interaction with dead code. Previously if a dead region was split out, split module would throw an error while attempting to access the outputs for the partition even though the partition has no outputs. This change adds a new unit test to cover the dead code case and changes the output check to allow no output. The split module with no output will now output None like a normal python function Unit Test Added: test_split_module_dead_code A module with dead code: ``` class ModWithDeadCode(torch.nn.Module): def forward(self, x): output = x * 2 # we want this dead_line = x + 2 # this is dead return output ``` Before: ``` torch/fx/passes/split_module.py, line 357, in split_module base_mod_env[list(partition.outputs)[0]] = output_val IndexError: list index out of range ``` After: ``` class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes submod_2 = self.submod_2(x) submod_1 = self.submod_1(x); x = None return submod_1 class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes add = x + 2; x = None return None class GraphModule(torch.nn.Module): def forward(self, x): # No stacktrace found for following nodes mul = x * 2; x = None return mul ``` Submod 2 is correctly extracted Test Plan: Tested with new unit test Reviewed By: mustafaozdal Differential Revision: D47196732 fbshipit-source-id: 86a6ab33d292735d07b0efd25dc93e0763f4b552
- Loading branch information