Skip to content

Commit

Permalink
[AudioLDM] Generalise conversion script (huggingface#3328)
Browse files Browse the repository at this point in the history
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  • Loading branch information
2 people authored and hari10599 committed May 20, 2023
1 parent 7315d6d commit de96990
Showing 1 changed file with 54 additions and 17 deletions.
71 changes: 54 additions & 17 deletions scripts/convert_original_audioldm_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,8 @@ def load_pipeline_from_original_audioldm_ckpt(
extract_ema: bool = False,
scheduler_type: str = "ddim",
num_in_channels: int = None,
model_channels: int = None,
num_head_channels: int = None,
device: str = None,
from_safetensors: bool = False,
) -> AudioLDMPipeline:
Expand All @@ -784,23 +786,36 @@ def load_pipeline_from_original_audioldm_ckpt(
global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is
recommended that you override the default values and/or supply an `original_config_file` wherever possible.
:param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file
corresponding to the original architecture.
If `None`, will be automatically instantiated based on default values.
:param image_size: The image size that the model was trained on. Use 512 for original AudioLDM checkpoints. :param
prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for original
AudioLDM checkpoints.
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
inferred.
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
"euler-ancestral", "dpm", "ddim"]`.
:param extract_ema: Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract
the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually
yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
:param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If
`checkpoint_path` is in `safetensors` format, load checkpoint with safetensors
instead of PyTorch.
:return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
Args:
checkpoint_path (`str`): Path to `.ckpt` file.
original_config_file (`str`):
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
set to the audioldm-s-full-v2 config.
image_size (`int`, *optional*, defaults to 512):
The image size that the model was trained on.
prediction_type (`str`, *optional*):
The prediction type that the model was trained on. If `None`, will be automatically
inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`.
num_in_channels (`int`, *optional*, defaults to None):
The number of UNet input channels. If `None`, it will be automatically inferred from the config.
model_channels (`int`, *optional*, defaults to None):
The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override
to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.
num_head_channels (`int`, *optional*, defaults to None):
The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override
to 32 for the small and medium checkpoints, and 64 for the large.
scheduler_type (`str`, *optional*, defaults to 'pndm'):
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
"ddim"]`.
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
`False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
inference. Non-EMA weights are usually better to continue fine-tuning.
device (`str`, *optional*, defaults to `None`):
The device to use. Pass `None` to determine automatically.
from_safetensors (`str`, *optional*, defaults to `False`):
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
"""

if not is_omegaconf_available():
Expand Down Expand Up @@ -837,6 +852,12 @@ def load_pipeline_from_original_audioldm_ckpt(
if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

if model_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels

if num_head_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels

if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
Expand Down Expand Up @@ -960,6 +981,20 @@ def load_pipeline_from_original_audioldm_ckpt(
type=int,
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
)
parser.add_argument(
"--model_channels",
default=None,
type=int,
help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override"
" to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.",
)
parser.add_argument(
"--num_head_channels",
default=None,
type=int,
help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override"
" to 32 for the small and medium checkpoints, and 64 for the large.",
)
parser.add_argument(
"--scheduler_type",
default="ddim",
Expand Down Expand Up @@ -1009,6 +1044,8 @@ def load_pipeline_from_original_audioldm_ckpt(
extract_ema=args.extract_ema,
scheduler_type=args.scheduler_type,
num_in_channels=args.num_in_channels,
model_channels=args.model_channels,
num_head_channels=args.num_head_channels,
from_safetensors=args.from_safetensors,
device=args.device,
)
Expand Down

0 comments on commit de96990

Please sign in to comment.