From 7dbafbabcddd0b09d9d06c019bcb6cd2faa3e828 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Sun, 14 Sep 2025 16:49:24 +0800 Subject: [PATCH 1/5] ace step can generate normal stuff, but code is shitty. need massive change --- diffsynth_engine/models/ace_step/ace_dit.py | 18 ++- .../lyric_tokenizer/lyric_tokenizer.py | 31 +++++ diffsynth_engine/models/wan/wan_dit.py | 24 +++- diffsynth_engine/pipelines/ace_step.py | 106 ++++++++++++++++-- diffsynth_engine/pipelines/base.py | 10 +- examples/ace_text_to_music.py | 1 + pyproject.toml | 2 +- 7 files changed, 166 insertions(+), 26 deletions(-) diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py index 09a4092..f3c8edd 100644 --- a/diffsynth_engine/models/ace_step/ace_dit.py +++ b/diffsynth_engine/models/ace_step/ace_dit.py @@ -26,14 +26,19 @@ def _set_cos_sin_cache(self, seq_len): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.int64).float() freqs = torch.outer(t, self.inv_freq) - self.freqs_cis_cached = torch.polar(torch.ones_like(freqs), freqs) + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos() + self.sin_cached = emb.sin() def forward(self, x: torch.Tensor): seq_len = x.shape[1] if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len) - return self.freqs_cis_cached[:seq_len][None, :, None, :].to(x.device) + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) class SelfAttention(nn.Module): @@ -67,11 +72,13 @@ def forward(self, x, freqs, attn_mask): # x: (b, s, d), attn_mask: (b, s) q = rearrange(q, "b s n d -> b n d s") k = rearrange(k, "b s n d -> b n s d") v = rearrange(v, "b s n d -> b n d s") + q, k, v = q.float(), k.float(), v.float() v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1.0) # b n (d+1) s + dtype = x.dtype x = torch.matmul(torch.matmul(v, k), q) # inner: b n (d+1) d x = x[:, :, :-1] / (x[:, :, -1:] + 1e-15) # b n d s x = rearrange(x, "b n d s -> b s (n d)") - return self.o(x) + return self.o(x.to(dtype=dtype)) class CrossAttention(nn.Module): @@ -175,7 +182,7 @@ def __init__( kernel_size=3, groups=hidden_features * 2, use_bias=True, - act="silu", + act=None, device=device, dtype=dtype, ) @@ -194,7 +201,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.inverted_conv(x) x = self.depth_conv(x) x, gate = torch.chunk(x, 2, dim=1) - x *= self.glu_act(gate) + gate = self.glu_act(gate) + x = x * gate x = self.point_conv(x) x = x.transpose(1, 2) return x diff --git a/diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py b/diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py index 5a133f1..f8c53c8 100644 --- a/diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py +++ b/diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py @@ -688,3 +688,34 @@ def encode(self, txt, lang): txt = f"[{lang}]{txt}" txt = txt.replace(" ", "[SPACE]") return self.tokenizer.encode(txt).ids + + def decode(self, seq): + import torch + if isinstance(seq, torch.Tensor): + seq = seq.cpu().numpy() + txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "") + txt = txt.replace("[SPACE]", " ") + txt = txt.replace("[STOP]", "") + # txt = txt.replace("[UNK]", "") + return txt + + # copy from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3936 + def batch_decode( + self, + sequences + ): + """ + Convert a 
list of lists of token ids into a list of strings by calling decode. + + Args: + sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `List[str]`: The list of decoded sentences. + """ + return [self.decode(seq) for seq in sequences] diff --git a/diffsynth_engine/models/wan/wan_dit.py b/diffsynth_engine/models/wan/wan_dit.py index a8b8b9c..3313da7 100644 --- a/diffsynth_engine/models/wan/wan_dit.py +++ b/diffsynth_engine/models/wan/wan_dit.py @@ -61,9 +61,27 @@ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): def rope_apply(x, freqs): - b, s, n, d = x.shape - x_out = torch.view_as_complex(x.to(torch.float64).reshape(b, s, n, d // 2, 2)) - x_out = torch.view_as_real(x_out * freqs) + # b, s, n, d = x.shape + # x_out = torch.view_as_complex(x.to(torch.float64).reshape(b, s, n, d // 2, 2)) + # x_out = torch.view_as_real(x_out * freqs) + # # get real part and imag part from freqs + # cos = freqs.real + # sin = freqs.imag + # out = x_out.to(x.dtype).flatten(3) + cos, sin = freqs # [S, D] + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + cos, sin = cos.to(x.device), sin.to(x.device) + + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + x_out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + # rotary_debug = torch.load("/home/zhangchengsong.zcs/ACE-Step/rotary_debug.pt", map_location=x.device) + # print("x max diff:",(x - rotary_debug["x"].permute(0, 2, 1, 3)).abs().max()) + # print("cos max diff:",(cos - rotary_debug["cos"].permute(0, 2, 1, 3)).abs().max()) + # print("sin max diff:",(sin - rotary_debug["sin"].permute(0, 2, 1, 3)).abs().max()) + # print("out max diff:",(x_out - rotary_debug["out"].permute(0, 2, 1, 3)).abs().max()) + return x_out.to(x.dtype).flatten(3) diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py index 4a8de83..36d407f 100644 --- a/diffsynth_engine/pipelines/ace_step.py +++ b/diffsynth_engine/pipelines/ace_step.py @@ -1,5 +1,6 @@ -from typing import Tuple +from typing import Tuple, Union +import numpy as np import torch import torch.nn.functional as F import torch.distributed as dist @@ -31,6 +32,55 @@ logger = logging.get_logger(__name__) +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + device: Optional[Union[str, "torch.device"]] = None, + dtype: Optional["torch.dtype"] = None, + layout: Optional["torch.layout"] = None, +): + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. 
+ """ + # device on which tensor is created defaults to device + if isinstance(device, str): + device = torch.device(device) + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + logger.info( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slightly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents + def fwd_with_temperature( inputs, model_fwd_func, get_hooked_layer_func, layer_start_idx, layer_end_idx, temperature=0.01 @@ -51,11 +101,11 @@ def hook(module, input, output): handlers.append(handler) with torch.no_grad(): - prompt_emb = model_fwd_func(**inputs) + output = model_fwd_func(**inputs) for handler in handlers: handler.remove() - return prompt_emb + return output class MomentumBuffer: @@ -124,7 +174,7 @@ def __init__( ) self.config = config # sampler - self.noise_scheduler = RecifitedFlowScheduler(shift=3.0) + self.noise_scheduler = RecifitedFlowScheduler(shift=config.shift) self.sampler = FlowMatchEulerSampler() # models self.lyric_tokenizer = VoiceBpeTokenizer() @@ -157,7 +207,7 @@ def encode_prompt(self, prompt): ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True) ids = ids.to(self.device)[:, :17] mask = mask.to(self.device)[:, :17] - prompt_emb = self.text_encoder(ids, mask) + prompt_emb = self.text_encoder(ids, mask).last_hidden_state return prompt_emb, mask def encode_prompt_null(self, prompt): @@ -171,10 +221,13 @@ def encode_prompt_null(self, prompt): "attention_mask": mask }, model_fwd_func=self.text_encoder, - get_hooked_layer_func=lambda i: self.text_encoder.encoders[i].attn.to_q, - layer_start_idx=4, - layer_end_idx=6, - ) + get_hooked_layer_func=lambda i: self.text_encoder + .encoder.block[i] + .layer[0] + .SelfAttention.q, + layer_start_idx=8, + layer_end_idx=10, + ).last_hidden_state return prompt_emb, mask def tokenize_lyric(self, lyrics: str): @@ -206,6 +259,10 @@ def tokenize_lyric(self, lyrics: str): token_idx = self.lyric_tokenizer.encode(line, "en") else: token_idx = self.lyric_tokenizer.encode(line, language) + toks = self.lyric_tokenizer.batch_decode( + [[tok_id] for tok_id in token_idx] + ) + logger.info(f"{line} --> {language} --> {toks}") lyric_token_idx += token_idx + [2] except Exception as e: logger.warning("tokenize error", e, "for line", line, "major_language", language) @@ -309,7 +366,13 @@ def text2audio( lyric_mask = 
torch.zeros((1, 1), device=self.device, dtype=torch.long) num_frames = int(audio_duration * 44100 / 512 / 8) - noise = self.generate_noise((1, 8, 16, num_frames), seed=seed, device="cpu", dtype=torch.float32).to(self.device) + # noise = randn_tensor( + # shape=(1, 8, 16, num_frames), + # generator=torch.Generator(device=self.device).manual_seed(seed), + # device=self.device, + # dtype=self.dtype, + # ) + noise = torch.randn(1, 8, 16, num_frames, device=self.device, dtype=self.dtype, generator=torch.Generator(device=self.device).manual_seed(42)) attn_mask = torch.ones(1, num_frames, device=self.device, dtype=self.dtype) _, latents, sigmas, timesteps = self.prepare_latents( latents=noise, @@ -366,10 +429,26 @@ def text2audio( attn_mask_ctx=context_mask, ) # Scheduler + def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1): + # L = Lower bound + # U = Upper bound + # x_0 = Midpoint (x corresponding to y = 1.0) + # k = Steepness, can adjust based on preference + + if isinstance(x, torch.Tensor): + device_ = x.device + x = x.to(torch.float).cpu().numpy() + + new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0))) + + if isinstance(new_x, np.ndarray): + new_x = torch.from_numpy(new_x).to(device_) + return new_x + omega = logistic_function(omega_scale, k=0.1) dx: torch.Tensor = noise_pred * (self.sampler.sigmas[i + 1] - self.sampler.sigmas[i]) dx_mean = dx.mean(dim=(1, 2, 3), keepdim=True) latents = latents.to(dtype=torch.float32) - latents += (dx - dx_mean) * omega_scale + dx + latents += (dx - dx_mean) * omega + dx_mean latents = latents.to(dtype=noise_pred.dtype) if progress_callback is not None: progress_callback(i + 1, len(timesteps), "DENOISING") @@ -447,7 +526,10 @@ def _from_state_dict(cls, state_dicts: ACEStateDicts, config: ACEStepPipelineCon init_device = "cpu" if config.offload_mode is not None else config.device tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=256, clean="whitespace") - text_encoder = ACETextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype, **t5_config) + # text_encoder = ACETextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype, **t5_config) + from transformers import UMT5EncoderModel + text_encoder = UMT5EncoderModel.from_pretrained("/home/zhangchengsong.zcs/.cache/diffsynth/modelscope/ACE-Step/ACE-Step-v1-3.5B/__version/umt5-base") + text_encoder = text_encoder.to(config.device).eval().to(config.t5_dtype) dcae = DCAE.from_state_dict(state_dicts.dcae, config=dcae_config, device=init_device, dtype=config.vae_dtype) hifi_gan = ADaMoSHiFiGANV1.from_state_dict(state_dicts.vocoder, config=vocoder_config, device=init_device, dtype=config.vae_dtype) vae = MusicDCAE(dcae=dcae, vocoder=hifi_gan) diff --git a/diffsynth_engine/pipelines/base.py b/diffsynth_engine/pipelines/base.py index 37fbccb..b40bbb7 100644 --- a/diffsynth_engine/pipelines/base.py +++ b/diffsynth_engine/pipelines/base.py @@ -210,12 +210,12 @@ def prepare_latents( sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps) # k-diffusion # if you have any questions about this, please ask @dizhipeng.dzp for more details - latents = latents * sigmas[0] / ((sigmas[0] ** 2 + 1) ** 0.5) + # latents = latents * sigmas[0] / ((sigmas[0] ** 2 + 1) ** 0.5) init_latents = latents.clone() - sigmas, timesteps = ( - sigmas.to(device=self.device, dtype=self.dtype), - timesteps.to(device=self.device, dtype=self.dtype), - ) + # sigmas, timesteps = ( + # sigmas.to(device=self.device, dtype=self.dtype), + # timesteps.to(device=self.device, 
dtype=self.dtype), + # ) init_latents, latents = ( init_latents.to(device=self.device, dtype=self.dtype), latents.to(device=self.device, dtype=self.dtype), diff --git a/examples/ace_text_to_music.py b/examples/ace_text_to_music.py index f817485..1bec574 100644 --- a/examples/ace_text_to_music.py +++ b/examples/ace_text_to_music.py @@ -21,6 +21,7 @@ prompt="pop, rap, electronic, blues, hip-house, rhythm and blues", lyrics="[verse]\n我走过深夜的街道\n冷风吹乱思念的漂亮外套\n你的微笑像星光很炫耀\n照亮了我孤独的每分每秒\n\n[chorus]\n愿你是风吹过我的脸\n带我飞过最远最遥远的山间\n愿你是风轻触我的梦\n停在心头不再飘散无迹无踪\n\n[verse]\n一起在喧哗避开世俗的骚动\n独自在天台探望月色的朦胧\n你说爱像音乐带点重节奏\n一拍一跳让我忘了心的温度多空洞\n\n[bridge]\n唱起对你的想念不隐藏\n像诗又像画写满藏不了的渴望\n你的影子挥不掉像风的倔强\n追着你飞扬穿越云海一样泛光\n\n[chorus]\n愿你是风吹过我的手\n暖暖的触碰像春日细雨温柔\n愿你是风盘绕我的身\n深情万万重不会有一天走远走\n\n[verse]\n深夜的钢琴弹起动人的旋律\n低音鼓砸进心底的每一次呼吸\n要是能将爱化作歌声传递\n你是否会听见我心里的真心实意", audio_duration=170.63997916666668, + seed=3299954530, ) save_audio(audio, f"tmp/ace_t2m_{seed}") diff --git a/pyproject.toml b/pyproject.toml index 8b4c119..2c5dde9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ "moviepy", "librosa", "scikit-image", - "trimesh" + "trimesh", "py3langid", "pypinyin", "hangul_romanize", From 5d1dc41e8acf54afda75a990438bd03092afa6d1 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Sun, 14 Sep 2025 17:09:24 +0800 Subject: [PATCH 2/5] revert t5, rand generator, etc --- .../lyric_tokenizer/lyric_tokenizer.py | 31 ------- diffsynth_engine/pipelines/ace_step.py | 81 ++----------------- examples/ace_text_to_music.py | 2 +- 3 files changed, 9 insertions(+), 105 deletions(-) diff --git a/diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py b/diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py index f8c53c8..5a133f1 100644 --- a/diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py +++ b/diffsynth_engine/models/ace_step/lyric_tokenizer/lyric_tokenizer.py @@ -688,34 +688,3 @@ def encode(self, txt, lang): txt = f"[{lang}]{txt}" txt = txt.replace(" ", "[SPACE]") return self.tokenizer.encode(txt).ids - - def decode(self, seq): - import torch - if isinstance(seq, torch.Tensor): - seq = seq.cpu().numpy() - txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "") - txt = txt.replace("[SPACE]", " ") - txt = txt.replace("[STOP]", "") - # txt = txt.replace("[UNK]", "") - return txt - - # copy from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3936 - def batch_decode( - self, - sequences - ): - """ - Convert a list of lists of token ids into a list of strings by calling decode. - - Args: - sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`): - List of tokenized input ids. Can be obtained using the `__call__` method. - skip_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not to remove special tokens in the decoding. - kwargs (additional keyword arguments, *optional*): - Will be passed to the underlying model specific decode method. - - Returns: - `List[str]`: The list of decoded sentences. 
- """ - return [self.decode(seq) for seq in sequences] diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py index 36d407f..7dfabaa 100644 --- a/diffsynth_engine/pipelines/ace_step.py +++ b/diffsynth_engine/pipelines/ace_step.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Tuple import numpy as np import torch @@ -32,55 +32,6 @@ logger = logging.get_logger(__name__) -def randn_tensor( - shape: Union[Tuple, List], - generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, - device: Optional[Union[str, "torch.device"]] = None, - dtype: Optional["torch.dtype"] = None, - layout: Optional["torch.layout"] = None, -): - """A helper function to create random tensors on the desired `device` with the desired `dtype`. When - passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor - is always created on the CPU. - """ - # device on which tensor is created defaults to device - if isinstance(device, str): - device = torch.device(device) - rand_device = device - batch_size = shape[0] - - layout = layout or torch.strided - device = device or torch.device("cpu") - - if generator is not None: - gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type - if gen_device_type != device.type and gen_device_type == "cpu": - rand_device = "cpu" - if device != "mps": - logger.info( - f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." - f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" - f" slightly speed up this function by passing a generator that was created on the {device} device." - ) - elif gen_device_type != device.type and gen_device_type == "cuda": - raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") - - # make sure generator list of length 1 is treated like a non-list - if isinstance(generator, list) and len(generator) == 1: - generator = generator[0] - - if isinstance(generator, list): - shape = (1,) + shape[1:] - latents = [ - torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) - for i in range(batch_size) - ] - latents = torch.cat(latents, dim=0).to(device) - else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) - - return latents - def fwd_with_temperature( inputs, model_fwd_func, get_hooked_layer_func, layer_start_idx, layer_end_idx, temperature=0.01 @@ -207,7 +158,7 @@ def encode_prompt(self, prompt): ids, mask = self.tokenizer(prompt, return_mask=True, add_special_tokens=True) ids = ids.to(self.device)[:, :17] mask = mask.to(self.device)[:, :17] - prompt_emb = self.text_encoder(ids, mask).last_hidden_state + prompt_emb = self.text_encoder(ids, mask) return prompt_emb, mask def encode_prompt_null(self, prompt): @@ -221,13 +172,10 @@ def encode_prompt_null(self, prompt): "attention_mask": mask }, model_fwd_func=self.text_encoder, - get_hooked_layer_func=lambda i: self.text_encoder - .encoder.block[i] - .layer[0] - .SelfAttention.q, - layer_start_idx=8, - layer_end_idx=10, - ).last_hidden_state + get_hooked_layer_func=lambda i: self.text_encoder.encoders[i].attn.to_q, + layer_start_idx=4, + layer_end_idx=6, + ) return prompt_emb, mask def tokenize_lyric(self, lyrics: str): @@ -259,10 +207,6 @@ def tokenize_lyric(self, lyrics: str): token_idx = self.lyric_tokenizer.encode(line, "en") 
else: token_idx = self.lyric_tokenizer.encode(line, language) - toks = self.lyric_tokenizer.batch_decode( - [[tok_id] for tok_id in token_idx] - ) - logger.info(f"{line} --> {language} --> {toks}") lyric_token_idx += token_idx + [2] except Exception as e: logger.warning("tokenize error", e, "for line", line, "major_language", language) @@ -366,13 +310,7 @@ def text2audio( lyric_mask = torch.zeros((1, 1), device=self.device, dtype=torch.long) num_frames = int(audio_duration * 44100 / 512 / 8) - # noise = randn_tensor( - # shape=(1, 8, 16, num_frames), - # generator=torch.Generator(device=self.device).manual_seed(seed), - # device=self.device, - # dtype=self.dtype, - # ) - noise = torch.randn(1, 8, 16, num_frames, device=self.device, dtype=self.dtype, generator=torch.Generator(device=self.device).manual_seed(42)) + noise = self.generate_noise((1, 8, 16, num_frames), seed=seed, device="cpu", dtype=torch.float32).to(self.device) attn_mask = torch.ones(1, num_frames, device=self.device, dtype=self.dtype) _, latents, sigmas, timesteps = self.prepare_latents( latents=noise, @@ -526,10 +464,7 @@ def _from_state_dict(cls, state_dicts: ACEStateDicts, config: ACEStepPipelineCon init_device = "cpu" if config.offload_mode is not None else config.device tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=256, clean="whitespace") - # text_encoder = ACETextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype, **t5_config) - from transformers import UMT5EncoderModel - text_encoder = UMT5EncoderModel.from_pretrained("/home/zhangchengsong.zcs/.cache/diffsynth/modelscope/ACE-Step/ACE-Step-v1-3.5B/__version/umt5-base") - text_encoder = text_encoder.to(config.device).eval().to(config.t5_dtype) + text_encoder = ACETextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype, **t5_config) dcae = DCAE.from_state_dict(state_dicts.dcae, config=dcae_config, device=init_device, dtype=config.vae_dtype) hifi_gan = ADaMoSHiFiGANV1.from_state_dict(state_dicts.vocoder, config=vocoder_config, device=init_device, dtype=config.vae_dtype) vae = MusicDCAE(dcae=dcae, vocoder=hifi_gan) diff --git a/examples/ace_text_to_music.py b/examples/ace_text_to_music.py index 1bec574..b559731 100644 --- a/examples/ace_text_to_music.py +++ b/examples/ace_text_to_music.py @@ -21,7 +21,7 @@ prompt="pop, rap, electronic, blues, hip-house, rhythm and blues", lyrics="[verse]\n我走过深夜的街道\n冷风吹乱思念的漂亮外套\n你的微笑像星光很炫耀\n照亮了我孤独的每分每秒\n\n[chorus]\n愿你是风吹过我的脸\n带我飞过最远最遥远的山间\n愿你是风轻触我的梦\n停在心头不再飘散无迹无踪\n\n[verse]\n一起在喧哗避开世俗的骚动\n独自在天台探望月色的朦胧\n你说爱像音乐带点重节奏\n一拍一跳让我忘了心的温度多空洞\n\n[bridge]\n唱起对你的想念不隐藏\n像诗又像画写满藏不了的渴望\n你的影子挥不掉像风的倔强\n追着你飞扬穿越云海一样泛光\n\n[chorus]\n愿你是风吹过我的手\n暖暖的触碰像春日细雨温柔\n愿你是风盘绕我的身\n深情万万重不会有一天走远走\n\n[verse]\n深夜的钢琴弹起动人的旋律\n低音鼓砸进心底的每一次呼吸\n要是能将爱化作歌声传递\n你是否会听见我心里的真心实意", audio_duration=170.63997916666668, - seed=3299954530, + seed=seed, ) save_audio(audio, f"tmp/ace_t2m_{seed}") From 15902beb3a4d005579889d149854288158767c62 Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Sun, 14 Sep 2025 17:16:39 +0800 Subject: [PATCH 3/5] revert more --- diffsynth_engine/models/ace_step/ace_dit.py | 7 ++----- diffsynth_engine/pipelines/base.py | 8 ++++---- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py index f3c8edd..af84eb3 100644 --- a/diffsynth_engine/models/ace_step/ace_dit.py +++ b/diffsynth_engine/models/ace_step/ace_dit.py @@ -72,13 +72,11 @@ def forward(self, x, freqs, attn_mask): # x: (b, s, d), 
attn_mask: (b, s) q = rearrange(q, "b s n d -> b n d s") k = rearrange(k, "b s n d -> b n s d") v = rearrange(v, "b s n d -> b n d s") - q, k, v = q.float(), k.float(), v.float() v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1.0) # b n (d+1) s - dtype = x.dtype x = torch.matmul(torch.matmul(v, k), q) # inner: b n (d+1) d x = x[:, :, :-1] / (x[:, :, -1:] + 1e-15) # b n d s x = rearrange(x, "b n d s -> b s (n d)") - return self.o(x.to(dtype=dtype)) + return self.o(x) class CrossAttention(nn.Module): @@ -201,8 +199,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.inverted_conv(x) x = self.depth_conv(x) x, gate = torch.chunk(x, 2, dim=1) - gate = self.glu_act(gate) - x = x * gate + x = x * self.glu_act(gate) x = self.point_conv(x) x = x.transpose(1, 2) return x diff --git a/diffsynth_engine/pipelines/base.py b/diffsynth_engine/pipelines/base.py index b40bbb7..5c2e2ae 100644 --- a/diffsynth_engine/pipelines/base.py +++ b/diffsynth_engine/pipelines/base.py @@ -212,10 +212,10 @@ def prepare_latents( # if you have any questions about this, please ask @dizhipeng.dzp for more details # latents = latents * sigmas[0] / ((sigmas[0] ** 2 + 1) ** 0.5) init_latents = latents.clone() - # sigmas, timesteps = ( - # sigmas.to(device=self.device, dtype=self.dtype), - # timesteps.to(device=self.device, dtype=self.dtype), - # ) + sigmas, timesteps = ( + sigmas.to(device=self.device, dtype=self.dtype), + timesteps.to(device=self.device, dtype=self.dtype), + ) init_latents, latents = ( init_latents.to(device=self.device, dtype=self.dtype), latents.to(device=self.device, dtype=self.dtype), From 84d9ded1388de840cd024434f95a482251283a7e Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Sun, 14 Sep 2025 17:39:39 +0800 Subject: [PATCH 4/5] rope problem partially resolve --- diffsynth_engine/models/ace_step/ace_dit.py | 28 +++++++++++++-------- diffsynth_engine/models/wan/wan_dit.py | 24 +++--------------- examples/ace_text_to_music.py | 2 +- 3 files changed, 21 insertions(+), 33 deletions(-) diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py index af84eb3..116dd70 100644 --- a/diffsynth_engine/models/ace_step/ace_dit.py +++ b/diffsynth_engine/models/ace_step/ace_dit.py @@ -10,34 +10,41 @@ from diffsynth_engine.models.basic.timestep import TimestepEmbeddings from diffsynth_engine.models.basic.attention import attention from diffsynth_engine.models.basic.transformer_helper import RMSNorm -from diffsynth_engine.models.wan.wan_dit import rope_apply, modulate +from diffsynth_engine.models.wan.wan_dit import modulate from diffsynth_engine.models.ace_step.ace_lyric_encoder import ConformerEncoder from diffsynth_engine.utils.constants import ACE_DIT_CONFIG_FILE -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cuda:0"): - super().__init__() # TODO: how to deal with meta device issue? 
- device = "cuda:2" - self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)) +def rope_apply(x, freqs): # TODO: edit this into complex calculation + cos, sin = freqs # [1, S, 1, D] + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + x_out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return x_out.to(x.dtype).flatten(3) + + +class Qwen2RotaryEmbedding: + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) self._set_cos_sin_cache(seq_len=max_position_embeddings) def _set_cos_sin_cache(self, seq_len): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.int64).float() + t = torch.arange(self.max_seq_len_cached, dtype=torch.float32) freqs = torch.outer(t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.cos_cached = emb.cos() self.sin_cached = emb.sin() - def forward(self, x: torch.Tensor): + def __call__(self, x: torch.Tensor): seq_len = x.shape[1] if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len) return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), + self.cos_cached[None, :seq_len, None, :].to(x), + self.sin_cached[None, :seq_len, None, :].to(x), ) @@ -366,7 +373,6 @@ def __init__( dim=head_dim, max_position_embeddings=max_position, base=rope_theta, - device=device, ) inner_dim = num_heads * head_dim diff --git a/diffsynth_engine/models/wan/wan_dit.py b/diffsynth_engine/models/wan/wan_dit.py index 3313da7..a8b8b9c 100644 --- a/diffsynth_engine/models/wan/wan_dit.py +++ b/diffsynth_engine/models/wan/wan_dit.py @@ -61,27 +61,9 @@ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0): def rope_apply(x, freqs): - # b, s, n, d = x.shape - # x_out = torch.view_as_complex(x.to(torch.float64).reshape(b, s, n, d // 2, 2)) - # x_out = torch.view_as_real(x_out * freqs) - # # get real part and imag part from freqs - # cos = freqs.real - # sin = freqs.imag - # out = x_out.to(x.dtype).flatten(3) - cos, sin = freqs # [S, D] - cos = cos[None, :, None, :] - sin = sin[None, :, None, :] - cos, sin = cos.to(x.device), sin.to(x.device) - - x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) - x_out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) - # rotary_debug = torch.load("/home/zhangchengsong.zcs/ACE-Step/rotary_debug.pt", map_location=x.device) - # print("x max diff:",(x - rotary_debug["x"].permute(0, 2, 1, 3)).abs().max()) - # print("cos max diff:",(cos - rotary_debug["cos"].permute(0, 2, 1, 3)).abs().max()) - # print("sin max diff:",(sin - rotary_debug["sin"].permute(0, 2, 1, 3)).abs().max()) - # print("out max diff:",(x_out - rotary_debug["out"].permute(0, 2, 1, 3)).abs().max()) - + b, s, n, d = x.shape + x_out = torch.view_as_complex(x.to(torch.float64).reshape(b, s, n, d // 2, 2)) + x_out = torch.view_as_real(x_out * freqs) return x_out.to(x.dtype).flatten(3) diff --git a/examples/ace_text_to_music.py b/examples/ace_text_to_music.py index b559731..4efb819 100644 --- a/examples/ace_text_to_music.py +++ b/examples/ace_text_to_music.py @@ -14,7 +14,7 @@ ), device="cuda:2", ) - seed = 3299954530 + seed = 42 pipe = 
ACEStepMusicPipeline.from_pretrained(config) audio = pipe.text2audio( From 81a23ce9a6722934edcf9b451390cb248a995e9b Mon Sep 17 00:00:00 2001 From: continue-revolution Date: Sun, 14 Sep 2025 17:50:39 +0800 Subject: [PATCH 5/5] I think there should be no problem --- diffsynth_engine/models/ace_step/ace_dit.py | 2 +- diffsynth_engine/pipelines/ace_step.py | 24 ++++++--------------- examples/ace_text_to_music.py | 19 +++++++++++----- 3 files changed, 21 insertions(+), 24 deletions(-) diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py index 116dd70..c319e4e 100644 --- a/diffsynth_engine/models/ace_step/ace_dit.py +++ b/diffsynth_engine/models/ace_step/ace_dit.py @@ -297,7 +297,7 @@ def forward(self, x, t): class ACEStepDiTStateDictConverter(StateDictConverter): - def convert(self, state_dict): + def convert(self, state_dict): # TODO: can this be more elegant? for key in list(state_dict.keys()): # change all linear_q / linear_k / linear_v / linear_p to q / k / v / p if "linear_q" in key: diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py index 7dfabaa..9e94f70 100644 --- a/diffsynth_engine/pipelines/ace_step.py +++ b/diffsynth_engine/pipelines/ace_step.py @@ -1,6 +1,6 @@ from typing import Tuple -import numpy as np +import math import torch import torch.nn.functional as F import torch.distributed as dist @@ -301,6 +301,10 @@ def text2audio( guidance_interval: float = 0.5, progress_callback: Optional[Callable[[int, int, str], None]] = None, ): + def logistic(x, L=0.9, U=1.1, x_0=0.0, k=0.1): + return L + (U - L) / (1 + math.exp(-k * (x - x_0))) + omega = logistic(omega_scale) + prompt_emb, prompt_attn_mask = self.encode_prompt(prompt) prompt_emb_null, prompt_attn_mask_null = self.encode_prompt_null(prompt) if len(lyrics.strip()) > 0: @@ -320,7 +324,7 @@ def text2audio( ) # Initialize sampler self.sampler.initialize(sigmas=sigmas) - # guidance interval + # Guidance interval cfg_start_step = int(num_inference_steps * ((1 - guidance_interval) / 2)) cfg_end_step = int(num_inference_steps * (guidance_interval / 2 + 0.5)) momentum_buffer = MomentumBuffer() @@ -367,22 +371,6 @@ def text2audio( attn_mask_ctx=context_mask, ) # Scheduler - def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1): - # L = Lower bound - # U = Upper bound - # x_0 = Midpoint (x corresponding to y = 1.0) - # k = Steepness, can adjust based on preference - - if isinstance(x, torch.Tensor): - device_ = x.device - x = x.to(torch.float).cpu().numpy() - - new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0))) - - if isinstance(new_x, np.ndarray): - new_x = torch.from_numpy(new_x).to(device_) - return new_x - omega = logistic_function(omega_scale, k=0.1) dx: torch.Tensor = noise_pred * (self.sampler.sigmas[i + 1] - self.sampler.sigmas[i]) dx_mean = dx.mean(dim=(1, 2, 3), keepdim=True) latents = latents.to(dtype=torch.float32) diff --git a/examples/ace_text_to_music.py b/examples/ace_text_to_music.py index 4efb819..a6eac32 100644 --- a/examples/ace_text_to_music.py +++ b/examples/ace_text_to_music.py @@ -1,4 +1,5 @@ -# import random +import random +import argparse from diffsynth_engine.configs import ACEStepPipelineConfig from diffsynth_engine.pipelines.ace_step import ACEStepMusicPipeline @@ -7,22 +8,30 @@ if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--seed", + type=int, + default=random.randint(0, 2**32 - 1), + ) + parser.add_argument("--device", type=str, default="cuda") + args = parser.parse_args() + 
config = ACEStepPipelineConfig( model_path=fetch_model( model_uri="ACE-Step/ACE-Step-v1-3.5B", path="ace_step_transformer/diffusion_pytorch_model.safetensors", ), - device="cuda:2", + device=args.device, ) - seed = 42 pipe = ACEStepMusicPipeline.from_pretrained(config) audio = pipe.text2audio( prompt="pop, rap, electronic, blues, hip-house, rhythm and blues", lyrics="[verse]\n我走过深夜的街道\n冷风吹乱思念的漂亮外套\n你的微笑像星光很炫耀\n照亮了我孤独的每分每秒\n\n[chorus]\n愿你是风吹过我的脸\n带我飞过最远最遥远的山间\n愿你是风轻触我的梦\n停在心头不再飘散无迹无踪\n\n[verse]\n一起在喧哗避开世俗的骚动\n独自在天台探望月色的朦胧\n你说爱像音乐带点重节奏\n一拍一跳让我忘了心的温度多空洞\n\n[bridge]\n唱起对你的想念不隐藏\n像诗又像画写满藏不了的渴望\n你的影子挥不掉像风的倔强\n追着你飞扬穿越云海一样泛光\n\n[chorus]\n愿你是风吹过我的手\n暖暖的触碰像春日细雨温柔\n愿你是风盘绕我的身\n深情万万重不会有一天走远走\n\n[verse]\n深夜的钢琴弹起动人的旋律\n低音鼓砸进心底的每一次呼吸\n要是能将爱化作歌声传递\n你是否会听见我心里的真心实意", audio_duration=170.63997916666668, - seed=seed, + seed=args.seed, ) - save_audio(audio, f"tmp/ace_t2m_{seed}") + save_audio(audio, f"tmp/ace_t2m_{args.seed}") del pipe
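
For reference, a minimal standalone sketch of the omega-scaled Euler update that PATCH 5/5 settles on in `text2audio`. The `logistic` helper mirrors the one added in the patch; `euler_step_with_omega` and the toy tensor shapes are illustrative assumptions only, not code from the diffs (the pipeline itself operates on (1, 8, 16, num_frames) latents and takes sigmas from `self.sampler.sigmas`):

    import math

    import torch


    def logistic(x, L=0.9, U=1.1, x_0=0.0, k=0.1):
        # Squash omega_scale into (L, U); x_0 is the midpoint where the multiplier is exactly 1.0.
        return L + (U - L) / (1 + math.exp(-k * (x - x_0)))


    def euler_step_with_omega(latents, noise_pred, sigma_curr, sigma_next, omega_scale):
        # Scale the per-element update around its mean so the mean drift is preserved,
        # mirroring: latents += (dx - dx_mean) * omega + dx_mean
        omega = logistic(omega_scale)
        dx = noise_pred * (sigma_next - sigma_curr)
        dx_mean = dx.mean(dim=(1, 2, 3), keepdim=True)
        return (latents.float() + (dx - dx_mean) * omega + dx_mean).to(latents.dtype)


    if __name__ == "__main__":
        for s in (-10.0, 0.0, 10.0):
            print(f"omega_scale={s:+.1f} -> omega={logistic(s):.4f}")
        # omega_scale=-10.0 -> omega=0.9538, omega_scale=+0.0 -> omega=1.0000, omega_scale=+10.0 -> omega=1.0462

        latents = torch.randn(1, 8, 16, 64)
        noise_pred = torch.randn_like(latents)
        out = euler_step_with_omega(latents, noise_pred, sigma_curr=1.0, sigma_next=0.9, omega_scale=10.0)
        print(out.shape)  # torch.Size([1, 8, 16, 64])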