diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index a119e12f73d1..9de27d38b44f 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -451,19 +451,18 @@ def main():
     # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
     # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
     # initialized to zero.
-    if accelerator.is_main_process:
-        logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
-        in_channels = 8
-        out_channels = unet.conv_in.out_channels
-        unet.register_to_config(in_channels=in_channels)
-
-        with torch.no_grad():
-            new_conv_in = nn.Conv2d(
-                in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
-            )
-            new_conv_in.weight.zero_()
-            new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
-            unet.conv_in = new_conv_in
+    logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+    in_channels = 8
+    out_channels = unet.conv_in.out_channels
+    unet.register_to_config(in_channels=in_channels)
+
+    with torch.no_grad():
+        new_conv_in = nn.Conv2d(
+            in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+        )
+        new_conv_in.weight.zero_()
+        new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
+        unet.conv_in = new_conv_in
 
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
@@ -892,9 +891,12 @@ def collate_fn(examples):
                     # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                     ema_unet.store(unet.parameters())
                     ema_unet.copy_to(unet.parameters())
+                # The models need unwrapping for compatibility in distributed training mode.
                 pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    unet=unet,
+                    unet=accelerator.unwrap_model(unet),
+                    text_encoder=accelerator.unwrap_model(text_encoder),
+                    vae=accelerator.unwrap_model(vae),
                     revision=args.revision,
                     torch_dtype=weight_dtype,
                 )
@@ -904,7 +906,9 @@ def collate_fn(examples):
                 # run inference
                 original_image = download_image(args.val_image_url)
                 edited_images = []
-                with torch.autocast(str(accelerator.device), enabled=accelerator.mixed_precision == "fp16"):
+                with torch.autocast(
+                    str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
+                ):
                     for _ in range(args.num_validation_images):
                         edited_images.append(
                             pipeline(
@@ -959,7 +963,7 @@ def collate_fn(examples):
         if args.validation_prompt is not None:
            edited_images = []
            pipeline = pipeline.to(accelerator.device)
-            with torch.autocast(str(accelerator.device)):
+            with torch.autocast(str(accelerator.device).replace(":0", "")):
                for _ in range(args.num_validation_images):
                    edited_images.append(
                        pipeline(
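
Two notes on the techniques in this patch, each with a small standalone sketch. The sketches are illustrative only: the layer shapes and the Accelerate-free setup below are assumptions for demonstration, not code taken from the training script.

First, the conv_in surgery in the first hunk: the UNet's first convolution is rebuilt with twice the input channels so it can take the noisy latents concatenated with the encoded original image, and the new weight is zeroed except for a copy of the pretrained one, so the extra channels contribute nothing at step 0. A minimal sketch with a plain nn.Conv2d standing in for unet.conv_in (the 4-in/320-out shape mirrors the usual SD v1 conv_in and is only an example):

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the pretrained unet.conv_in (4 latent channels in).
    old_conv_in = nn.Conv2d(4, 320, kernel_size=3, padding=1)

    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            8, old_conv_in.out_channels, old_conv_in.kernel_size, old_conv_in.stride, old_conv_in.padding
        )
        new_conv_in.weight.zero_()  # the 4 extra input channels start out ignored
        new_conv_in.weight[:, :4, :, :].copy_(old_conv_in.weight)  # keep pretrained weights for the first 4

    x = torch.randn(1, 8, 64, 64)  # [noisy latents | image latents] concatenated on dim 1
    assert new_conv_in(x).shape == (1, 320, 64, 64)

Second, the ":0" strip at both autocast call sites: torch.autocast expects a device type such as "cuda" or "cpu", while str(accelerator.device) yields a full device string like "cuda:0", which autocast rejects. The validation code runs on the main process, whose device index is 0, so stripping ":0" suffices here; device.type would be the fully general alternative. A sketch without Accelerate:

    import torch

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device_type = str(device).replace(":0", "")  # "cuda:0" -> "cuda"; "cpu" is unchanged

    # Autocast is enabled only on CUDA here, mirroring the fp16 gating in the patch.
    with torch.autocast(device_type, enabled=device_type == "cuda"):
        y = torch.randn(8, 8, device=device) @ torch.randn(8, 8, device=device)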