From f9ead54fda4e86f166e4e42decf340fe384a8377 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Mon, 15 Apr 2024 19:10:49 -0400 Subject: [PATCH 01/13] Add support for _foreach operations and non-blocking to EMAModel --- src/diffusers/training_utils.py | 71 +++++++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 13 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 25e02a3d1492..5f843e697b34 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -180,6 +180,7 @@ def __init__( use_ema_warmup: bool = False, inv_gamma: Union[float, int] = 1.0, power: Union[float, int] = 2 / 3, + foreach: bool = True, model_cls: Optional[Any] = None, model_config: Dict[str, Any] = None, **kwargs, @@ -194,6 +195,7 @@ def __init__( inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster. device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA weights will be stored on CPU. @@ -248,6 +250,7 @@ def __init__( self.power = power self.optimization_step = 0 self.cur_decay_value = None # set in `step()` + self.foreach = foreach self.model_cls = model_cls self.model_config = model_config @@ -324,15 +327,37 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): import deepspeed - for s_param, param in zip(self.shadow_params, parameters): + if self.foreach: if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): - context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) + context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None) with context_manager(): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) + params_grad = [param for param in parameters if param.requires_grad] + s_params_grad = [s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad] + + if len(params_grad) < len(parameters): + torch._foreach_copy_( + [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad], + [param for param in parameters if not param.requires_grad], + non_blocking=True + ) + + torch._foreach_sub_( + s_params_grad, + torch._foreach_sub(s_params_grad, params_grad), + alpha=one_minus_decay + ) + + else: + for s_param, param in zip(self.shadow_params, parameters): + if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): + context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) + + with context_manager(): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: """ @@ -344,10 +369,24 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: `ExponentialMovingAverage` was initialized will be used. """ parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.to(param.device).data) + if self.foreach: + torch._foreach_copy_( + [param.data for param in parameters], + [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)] + ) + else: + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.to(param.device).data) - def to(self, device=None, dtype=None) -> None: + def pin_memory(self) -> None: + r""" + Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers + for offloading EMA params to the host. + """ + + self.shadow_params = [p.pin_memory() for p in self.shadow_params] + + def to(self, device=None, dtype=None, non_blocking=False) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. Args: @@ -355,8 +394,8 @@ def to(self, device=None, dtype=None) -> None: """ # .to() on the tensors handles None correctly self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) - for p in self.shadow_params + p.to(device=device, dtype=dtype, non_blocking=non_blocking) if p.is_floating_point() + else p.to(device=device, non_blocking=non_blocking) for p in self.shadow_params ] def state_dict(self) -> dict: @@ -399,8 +438,14 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: """ if self.temp_stored_params is None: raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") - for c_param, param in zip(self.temp_stored_params, parameters): - param.data.copy_(c_param.data) + if self.foreach: + torch._foreach_copy_( + [param.data for param in parameters], + [c_param.data for c_param in self.temp_stored_params] + ) + else: + for c_param, param in zip(self.temp_stored_params, parameters): + param.data.copy_(c_param.data) # Better memory-wise. self.temp_stored_params = None From caab32699764ee8a038930e02d8fbeb552a290dd Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Mon, 15 Apr 2024 20:07:11 -0400 Subject: [PATCH 02/13] default foreach to false --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 5f843e697b34..6e8a14a8af75 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -180,7 +180,7 @@ def __init__( use_ema_warmup: bool = False, inv_gamma: Union[float, int] = 1.0, power: Union[float, int] = 2 / 3, - foreach: bool = True, + foreach: bool = False, model_cls: Optional[Any] = None, model_config: Dict[str, Any] = None, **kwargs, From 49c8b606cc0e671923c83985df68d964e74f7e04 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Tue, 16 Apr 2024 17:01:05 -0400 Subject: [PATCH 03/13] add non-blocking EMA offloading to SD1.5 T2I example script --- examples/text_to_image/train_text_to_image.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 84f4c6514cfd..77bd1e202bb2 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -373,6 +373,7 @@ def parse_args(): ), ) parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.") parser.add_argument( "--non_ema_revision", type=str, @@ -643,7 +644,10 @@ def load_model_hook(models, input_dir): if args.use_ema: load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) ema_unet.load_state_dict(load_model.state_dict()) - ema_unet.to(accelerator.device) + if args.offload_ema: + ema_unet.pin_memory() + else: + ema_unet.to(accelerator.device) del load_model for _ in range(len(models)): @@ -819,7 +823,10 @@ def collate_fn(examples): ) if args.use_ema: - ema_unet.to(accelerator.device) + if args.offload_ema: + ema_unet.pin_memory() + else: + ema_unet.to(accelerator.device) # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. @@ -985,7 +992,11 @@ def unwrap_model(model): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: if args.use_ema: + if args.offload_ema: + ema_unet.to(device="cuda", non_blocking=True) ema_unet.step(unet.parameters()) + if args.offload_ema: + ema_unet.to(device="cpu", non_blocking=True) progress_bar.update(1) global_step += 1 accelerator.log({"train_loss": train_loss}, step=global_step) From 4bbe8232df62f511f4cd01724506e699ddd51453 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Fri, 19 Apr 2024 23:16:24 -0400 Subject: [PATCH 04/13] fix whitespace --- src/diffusers/training_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 6e8a14a8af75..bb25be5f9951 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -334,14 +334,14 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): with context_manager(): params_grad = [param for param in parameters if param.requires_grad] s_params_grad = [s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad] - + if len(params_grad) < len(parameters): torch._foreach_copy_( [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad], [param for param in parameters if not param.requires_grad], non_blocking=True ) - + torch._foreach_sub_( s_params_grad, torch._foreach_sub(s_params_grad, params_grad), @@ -352,7 +352,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): for s_param, param in zip(self.shadow_params, parameters): if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) - + with context_manager(): if param.requires_grad: s_param.sub_(one_minus_decay * (s_param - param)) @@ -385,7 +385,7 @@ def pin_memory(self) -> None: """ self.shadow_params = [p.pin_memory() for p in self.shadow_params] - + def to(self, device=None, dtype=None, non_blocking=False) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. @@ -394,7 +394,7 @@ def to(self, device=None, dtype=None, non_blocking=False) -> None: """ # .to() on the tensors handles None correctly self.shadow_params = [ - p.to(device=device, dtype=dtype, non_blocking=non_blocking) if p.is_floating_point() + p.to(device=device, dtype=dtype, non_blocking=non_blocking) if p.is_floating_point() else p.to(device=device, non_blocking=non_blocking) for p in self.shadow_params ] From 527b8cb285338128c8caa9e710f8673defb1061d Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 20 Apr 2024 12:05:43 -0400 Subject: [PATCH 05/13] move foreach to cli argument --- examples/text_to_image/train_text_to_image.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 77bd1e202bb2..807451602a3b 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -374,6 +374,7 @@ def parse_args(): ) parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.") + parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.") parser.add_argument( "--non_ema_revision", type=str, @@ -611,7 +612,7 @@ def deepspeed_zero_init_disabled_context_manager(): ema_unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) - ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) + ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config, foreach=args.foreach) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -642,7 +643,7 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): if args.use_ema: - load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach) ema_unet.load_state_dict(load_model.state_dict()) if args.offload_ema: ema_unet.pin_memory() From 9d4794703da411b932aa4ba02121b32e70963687 Mon Sep 17 00:00:00 2001 From: drhead Date: Wed, 24 Apr 2024 14:42:15 -0400 Subject: [PATCH 06/13] linting --- examples/text_to_image/train_text_to_image.py | 8 ++++-- src/diffusers/training_utils.py | 25 ++++++++++--------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 807451602a3b..eb299924ac13 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -612,7 +612,9 @@ def deepspeed_zero_init_disabled_context_manager(): ema_unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) - ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config, foreach=args.foreach) + ema_unet = EMAModel( + ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config, foreach=args.foreach + ) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -643,7 +645,9 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): if args.use_ema: - load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach) + load_model = EMAModel.from_pretrained( + os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach + ) ema_unet.load_state_dict(load_model.state_dict()) if args.offload_ema: ema_unet.pin_memory() diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index bb25be5f9951..5e900f6a1a73 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -333,19 +333,19 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): with context_manager(): params_grad = [param for param in parameters if param.requires_grad] - s_params_grad = [s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad] + s_params_grad = [ + s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad + ] if len(params_grad) < len(parameters): torch._foreach_copy_( [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad], [param for param in parameters if not param.requires_grad], - non_blocking=True + non_blocking=True, ) torch._foreach_sub_( - s_params_grad, - torch._foreach_sub(s_params_grad, params_grad), - alpha=one_minus_decay + s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay ) else: @@ -372,7 +372,7 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: if self.foreach: torch._foreach_copy_( [param.data for param in parameters], - [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)] + [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)], ) else: for s_param, param in zip(self.shadow_params, parameters): @@ -380,8 +380,8 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: def pin_memory(self) -> None: r""" - Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers - for offloading EMA params to the host. + Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for + offloading EMA params to the host. """ self.shadow_params = [p.pin_memory() for p in self.shadow_params] @@ -394,8 +394,10 @@ def to(self, device=None, dtype=None, non_blocking=False) -> None: """ # .to() on the tensors handles None correctly self.shadow_params = [ - p.to(device=device, dtype=dtype, non_blocking=non_blocking) if p.is_floating_point() - else p.to(device=device, non_blocking=non_blocking) for p in self.shadow_params + p.to(device=device, dtype=dtype, non_blocking=non_blocking) + if p.is_floating_point() + else p.to(device=device, non_blocking=non_blocking) + for p in self.shadow_params ] def state_dict(self) -> dict: @@ -440,8 +442,7 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") if self.foreach: torch._foreach_copy_( - [param.data for param in parameters], - [c_param.data for c_param in self.temp_stored_params] + [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params] ) else: for c_param, param in zip(self.temp_stored_params, parameters): From a09614e9050f7795bf5ae9d25a047e3bcee1ded0 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Wed, 24 Apr 2024 15:22:24 -0400 Subject: [PATCH 07/13] Update README.md re: EMA weight training --- examples/text_to_image/README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index f2931d3f347e..63ab1ea4acdd 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -170,6 +170,12 @@ For our small Pokemons dataset, the effects of Min-SNR weighting strategy might Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds. +#### Training with EMA weights + +Through the `EMAModel` class, we support a convenient method of tracking an exponential moving average of model parameters. This helps to smooth out noise in model parameter updates and generally improves model performance. If enabled with the `--use_ema` argument, the final model checkpoint that is saved at the end of training will use the EMA weights. + +EMA weights require an additional full-precision copy of the model parameters to be stored in memory, but otherwise have very little performance overhead. `--foreach_ema` can be used to further reduce the overhead. If you are short on VRAM and still want to use EMA weights, you can store them in CPU RAM by using the `--offload_ema` argument. This will keep the EMA weights in pinned CPU memory during the training step. Then, once every model parameter update, it will transfer the EMA weights back to the GPU which can then update the parameters on the GPU, before sending them back to the CPU. Both of these transfers are set up as non-blocking, so CUDA devices should be able to overlap this transfer with other computations. With sufficient bandwidth between the host and device and a sufficiently long gap between model parameter updates, storing EMA weights in CPU RAM should have no additional performance overhead, as long as no other calls force synchronization. + ## Training with LoRA Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. From 503a76503d81323a97de5df134ec82970525205b Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Fri, 21 Jun 2024 22:49:29 -0400 Subject: [PATCH 08/13] correct args.foreach_ema --- examples/text_to_image/train_text_to_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index eb299924ac13..83d9736e9661 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -613,7 +613,7 @@ def deepspeed_zero_init_disabled_context_manager(): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) ema_unet = EMAModel( - ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config, foreach=args.foreach + ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config, foreach=args.foreach_ema ) if args.enable_xformers_memory_efficient_attention: @@ -646,7 +646,7 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): if args.use_ema: load_model = EMAModel.from_pretrained( - os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach + os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema ) ema_unet.load_state_dict(load_model.state_dict()) if args.offload_ema: From 662e95801c0bf453300d2fff28c4ef6b6a26e162 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Fri, 21 Jun 2024 22:57:48 -0400 Subject: [PATCH 09/13] add tests for foreach ema --- tests/others/test_ema.py | 132 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py index 48437c575a91..253341db4141 100644 --- a/tests/others/test_ema.py +++ b/tests/others/test_ema.py @@ -157,3 +157,135 @@ def test_serialization(self): output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample assert torch.allclose(output, output_loaded, atol=1e-4) + +class EMAModelTestsForeach(unittest.TestCase): + model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" + batch_size = 1 + prompt_length = 77 + text_encoder_hidden_dim = 32 + num_in_channels = 4 + latent_height = latent_width = 64 + generator = torch.manual_seed(0) + + def get_models(self, decay=0.9999): + unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") + unet = unet.to(torch_device) + ema_unet = EMAModel(unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=unet.config, foreach=True) + return unet, ema_unet + + def get_dummy_inputs(self): + noisy_latents = torch.randn( + self.batch_size, self.num_in_channels, self.latent_height, self.latent_width, generator=self.generator + ).to(torch_device) + timesteps = torch.randint(0, 1000, size=(self.batch_size,), generator=self.generator).to(torch_device) + encoder_hidden_states = torch.randn( + self.batch_size, self.prompt_length, self.text_encoder_hidden_dim, generator=self.generator + ).to(torch_device) + return noisy_latents, timesteps, encoder_hidden_states + + def simulate_backprop(self, unet): + updated_state_dict = {} + for k, param in unet.state_dict().items(): + updated_param = torch.randn_like(param) + (param * torch.randn_like(param)) + updated_state_dict.update({k: updated_param}) + unet.load_state_dict(updated_state_dict) + return unet + + def test_optimization_steps_updated(self): + unet, ema_unet = self.get_models() + # Take the first (hypothetical) EMA step. + ema_unet.step(unet.parameters()) + assert ema_unet.optimization_step == 1 + + # Take two more. + for _ in range(2): + ema_unet.step(unet.parameters()) + assert ema_unet.optimization_step == 3 + + def test_shadow_params_not_updated(self): + unet, ema_unet = self.get_models() + # Since the `unet` is not being updated (i.e., backprop'd) + # there won't be any difference between the `params` of `unet` + # and `ema_unet` even if we call `ema_unet.step(unet.parameters())`. + ema_unet.step(unet.parameters()) + orig_params = list(unet.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert torch.allclose(s_param, param) + + # The above holds true even if we call `ema.step()` multiple times since + # `unet` params are still not being updated. + for _ in range(4): + ema_unet.step(unet.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert torch.allclose(s_param, param) + + def test_shadow_params_updated(self): + unet, ema_unet = self.get_models() + # Here we simulate the parameter updates for `unet`. Since there might + # be some parameters which are initialized to zero we take extra care to + # initialize their values to something non-zero before the multiplication. + unet_pseudo_updated_step_one = self.simulate_backprop(unet) + + # Take the EMA step. + ema_unet.step(unet_pseudo_updated_step_one.parameters()) + + # Now the EMA'd parameters won't be equal to the original model parameters. + orig_params = list(unet_pseudo_updated_step_one.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert ~torch.allclose(s_param, param) + + # Ensure this is the case when we take multiple EMA steps. + for _ in range(4): + ema_unet.step(unet.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert ~torch.allclose(s_param, param) + + def test_consecutive_shadow_params_updated(self): + # If we call EMA step after a backpropagation consecutively for two times, + # the shadow params from those two steps should be different. + unet, ema_unet = self.get_models() + + # First backprop + EMA + unet_step_one = self.simulate_backprop(unet) + ema_unet.step(unet_step_one.parameters()) + step_one_shadow_params = ema_unet.shadow_params + + # Second backprop + EMA + unet_step_two = self.simulate_backprop(unet_step_one) + ema_unet.step(unet_step_two.parameters()) + step_two_shadow_params = ema_unet.shadow_params + + for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params): + assert ~torch.allclose(step_one, step_two) + + def test_zero_decay(self): + # If there's no decay even if there are backprops, EMA steps + # won't take any effect i.e., the shadow params would remain the + # same. + unet, ema_unet = self.get_models(decay=0.0) + unet_step_one = self.simulate_backprop(unet) + ema_unet.step(unet_step_one.parameters()) + step_one_shadow_params = ema_unet.shadow_params + + unet_step_two = self.simulate_backprop(unet_step_one) + ema_unet.step(unet_step_two.parameters()) + step_two_shadow_params = ema_unet.shadow_params + + for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params): + assert torch.allclose(step_one, step_two) + + @skip_mps + def test_serialization(self): + unet, ema_unet = self.get_models() + noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs() + + with tempfile.TemporaryDirectory() as tmpdir: + ema_unet.save_pretrained(tmpdir) + loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel) + loaded_unet = loaded_unet.to(unet.device) + + # Since no EMA step has been performed the outputs should match. + output = unet(noisy_latents, timesteps, encoder_hidden_states).sample + output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample + + assert torch.allclose(output, output_loaded, atol=1e-4) From e9ed284539d3e13dbf539c85ec095c220e036226 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 22 Jun 2024 08:42:29 +0530 Subject: [PATCH 10/13] code quality --- examples/text_to_image/train_text_to_image.py | 5 ++++- tests/others/test_ema.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 98add9a39eca..fa09671681b0 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -627,7 +627,10 @@ def deepspeed_zero_init_disabled_context_manager(): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) ema_unet = EMAModel( - ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config, foreach=args.foreach_ema + ema_unet.parameters(), + model_cls=UNet2DConditionModel, + model_config=ema_unet.config, + foreach=args.foreach_ema, ) if args.enable_xformers_memory_efficient_attention: diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py index 253341db4141..5bed42b8488f 100644 --- a/tests/others/test_ema.py +++ b/tests/others/test_ema.py @@ -158,6 +158,7 @@ def test_serialization(self): assert torch.allclose(output, output_loaded, atol=1e-4) + class EMAModelTestsForeach(unittest.TestCase): model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" batch_size = 1 @@ -170,7 +171,9 @@ class EMAModelTestsForeach(unittest.TestCase): def get_models(self, decay=0.9999): unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") unet = unet.to(torch_device) - ema_unet = EMAModel(unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=unet.config, foreach=True) + ema_unet = EMAModel( + unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=unet.config, foreach=True + ) return unet, ema_unet def get_dummy_inputs(self): From 8bdca089ad2c66645a7f2bb40855b5c8a12c4ab7 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Fri, 21 Jun 2024 23:29:33 -0400 Subject: [PATCH 11/13] add foreach to from_pretrained --- src/diffusers/training_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 58dcf8c81caa..8e1a0a6f208b 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -350,11 +350,11 @@ def __init__( self.model_config = model_config @classmethod - def from_pretrained(cls, path, model_cls) -> "EMAModel": + def from_pretrained(cls, path, model_cls, foreach) -> "EMAModel": _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) model = model_cls.from_pretrained(path) - ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) + ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach) ema_model.load_state_dict(ema_kwargs) return ema_model From 48039c3eb75d8fbd3ec9ec52dfa7a68cd6ce2525 Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 22 Jun 2024 03:23:04 -0400 Subject: [PATCH 12/13] default foreach false --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 8e1a0a6f208b..db69b436617b 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -350,7 +350,7 @@ def __init__( self.model_config = model_config @classmethod - def from_pretrained(cls, path, model_cls, foreach) -> "EMAModel": + def from_pretrained(cls, path, model_cls, foreach = False) -> "EMAModel": _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) model = model_cls.from_pretrained(path) From cff1e06acb832e43a53733302aab3cb5ad573fcf Mon Sep 17 00:00:00 2001 From: drhead <1313496+drhead@users.noreply.github.com> Date: Sat, 22 Jun 2024 13:24:33 -0400 Subject: [PATCH 13/13] fix linting --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index db69b436617b..dd8889f9bce5 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -350,7 +350,7 @@ def __init__( self.model_config = model_config @classmethod - def from_pretrained(cls, path, model_cls, foreach = False) -> "EMAModel": + def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel": _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) model = model_cls.from_pretrained(path)