Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 42 additions & 23 deletions tests/pipelines/audioldm2/test_audioldm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,15 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
unet = AudioLDM2UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
block_out_channels=(8, 16),
layers_per_block=1,
norm_num_groups=8,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=([None, 16, 32], [None, 16, 32]),
cross_attention_dim=(8, 16),
)
scheduler = DDIMScheduler(
beta_start=0.00085,
Expand All @@ -91,9 +92,10 @@ def get_dummy_components(self):
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
block_out_channels=[8, 16],
in_channels=1,
out_channels=1,
norm_num_groups=8,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
Expand All @@ -102,32 +104,34 @@ def get_dummy_components(self):
text_branch_config = ClapTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=16,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=2,
num_hidden_layers=2,
num_attention_heads=1,
num_hidden_layers=1,
pad_token_id=1,
vocab_size=1000,
projection_dim=16,
projection_dim=8,
)
audio_branch_config = ClapAudioConfig(
spec_size=64,
spec_size=8,
window_size=4,
num_mel_bins=64,
num_mel_bins=8,
intermediate_size=37,
layer_norm_eps=1e-05,
depths=[2, 2],
num_attention_heads=[2, 2],
num_hidden_layers=2,
depths=[1, 1],
num_attention_heads=[1, 1],
num_hidden_layers=1,
hidden_size=192,
projection_dim=16,
projection_dim=8,
patch_size=2,
patch_stride=2,
patch_embed_input_channels=4,
)
text_encoder_config = ClapConfig.from_text_audio_configs(
text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=16
text_config=text_branch_config,
audio_config=audio_branch_config,
projection_dim=16,
)
text_encoder = ClapModel(text_encoder_config)
tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
Expand All @@ -141,17 +145,17 @@ def get_dummy_components(self):
d_model=32,
d_ff=37,
d_kv=8,
num_heads=2,
num_layers=2,
num_heads=1,
num_layers=1,
)
text_encoder_2 = T5EncoderModel(text_encoder_2_config)
tokenizer_2 = T5Tokenizer.from_pretrained("hf-internal-testing/tiny-random-T5Model", model_max_length=77)

torch.manual_seed(0)
language_model_config = GPT2Config(
n_embd=16,
n_head=2,
n_layer=2,
n_head=1,
n_layer=1,
vocab_size=1000,
n_ctx=99,
n_positions=99,
Expand All @@ -160,7 +164,11 @@ def get_dummy_components(self):
language_model.config.max_new_tokens = 8

torch.manual_seed(0)
projection_model = AudioLDM2ProjectionModel(text_encoder_dim=16, text_encoder_1_dim=32, langauge_model_dim=16)
projection_model = AudioLDM2ProjectionModel(
text_encoder_dim=16,
text_encoder_1_dim=32,
langauge_model_dim=16,
)

vocoder_config = SpeechT5HifiGanConfig(
model_in_dim=8,
Expand Down Expand Up @@ -220,7 +228,18 @@ def test_audioldm2_ddim(self):

audio_slice = audio[:10]
expected_slice = np.array(
[0.0025, 0.0018, 0.0018, -0.0023, -0.0026, -0.0020, -0.0026, -0.0021, -0.0027, -0.0020]
[
2.602e-03,
1.729e-03,
1.863e-03,
-2.219e-03,
-2.656e-03,
-2.017e-03,
-2.648e-03,
-2.115e-03,
-2.502e-03,
-2.081e-03,
]
)

assert np.abs(audio_slice - expected_slice).max() < 1e-4
Expand Down Expand Up @@ -361,7 +380,7 @@ def test_audioldm2_negative_prompt(self):

audio_slice = audio[:10]
expected_slice = np.array(
[0.0025, 0.0018, 0.0018, -0.0023, -0.0026, -0.0020, -0.0026, -0.0021, -0.0027, -0.0020]
[0.0026, 0.0017, 0.0018, -0.0022, -0.0026, -0.002, -0.0026, -0.0021, -0.0025, -0.0021]
)

assert np.abs(audio_slice - expected_slice).max() < 1e-4
Expand All @@ -388,7 +407,7 @@ def test_audioldm2_num_waveforms_per_prompt(self):
assert audios.shape == (batch_size, 256)

# test num_waveforms_per_prompt for single prompt
num_waveforms_per_prompt = 2
num_waveforms_per_prompt = 1
audios = audioldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios

assert audios.shape == (num_waveforms_per_prompt, 256)
Expand Down