@@ -545,12 +545,6 @@ def collate_fn(examples, with_prior_preservation=False):
 
     return batch
 
-def dump_keys(parent, suffix=''):
-    for k in sorted(parent.keys()):
-        if isinstance(parent[k], torch.Tensor):
-            print(f'{suffix}{k} {list(parent[k].shape)} mean={torch.mean(parent[k]):.3g} std={torch.std(parent[k]):.3g}')
-        else:
-            dump_keys(parent[k], f'{suffix}{k}.')
 
 class PromptDataset(Dataset):
     "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
@@ -754,8 +748,9 @@ def main(args):
         weight_dtype = torch.bfloat16
 
     # Move unet, vae and text_encoder to device and cast to weight_dtype
+    # The VAE is in float32 to avoid NaN losses.
     unet.to(accelerator.device, dtype=weight_dtype)
-    vae.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=torch.float32)
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
 
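
The float32 VAE is the substantive change in this hunk: SDXL's VAE is known to overflow to NaN in half precision, so it stays in full precision while the UNet and text encoders use weight_dtype. A minimal sketch of the pattern under assumed names (the checkpoint id and the explicit input/output casts are illustrative, not lines from this script):

import torch
from diffusers import AutoencoderKL

device = "cuda"
weight_dtype = torch.bfloat16  # dtype used by the UNet and text encoders

# Illustrative checkpoint; any SDXL VAE would do here.
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
)
vae.to(device, dtype=torch.float32)  # full precision to avoid NaN losses

# Inputs must match the VAE's dtype, and the latents come back as float32,
# so they are cast down again before reaching the UNet.
pixel_values = torch.randn(1, 3, 1024, 1024, device=device)
latents = vae.encode(pixel_values.to(torch.float32)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
latents = latents.to(weight_dtype)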
@@ -795,7 +790,6 @@ def main(args):
 
     unet.set_attn_processor(unet_lora_attn_procs)
     unet_lora_layers = AttnProcsLayers(unet.attn_processors)
-    print(f"UNet LoRA dict: {len(unet_lora_layers.state_dict())}")
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
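
The save_model_hook context line belongs to Accelerate's custom checkpointing hooks: a hook registered on the accelerator runs inside accelerator.save_state(...) / load_state(...), so the LoRA layers can be serialized in the diffusers format instead of a raw state-dict dump. A sketch of the registration pattern (hook bodies reduced to placeholders; only the hook signatures and the register_* calls are the real Accelerate API):

from accelerate import Accelerator

accelerator = Accelerator()

def save_model_hook(models, weights, output_dir):
    # Runs inside accelerator.save_state(output_dir). Pop each weight you
    # serialize yourself so Accelerate does not write it a second time.
    for model in models:
        ...  # e.g. write this model's LoRA layers to output_dir
        weights.pop()

def load_model_hook(models, input_dir):
    # Runs inside accelerator.load_state(input_dir) before weights load.
    while len(models) > 0:
        model = models.pop()
        ...  # e.g. restore this model's LoRA layers from input_dir

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)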
@@ -1029,12 +1023,10 @@ def compute_embeddings(prompt, text_encoders, tokenizers):
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
 
-                if vae is not None:
-                    # Convert images to latent space
-                    model_input = vae.encode(pixel_values).latent_dist.sample()
-                    model_input = model_input * vae.config.scaling_factor
-                else:
-                    model_input = pixel_values
+                # Convert images to latent space
+                model_input = vae.encode(pixel_values).latent_dist.sample()
+                model_input = model_input * vae.config.scaling_factor
+                print(f"Model input dtype: {model_input.dtype}")
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
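
A note on the new debug print: with the VAE pinned to float32, vae.encode returns float32 latents regardless of weight_dtype, and torch.randn_like mirrors the dtype of its argument, so the noise follows suit. The hunk still casts pixel_values to weight_dtype, and if that differs from the VAE's float32 the encode call needs a matching cast, which is exactly the kind of mismatch the added print helps diagnose. A sketch of one consistent way to wire the dtypes up (the explicit casts are assumptions for illustration, not lines from this commit):

# Assumed flow; the casts are illustrative, not part of this diff.
pixel_values = batch["pixel_values"].to(dtype=torch.float32)  # match the VAE
model_input = vae.encode(pixel_values).latent_dist.sample()   # float32 out
model_input = model_input * vae.config.scaling_factor
model_input = model_input.to(weight_dtype)                    # back to bf16
noise = torch.randn_like(model_input)                         # same dtype/device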