diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 91af7360b711..0bdf02f804bd 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -170,11 +170,19 @@ For our small Narutos dataset, the effects of Min-SNR weighting strategy might n Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds. + +#### Training with EMA weights + +Through the `EMAModel` class, we support a convenient method of tracking an exponential moving average of model parameters. This helps to smooth out noise in model parameter updates and generally improves model performance. If enabled with the `--use_ema` argument, the final model checkpoint that is saved at the end of training will use the EMA weights. + +EMA weights require an additional full-precision copy of the model parameters to be stored in memory, but otherwise have very little performance overhead. `--foreach_ema` can be used to further reduce the overhead. If you are short on VRAM and still want to use EMA weights, you can store them in CPU RAM by using the `--offload_ema` argument. This will keep the EMA weights in pinned CPU memory during the training step. Then, once every model parameter update, it will transfer the EMA weights back to the GPU which can then update the parameters on the GPU, before sending them back to the CPU. Both of these transfers are set up as non-blocking, so CUDA devices should be able to overlap this transfer with other computations. With sufficient bandwidth between the host and device and a sufficiently long gap between model parameter updates, storing EMA weights in CPU RAM should have no additional performance overhead, as long as no other calls force synchronization. + #### Training with DREAM We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity for the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and is the default of 1 from the paper. + ## Training with LoRA Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 1d36fd8cc79a..fa09671681b0 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -387,6 +387,8 @@ def parse_args(): ), ) parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.") + parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.") parser.add_argument( "--non_ema_revision", type=str, @@ -624,7 +626,12 @@ def deepspeed_zero_init_disabled_context_manager(): ema_unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) - ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) + ema_unet = EMAModel( + ema_unet.parameters(), + model_cls=UNet2DConditionModel, + model_config=ema_unet.config, + foreach=args.foreach_ema, + ) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -655,9 +662,14 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): if args.use_ema: - load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + load_model = EMAModel.from_pretrained( + os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema + ) ema_unet.load_state_dict(load_model.state_dict()) - ema_unet.to(accelerator.device) + if args.offload_ema: + ema_unet.pin_memory() + else: + ema_unet.to(accelerator.device) del load_model for _ in range(len(models)): @@ -833,7 +845,10 @@ def collate_fn(examples): ) if args.use_ema: - ema_unet.to(accelerator.device) + if args.offload_ema: + ema_unet.pin_memory() + else: + ema_unet.to(accelerator.device) # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. @@ -1011,7 +1026,11 @@ def unwrap_model(model): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: if args.use_ema: + if args.offload_ema: + ema_unet.to(device="cuda", non_blocking=True) ema_unet.step(unet.parameters()) + if args.offload_ema: + ema_unet.to(device="cpu", non_blocking=True) progress_bar.update(1) global_step += 1 accelerator.log({"train_loss": train_loss}, step=global_step) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index d3ff926eac8a..dd8889f9bce5 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -274,6 +274,7 @@ def __init__( use_ema_warmup: bool = False, inv_gamma: Union[float, int] = 1.0, power: Union[float, int] = 2 / 3, + foreach: bool = False, model_cls: Optional[Any] = None, model_config: Dict[str, Any] = None, **kwargs, @@ -288,6 +289,7 @@ def __init__( inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster. device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA weights will be stored on CPU. @@ -342,16 +344,17 @@ def __init__( self.power = power self.optimization_step = 0 self.cur_decay_value = None # set in `step()` + self.foreach = foreach self.model_cls = model_cls self.model_config = model_config @classmethod - def from_pretrained(cls, path, model_cls) -> "EMAModel": + def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel": _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) model = model_cls.from_pretrained(path) - ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) + ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach) ema_model.load_state_dict(ema_kwargs) return ema_model @@ -418,15 +421,37 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): import deepspeed - for s_param, param in zip(self.shadow_params, parameters): + if self.foreach: if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): - context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) + context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None) with context_manager(): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) + params_grad = [param for param in parameters if param.requires_grad] + s_params_grad = [ + s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad + ] + + if len(params_grad) < len(parameters): + torch._foreach_copy_( + [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad], + [param for param in parameters if not param.requires_grad], + non_blocking=True, + ) + + torch._foreach_sub_( + s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay + ) + + else: + for s_param, param in zip(self.shadow_params, parameters): + if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): + context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) + + with context_manager(): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: """ @@ -438,10 +463,24 @@ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: `ExponentialMovingAverage` was initialized will be used. """ parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.to(param.device).data) + if self.foreach: + torch._foreach_copy_( + [param.data for param in parameters], + [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)], + ) + else: + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.to(param.device).data) - def to(self, device=None, dtype=None) -> None: + def pin_memory(self) -> None: + r""" + Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for + offloading EMA params to the host. + """ + + self.shadow_params = [p.pin_memory() for p in self.shadow_params] + + def to(self, device=None, dtype=None, non_blocking=False) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. Args: @@ -449,7 +488,9 @@ def to(self, device=None, dtype=None) -> None: """ # .to() on the tensors handles None correctly self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + p.to(device=device, dtype=dtype, non_blocking=non_blocking) + if p.is_floating_point() + else p.to(device=device, non_blocking=non_blocking) for p in self.shadow_params ] @@ -493,8 +534,13 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: """ if self.temp_stored_params is None: raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") - for c_param, param in zip(self.temp_stored_params, parameters): - param.data.copy_(c_param.data) + if self.foreach: + torch._foreach_copy_( + [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params] + ) + else: + for c_param, param in zip(self.temp_stored_params, parameters): + param.data.copy_(c_param.data) # Better memory-wise. self.temp_stored_params = None diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py index 48437c575a91..5bed42b8488f 100644 --- a/tests/others/test_ema.py +++ b/tests/others/test_ema.py @@ -157,3 +157,138 @@ def test_serialization(self): output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample assert torch.allclose(output, output_loaded, atol=1e-4) + + +class EMAModelTestsForeach(unittest.TestCase): + model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" + batch_size = 1 + prompt_length = 77 + text_encoder_hidden_dim = 32 + num_in_channels = 4 + latent_height = latent_width = 64 + generator = torch.manual_seed(0) + + def get_models(self, decay=0.9999): + unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") + unet = unet.to(torch_device) + ema_unet = EMAModel( + unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=unet.config, foreach=True + ) + return unet, ema_unet + + def get_dummy_inputs(self): + noisy_latents = torch.randn( + self.batch_size, self.num_in_channels, self.latent_height, self.latent_width, generator=self.generator + ).to(torch_device) + timesteps = torch.randint(0, 1000, size=(self.batch_size,), generator=self.generator).to(torch_device) + encoder_hidden_states = torch.randn( + self.batch_size, self.prompt_length, self.text_encoder_hidden_dim, generator=self.generator + ).to(torch_device) + return noisy_latents, timesteps, encoder_hidden_states + + def simulate_backprop(self, unet): + updated_state_dict = {} + for k, param in unet.state_dict().items(): + updated_param = torch.randn_like(param) + (param * torch.randn_like(param)) + updated_state_dict.update({k: updated_param}) + unet.load_state_dict(updated_state_dict) + return unet + + def test_optimization_steps_updated(self): + unet, ema_unet = self.get_models() + # Take the first (hypothetical) EMA step. + ema_unet.step(unet.parameters()) + assert ema_unet.optimization_step == 1 + + # Take two more. + for _ in range(2): + ema_unet.step(unet.parameters()) + assert ema_unet.optimization_step == 3 + + def test_shadow_params_not_updated(self): + unet, ema_unet = self.get_models() + # Since the `unet` is not being updated (i.e., backprop'd) + # there won't be any difference between the `params` of `unet` + # and `ema_unet` even if we call `ema_unet.step(unet.parameters())`. + ema_unet.step(unet.parameters()) + orig_params = list(unet.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert torch.allclose(s_param, param) + + # The above holds true even if we call `ema.step()` multiple times since + # `unet` params are still not being updated. + for _ in range(4): + ema_unet.step(unet.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert torch.allclose(s_param, param) + + def test_shadow_params_updated(self): + unet, ema_unet = self.get_models() + # Here we simulate the parameter updates for `unet`. Since there might + # be some parameters which are initialized to zero we take extra care to + # initialize their values to something non-zero before the multiplication. + unet_pseudo_updated_step_one = self.simulate_backprop(unet) + + # Take the EMA step. + ema_unet.step(unet_pseudo_updated_step_one.parameters()) + + # Now the EMA'd parameters won't be equal to the original model parameters. + orig_params = list(unet_pseudo_updated_step_one.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert ~torch.allclose(s_param, param) + + # Ensure this is the case when we take multiple EMA steps. + for _ in range(4): + ema_unet.step(unet.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert ~torch.allclose(s_param, param) + + def test_consecutive_shadow_params_updated(self): + # If we call EMA step after a backpropagation consecutively for two times, + # the shadow params from those two steps should be different. + unet, ema_unet = self.get_models() + + # First backprop + EMA + unet_step_one = self.simulate_backprop(unet) + ema_unet.step(unet_step_one.parameters()) + step_one_shadow_params = ema_unet.shadow_params + + # Second backprop + EMA + unet_step_two = self.simulate_backprop(unet_step_one) + ema_unet.step(unet_step_two.parameters()) + step_two_shadow_params = ema_unet.shadow_params + + for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params): + assert ~torch.allclose(step_one, step_two) + + def test_zero_decay(self): + # If there's no decay even if there are backprops, EMA steps + # won't take any effect i.e., the shadow params would remain the + # same. + unet, ema_unet = self.get_models(decay=0.0) + unet_step_one = self.simulate_backprop(unet) + ema_unet.step(unet_step_one.parameters()) + step_one_shadow_params = ema_unet.shadow_params + + unet_step_two = self.simulate_backprop(unet_step_one) + ema_unet.step(unet_step_two.parameters()) + step_two_shadow_params = ema_unet.shadow_params + + for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params): + assert torch.allclose(step_one, step_two) + + @skip_mps + def test_serialization(self): + unet, ema_unet = self.get_models() + noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs() + + with tempfile.TemporaryDirectory() as tmpdir: + ema_unet.save_pretrained(tmpdir) + loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel) + loaded_unet = loaded_unet.to(unet.device) + + # Since no EMA step has been performed the outputs should match. + output = unet(noisy_latents, timesteps, encoder_hidden_states).sample + output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample + + assert torch.allclose(output, output_loaded, atol=1e-4)