audioldm2 model/pipeline review #13602

@hlky

Description

Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against .ai/review-rules.md and all referenced rule files that are present. AGENTS.md is referenced by the rules but is not present in this checkout.

Duplicate search: I checked GitHub issues/PRs for audioldm2, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel, pipeline_audioldm2, modeling_audioldm2, cross_attention_kwargs, gradient-checkpointing masks, scoring, and the cross_attention_dim IndexError, and found no likely duplicates of the issues below. Existing #12630 / PR #13111 cover a different GPT2Model AttributeError.

Issue 1: Projection mask fallback crashes when only one mask is provided

Affected code:

# concatenate clap and t5 text encoding
hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1)

# concatenate attention masks
if attention_mask is None and attention_mask_1 is not None:
    attention_mask = attention_mask_1.new_ones((hidden_states[:2]))
elif attention_mask is not None and attention_mask_1 is None:
    attention_mask_1 = attention_mask.new_ones((hidden_states_1[:2]))

if attention_mask is not None and attention_mask_1 is not None:
    attention_mask = torch.cat([attention_mask, attention_mask_1], dim=-1)

Problem:
AudioLDM2ProjectionModel.forward() calls new_ones((hidden_states[:2])) and new_ones((hidden_states_1[:2])). Those arguments are tensor slices, not shape tuples, so direct projection-model use crashes whenever one encoder mask is provided and the other is omitted. The first branch is also placed after hidden_states has already been concatenated, so even rewriting it as hidden_states.shape[:2] there would build a mask with the combined sequence length rather than the CLAP-only length.
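
For reference, a standalone sketch of the slice-versus-shape confusion (plain PyTorch, independent of the model code):

import torch

t = torch.randn(2, 7, 4)
print(t.shape[:2])  # torch.Size([2, 7]): a valid size argument for new_ones
print(t[:2].shape)  # torch.Size([2, 7, 4]): t[:2] is a tensor slice, not a size
mask = t.new_ones(t.shape[:2])  # ok: builds a (2, 7) mask of ones
# t.new_ones((t[:2]))           # TypeError: argument 'size' must be tuple of ints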

Impact:
Users providing one precomputed attention mask cannot use the projection model directly, and pipeline paths that mix precomputed embeddings/masks are fragile.

Reproduction:

import torch
from diffusers import AudioLDM2ProjectionModel

model = AudioLDM2ProjectionModel(text_encoder_dim=3, text_encoder_1_dim=4, langauge_model_dim=5)
h0 = torch.randn(2, 1, 3)
h1 = torch.randn(2, 7, 4)
mask1 = torch.ones(2, 7, dtype=torch.long)

model(hidden_states=h0, hidden_states_1=h1, attention_mask=None, attention_mask_1=mask1)
# TypeError: new_ones(): argument 'size' must be tuple of ints

Relevant precedent:
Use tensor .shape[:2] for mask creation, as the pipeline already does for default prompt masks.

Suggested fix:

# before concatenating hidden_states and hidden_states_1
if attention_mask is None and attention_mask_1 is not None:
    attention_mask = attention_mask_1.new_ones(hidden_states.shape[:2])
elif attention_mask is not None and attention_mask_1 is None:
    attention_mask_1 = attention_mask.new_ones(hidden_states_1.shape[:2])

hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1)
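
With the masks built before concatenation, the repro above should return a combined attention mask of shape (2, 8) (one CLAP token plus seven T5 tokens) instead of raising.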

Issue 2: Omitted second encoder states ignore the first encoder mask

Affected code:

# the same fallback appears, verbatim, in each of the three cross-attention block forwards
encoder_hidden_states_1 = (
    encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
)
encoder_attention_mask_1 = (
    encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
)

Problem:
Each block first replaces encoder_hidden_states_1=None with encoder_hidden_states, then decides whether to fall back encoder_attention_mask_1 based on the already-mutated encoder_hidden_states_1, which can no longer be None at that point. As a result, the second cross-attention stream reuses the first stream's hidden states but silently drops the first stream's mask.
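
The ordering bug can be seen in isolation with plain Python (a standalone sketch, not the model code):

encoder_hidden_states = "stream 0 states"
encoder_attention_mask = "stream 0 mask"
encoder_hidden_states_1 = None
encoder_attention_mask_1 = None

encoder_hidden_states_1 = (
    encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
)
# encoder_hidden_states_1 is non-None from here on, so this ternary keeps the None mask
encoder_attention_mask_1 = (
    encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
)
print(encoder_attention_mask_1)  # None: the mask fallback never fires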

Impact:
Calling AudioLDM2UNet2DConditionModel with only encoder_hidden_states and encoder_attention_mask produces different results than explicitly passing the same states/mask as stream 1.

Reproduction:

import torch
from diffusers import AudioLDM2UNet2DConditionModel

torch.manual_seed(0)
model = AudioLDM2UNet2DConditionModel(
    sample_size=8, in_channels=4, out_channels=4, block_out_channels=(8,),
    layers_per_block=1, norm_num_groups=1,
    down_block_types=("CrossAttnDownBlock2D",),
    up_block_types=("CrossAttnUpBlock2D",),
    cross_attention_dim=((None, 8, 8),),
    attention_head_dim=1,
).eval()

sample = torch.randn(1, 4, 8, 8)
encoder = torch.randn(1, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0]])

with torch.no_grad():
    implicit = model(sample, 1, encoder_hidden_states=encoder, encoder_attention_mask=mask).sample
    explicit = model(
        sample, 1, encoder_hidden_states=encoder, encoder_attention_mask=mask,
        encoder_hidden_states_1=encoder, encoder_attention_mask_1=mask,
    ).sample

print((implicit - explicit).abs().max())
# tensor(0.1561...)

Relevant precedent:
The intended fallback is visible from the code itself: stream 1 defaults to stream 0 when omitted.

Suggested fix:

if encoder_hidden_states_1 is None:
    encoder_hidden_states_1 = encoder_hidden_states
    if encoder_attention_mask_1 is None:
        encoder_attention_mask_1 = encoder_attention_mask
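
With the fallback hoisted above the mutation, the repro above should print tensor(0.), since the implicit and explicit calls then receive identical inputs.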

Issue 3: cross_attention_kwargs is accepted but ignored

Affected code:

# predict the noise residual
noise_pred = self.unet(
    latent_model_input,
    t,
    encoder_hidden_states=generated_prompt_embeds,
    encoder_hidden_states_1=prompt_embeds,
    encoder_attention_mask_1=attention_mask,
    return_dict=False,
)[0]

# the same call appears in each cross-attention block's normal (non-checkpointed) path
hidden_states = self.attentions[i * num_attention_per_layer + idx](
    hidden_states,
    attention_mask=attention_mask,
    encoder_hidden_states=forward_encoder_hidden_states,
    encoder_attention_mask=forward_encoder_attention_mask,
    return_dict=False,
)[0]

Problem:
AudioLDM2Pipeline.__call__() exposes cross_attention_kwargs but never passes it to the UNet, and the UNet blocks' normal (non-checkpointed) paths likewise do not forward cross_attention_kwargs to Transformer2DModel.

Impact:
Custom attention processors and LoRA-style attention kwargs silently do nothing in normal inference.

Reproduction:

import torch
from diffusers import AudioLDM2UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

class MarkerProcessor(AttnProcessor):
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, marker=None, **kwargs):
        if marker:
            raise RuntimeError("cross_attention_kwargs propagated")
        return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb, **kwargs)

model = AudioLDM2UNet2DConditionModel(
    sample_size=8, in_channels=4, out_channels=4, block_out_channels=(8,),
    layers_per_block=1, norm_num_groups=1,
    down_block_types=("CrossAttnDownBlock2D",),
    up_block_types=("CrossAttnUpBlock2D",),
    cross_attention_dim=(8,),
    attention_head_dim=1,
).eval()
model.set_attn_processor(MarkerProcessor())

with torch.no_grad():
    model(torch.randn(1, 4, 8, 8), 1, encoder_hidden_states=torch.randn(1, 5, 8), cross_attention_kwargs={"marker": True})

print("marker was ignored")

Relevant precedent:

for i, (resnet, attn) in enumerate(blocks):
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
        hidden_states = attn(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=False,
        )[0]
    else:
        hidden_states = resnet(hidden_states, temb)
        hidden_states = attn(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=False,
        )[0]

Suggested fix:

# pipeline denoise call
noise_pred = self.unet(
    latent_model_input,
    t,
    encoder_hidden_states=generated_prompt_embeds,
    encoder_hidden_states_1=prompt_embeds,
    encoder_attention_mask_1=attention_mask,
    cross_attention_kwargs=cross_attention_kwargs,
    return_dict=False,
)[0]

# all AudioLDM2 cross-attn block normal paths
hidden_states = self.attentions[i * num_attention_per_layer + idx](
    hidden_states,
    encoder_hidden_states=forward_encoder_hidden_states,
    cross_attention_kwargs=cross_attention_kwargs,
    attention_mask=attention_mask,
    encoder_attention_mask=forward_encoder_attention_mask,
    return_dict=False,
)[0]
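
With both changes applied, the MarkerProcessor repro above should raise RuntimeError("cross_attention_kwargs propagated") instead of reaching the final print.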

Issue 4: Gradient checkpointing misroutes Transformer2DModel arguments

Affected code:

# the same positional call appears in each cross-attention block's checkpointing path
hidden_states = self._gradient_checkpointing_func(
    self.attentions[i * num_attention_per_layer + idx],
    hidden_states,
    forward_encoder_hidden_states,
    None,  # timestep
    None,  # class_labels
    cross_attention_kwargs,
    attention_mask,
    forward_encoder_attention_mask,
)[0]

Problem:
The gradient-checkpointing branch calls Transformer2DModel positionally using a stale signature. With the current signature, cross_attention_kwargs, attention_mask, and encoder_attention_mask are shifted into the wrong parameters.
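
The general failure mode is easy to demonstrate with toy functions (hypothetical signatures, not the real Transformer2DModel.forward):

# inserting a parameter upstream silently rebinds every later positional
# argument at stale call sites
def forward_old(hidden_states, encoder_hidden_states=None, timestep=None,
                class_labels=None, cross_attention_kwargs=None, attention_mask=None):
    return attention_mask

def forward_new(hidden_states, encoder_hidden_states=None, timestep=None,
                added_cond_kwargs=None,  # new parameter inserted here
                class_labels=None, cross_attention_kwargs=None, attention_mask=None):
    return attention_mask

args = ("h", "ehs", None, None, {"scale": 1.0}, "mask")
print(forward_old(*args))  # 'mask' binds to attention_mask as intended
print(forward_new(*args))  # None: 'mask' has shifted into cross_attention_kwargs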

Impact:
_supports_gradient_checkpointing=True is advertised, but training with gradient checkpointing and encoder masks can crash or use the wrong masks.

Reproduction:

import torch
from diffusers import AudioLDM2UNet2DConditionModel

model = AudioLDM2UNet2DConditionModel(
    sample_size=8, in_channels=4, out_channels=4, block_out_channels=(8,),
    layers_per_block=1, norm_num_groups=1,
    down_block_types=("CrossAttnDownBlock2D",),
    up_block_types=("CrossAttnUpBlock2D",),
    cross_attention_dim=(8,),
    attention_head_dim=1,
)
model.enable_gradient_checkpointing()

sample = torch.randn(1, 4, 8, 8, requires_grad=True)
encoder = torch.randn(1, 5, 8, requires_grad=True)
mask = torch.tensor([[1, 1, 1, 0, 0]])

model(sample, 1, encoder_hidden_states=encoder, encoder_attention_mask=mask)
# RuntimeError: The size of tensor a (64) must match the size of tensor b (69)...

Relevant precedent:

for i, (resnet, attn) in enumerate(blocks):
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
        hidden_states = attn(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=False,
        )[0]
    else:
        hidden_states = resnet(hidden_states, temb)
        hidden_states = attn(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=False,
        )[0]

Suggested fix:
Do not checkpoint-call Transformer2DModel positionally. Follow the regular UNet block pattern: checkpoint the ResNet, then call the attention module with keyword arguments so the nested Transformer2DModel handles its own checkpointing.

hidden_states = self._gradient_checkpointing_func(self.resnets[i], hidden_states, temb)
hidden_states = self.attentions[i * num_attention_per_layer + idx](
    hidden_states,
    encoder_hidden_states=forward_encoder_hidden_states,
    cross_attention_kwargs=cross_attention_kwargs,
    attention_mask=attention_mask,
    encoder_attention_mask=forward_encoder_attention_mask,
    return_dict=False,
)[0]
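
Keyword-calling the attention module also decouples these blocks from Transformer2DModel's parameter order, so a future signature change cannot silently rebind arguments again.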

Issue 5: Automatic scoring ranks across the full batch, not per prompt

Affected code:

# compute the audio-text similarity score using the CLAP model
logits_per_text = self.text_encoder(**inputs).logits_per_text
# sort by the highest matching generations per prompt
indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())

Problem:
score_waveforms() ranks every generated waveform in the batch against each prompt, then index-selects from that global argsort. For batched prompts with num_waveforms_per_prompt > 1, a prompt can therefore select waveforms that were generated for a different prompt.

Impact:
The returned audio order can mix prompts, so users may receive a waveform for the wrong prompt after automatic scoring.

Reproduction:

import torch
from types import SimpleNamespace
from diffusers import AudioLDM2Pipeline

class Inputs(dict):
    def to(self, device): return self

class Tokenizer:
    def __call__(self, text, return_tensors=None, padding=None):
        return Inputs(input_ids=torch.ones(len(text), 2, dtype=torch.long))

class FeatureExtractor:
    sampling_rate = 16000
    def __call__(self, audio, return_tensors=None, sampling_rate=None):
        return SimpleNamespace(input_features=torch.zeros(len(audio), 1, 4))

class TextEncoder:
    def __call__(self, **inputs):
        return SimpleNamespace(logits_per_text=torch.tensor([[0.1, 0.2, 0.9, 0.8], [0.7, 0.6, 0.5, 0.4]]))

pipe = AudioLDM2Pipeline.__new__(AudioLDM2Pipeline)
pipe.tokenizer = Tokenizer()
pipe.feature_extractor = FeatureExtractor()
pipe.text_encoder = TextEncoder()
pipe.vocoder = SimpleNamespace(config=SimpleNamespace(sampling_rate=16000))

audio = torch.arange(16, dtype=torch.float32).view(4, 4)
ranked = pipe.score_waveforms(["prompt a", "prompt b"], audio, num_waveforms_per_prompt=2, device="cpu", dtype=torch.float32)
print(ranked[:, 0].tolist())
# [8.0, 12.0, 0.0, 4.0] selects rows 2/3 for prompt 0 and rows 0/1 for prompt 1

Relevant precedent:
No good in-repo precedent found; MusicLDMPipeline copies this method and appears to carry the same behavior.

Suggested fix:

batch_size = len(text) if isinstance(text, list) else 1
selected = []
for i in range(batch_size):
    start = i * num_waveforms_per_prompt
    end = start + num_waveforms_per_prompt
    local_scores = logits_per_text[i, start:end]
    local_indices = torch.argsort(local_scores, descending=True) + start
    selected.append(local_indices)

indices = torch.cat(selected)
audio = torch.index_select(audio, 0, indices.cpu())
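
For larger batches, the same per-prompt ranking can be vectorized (a sketch under the same assumption as the loop above: logits_per_text has shape (batch_size, batch_size * num_waveforms_per_prompt), with prompt i's waveforms in columns [i * n, (i + 1) * n)):

n = num_waveforms_per_prompt
ar = torch.arange(batch_size, device=logits_per_text.device)
blocks = logits_per_text.view(batch_size, batch_size, n)
local_scores = blocks[ar, ar]  # (batch_size, n): each prompt scored against its own waveforms
indices = torch.argsort(local_scores, dim=-1, descending=True) + ar.unsqueeze(1) * n
audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())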

Issue 6: Tuple cross_attention_dim length is not validated

Affected code:

if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
    raise ValueError(
        f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
    )

down_block = get_down_block(
    down_block_type,
    num_layers=layers_per_block[i],
    transformer_layers_per_block=transformer_layers_per_block[i],
    in_channels=input_channel,
    out_channels=output_channel,
    temb_channels=blocks_time_embed_dim,
    add_downsample=not is_final_block,
    resnet_eps=norm_eps,
    resnet_act_fn=act_fn,
    resnet_groups=norm_num_groups,
    cross_attention_dim=cross_attention_dim[i],
    num_attention_heads=num_attention_heads[i],

Problem:
The constructor validates mismatched cross_attention_dim only when it is a list, but the public type allows tuples. A short tuple falls through and later raises an IndexError.
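
A one-line illustration of why the guard misses (standalone sketch):

cross_attention_dim = (8,)  # tuple, as in the repro below
print(isinstance(cross_attention_dim, list))  # False: the list-only check is skipped entirely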

Impact:
Invalid configs fail with an opaque internal error instead of the intended config validation error. This is especially confusing for from_config() / custom checkpoint users.

Reproduction:

from diffusers import AudioLDM2UNet2DConditionModel

AudioLDM2UNet2DConditionModel(
    block_out_channels=(8, 16),
    layers_per_block=1,
    norm_num_groups=1,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=(8,),
)
# IndexError: tuple index out of range

Relevant precedent:
The same constructor already validates tuple-like attention_head_dim and layers_per_block.

Suggested fix:

if not isinstance(cross_attention_dim, int) and len(cross_attention_dim) != len(down_block_types):
    raise ValueError(
        f"Must provide the same number of `cross_attention_dim` as `down_block_types`. "
        f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
    )

if not isinstance(transformer_layers_per_block, int) and len(transformer_layers_per_block) != len(down_block_types):
    raise ValueError(
        f"Must provide the same number of `transformer_layers_per_block` as `down_block_types`. "
        f"`transformer_layers_per_block`: {transformer_layers_per_block}. `down_block_types`: {down_block_types}."
    )

Coverage status

Fast and slow AudioLDM2 tests exist:

class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = AudioLDM2Pipeline
    params = TEXT_TO_AUDIO_PARAMS
    batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "num_waveforms_per_prompt",
            "generator",
            "latents",
            "output_type",
            "return_dict",
            "callback",
            "callback_steps",
        ]
    )
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = AudioLDM2UNet2DConditionModel(
            block_out_channels=(8, 16),
            layers_per_block=1,
            norm_num_groups=8,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=(8, 16),
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[8, 16],
            in_channels=1,
            out_channels=1,
            norm_num_groups=8,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_branch_config = {
            "bos_token_id": 0,
            "eos_token_id": 2,
            "hidden_size": 8,
            "intermediate_size": 37,
            "layer_norm_eps": 1e-05,
            "num_attention_heads": 1,
            "num_hidden_layers": 1,
            "pad_token_id": 1,
            "vocab_size": 1000,
            "projection_dim": 8,
        }
        audio_branch_config = {
            "spec_size": 8,
            "window_size": 4,
            "num_mel_bins": 8,
            "intermediate_size": 37,
            "layer_norm_eps": 1e-05,
            "depths": [1, 1],
            "num_attention_heads": [1, 1],
            "num_hidden_layers": 1,
            "hidden_size": 192,
            "projection_dim": 8,
            "patch_size": 2,
            "patch_stride": 2,
            "patch_embed_input_channels": 4,
        }
        text_encoder_config = ClapConfig(
            text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=16
        )
        text_encoder = ClapModel(text_encoder_config)
        tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
        feature_extractor = ClapFeatureExtractor.from_pretrained(
            "hf-internal-testing/tiny-random-ClapModel", hop_length=7900
        )
        torch.manual_seed(0)
        text_encoder_2_config = T5Config(
            vocab_size=32100,
            d_model=32,
            d_ff=37,
            d_kv=8,
            num_heads=1,
            num_layers=1,
        )
        text_encoder_2 = T5EncoderModel(text_encoder_2_config)
        tokenizer_2 = T5Tokenizer.from_pretrained("hf-internal-testing/tiny-random-T5Model", model_max_length=77)
        torch.manual_seed(0)
        language_model_config = GPT2Config(
            n_embd=16,
            n_head=1,
            n_layer=1,
            vocab_size=1000,
            n_ctx=99,
            n_positions=99,
        )
        language_model = GPT2LMHeadModel(language_model_config)
        language_model.config.max_new_tokens = 8
        torch.manual_seed(0)
        projection_model = AudioLDM2ProjectionModel(
            text_encoder_dim=16,
            text_encoder_1_dim=32,
            langauge_model_dim=16,
        )
        vocoder_config = SpeechT5HifiGanConfig(
            model_in_dim=8,
            sampling_rate=16000,
            upsample_initial_channel=16,
            upsample_rates=[2, 2],
            upsample_kernel_sizes=[4, 4],
            resblock_kernel_sizes=[3, 7],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
            normalize_before=False,
        )
        vocoder = SpeechT5HifiGan(vocoder_config)
        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "feature_extractor": feature_extractor,
            "language_model": language_model,
            "projection_model": projection_model,
            "vocoder": vocoder,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A hammer hitting a wooden surface",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
        }
        return inputs
    @pytest.mark.xfail(
        condition=is_transformers_version(">=", "4.54.1"),
        reason="Test currently fails on Transformers version 4.54.1.",
        strict=False,
    )
    def test_audioldm2_ddim(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(device)
        output = audioldm_pipe(**inputs)
        audio = output.audios[0]
        assert audio.ndim == 1
        assert len(audio) == 256
        audio_slice = audio[:10]
        expected_slice = np.array(
            [
                2.602e-03,
                1.729e-03,
                1.863e-03,
                -2.219e-03,
                -2.656e-03,
                -2.017e-03,
                -2.648e-03,
                -2.115e-03,
                -2.502e-03,
                -2.081e-03,
            ]
        )
        assert np.abs(audio_slice - expected_slice).max() < 1e-4

    def test_audioldm2_prompt_embeds(self):
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = 3 * [inputs["prompt"]]

        # forward
        output = audioldm_pipe(**inputs)
        audio_1 = output.audios[0]

        inputs = self.get_dummy_inputs(torch_device)
        prompt = 3 * [inputs.pop("prompt")]
        text_inputs = audioldm_pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=audioldm_pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_inputs = text_inputs["input_ids"].to(torch_device)
        clap_prompt_embeds = audioldm_pipe.text_encoder.get_text_features(text_inputs)
        if hasattr(clap_prompt_embeds, "pooler_output"):
            clap_prompt_embeds = clap_prompt_embeds.pooler_output
        clap_prompt_embeds = clap_prompt_embeds[:, None, :]
        text_inputs = audioldm_pipe.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=True,
            truncation=True,
            return_tensors="pt",
        )
        text_inputs = text_inputs["input_ids"].to(torch_device)
        t5_prompt_embeds = audioldm_pipe.text_encoder_2(
            text_inputs,
        )
        t5_prompt_embeds = t5_prompt_embeds[0]
        projection_embeds = audioldm_pipe.projection_model(clap_prompt_embeds, t5_prompt_embeds)[0]
        generated_prompt_embeds = audioldm_pipe.generate_language_model(projection_embeds, max_new_tokens=8)
        inputs["prompt_embeds"] = t5_prompt_embeds
        inputs["generated_prompt_embeds"] = generated_prompt_embeds

        # forward
        output = audioldm_pipe(**inputs)
        audio_2 = output.audios[0]

        assert np.abs(audio_1 - audio_2).max() < 1e-2

    def test_audioldm2_negative_prompt_embeds(self):
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(torch_device)
        negative_prompt = 3 * ["this is a negative prompt"]
        inputs["negative_prompt"] = negative_prompt
        inputs["prompt"] = 3 * [inputs["prompt"]]

        # forward
        output = audioldm_pipe(**inputs)
        audio_1 = output.audios[0]

        inputs = self.get_dummy_inputs(torch_device)
        prompt = 3 * [inputs.pop("prompt")]
        embeds = []
        generated_embeds = []
        for p in [prompt, negative_prompt]:
            text_inputs = audioldm_pipe.tokenizer(
                p,
                padding="max_length",
                max_length=audioldm_pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_inputs = text_inputs["input_ids"].to(torch_device)
            clap_prompt_embeds = audioldm_pipe.text_encoder.get_text_features(text_inputs)
            if hasattr(clap_prompt_embeds, "pooler_output"):
                clap_prompt_embeds = clap_prompt_embeds.pooler_output
            clap_prompt_embeds = clap_prompt_embeds[:, None, :]
            text_inputs = audioldm_pipe.tokenizer_2(
                prompt,
                padding="max_length",
                max_length=True if len(embeds) == 0 else embeds[0].shape[1],
                truncation=True,
                return_tensors="pt",
            )
            text_inputs = text_inputs["input_ids"].to(torch_device)
            t5_prompt_embeds = audioldm_pipe.text_encoder_2(
                text_inputs,
            )
            t5_prompt_embeds = t5_prompt_embeds[0]
            projection_embeds = audioldm_pipe.projection_model(clap_prompt_embeds, t5_prompt_embeds)[0]
            generated_prompt_embeds = audioldm_pipe.generate_language_model(projection_embeds, max_new_tokens=8)
            embeds.append(t5_prompt_embeds)
            generated_embeds.append(generated_prompt_embeds)
        inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
        inputs["generated_prompt_embeds"], inputs["negative_generated_prompt_embeds"] = generated_embeds

        # forward
        output = audioldm_pipe(**inputs)
        audio_2 = output.audios[0]

        assert np.abs(audio_1 - audio_2).max() < 1e-2

    @pytest.mark.xfail(
        condition=is_transformers_version(">=", "4.54.1"),
        reason="Test currently fails on Transformers version 4.54.1.",
        strict=False,
    )
    def test_audioldm2_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(device)
        negative_prompt = "egg cracking"
        output = audioldm_pipe(**inputs, negative_prompt=negative_prompt)
        audio = output.audios[0]
        assert audio.ndim == 1
        assert len(audio) == 256
        audio_slice = audio[:10]
        expected_slice = np.array(
            [0.0026, 0.0017, 0.0018, -0.0022, -0.0026, -0.002, -0.0026, -0.0021, -0.0025, -0.0021]
        )
        assert np.abs(audio_slice - expected_slice).max() < 1e-4

    def test_audioldm2_num_waveforms_per_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        prompt = "A hammer hitting a wooden surface"

        # test num_waveforms_per_prompt=1 (default)
        audios = audioldm_pipe(prompt, num_inference_steps=2).audios
        assert audios.shape == (1, 256)

        # test num_waveforms_per_prompt=1 (default) for batch of prompts
        batch_size = 2
        audios = audioldm_pipe([prompt] * batch_size, num_inference_steps=2).audios
        assert audios.shape == (batch_size, 256)

        # test num_waveforms_per_prompt for single prompt
        num_waveforms_per_prompt = 1
        audios = audioldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios
        assert audios.shape == (num_waveforms_per_prompt, 256)

        # test num_waveforms_per_prompt for batch of prompts
        batch_size = 2
        audios = audioldm_pipe(
            [prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
        ).audios
        assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)

    def test_audioldm2_audio_length_in_s(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        vocoder_sampling_rate = audioldm_pipe.vocoder.config.sampling_rate
        inputs = self.get_dummy_inputs(device)

        output = audioldm_pipe(audio_length_in_s=0.016, **inputs)
        audio = output.audios[0]
        assert audio.ndim == 1
        assert len(audio) / vocoder_sampling_rate == 0.016

        output = audioldm_pipe(audio_length_in_s=0.032, **inputs)
        audio = output.audios[0]
        assert audio.ndim == 1
        assert len(audio) / vocoder_sampling_rate == 0.032

    def test_audioldm2_vocoder_model_in_dim(self):
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        prompt = ["hey"]

        output = audioldm_pipe(prompt, num_inference_steps=1)
        audio_shape = output.audios.shape
        assert audio_shape == (1, 256)

        config = audioldm_pipe.vocoder.config
        config.model_in_dim *= 2
        audioldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
        output = audioldm_pipe(prompt, num_inference_steps=1)
        audio_shape = output.audios.shape
        # waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
        assert audio_shape == (1, 256)

    def test_attention_slicing_forward_pass(self):
        self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)

    @unittest.skip("Raises a not implemented error in AudioLDM2")
    def test_xformers_attention_forwardGenerator_pass(self):
        pass

    def test_dict_tuple_outputs_equivalent(self):
        # increase tolerance from 1e-4 -> 3e-4 to account for large composite model
        super().test_dict_tuple_outputs_equivalent(expected_max_difference=3e-4)

    @pytest.mark.xfail(
        condition=is_torch_version(">=", "2.7"),
        reason="Test currently fails on PyTorch 2.7.",
        strict=False,
    )
    def test_inference_batch_single_identical(self):
        # increase tolerance from 1e-4 -> 2e-4 to account for large composite model
        self._test_inference_batch_single_identical(expected_max_diff=2e-4)

    def test_save_load_local(self):
        # increase tolerance from 1e-4 -> 2e-4 to account for large composite model
        super().test_save_load_local(expected_max_difference=2e-4)

    def test_save_load_optional_components(self):
        # increase tolerance from 1e-4 -> 2e-4 to account for large composite model
        super().test_save_load_optional_components(expected_max_difference=2e-4)

    def test_to_dtype(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        # The method component.dtype returns the dtype of the first parameter registered in the model, not the
        # dtype of the entire model. In the case of CLAP, the first parameter is a float64 constant (logit scale)
        model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}

        # Without the logit scale parameters, everything is float32
        model_dtypes.pop("text_encoder")
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))

        # the CLAP sub-models are float32
        model_dtypes["clap_text_branch"] = components["text_encoder"].text_model.dtype
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))

        # Once we send to fp16, all params are in half-precision, including the logit scale
        pipe.to(dtype=torch.float16)
        model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))

    @unittest.skip("Test not supported.")
    def test_sequential_cpu_offload_forward_pass(self):
        pass

    @unittest.skip("Test not supported for now because of the use of `projection_model` in `encode_prompt()`.")
    def test_encode_prompt_works_in_isolation(self):
        pass

    @unittest.skip("Not supported yet due to CLAPModel.")
    def test_sequential_offload_forward_pass_twice(self):
        pass

    @unittest.skip("Not supported yet, the second forward has mixed devices and `vocoder` is not offloaded.")
    def test_cpu_offload_forward_pass_twice(self):
        pass

    @unittest.skip("Not supported yet. `vocoder` is not offloaded.")
    def test_model_cpu_offload_forward_pass(self):
        pass
@nightly
class AudioLDM2PipelineSlowTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
        inputs = {
            "prompt": "A hammer hitting a wooden surface",
            "latents": latents,
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 2.5,
        }
        return inputs

    def get_inputs_tts(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
        inputs = {
            "prompt": "A men saying",
            "transcription": "hello my name is John",
            "latents": latents,
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 2.5,
        }
        return inputs

    def test_audioldm2(self):
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 25
        audio = audioldm_pipe(**inputs).audios[0]
        assert audio.ndim == 1
        assert len(audio) == 81952
        # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
        audio_slice = audio[17275:17285]
        expected_slice = np.array([0.0791, 0.0666, 0.1158, 0.1227, 0.1171, -0.2880, -0.1940, -0.0283, -0.0126, 0.1127])
        max_diff = np.abs(expected_slice - audio_slice).max()
        assert max_diff < 1e-3

    def test_audioldm2_lms(self):
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
        audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        inputs = self.get_inputs(torch_device)
        audio = audioldm_pipe(**inputs).audios[0]
        assert audio.ndim == 1
        assert len(audio) == 81952
        # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
        audio_slice = audio[31390:31400]
        expected_slice = np.array(
            [-0.1318, -0.0577, 0.0446, -0.0573, 0.0659, 0.1074, -0.2600, 0.0080, -0.2190, -0.4301]
        )
        max_diff = np.abs(expected_slice - audio_slice).max()
        assert max_diff < 1e-3

    def test_audioldm2_large(self):
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large")
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        inputs = self.get_inputs(torch_device)
        audio = audioldm_pipe(**inputs).audios[0]
        assert audio.ndim == 1
        assert len(audio) == 81952
        # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
        audio_slice = audio[8825:8835]
        expected_slice = np.array(
            [-0.1829, -0.1461, 0.0759, -0.1493, -0.1396, 0.5783, 0.3001, -0.3038, -0.0639, -0.2244]
        )
        max_diff = np.abs(expected_slice - audio_slice).max()
        assert max_diff < 1e-3

    def test_audioldm2_tts(self):

Slow coverage is present for cvssp/audioldm2, LMS, cvssp/audioldm2-large, and anhnct/audioldm2_gigaspeech. Several offload / isolation tests remain skipped:

@unittest.skip("Test not supported.")
def test_sequential_cpu_offload_forward_pass(self):
    pass

@unittest.skip("Test not supported for now because of the use of `projection_model` in `encode_prompt()`.")
def test_encode_prompt_works_in_isolation(self):
    pass

@unittest.skip("Not supported yet due to CLAPModel.")
def test_sequential_offload_forward_pass_twice(self):
    pass

@unittest.skip("Not supported yet, the second forward has mixed devices and `vocoder` is not offloaded.")
def test_cpu_offload_forward_pass_twice(self):
    pass

@unittest.skip("Not supported yet. `vocoder` is not offloaded.")
def test_model_cpu_offload_forward_pass(self):
    pass

Local pytest collection could not complete in .venv because the installed torch build is missing torch._C._distributed_c10d; the standalone repro snippets above were run successfully under .venv.
