diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0890842f5775..7684fdf9cd6c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -319,12 +319,16 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): assert embed_dim % 4 == 0 # use half of dimensions to encode grid_h - emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4) - emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4) + emb_h = get_1d_rotary_pos_embed( + embed_dim // 2, grid[0].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) + emb_w = get_1d_rotary_pos_embed( + embed_dim // 2, grid[1].reshape(-1), use_real=use_real + ) # (H*W, D/2) if use_real else (H*W, D/4) if use_real: - cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2) - sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2) + cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D) + sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D) return cos, sin else: emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) @@ -371,6 +375,8 @@ def get_1d_rotary_pos_embed( Returns: `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] """ + assert dim % 2 == 0 + if isinstance(pos, int): pos = np.arange(pos) theta = theta * ntk_factor