torch.jit support #25

Closed
mdraw opened this issue Feb 1, 2019 · 6 comments
Labels: enhancement (New feature or request), v0.4

Comments

@mdraw

mdraw commented Feb 1, 2019

All network layers in this repo can be traced with torch.jit.trace(), but their control flow won't work correctly in traced modules, unlike PyTorch's native Dropout implementations. For example, training/eval mode is not respected, and the check

if self.i < len(self.drop_values):

is effectively ignored: if it evaluates to true during tracing, that branch is hard-coded into the traced module. I've tried to fix this by rewriting the control flow explicitly in TorchScript, but I haven't gotten it to work yet.
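To illustrate why the branch gets frozen, here is a toy analogue in plain Python rather than PyTorch. The hypothetical toy_trace function runs a function once and records only the operations that actually execute, which is roughly how tracing treats Python if statements. It is an analogy, not torch.jit's implementation, and drop_like is a made-up stand-in for a dropout-style layer.

```python
from typing import Callable, List

Op = Callable[[float], float]


def toy_trace(fn: Callable[[float, bool, Callable[[Op], Op]], float],
              example: float, training: bool) -> Callable[[float], float]:
    """Run fn once and record only the operations it actually executes.

    Hypothetical toy analogue of a tracer (not PyTorch's implementation):
    Python `if` statements are evaluated once, at trace time, and only the
    branch that was taken ends up in the recorded op list.
    """
    ops: List[Op] = []

    def record(op: Op) -> Op:
        ops.append(op)
        return op

    fn(example, training, record)

    def replay(x: float) -> float:
        # The replayed function has no `training` argument at all.
        for op in ops:
            x = op(x)
        return x

    return replay


def drop_like(x: float, training: bool, record: Callable[[Op], Op]) -> float:
    # Made-up dropout-style layer: identity in eval mode, and a
    # deterministic scaling (for illustration only) in training mode.
    if not training:
        return record(lambda v: v)(x)
    return record(lambda v: v * 0.5)(x)


traced_eval = toy_trace(drop_like, 1.0, training=False)
traced_train = toy_trace(drop_like, 1.0, training=True)

print(traced_eval(4.0))   # 4.0: the eval branch was baked in
print(traced_train(4.0))  # 2.0: the training branch was baked in
```

Whichever mode was active during tracing is the only behavior the replayed function knows, which matches the symptom described above.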

Since tracing this code does not necessarily emit warnings (see the example below), I think this incompatibility should be documented here so that no one mistakenly trains traced networks. On the other hand, networks that are traced in eval mode after training is complete should work as intended, as long as they are not put back into training mode.

Example code that produces a wrong TracedModule without any warning (using PyTorch 1.0.0):

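import torch
from dropblock import DropBlock2D, LinearScheduler

# drop_prob starts at 0 and is meant to be ramped up to 0.25 over 5 steps.
drop_block = LinearScheduler(
    DropBlock2D(block_size=3, drop_prob=0.),
    start_value=0.,
    stop_value=0.25,
    nr_steps=5
)

x = torch.randn(1, 1, 8, 8)
# The tracer runs forward() once on x and records only the tensor ops that
# execute; Python-level branches such as the training/eval check are frozen
# as they were taken here.
traced = torch.jit.trace(drop_block, x)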
miguelvr added the enhancement (New feature or request) and v0.4 labels on Feb 3, 2019
@miguelvr
Owner

miguelvr commented Feb 3, 2019

I haven't messed with PyTorch 1.0 tracing features yet. However, this is something that interests me and that I would very much like to support.

Both modules will indeed have to be converted to TorchScript because of their control flow.

I'll have a look today to see if I can make any progress.

@miguelvr
Owner

miguelvr commented Feb 3, 2019

@mdraw I'm surprised you had no errors tracing the DropBlock2D module. I'm converting it to a ScriptModule, and that requires a number of subtle changes to the code.

@miguelvr
Owner

miguelvr commented Feb 3, 2019

I think it might not be possible to trace the LinearScheduler, because it changes the drop_prob attribute of another module at runtime, which I don't think can be done with TorchScript.
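To make the problem concrete, here is a plain-Python sketch of a linear drop_prob schedule. DropLayerSketch, LinearSchedulerSketch and the hand-written linspace are hypothetical stand-ins, not dropblock's classes; the step() guard mirrors the if self.i < len(self.drop_values): check quoted in the first comment. The live layer's drop_prob follows the schedule, while a value captured once (the way a trace captures a Python float as a constant) stays at its initial value.

```python
from typing import List


def linspace(start: float, stop: float, num: int) -> List[float]:
    # num evenly spaced values from start to stop, both endpoints included
    # (same idea as numpy.linspace).
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


class DropLayerSketch:
    """Hypothetical dropout-style layer with a mutable drop_prob attribute."""

    def __init__(self, drop_prob: float) -> None:
        self.drop_prob = drop_prob


class LinearSchedulerSketch:
    """Hypothetical linear schedule: each step() overwrites layer.drop_prob."""

    def __init__(self, layer: DropLayerSketch, start_value: float,
                 stop_value: float, nr_steps: int) -> None:
        self.layer = layer
        self.i = 0
        self.drop_values = linspace(start_value, stop_value, nr_steps)

    def step(self) -> None:
        # Runtime mutation of another object's attribute: easy in Python,
        # but invisible to a graph that froze drop_prob as a constant.
        if self.i < len(self.drop_values):
            self.layer.drop_prob = self.drop_values[self.i]
        self.i += 1


layer = DropLayerSketch(drop_prob=0.0)
sched = LinearSchedulerSketch(layer, start_value=0.0, stop_value=0.25, nr_steps=5)

# Analogy for tracing: the current value is copied once and reused forever.
frozen_drop_prob = layer.drop_prob

for _ in range(5):
    sched.step()

print(layer.drop_prob)   # 0.25: the live layer follows the schedule
print(frozen_drop_prob)  # 0.0: the "traced" copy never changes
```

Once i reaches len(drop_values), further step() calls leave drop_prob at the final value.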

I have to investigate this, however.

miguelvr mentioned this issue on Feb 3, 2019
@mdraw
Author

mdraw commented Feb 4, 2019

> I'm surprised you had no errors tracing the DropBlock2D module.

In my (limited) experience, implementing things as ScriptModules is currently pretty hard because of the limited language and library features that the compiler understands, but tracing has somehow worked even for complicated models, although you have to watch out for code that the tracer doesn't capture, such as Python control flow.

> I think it might not be possible to trace the LinearScheduler because it changes the value of the drop_prob of a module, which I think cannot be done with TorchScript

Yes, I don't know how this could be done either. Attributes listed in __constants__ work fine for constants, but I don't see any way to define module-level variables that can be changed dynamically. I just checked: torch.nn.Dropout2d doesn't support changing p in ScriptModule form either.

Besides this issue, your implementation in #27 looks good to me. DropBlock2D and DropBlock3D work as intended now in my local tests.

@miguelvr
Owner

miguelvr commented Feb 4, 2019

Yes, the DropBlock modules are now working fine. I haven't merged yet because of the LinearScheduler. I will do some more research, but I really think it is not possible at the moment.

@miguelvr
Owner

Closing this, as DropBlock does not benefit from JIT scripting: it can be traced for inference without problems.
