From 6cb8e20f92b7a8915ccae44b6a994e4c71232b22 Mon Sep 17 00:00:00 2001 From: Jiwook Han <33192762+mreraser@users.noreply.github.com> Date: Tue, 8 Oct 2024 21:06:34 +0900 Subject: [PATCH 1/4] refac: docstrings in training_utils.py --- src/diffusers/training_utils.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 57bd9074870c..72c8f905d426 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -36,8 +36,9 @@ def set_seed(seed: int): """ - Args: Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + + Args: seed (`int`): The seed to set. """ random.seed(seed) @@ -225,7 +226,8 @@ def _set_state_dict_into_text_encoder( def compute_density_for_timestep_sampling( weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None ): - """Compute the density for sampling the timesteps when doing SD3 training. + """ + Compute the density for sampling the timesteps when doing SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. @@ -244,7 +246,8 @@ def compute_density_for_timestep_sampling( def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): - """Computes loss weighting scheme for SD3 training. + """ + Computes loss weighting scheme for SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. @@ -261,7 +264,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def free_memory(): - """Runs garbage collection. Then clears the cache of the available accelerator.""" + """ + Runs garbage collection. Then clears the cache of the available accelerator. + """ gc.collect() if torch.cuda.is_available(): @@ -494,7 +499,8 @@ def pin_memory(self) -> None: self.shadow_params = [p.pin_memory() for p in self.shadow_params] def to(self, device=None, dtype=None, non_blocking=False) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. + r""" + Move internal buffers of the ExponentialMovingAverage to `device`. Args: device: like `device` argument to `torch.Tensor.to` @@ -528,23 +534,25 @@ def state_dict(self) -> dict: def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" + Saves the current parameters for restoring later. + Args: - Save the current parameters for restoring later. - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. + parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored. """ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" - Args: Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After validation (or model saving), use this to restore the former parameters. + + Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ + if self.temp_stored_params is None: raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") if self.foreach: @@ -560,9 +568,10 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: def load_state_dict(self, state_dict: dict) -> None: r""" - Args: Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the ema state dict. + + Args: state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`. """ From 0ae2e22256f3e5414f12278ec91d591b94b742b9 Mon Sep 17 00:00:00 2001 From: Jiwook Han <33192762+mreraser@users.noreply.github.com> Date: Tue, 8 Oct 2024 21:10:14 +0900 Subject: [PATCH 2/4] fix: manual edits --- src/diffusers/training_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 72c8f905d426..04f72ae67c21 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -90,8 +90,7 @@ def resolve_interpolation_mode(interpolation_type: str): in torchvision. Returns: - `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize` - transform. + `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize` transform. """ if not is_torchvision_available(): raise ImportError( From 0c6df60f7c6288a2ca1d6d3f85f94456bd7a3a9c Mon Sep 17 00:00:00 2001 From: Jiwook Han Date: Thu, 10 Oct 2024 11:02:21 +0900 Subject: [PATCH 3/4] run make style --- src/diffusers/training_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 04f72ae67c21..a27c1ad6f176 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -90,7 +90,8 @@ def resolve_interpolation_mode(interpolation_type: str): in torchvision. Returns: - `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize` transform. + `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize` + transform. """ if not is_torchvision_available(): raise ImportError( @@ -542,8 +543,8 @@ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" - Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: - affecting the original optimization process. Store the parameters before the `copy_to()` method. After + Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters + without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After validation (or model saving), use this to restore the former parameters. Args: @@ -551,7 +552,7 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: updated with the stored parameters. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ - + if self.temp_stored_params is None: raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") if self.foreach: From 71a98997c1b18899906ffa94bfa149c4ad8d52d0 Mon Sep 17 00:00:00 2001 From: Jiwook Han <33192762+mreraser@users.noreply.github.com> Date: Tue, 15 Oct 2024 21:51:00 +0900 Subject: [PATCH 4/4] add docstring at cast_training_params --- src/diffusers/training_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index a27c1ad6f176..11a4e1cc8069 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -195,6 +195,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32): + """ + Casts the training parameters of the model to the specified data type. + + Args: + model: The PyTorch model whose parameters will be cast. + dtype: The data type to which the model parameters will be cast. + """ if not isinstance(model, list): model = [model] for m in model: