-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Flux followup #9074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Flux followup #9074
Changes from all commits
efc7ed9
9b8f8c7
1b4d1c5
1887bda
a9cdfcc
de66c58
abad854
463b910
f23cb1b
568884a
ab3a550
079cb33
4161d93
89e0ccc
0ff2266
40e94e0
293bcd8
72d1cf0
f0301b2
95b0a55
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -446,6 +446,7 @@ def get_1d_rotary_pos_embed( | |
| linear_factor=1.0, | ||
| ntk_factor=1.0, | ||
| repeat_interleave_real=True, | ||
| freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux) | ||
| ): | ||
| """ | ||
| Precompute the frequency tensor for complex exponentials (cis) with given dimensions. | ||
|
|
@@ -468,6 +469,8 @@ def get_1d_rotary_pos_embed( | |
| repeat_interleave_real (`bool`, *optional*, defaults to `True`): | ||
| If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. | ||
| Otherwise, they are concatenated with themselves. ||
| freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): | ||
| the dtype of the frequency tensor. | ||
| Returns: | ||
| `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] | ||
| """ | ||
|
|
@@ -476,19 +479,19 @@ def get_1d_rotary_pos_embed( | |
| if isinstance(pos, int): | ||
| pos = np.arange(pos) | ||
| theta = theta * ntk_factor | ||
| freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2] | ||
| freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] | ||
| t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] | ||
| freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] | ||
| freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] | ||
| if use_real and repeat_interleave_real: | ||
| freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] | ||
| freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] | ||
| freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] | ||
| freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] | ||
| return freqs_cos, freqs_sin | ||
| elif use_real: | ||
| freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D] | ||
| freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D] | ||
| freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] | ||
| freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] | ||
| return freqs_cos, freqs_sin | ||
| else: | ||
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] | ||
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2] | ||
| return freqs_cis | ||
|
|
||
|
|
||
|
|
@@ -540,6 +543,31 @@ def apply_rotary_emb( | |
| return x_out.type_as(x) | ||
|
|
||
|
|
||
class FluxPosEmbed(nn.Module):
    """Multi-axis rotary position embedding used by Flux.

    Builds a 1D rotary embedding per position axis (one axis per column of
    `ids`, sized by the matching entry of `axes_dim`) and concatenates them
    into a single ``(cos, sin)`` pair covering ``sum(axes_dim)`` channels.

    Modified from
    https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    """

    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta  # rotary base frequency
        self.axes_dim = axes_dim  # channels allotted to each position axis

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # `ids` carries one position per axis in its last dimension; after
        # squeeze() it is indexed as pos[:, i], so it is assumed to reduce to
        # (seq_len, n_axes) — TODO confirm against callers.
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.squeeze().float().cpu().numpy()
        # MPS lacks float64 support; use float64 elsewhere for precision (flux).
        is_mps = ids.device.type == "mps"
        freqs_dtype = torch.float32 if is_mps else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,  # fix: was omitted, silently falling back to the helper's default
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        # NOTE(review): returns a (cos, sin) tuple despite the single-tensor
        # annotation; annotation kept unchanged to avoid an interface edit.
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
|
|
||
|
|
||
| class TimestepEmbedding(nn.Module): | ||
| def __init__( | ||
| self, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe a reference to the original BFL inference code?