From b649cb99cae89215fe3a972014702b46f348e356 Mon Sep 17 00:00:00 2001 From: Alan Du Date: Mon, 8 Jul 2024 15:34:58 -0400 Subject: [PATCH 1/4] Reformat docstring for `get_timestep_embedding` The original docstring formatting was off (looks like a bad line-wrap mangling the parameters), but I figured I might as well reformat to match the same docstring style as the rest of the file. --- src/diffusers/models/embeddings.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 8bc30f7cabcf..319e871b170a 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -35,10 +35,15 @@ def get_timestep_embedding( """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the - embeddings. :return: an [N x dim] Tensor of positional embeddings. + Args + timesteps (torch.Tensor): + a 1-D Tensor of N indices, one per batch element. These may be fractional. + embedding_dim (int): + the dimension of the output. + max_period (int): + controls the minimum frequency of the embeddings. + Returns + torch.Tensor: an [N x dim] Tensor of positional embeddings. """ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" From 56829fe54b3d5d14870763de86fe4538443d4055 Mon Sep 17 00:00:00 2001 From: Alan Du Date: Mon, 8 Jul 2024 17:00:16 -0400 Subject: [PATCH 2/4] Simplify `get_timestep_embedding` This simplifies the implementation (IMO) by replacing the redundant `torch.exp(-math.log(a) * b)` with `torch.pow(a, -b)` and updates the docstring to capture what each argument does. 
--- src/diffusers/models/embeddings.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 319e871b170a..12fda0d8c657 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -39,26 +39,27 @@ def get_timestep_embedding( timesteps (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. embedding_dim (int): - the dimension of the output. + the dimension of the output. + flip_sin_to_cos (bool): + Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) + downscale_freq_shift (float): + Controls the delta between frequencies between dimensions + scale (float): + The minimum frequency of the embeddings. max_period (int): - controls the minimum frequency of the embeddings. + Controls the maximum frequency of the embeddings Returns torch.Tensor: an [N x dim] Tensor of positional embeddings. 
""" assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) - exponent = exponent / (half_dim - downscale_freq_shift) + steps = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + steps = steps / (half_dim - downscale_freq_shift) - emb = torch.exp(exponent) + emb = scale * torch.pow(max_period, -steps) emb = timesteps[:, None].float() * emb[None, :] - # scale embeddings - emb = scale * emb - # concat sine and cosine embeddings emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) From 1eb3688b62e88ffdbea0c1351aedad81e59055b4 Mon Sep 17 00:00:00 2001 From: Alan Du Date: Tue, 9 Jul 2024 11:27:07 -0400 Subject: [PATCH 3/4] Revert implementation changes --- src/diffusers/models/embeddings.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 12fda0d8c657..99be90f384e3 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -54,12 +54,17 @@ def get_timestep_embedding( assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - steps = torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) - steps = steps / (half_dim - downscale_freq_shift) + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) - emb = scale * torch.pow(max_period, -steps) + emb = torch.exp(exponent) emb = timesteps[:, None].float() * emb[None, :] + # scale embeddings + emb = scale * emb + # concat sine and cosine embeddings emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) From 2473871e9ab544bdaac952d0a2c1b2ef50b87fdd Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 
10 Jul 2024 15:07:48 -1000 Subject: [PATCH 4/4] Update src/diffusers/models/embeddings.py --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 99be90f384e3..ec1c68b86c89 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -45,7 +45,7 @@ def get_timestep_embedding( downscale_freq_shift (float): Controls the delta between frequencies between dimensions scale (float): - The minimum frequency of the embeddings. + Scaling factor applied to the embeddings. max_period (int): Controls the maximum frequency of the embeddings Returns