torch.compile for improved performance/AssertionError #98

Open
phalexo opened this issue May 6, 2023 · 6 comments

Comments

@phalexo

phalexo commented May 6, 2023

The IF documentation suggests that torch.compile can improve performance. I have tried:

if_I = torch.compile(IFStageI('IF-I-XL-v1.0', device='cuda:1'))
if_II = torch.compile(IFStageII('IF-II-L-v1.0', device='cuda:2'))
if_III = torch.compile(StableStageIII('stable-diffusion-x4-upscaler', device='cuda:3'))

It is not clear from the documentation which types of objects can be compiled.

AssertionError Traceback (most recent call last)
Cell In[2], line 1
----> 1 if_I = torch.compile(IFStageI('IF-I-XL-v1.0', device='cuda:1'))
4 if_II = IFStageII('IF-II-L-v1.0', device='cuda:2')
7 if_III = StableStageIII('stable-diffusion-x4-upscaler', device='cuda:3')

File ~/.local/lib/python3.8/site-packages/torch/__init__.py:1441, in compile(model, fullgraph, dynamic, backend, mode, options, disable)
1439 if backend == "inductor":
1440 backend = _TorchCompileInductorWrapper(mode, options, dynamic)
-> 1441 return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, disable=disable)(model)

File ~/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py:182, in _TorchDynamoContext.__call__(self, fn)
179 new_mod._torchdynamo_orig_callable = mod.forward
180 return new_mod
--> 182 assert callable(fn)
184 callback = self.callback
185 on_enter = self.on_enter

AssertionError:

@MohsenSadeghi

I got the same error when I tried to compile something other than a callable. Make sure that the model you use in torch.compile(model) can actually be called on some data as y = model(x).
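
For illustration, here is a minimal sketch of what that callable check expects. The .model attribute below is an assumption about how the deepfloyd_if stage objects expose their underlying network; check the class source for the actual attribute name:

import torch
from deepfloyd_if.modules import IFStageI

if_I = IFStageI('IF-I-XL-v1.0', device='cuda:1')

# torch.compile only accepts callables (an nn.Module or a plain function).
# The stage wrapper itself is not callable, which trips the
# `assert callable(fn)` seen in the traceback above.
if callable(if_I):
    if_I = torch.compile(if_I)
elif hasattr(if_I, 'model') and isinstance(if_I.model, torch.nn.Module):
    # Assumed attribute: compile the underlying network, not the wrapper.
    if_I.model = torch.compile(if_I.model)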

@phalexo

phalexo commented May 15, 2023 via email

@MohsenSadeghi

Yeah, I did manage to compile the whole model, but didn't get much of a performance boost! :') The bottleneck apparently was in the dataloader pipeline.

What are you trying to compile?
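
As a rough way to confirm a dataloader bottleneck like that, one can time data loading and compute separately; the toy dataset and model below are illustrative stand-ins, not anything from this thread:

import time
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins; substitute your own loader and model.
data = TensorDataset(torch.randn(1024, 16), torch.randn(1024, 1))
loader = DataLoader(data, batch_size=32)
model = torch.nn.Linear(16, 1)

t0 = time.perf_counter()
for i, (x, y) in enumerate(loader):
    t1 = time.perf_counter()
    data_time = t1 - t0    # time spent waiting on the loader
    out = model(x)         # on GPU, call torch.cuda.synchronize() before timing
    t0 = time.perf_counter()
    compute_time = t0 - t1
    print(f"batch {i}: data {data_time*1e3:.2f} ms, compute {compute_time*1e3:.2f} ms")
    if i == 5:
        break

If the data column dominates, torch.compile will not help much; more dataloader workers or prefetching would be the lever instead.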

@phalexo

phalexo commented May 16, 2023 via email

@MohsenSadeghi

I'm afraid it's just the vanilla opt_net = torch.compile(net), where net is an nn.Module with a forward() method. Alternatively, I could add the @torch.compile decorator to the forward() method itself.
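
Spelled out, that vanilla pattern looks like the following; the toy Net class is illustrative, not the model discussed in this thread:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))

net = Net()
opt_net = torch.compile(net)       # compile the whole module
y = opt_net(torch.randn(2, 16))    # first call triggers the actual compilation

# Decorator alternative: compile just the forward() method.
class Net2(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    @torch.compile
    def forward(self, x):
        return torch.relu(self.linear(x))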

@phalexo

phalexo commented May 16, 2023

Regardless of where I put it, it either complains about parallelism or simply hangs. Where specifically did you put torch.compile or @torch.compile?
