Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

errer: tramsform Model to Torch Script #89

Open
huangzhenjie opened this issue Oct 15, 2019 · 11 comments
Open

errer: tramsform Model to Torch Script #89

huangzhenjie opened this issue Oct 15, 2019 · 11 comments

Comments

@huangzhenjie
Copy link

when I want to save the trained Model as .pt. The error is (RuntimeError:
Could not export Python function call 'SwishImplementation'. Remove calls to Python functions before export. Did you forget add @script or @script_method annotation? If this is a nn.ModuleList, add it to constants:) . How could I solve it? thanks!!!

@lukemelas
Copy link
Owner

Hello! Yes, the latest update to the swish implementation saves memory but broke torchscript. An update to fix this is coming.

@lukemelas
Copy link
Owner

Done -- check the latest update!

@huangzhenjie
Copy link
Author

thanks!! let me try.

@xiaoliumi
Copy link

xiaoliumi commented Jan 19, 2020

thanks!! let me try.

Do you get it done on this issue? I am precisely in the same trouble. My pytorch version is 1.3.1.

@ne-bo
Copy link

ne-bo commented Jan 24, 2020

Hi Luke!
Thanks a ton for this wonderful repo!

I have problems with scripting your implementation:
Getting attributes of tuples is not supported:
Getting attributes of tuples is not supported:
at /home/pavlovskaya/anaconda3/envs/antispoofing/lib/python3.7/site-packages/efficientnet_pytorch/model.py:76:11
def forward(self, inputs, drop_connect_rate=None):
"""
:param inputs: input tensor
:param drop_connect_rate: drop connect rate (float, between 0 and 1)
:return: output of block
"""

    # Expansion and Depthwise Convolution
    x = inputs
    if self._block_args.expand_ratio != 1:
       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        x = self._swish(self._bn0(self._expand_conv(inputs)))
    x = self._swish(self._bn1(self._depthwise_conv(x)))

And with tracing:
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:

So, unfortunately, can't use the EN((((

@jesford
Copy link

jesford commented Jun 24, 2020

I am having the same issue as @ne-bo, has there been any progress on this? Thanks so much!

@Daonancai
Copy link

I am having the same issue too

@Podidiving
Copy link

0.7.0 version has the same issue

@mzuzic
Copy link

mzuzic commented Nov 29, 2020

A workaround is to set model.set_swish(memory_efficient=False) before saving.

@medric49
Copy link

medric49 commented Mar 24, 2021

Hi !
Thanks a lot for this repo !

I have encountered the same problem stated at the beginning of this issue. When I try to get the torch script of my efficientnet model I get this error

RuntimeError:
Python builtin <built-in method apply of FunctionMeta object at 0x9145330> is currently not supported in Torchscript:
File "/home/xxx/.local/lib/python3.8/site-packages/efficientnet_pytorch/utils.py", line 76
def forward(self, x):
return SwishImplementation.apply(x)
~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

This is my code

model = EfficientNet.from_pretrained('efficientnet-b2', num_classes=10)
torch.jit.script(model).save('model.ptc')

My pytorch version is 1.7.1
And my efficientnet_pytorch version is 0.7.0

Thanks

@mikel-brostrom
Copy link

mikel-brostrom commented Aug 12, 2022

Than you @mzuzic for #89 (comment)
Any news regarding export with memory efficient swish?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

10 participants