[JIT] Add __prepare_scriptable__ duck typing to allow replacing nn.modules with scriptable preparations (pytorch#45645) (pytorch#49242)

Summary:
Pull Request resolved: pytorch#49242

Fixes pytorch#45072

As discussed with zdevito, gchanan, cpuhrsch, and suo, this change allows developers to create custom preparations for their modules before scripting. This is done by adding a `__prepare_scriptable__` method to a module, which returns the prepared, scriptable module out-of-place. It does not expand the API surface for end users.
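For illustration, a minimal sketch of the hook (the `Eager` and `Scriptable` module names are hypothetical, mirroring the tests below):

import torch
import torch.nn as nn

class Scriptable(nn.Module):
    def forward(self, x):
        return torch.relu(x)

class Eager(nn.Module):
    def forward(self, x):
        return torch.selu(x)

    def __prepare_scriptable__(self):
        # Called by torch.jit.script(); returns a replacement module
        # out-of-place, leaving `self` untouched for eager use.
        return Scriptable()

m = Eager()
t = torch.randn(3)
print(m(t))                    # SELU in eager mode
print(torch.jit.script(m)(t))  # ReLU once scripted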

Prior art by jamesr66a: pytorch#42244

Test Plan: Imported from OSS

Reviewed By: dongreenberg

Differential Revision: D25500303

fbshipit-source-id: d3ec9005de27d8882fc29d02f0d08acd2a5c6b2c
Meghan Lele authored and hwangdeyu committed Jan 14, 2021
1 parent e9af379 commit 3763093
Showing 3 changed files with 117 additions and 0 deletions.
53 changes: 53 additions & 0 deletions test/jit/test_recursive_script.py
@@ -495,6 +495,59 @@ def forward(self, x):

        self.checkModule(M(), (torch.randn(5, 5),))

    def test_prepare_scriptable_basic(self):
        class SeluButReluWhenScripted(torch.nn.SELU):
            def __prepare_scriptable__(self):
                return nn.ReLU()

        t = torch.randn(5, 5)
        m = SeluButReluWhenScripted()
        sm = torch.jit.script(m)
        eager_out = m(t)
        script_out = sm(t)
        self.assertNotEqual(eager_out, script_out)

    def test_prepare_scriptable_iterable_modules(self):
        class SeluButReluWhenScripted(torch.nn.SELU):
            def __prepare_scriptable__(self):
                return nn.ReLU()

        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                shared = SeluButReluWhenScripted()
                self.sequential = nn.Sequential(
                    SeluButReluWhenScripted(),
                    SeluButReluWhenScripted(),
                    nn.Sequential(SeluButReluWhenScripted(), shared, SeluButReluWhenScripted()),
                    shared,
                )
                self.module_list = nn.ModuleList([SeluButReluWhenScripted(),
                                                  shared,
                                                  SeluButReluWhenScripted()])

            def forward(self, x):
                for mod in self.module_list:
                    x += mod(x)
                x += self.sequential(x)
                return x

        t = torch.randn(5, 5)
        m = M()
        eager_out = m(t.clone())
        sm = torch.jit.script(m)
        script_out = sm(t.clone())
        self.assertNotEqual(eager_out, script_out)

    def test_prepare_scriptable_cycle(self):
        t = torch.randn(5, 5)
        c = torch.nn.Module()
        p = torch.nn.Module()
        c.__dict__["_p"] = p
        p.__dict__["_c"] = c

        # Scripting should terminate: the memo table breaks the _p <-> _c cycle.
        sm = torch.jit.script(p)

    def test_attributes(self):
        @torch.jit.script
        class Inner2(object):
26 changes: 26 additions & 0 deletions test/jit/test_torchbind.py
@@ -62,6 +62,32 @@ def f():
            return ss1.pop() + ss2.pop()
        test_equality(f, lambda x: x)

        # Test an nn.Module whose __prepare_scriptable__ swaps a non-scriptable
        # Python class for a torchbind class before scripting.
        class NonJitableClass(object):
            def __init__(self, int1, int2):
                self.int1 = int1
                self.int2 = int2

            def return_vals(self):
                return self.int1, self.int2

        class CustomWrapper(torch.nn.Module):
            def __init__(self, foo):
                super(CustomWrapper, self).__init__()
                self.foo = foo

            def forward(self) -> None:
                self.foo.increment(1)
                return

            def __prepare_scriptable__(self):
                int1, int2 = self.foo.return_vals()
                foo = torch.classes._TorchScriptTesting._Foo(int1, int2)
                return CustomWrapper(foo)

        foo = CustomWrapper(NonJitableClass(1, 2))
        jit_foo = torch.jit.script(foo)

    def test_torchbind_take_as_arg(self):
        global StackString  # see [local resolution in python]
        StackString = torch.classes._TorchScriptTesting._StackString
38 changes: 38 additions & 0 deletions torch/jit/_script.py
@@ -741,6 +741,43 @@ class RecursiveScriptModule(ScriptModule): # type: ignore
    def __init__(self, arg=None):
        super().__init__()

def call_prepare_scriptable_func_impl(obj, memo):
    if not isinstance(obj, torch.nn.Module):
        return obj

    obj_id = id(obj)

    # If obj_id is in memo, obj has already been prepared or is being
    # prepared in another call up the stack.
    if obj_id in memo:
        return memo[obj_id]

    obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj  # type: ignore
    # Record obj in memo to avoid infinite recursion in the case of cycles in the module
    # hierarchy when recursing below.
    memo[obj_id] = obj

    new_obj_dict = {}

    for name in obj.__dict__:
        sub_module = obj.__dict__.get(name)
        if name == '_modules':
            for k, v in sub_module.items():
                sub_module[k] = call_prepare_scriptable_func_impl(v, memo)
            new_obj_dict[name] = sub_module
        elif isinstance(sub_module, torch.nn.Module) and not isinstance(sub_module, ScriptModule):
            new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo)
        else:
            new_obj_dict[name] = sub_module

    # Copy the prepared attributes back onto obj under their own names.
    for k, v in new_obj_dict.items():
        obj.__dict__[k] = v

    return obj


def call_prepare_scriptable_func(obj):
    memo: Dict[int, torch.nn.Module] = {}
    return call_prepare_scriptable_func_impl(obj, memo)

def script(obj, optimize=None, _frames_up=0, _rcb=None):
r"""
@@ -894,6 +931,7 @@ def forward(self, input):
        return obj

    if isinstance(obj, torch.nn.Module):
        obj = call_prepare_scriptable_func(obj)
        return torch.jit._recursive.create_script_module(
            obj, torch.jit._recursive.infer_methods_to_compile
        )
