
Commit c6fe3ab

upcast vae to float32.
1 parent: 6133592


examples/dreambooth/train_dreambooth_lora_sd_xl.py

Lines changed: 6 additions & 14 deletions
@@ -545,12 +545,6 @@ def collate_fn(examples, with_prior_preservation=False):
 
     return batch
 
-def dump_keys(parent, suffix=''):
-    for k in sorted(parent.keys()):
-        if isinstance(parent[k], torch.Tensor):
-            print(f'{suffix}{k} {list(parent[k].shape)} mean={torch.mean(parent[k]):.3g} std={torch.std(parent[k]):.3g}')
-        else:
-            dump_keys(parent[k], f'{suffix}{k}.')
 
 class PromptDataset(Dataset):
     "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
@@ -754,8 +748,9 @@ def main(args):
         weight_dtype = torch.bfloat16
 
     # Move unet, vae and text_encoder to device and cast to weight_dtype
+    # The VAE is in float32 to avoid NaN losses.
     unet.to(accelerator.device, dtype=weight_dtype)
-    vae.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=torch.float32)
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
 
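For reference, the sketch below reproduces the dtype layout this hunk establishes, with randomly initialized diffusers models standing in for the pretrained SDXL checkpoints; the class names match the script, while the device and dtype choices here are illustrative:

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel

device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = torch.bfloat16  # set from --mixed_precision in the script

unet = UNet2DConditionModel()  # stand-in for the pretrained SDXL UNet
vae = AutoencoderKL()          # stand-in for the pretrained SDXL VAE

# Everything runs in weight_dtype except the VAE, which is kept in
# float32 to avoid the NaN losses the new comment refers to.
unet.to(device, dtype=weight_dtype)
vae.to(device, dtype=torch.float32)

print(next(unet.parameters()).dtype)  # torch.bfloat16
print(next(vae.parameters()).dtype)   # torch.float32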
@@ -795,7 +790,6 @@ def main(args):
 
     unet.set_attn_processor(unet_lora_attn_procs)
     unet_lora_layers = AttnProcsLayers(unet.attn_processors)
-    print(f"UNet LoRA dict: {len(unet_lora_layers.state_dict())}")
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
@@ -1029,12 +1023,10 @@ def compute_embeddings(prompt, text_encoders, tokenizers):
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
 
-                if vae is not None:
-                    # Convert images to latent space
-                    model_input = vae.encode(pixel_values).latent_dist.sample()
-                    model_input = model_input * vae.config.scaling_factor
-                else:
-                    model_input = pixel_values
+                # Convert images to latent space
+                model_input = vae.encode(pixel_values).latent_dist.sample()
+                model_input = model_input * vae.config.scaling_factor
+                print(f"Model input dtype: {model_input.dtype}")
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
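For reference, a minimal self-contained sketch of the image-to-latent step this hunk simplifies; a randomly initialized AutoencoderKL and a random tensor stand in for the pretrained SDXL VAE and a real image batch, and the final print mirrors the debug line added in the diff:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL().to(dtype=torch.float32)  # stand-in for the SDXL VAE
pixel_values = torch.randn(1, 3, 64, 64)       # stand-in for batch["pixel_values"]

with torch.no_grad():
    # Encode to the latent distribution, sample from it, and apply the
    # scaling factor read from the VAE config, exactly as in the hunk.
    model_input = vae.encode(pixel_values).latent_dist.sample()
    model_input = model_input * vae.config.scaling_factor

print(f"Model input dtype: {model_input.dtype}")  # float32, matching the VAE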
