diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 57bd9074870c..11a4e1cc8069 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -36,8 +36,9 @@ def set_seed(seed: int):
     """
-    Args:
     Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
+
+    Args:
         seed (`int`): The seed to set.
     """
     random.seed(seed)
@@ -194,6 +195,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:


 def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+    """
+    Casts the training parameters of the model to the specified data type.
+
+    Args:
+        model: The PyTorch model whose parameters will be cast.
+        dtype: The data type to which the model parameters will be cast.
+    """
     if not isinstance(model, list):
         model = [model]
     for m in model:
@@ -225,7 +233,8 @@ def _set_state_dict_into_text_encoder(
 def compute_density_for_timestep_sampling(
     weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
 ):
-    """Compute the density for sampling the timesteps when doing SD3 training.
+    """
+    Compute the density for sampling the timesteps when doing SD3 training.

     Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

@@ -244,7 +253,8 @@ def compute_density_for_timestep_sampling(


 def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
-    """Computes loss weighting scheme for SD3 training.
+    """
+    Computes loss weighting scheme for SD3 training.

     Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

@@ -261,7 +271,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):


 def free_memory():
-    """Runs garbage collection. Then clears the cache of the available accelerator."""
+    """
+    Runs garbage collection. Then clears the cache of the available accelerator.
+    """
     gc.collect()

     if torch.cuda.is_available():
@@ -494,7 +506,8 @@ def pin_memory(self) -> None:
         self.shadow_params = [p.pin_memory() for p in self.shadow_params]

     def to(self, device=None, dtype=None, non_blocking=False) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+        r"""
+        Move internal buffers of the ExponentialMovingAverage to `device`.

         Args:
             device: like `device` argument to `torch.Tensor.to`
@@ -528,23 +541,25 @@ def state_dict(self) -> dict:

     def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
+        Saves the current parameters for restoring later.
+
         Args:
-        Save the current parameters for restoring later.
-        parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-            temporarily stored.
+            parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
         """
         self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

     def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
-        Args:
-        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
-        affecting the original optimization process. Store the parameters before the `copy_to()` method. After
+        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
+        without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After
         validation (or model saving), use this to restore the former parameters.
+
+        Args:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. If
                 `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used.
         """
+
         if self.temp_stored_params is None:
             raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
         if self.foreach:
@@ -560,9 +575,10 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:

     def load_state_dict(self, state_dict: dict) -> None:
         r"""
-        Args:
         Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
         ema state dict.
+
+        Args:
             state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.
         """
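For context, the workflow described by the reformatted `store()` / `copy_to()` / `restore()` and `cast_training_params` docstrings typically looks like the sketch below. This snippet is not part of the patch; the `model` and `ema` names and the toy `torch.nn.Linear` stand-in are illustrative assumptions, with the real training loop omitted.

```python
import torch

from diffusers.training_utils import EMAModel, cast_training_params

# Stand-in module; in real training this would be e.g. the UNet.
model = torch.nn.Linear(4, 4)

# Keep trainable parameters in float32, as the new `cast_training_params` docstring describes.
cast_training_params(model, dtype=torch.float32)

# Track an exponential moving average of the parameters.
ema = EMAModel(model.parameters())

for _ in range(10):
    # ... optimizer step on `model` would go here ...
    ema.step(model.parameters())

# Validate (or save) with EMA weights: store the live weights, swap in the EMA
# weights, then restore the originals, per the `store()`/`copy_to()`/`restore()` docstrings.
ema.store(model.parameters())
ema.copy_to(model.parameters())
# ... run validation or save the model here ...
ema.restore(model.parameters())
```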