-
Notifications
You must be signed in to change notification settings - Fork 6.6k
Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning #12814
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -165,10 +165,12 @@ def __init__(self, model_dim, time_dim, max_period=10000.0): | |||||
| self.activation = nn.SiLU() | ||||||
| self.out_layer = nn.Linear(time_dim, time_dim, bias=True) | ||||||
|
|
||||||
| @torch.autocast(device_type="cuda", dtype=torch.float32) | ||||||
| def forward(self, time): | ||||||
| args = torch.outer(time, self.freqs.to(device=time.device)) | ||||||
| time = time.to(dtype=torch.float32) | ||||||
| freqs = self.freqs.to(device=time.device, dtype=torch.float32) | ||||||
| args = torch.outer(time, freqs) | ||||||
| time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) | ||||||
| time_embed = time_embed.to(dtype=self.in_layer.weight.dtype) | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason I cast to |
||||||
| time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) | ||||||
| return time_embed | ||||||
|
|
||||||
|
|
@@ -269,8 +271,8 @@ def __init__(self, time_dim, model_dim, num_params): | |||||
| self.out_layer.weight.data.zero_() | ||||||
| self.out_layer.bias.data.zero_() | ||||||
|
|
||||||
| @torch.autocast(device_type="cuda", dtype=torch.float32) | ||||||
| def forward(self, x): | ||||||
| x = x.to(dtype=self.out_layer.weight.dtype) | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. umm actually this did not look correct to me - we want to upcast it to float32, no?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly, if we force |
||||||
| return self.out_layer(self.activation(x)) | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.