Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/inference/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,8 @@ with autocast("cuda"):

images[0].save("fantasy_landscape.png")
```
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)
You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb)

## Tweak prompts reusing seeds and latents

You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb).
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __call__(
guidance_scale: Optional[float] = 7.5,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
**kwargs,
):
Expand Down Expand Up @@ -98,12 +99,18 @@ def __call__(
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# get the intial random noise
latents = torch.randn(
(batch_size, self.unet.in_channels, height // 8, width // 8),
generator=generator,
device=self.device,
)
# get the initial random noise unless the user supplied it
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None:
latents = torch.randn(
latents_shape,
generator=generator,
device=self.device,
)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)

# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
Expand Down