[Half precision] Make sure half-precision is correct #182
Changes from all commits: 42e6d51, 4667928, 760a071, f3d19e1, c7743d5, b30c8c7, 468b548, 387a6b0
@@ -55,7 +55,13 @@ def __call__(
         self.text_encoder.to(torch_device)

         # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
         text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)

Comment on lines +58 to +64:

"This is important here: always pad to max_length, as that's how the model was trained."

"Yes, agree, but let's make sure not to do this when we create our text-to-image training script (it's definitely cleaner to mask out padding tokens, and it should help the model learn better, as stated by Katherine on Slack as well)."
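A minimal sketch of the two padding strategies discussed above, assuming a CLIP-style tokenizer and text encoder from `transformers`; the model name, variable names, and the training-time branch are illustrative, not part of this diff:

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompt = ["a photograph of an astronaut riding a horse"]

# Inference (this PR): always pad to the tokenizer's model_max_length,
# because that is how the text encoder saw its inputs during training.
text_input = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
text_embeddings = text_encoder(text_input.input_ids)[0]

# Training-time alternative mentioned in the review comments (an assumption,
# not part of this diff): pad dynamically and pass the attention mask so the
# padding tokens can be masked out.
train_input = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
train_embeddings = text_encoder(
    train_input.input_ids, attention_mask=train_input.attention_mask
)[0]
```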
@@ -79,19 +85,25 @@ def __call__(
         latents = torch.randn(
             (batch_size, self.unet.in_channels, height // 8, width // 8),
             generator=generator,
-            device=torch_device,
         )
+        latents = latents.to(torch_device)

+        # set timesteps
+        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+        extra_set_kwargs = {}
+        if accepts_offset:
+            extra_set_kwargs["offset"] = 1
+
+        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_kwargs = {}
+        extra_step_kwargs = {}
         if accepts_eta:
-            extra_kwargs["eta"] = eta
-
-        self.scheduler.set_timesteps(num_inference_steps)
+            extra_step_kwargs["eta"] = eta

         for t in tqdm(self.scheduler.timesteps):
             # expand the latents if we are doing classifier free guidance
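A compact sketch of the two patterns this hunk introduces: sampling the initial latents with the user-supplied generator and only then moving them to the execution device, and probing the scheduler's signature for optional arguments instead of assuming every scheduler accepts them. The standalone function and its parameter names (`unet`, `scheduler`, `batch_size`, `height`, `width`, etc.) stand in for the pipeline's attributes and are assumptions here:

```python
import inspect
import torch


def prepare_latents_and_scheduler(unet, scheduler, batch_size, height, width,
                                  num_inference_steps, eta, generator, torch_device):
    # Sample latents with the given generator (often a CPU generator), then
    # move them to the execution device afterwards.
    latents = torch.randn(
        (batch_size, unet.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)

    # Only pass `offset` if this scheduler's set_timesteps() actually accepts it.
    extra_set_kwargs = {}
    if "offset" in inspect.signature(scheduler.set_timesteps).parameters:
        extra_set_kwargs["offset"] = 1
    scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

    # Likewise, `eta` (η from the DDIM paper) is only meaningful for DDIM-style
    # schedulers, so forward it only when step() takes it.
    extra_step_kwargs = {}
    if "eta" in inspect.signature(scheduler.step).parameters:
        extra_step_kwargs["eta"] = eta

    return latents, extra_step_kwargs
```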
@@ -106,7 +118,7 @@ def __call__(
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
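For context, here is the denoising loop the last two hunks touch, sketched end to end. The classifier-free-guidance line and the scheduler step mirror the diff; the UNet call signature, the VAE decode call, and the assumption that `text_embeddings` holds the unconditional and prompt embeddings concatenated along the batch dimension are assumptions about the surrounding pipeline code rather than part of this change:

```python
import torch
from tqdm.auto import tqdm


def denoise_and_decode(unet, vae, scheduler, latents, text_embeddings,
                       guidance_scale, extra_step_kwargs):
    for t in tqdm(scheduler.timesteps):
        # Expand the latents so a single UNet forward pass covers both the
        # unconditional and the text-conditioned branch.
        latent_model_input = torch.cat([latents] * 2)

        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

        # Classifier-free guidance, as in the diff: push the prediction away
        # from the unconditional branch by `guidance_scale`.
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Compute the previous noisy sample x_t -> x_t-1.
        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

    # Undo the latent scaling (0.18215 is the Stable Diffusion VAE scaling
    # factor) and decode back to image space.
    latents = 1 / 0.18215 * latents
    return vae.decode(latents)  # assumption: decode() returns the image tensor here
```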