Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
931182f
[Repro] Correct reproducibility
patrickvonplaten Jan 3, 2023
57b6694
up
patrickvonplaten Jan 3, 2023
70a7afd
up
patrickvonplaten Jan 3, 2023
5e5035c
uP
patrickvonplaten Jan 3, 2023
ebfb3b7
up
patrickvonplaten Jan 3, 2023
2d826b4
need better image
patrickvonplaten Jan 3, 2023
3275375
allow conversion from no state dict checkpoints
patrickvonplaten Jan 3, 2023
880a5ef
up
patrickvonplaten Jan 4, 2023
6988144
up
patrickvonplaten Jan 4, 2023
0417a36
up
patrickvonplaten Jan 4, 2023
2ff5f0b
up
patrickvonplaten Jan 4, 2023
e5ae3e6
check tensors
patrickvonplaten Jan 4, 2023
34eec6d
check tensors
patrickvonplaten Jan 4, 2023
c1273c2
check tensors
patrickvonplaten Jan 4, 2023
1d3aa46
check tensors
patrickvonplaten Jan 4, 2023
0e7ce64
next try
patrickvonplaten Jan 4, 2023
e314d2e
up
patrickvonplaten Jan 4, 2023
5da5aab
up
patrickvonplaten Jan 4, 2023
604a93c
better name
patrickvonplaten Jan 4, 2023
188ed46
up
patrickvonplaten Jan 4, 2023
1cfa031
up
patrickvonplaten Jan 4, 2023
c673e4f
Apply suggestions from code review
patrickvonplaten Jan 4, 2023
05ad68c
correct more
patrickvonplaten Jan 4, 2023
f2b00b2
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patrickvonplaten Jan 4, 2023
bce82e5
up
patrickvonplaten Jan 4, 2023
5021981
replace all torch randn
patrickvonplaten Jan 4, 2023
abfbecd
fix
patrickvonplaten Jan 4, 2023
2334802
correct
patrickvonplaten Jan 4, 2023
50654a9
correct
patrickvonplaten Jan 4, 2023
25096d3
finish
patrickvonplaten Jan 4, 2023
5f02e58
fix more
patrickvonplaten Jan 4, 2023
e6c1847
up
patrickvonplaten Jan 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/diffusers/experimental/rl/value_guided_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ...models.unet_1d import UNet1DModel
from ...pipelines import DiffusionPipeline
from ...utils import randn_tensor
from ...utils.dummy_pt_objects import DDPMScheduler


Expand Down Expand Up @@ -127,7 +128,7 @@ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, sca
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)

# generate initial noise and apply our conditions (to make the trajectories start at current state)
x1 = torch.randn(shape, device=self.unet.device)
x1 = randn_tensor(shape, device=self.unet.device)
x = self.reset_x0(x1, conditions, self.action_dim)
x = self.to_torch(x)

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/prior_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)

causal_attention_mask = torch.full(
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
Copy link
Contributor Author

@patrickvonplaten patrickvonplaten Jan 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed for consistency; this doesn't change the outputs here. I'm actually surprised that PyTorch seems to be able to handle adding float("-inf") in newer versions, but I think this wasn't the case in older versions.

)
causal_attention_mask.triu_(1)
causal_attention_mask = causal_attention_mask[None, ...]
Expand Down
9 changes: 4 additions & 5 deletions src/diffusers/models/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch
import torch.nn as nn

from ..utils import BaseOutput
from ..utils import BaseOutput, randn_tensor
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block


Expand Down Expand Up @@ -323,11 +323,10 @@ def __init__(self, parameters, deterministic=False):
)

def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
device = self.parameters.device
sample_device = "cpu" if device.type == "mps" else device
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
# make sure sample is on the same device as the parameters and has same dtype
sample = sample.to(device=device, dtype=self.parameters.dtype)
sample = randn_tensor(
self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
)
x = self.mean + self.std * sample
return x

Expand Down
16 changes: 2 additions & 14 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import deprecate, logging, replace_example_docstring
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
Expand Down Expand Up @@ -401,20 +401,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
)

if latents is None:
rand_device = "cpu" if device.type == "mps" else device

if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)

# scale the initial noise by the standard deviation required by the scheduler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
Expand Down Expand Up @@ -461,16 +461,8 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
else:
init_latents = torch.cat([init_latents], dim=0)

rand_device = "cpu" if device.type == "mps" else device
shape = init_latents.shape
if isinstance(generator, list):
shape = (1,) + shape[1:]
noise = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
]
noise = torch.cat(noise, dim=0).to(device)
else:
noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
from .mel import Mel

Expand Down Expand Up @@ -126,7 +127,7 @@ def __call__(
input_dims = self.get_input_dims()
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
if noise is None:
noise = torch.randn(
noise = randn_tensor(
(
batch_size,
self.unet.in_channels,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import torch

from ...utils import logging
from ...utils import logging, randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline


Expand Down Expand Up @@ -100,16 +100,7 @@ def __call__(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

rand_device = "cpu" if self.device.type == "mps" else self.device
if isinstance(generator, list):
shape = (1,) + shape[1:]
audio = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
for i in range(batch_size)
]
audio = torch.cat(audio, dim=0).to(self.device)
else:
audio = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(self.device)
audio = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype)

# set step values
self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
Expand Down
14 changes: 2 additions & 12 deletions src/diffusers/pipelines/ddim/pipeline_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import torch

from ...utils import deprecate
from ...utils import deprecate, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


Expand Down Expand Up @@ -103,17 +103,7 @@ def __call__(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

rand_device = "cpu" if self.device.type == "mps" else self.device
if isinstance(generator, list):
shape = (1,) + image_shape[1:]
image = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
for i in range(batch_size)
]
image = torch.cat(image, dim=0).to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
image = image.to(self.device)
image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)

# set step values
self.scheduler.set_timesteps(num_inference_steps)
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/ddpm/pipeline_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch

from ...configuration_utils import FrozenDict
from ...utils import deprecate
from ...utils import deprecate, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


Expand Down Expand Up @@ -100,10 +100,10 @@ def __call__(

if self.device.type == "mps":
# randn does not work reproducibly on mps
image = torch.randn(image_shape, generator=generator)
image = randn_tensor(image_shape, generator=generator)
image = image.to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=self.device)
image = randn_tensor(image_shape, generator=generator, device=self.device)

# set step values
self.scheduler.set_timesteps(num_inference_steps)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


Expand Down Expand Up @@ -143,20 +144,7 @@ def __call__(
)

if latents is None:
rand_device = "cpu" if self.device.type == "mps" else self.device

if isinstance(generator, list):
latents_shape = (1,) + latents_shape[1:]
latents = [
torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0)
else:
latents = torch.randn(
latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
)
latents = latents.to(self.device)
latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...utils import PIL_INTERPOLATION, deprecate
from ...utils import PIL_INTERPOLATION, deprecate, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


Expand Down Expand Up @@ -121,12 +121,7 @@ def __call__(
latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
latents_dtype = next(self.unet.parameters()).dtype

if self.device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype)
latents = latents.to(self.device)
else:
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

image = image.to(device=self.device, dtype=latents_dtype)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ...models import UNet2DModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


Expand Down Expand Up @@ -71,7 +72,7 @@ def __call__(
True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""

latents = torch.randn(
latents = randn_tensor(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
)
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/paint_by_example/image_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, config, proj_size=768):
self.proj_out = nn.Linear(config.hidden_size, self.proj_size)

# uncondition for scaling
self.uncond_vector = nn.Parameter(torch.rand((1, 1, self.proj_size)))
self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))

def forward(self, pixel_values):
clip_output = self.model(pixel_values=pixel_values)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import logging
from ...utils import logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
Expand Down Expand Up @@ -300,20 +300,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
)

if latents is None:
rand_device = "cpu" if device.type == "mps" else device

if isinstance(generator, list):
shape = (1,) + shape[1:]
latents = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
for i in range(batch_size)
]
latents = torch.cat(latents, dim=0).to(device)
else:
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)

# scale the initial noise by the standard deviation required by the scheduler
Expand Down
5 changes: 3 additions & 2 deletions src/diffusers/pipelines/pndm/pipeline_pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from ...models import UNet2DModel
from ...schedulers import PNDMScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


Expand Down Expand Up @@ -72,11 +73,11 @@ def __call__(
# the official paper: https://arxiv.org/pdf/2202.09778.pdf

# Sample gaussian noise to begin loop
image = torch.randn(
image = randn_tensor(
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
generator=generator,
device=self.device,
)
image = image.to(self.device)

self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
Expand Down
14 changes: 2 additions & 12 deletions src/diffusers/pipelines/repaint/pipeline_repaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ...models import UNet2DModel
from ...schedulers import RePaintScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


Expand Down Expand Up @@ -143,18 +143,8 @@ def __call__(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

rand_device = "cpu" if self.device.type == "mps" else self.device
image_shape = original_image.shape
if isinstance(generator, list):
shape = (1,) + image_shape[1:]
image = [
torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
for i in range(batch_size)
]
image = torch.cat(image, dim=0).to(self.device)
else:
image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
image = image.to(self.device)
image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)

# set step values
self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from ...models import UNet2DModel
from ...schedulers import ScoreSdeVeScheduler
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


Expand Down Expand Up @@ -69,7 +70,7 @@ def __call__(

model = self.unet

sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
sample = randn_tensor(shape, generator=generator) * self.scheduler.init_noise_sigma
sample = sample.to(self.device)

self.scheduler.set_timesteps(num_inference_steps)
Expand Down
Loading