Conversation

@pcuenca pcuenca (Member) commented Aug 18, 2022

The following will run the pipeline on cuda:1 (not cuda:0), as would be expected:

```Python
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-3-diffusers",
    use_auth_token=True).to("cuda:1")
pipe("Some prompt")
```

However, passing a device to __call__ will still move the pipeline to that device.

The implementation of DiffusionPipeline.to() does nothing if the device is None. This is to preserve the same semantics as PyTorch, where AFAIK if you use to(None) the object is not moved anywhere.

If we were to forgo that behaviour, we could make DiffusionPipeline.to(None) select cuda by default (when available). This would make for much simpler code in all pipeline implementations, as they'd just need to invoke self.to, but it might break the expectations of PyTorch users. The current implementation in all the pipelines' __call__ methods does exactly that: select cuda if available, and only if no previous device was already set by the user. It is a bit repetitive and maybe a bit fragile.
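Roughly, each pipeline's `__call__` repeats something like the following (an illustrative sketch, not the exact source; it assumes the pipeline exposes its current device as `self.device`):

```Python
import torch

# Illustrative sketch of the boilerplate repeated in every pipeline's __call__:
def __call__(self, prompt, torch_device=None, **kwargs):
    if torch_device is None and self.device.type == "cpu":
        # Nothing requested and no device set earlier: default to GPU when available.
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    if torch_device is not None:
        self.to(torch_device)
    ...
```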

Is it important to preserve PyTorch semantics in this regard, or is it better to make all __call__ implementations simpler? What do you think @anton-l @patil-suraj @patrickvonplaten ?

Fixes #195

Note that pipelines will still be moved to the default cuda device during `__call__` unless the same device is used there. Addressing that in a separate commit.
@HuggingFaceDocBuilderDev commented Aug 18, 2022

The documentation is not available anymore as the PR was closed or merged.

@anton-l anton-l (Member) left a comment

Thanks for the PR @pcuenca!

Haven't seen cases of .to(None) in the wild before, so to me it's fine to break the expectations a bit :)
So implementing it like this:

```Python
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
    if torch_device is None:
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    ...
```

And then just doing self.to(torch_device) inside the pipelines would be much cleaner and less implementation error-prone for future pipelines.
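In other words, each pipeline's `__call__` could then shrink to something like this (sketch):

```Python
def __call__(self, prompt, torch_device=None, **kwargs):
    # to() picks a sensible default (cuda when available) if torch_device is None.
    self.to(torch_device)
    ...
```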

@pcuenca pcuenca (Member, Author) commented Aug 18, 2022

Haven't seen cases of .to(None) in the wild before, so to me it's fine to break the expectations a bit :)

Cool, I agree! That's how I would have done it if it had been an internal project; because it's open source, I may have been overly cautious here :)

```Python
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
    if torch_device is None:
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    ...
```

And then just doing self.to(torch_device) inside the pipelines would be much cleaner and less implementation error-prone for future pipelines

One thing, though. If we expect users to use .to(device) and then not provide any device during __call__, our implementation must avoid moving the models to the default cuda device if a previous one was selected.

I'll prepare it right away so you can take a look.

@pcuenca pcuenca (Member, Author) commented Aug 18, 2022

@anton-l I'm not super happy after the change: .to() works as expected, but the __call__ methods need to be careful to use self.device instead of the argument. What do you think?
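Concretely, inside `__call__` it looks something like this (a sketch only; shapes and argument names are illustrative, not the actual diff):

```Python
import torch

def __call__(self, prompt, height=512, width=512, generator=None, **kwargs):
    # Tensors that used to be created on the `torch_device` argument now have to
    # follow whatever device the pipeline was previously moved to.
    latents = torch.randn(
        (1, 4, height // 8, width // 8),
        generator=generator,
        device=self.device,  # was: device=torch_device
    )
    ...
```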

@anton-l anton-l (Member) commented Aug 18, 2022

Hmm, indeed. But choosing between the first version and this one, I would still go with the current image.to(self.device) 🙂
Guess there are just two options available, wdyt @patil-suraj?

@patrickvonplaten patrickvonplaten (Contributor) left a comment

Throwing in an idea -> we could also just remove torch_device as an input to the forward function. It's more "PyTorch-y" to use the .to(...) API IMO.

So we could just deprecate the "torch_device" parameter, saying that it'll be removed in 0.3.0 and only rely on to(...).
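For illustration, the deprecation path could look roughly like this (a sketch; the exact warning text and mechanism are assumptions, not the final implementation):

```Python
import warnings

def __call__(self, prompt, **kwargs):
    # Hypothetical sketch: accept the old argument through kwargs and warn about it.
    torch_device = kwargs.pop("torch_device", None)
    if torch_device is not None:
        warnings.warn(
            "`torch_device` is deprecated and will be removed in 0.3.0; "
            "use `pipe.to(torch_device)` instead.",
            DeprecationWarning,
        )
        self.to(torch_device)
    ...
```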

What do you think?

@pcuenca pcuenca (Member, Author) commented Aug 18, 2022

we could also just remove torch_device as an input to the forward function.

I think that's a very sensible idea. But I also love that if you do nothing it works as expected (uses cuda if available), so we'd do automatic placement on creation. Does that sound right?

@patil-suraj patil-suraj (Contributor) left a comment

Thanks a lot for the PR @pcuenca !

I agree with @patrickvonplaten about removing the torch_device argument. Would be nice to have just one API for device handling, rather than having both options. That would be much cleaner IMO.

@patrickvonplaten patrickvonplaten (Contributor):

I'd actually say not to do automatic placement, to really stay 1-to-1 with PyTorch. I really like the fact that in PyTorch you know models are always on "cpu" by default.

@anton-l @patil-suraj what do you think?

@pcuenca pcuenca (Member, Author) commented Aug 18, 2022

I'd actually say not to do automatic placement, to really stay 1-to-1 with PyTorch. I really like the fact that in PyTorch you know models are always on "cpu" by default.

That's a breaking change with respect to the current version. No big deal, and it's easy to understand in terms of PyTorch resemblance. I personally liked that you had to do nothing and the pipeline selected the GPU by default; to me, the pipeline is a high-level solution that is there to help you by providing sensible defaults, similar to disabling gradient computation.

But the alternatives are a bit feeble, so this might well be the best compromise.

@patrickvonplaten patrickvonplaten (Contributor):

+1 it is indeed a bigger breaking change.

For me, one of these two options sounds best:

1.)
Remove torch_device from forward(...) and keep defaulting the init to GPU. Thinking more about it, I'm actually fine with it! Also, consider that when passing device_map="auto" to from_pretrained(...) in Transformers, the model is automatically moved to GPU. I would then maybe add a logging statement, though ("pipeline moved to GPU...").
However, the drawback of this approach is: what do we do when multiple GPUs are available? device_map="auto" in that case moves different layers to different GPUs - that's out of scope here IMO, so we would probably just move everything to the first GPU.

2.)
Remove torch_device from forward(...) and change the default init to CPU. Big advantage: we leave the complexity of possible multi-device placement up to the user, and it's more PyTorch-y.
Drawback: more of a breaking change, but I think that's fine at this stage of the library (both options are sketched below).
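To make the trade-off concrete, here is a sketch of what each option would mean for users (illustrative only):

```Python
from diffusers import StableDiffusionPipeline

# Option 1: the pipeline defaults to the first GPU at init time (when one is available),
# and __call__ no longer takes torch_device.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True
)
pipe("Some prompt")  # already on GPU if cuda is available

# Option 2: the pipeline stays on CPU until the user moves it, like any torch.nn.Module.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True
)
pipe.to("cuda")
pipe("Some prompt")
```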

@patil-suraj @anton-l @pcuenca - let's maybe try to decide somewhat quickly now so that we can include this PR in today's release?

@anton-l anton-l (Member) commented Aug 19, 2022

I'm in favor of option 1 out of what Patrick suggested. "auto" can be improved a bit later, while defaulting to "cpu" might increase friction for first-timers

@patrickvonplaten patrickvonplaten (Contributor) commented Aug 19, 2022

Thinking a bit more about it, I'm actually much more in favor of option 2, mainly for the reasons listed under option 2 above.

Happy to adhere to option 1 though, if @patil-suraj, @anton-l and @pcuenca you prefer it.

@pcuenca pcuenca (Member, Author) commented Aug 19, 2022

Those are great points, @patrickvonplaten, let's go for the simpler option and do 2 instead.

@anton-l anton-l (Member) commented Aug 19, 2022

Ok, good points, option 2 it is! :)

@patil-suraj patil-suraj (Contributor):

I'm also very much in favor of 2. With .to we're trying to be more PyTorch-y, so we should mimic PyTorch as closely as possible to avoid any confusion. Let the user handle everything related to devices.

@pcuenca pcuenca (Member, Author) commented Aug 19, 2022

I did 2, can you please take another look? @patil-suraj @anton-l @patrickvonplaten

We possibly need to change some documentation and examples. Should we do that in a separate PR?

@patil-suraj patil-suraj (Contributor) left a comment

Looking good! Pretty much the same comment as @anton-l and @patrickvonplaten.

I also think the kwargs should go in all pipelines.

```Python
"`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0.
Consider using `pipe.to(torch_device)` instead."
)
# ...set device as previously
```

(nit) this comment should go above the if cond

@anton-l anton-l (Member) commented Aug 19, 2022

Yes, kwargs should be supported in all updated pipelines @pcuenca, sorry for commenting only on DDIM :)

@anton-l anton-l (Member) left a comment

Looks great!

@anton-l anton-l (Member) commented Aug 19, 2022

Ok, let's resolve the conflicts and merge if @patrickvonplaten and @patil-suraj don't have any objections :)

@patrickvonplaten patrickvonplaten (Contributor) left a comment

Looks good to me! @anton-l feel free to merge!

@pcuenca pcuenca (Member, Author) commented Aug 19, 2022

Yes, kwargs should be supported in all updated pipelines @pcuenca, sorry for commenting only on DDIM :)

Sure, no problem at all, that's what I understood :)

@pcuenca pcuenca (Member, Author) commented Aug 19, 2022

Sorry, I have some family business going on and can't fix the conflicts right now. Can you do it, @anton-l? Otherwise I'll do it later.

@anton-l anton-l merged commit 71ba8ae into main Aug 19, 2022
@patil-suraj patil-suraj deleted the pipeline-to-device branch August 19, 2022 16:56
natolambert pushed a commit that referenced this pull request Sep 7, 2022
* Implement `pipeline.to(device)`

* DiffusionPipeline.to() decides best device on None.

* Breaking change: torch_device removed from __call__

`pipeline.to()` now has PyTorch semantics.

* Use kwargs and deprecation notice

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply torch_device compatibility to all pipelines.

* style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: anton-l <anton@huggingface.co>
PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023
The PyTorch decomposition for the op `aten.upsample_bilinear2d.vec`
is merged in the upstream repo and hence removed from this file.
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023