diff --git a/diffsynth_engine/models/ace_step/ace_dit.py b/diffsynth_engine/models/ace_step/ace_dit.py
index 09a4092..c319e4e 100644
--- a/diffsynth_engine/models/ace_step/ace_dit.py
+++ b/diffsynth_engine/models/ace_step/ace_dit.py
@@ -10,30 +10,42 @@
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
 from diffsynth_engine.models.basic.attention import attention
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
-from diffsynth_engine.models.wan.wan_dit import rope_apply, modulate
+from diffsynth_engine.models.wan.wan_dit import modulate
 from diffsynth_engine.models.ace_step.ace_lyric_encoder import ConformerEncoder
 from diffsynth_engine.utils.constants import ACE_DIT_CONFIG_FILE
 
 
-class Qwen2RotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cuda:0"):
-        super().__init__() # TODO: how to deal with meta device issue?
-        device = "cuda:2"
-        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim))
+def rope_apply(x, freqs): # TODO: edit this into complex calculation
+    cos, sin = freqs # [1, S, 1, D]
+    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+    x_out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+    return x_out.to(x.dtype).flatten(3)
+
+
+class Qwen2RotaryEmbedding:
+    def __init__(self, dim, max_position_embeddings=2048, base=10000):
+        super().__init__()
+        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
         self._set_cos_sin_cache(seq_len=max_position_embeddings)
 
     def _set_cos_sin_cache(self, seq_len):
         self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.int64).float()
+        t = torch.arange(self.max_seq_len_cached, dtype=torch.float32)
         freqs = torch.outer(t, self.inv_freq)
-        self.freqs_cis_cached = torch.polar(torch.ones_like(freqs), freqs)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()
 
-    def forward(self, x: torch.Tensor):
+    def __call__(self, x: torch.Tensor):
         seq_len = x.shape[1]
         if seq_len > self.max_seq_len_cached:
             self._set_cos_sin_cache(seq_len=seq_len)
-        return self.freqs_cis_cached[:seq_len][None, :, None, :].to(x.device)
+        return (
+            self.cos_cached[None, :seq_len, None, :].to(x),
+            self.sin_cached[None, :seq_len, None, :].to(x),
+        )
 
 
 class SelfAttention(nn.Module):
@@ -175,7 +187,7 @@ def __init__(
             kernel_size=3,
             groups=hidden_features * 2,
             use_bias=True,
-            act="silu",
+            act=None,
             device=device,
             dtype=dtype,
         )
@@ -194,7 +206,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.inverted_conv(x)
         x = self.depth_conv(x)
         x, gate = torch.chunk(x, 2, dim=1)
-        x *= self.glu_act(gate)
+        x = x * self.glu_act(gate)
         x = self.point_conv(x)
         x = x.transpose(1, 2)
         return x
@@ -285,7 +297,7 @@ def forward(self, x, t):
 
 
 class ACEStepDiTStateDictConverter(StateDictConverter):
-    def convert(self, state_dict):
+    def convert(self, state_dict): # TODO: can this be more elegant?
         for key in list(state_dict.keys()):
             # change all linear_q / linear_k / linear_v / linear_p to q / k / v / p
             if "linear_q" in key:
@@ -361,7 +373,6 @@ def __init__(
             dim=head_dim,
             max_position_embeddings=max_position,
             base=rope_theta,
-            device=device,
         )
 
         inner_dim = num_heads * head_dim
diff --git a/diffsynth_engine/pipelines/ace_step.py b/diffsynth_engine/pipelines/ace_step.py
index 4a8de83..9e94f70 100644
--- a/diffsynth_engine/pipelines/ace_step.py
+++ b/diffsynth_engine/pipelines/ace_step.py
@@ -1,5 +1,6 @@
 from typing import Tuple
 
+import math
 import torch
 import torch.nn.functional as F
 import torch.distributed as dist
@@ -51,11 +52,11 @@ def hook(module, input, output):
         handlers.append(handler)
 
     with torch.no_grad():
-        prompt_emb = model_fwd_func(**inputs)
+        output = model_fwd_func(**inputs)
     for handler in handlers:
         handler.remove()
 
-    return prompt_emb
+    return output
 
 
 class MomentumBuffer:
@@ -124,7 +125,7 @@ def __init__(
         )
         self.config = config
         # sampler
-        self.noise_scheduler = RecifitedFlowScheduler(shift=3.0)
+        self.noise_scheduler = RecifitedFlowScheduler(shift=config.shift)
         self.sampler = FlowMatchEulerSampler()
         # models
         self.lyric_tokenizer = VoiceBpeTokenizer()
@@ -300,6 +301,10 @@ def text2audio(
         guidance_interval: float = 0.5,
         progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ):
+        def logistic(x, L=0.9, U=1.1, x_0=0.0, k=0.1):
+            return L + (U - L) / (1 + math.exp(-k * (x - x_0)))
+        omega = logistic(omega_scale)
+
         prompt_emb, prompt_attn_mask = self.encode_prompt(prompt)
         prompt_emb_null, prompt_attn_mask_null = self.encode_prompt_null(prompt)
         if len(lyrics.strip()) > 0:
@@ -319,7 +324,7 @@ def text2audio(
         )
         # Initialize sampler
         self.sampler.initialize(sigmas=sigmas)
-        # guidance interval
+        # Guidance interval
         cfg_start_step = int(num_inference_steps * ((1 - guidance_interval) / 2))
         cfg_end_step = int(num_inference_steps * (guidance_interval / 2 + 0.5))
         momentum_buffer = MomentumBuffer()
@@ -369,7 +374,7 @@ def text2audio(
             dx: torch.Tensor = noise_pred * (self.sampler.sigmas[i + 1] - self.sampler.sigmas[i])
             dx_mean = dx.mean(dim=(1, 2, 3), keepdim=True)
             latents = latents.to(dtype=torch.float32)
-            latents += (dx - dx_mean) * omega_scale + dx
+            latents += (dx - dx_mean) * omega + dx_mean
             latents = latents.to(dtype=noise_pred.dtype)
             if progress_callback is not None:
                 progress_callback(i + 1, len(timesteps), "DENOISING")
diff --git a/diffsynth_engine/pipelines/base.py b/diffsynth_engine/pipelines/base.py
index 37fbccb..5c2e2ae 100644
--- a/diffsynth_engine/pipelines/base.py
+++ b/diffsynth_engine/pipelines/base.py
@@ -210,7 +210,7 @@ def prepare_latents(
         sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps)
         # k-diffusion
         # if you have any questions about this, please ask @dizhipeng.dzp for more details
-        latents = latents * sigmas[0] / ((sigmas[0] ** 2 + 1) ** 0.5)
+        # latents = latents * sigmas[0] / ((sigmas[0] ** 2 + 1) ** 0.5)
         init_latents = latents.clone()
         sigmas, timesteps = (
             sigmas.to(device=self.device, dtype=self.dtype),
diff --git a/examples/ace_text_to_music.py b/examples/ace_text_to_music.py
index f817485..a6eac32 100644
--- a/examples/ace_text_to_music.py
+++ b/examples/ace_text_to_music.py
@@ -1,4 +1,5 @@
-# import random
+import random
+import argparse
 
 from diffsynth_engine.configs import ACEStepPipelineConfig
 from diffsynth_engine.pipelines.ace_step import ACEStepMusicPipeline
@@ -7,21 +8,30 @@
 
 
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=random.randint(0, 2**32 - 1),
+    )
+    parser.add_argument("--device", type=str, default="cuda")
+    args = parser.parse_args()
+
     config = ACEStepPipelineConfig(
         model_path=fetch_model(
             model_uri="ACE-Step/ACE-Step-v1-3.5B",
             path="ace_step_transformer/diffusion_pytorch_model.safetensors",
         ),
-        device="cuda:2",
+        device=args.device,
     )
-    seed = 3299954530
     pipe = ACEStepMusicPipeline.from_pretrained(config)
 
     audio = pipe.text2audio(
         prompt="pop, rap, electronic, blues, hip-house, rhythm and blues",
         lyrics="[verse]\n我走过深夜的街道\n冷风吹乱思念的漂亮外套\n你的微笑像星光很炫耀\n照亮了我孤独的每分每秒\n\n[chorus]\n愿你是风吹过我的脸\n带我飞过最远最遥远的山间\n愿你是风轻触我的梦\n停在心头不再飘散无迹无踪\n\n[verse]\n一起在喧哗避开世俗的骚动\n独自在天台探望月色的朦胧\n你说爱像音乐带点重节奏\n一拍一跳让我忘了心的温度多空洞\n\n[bridge]\n唱起对你的想念不隐藏\n像诗又像画写满藏不了的渴望\n你的影子挥不掉像风的倔强\n追着你飞扬穿越云海一样泛光\n\n[chorus]\n愿你是风吹过我的手\n暖暖的触碰像春日细雨温柔\n愿你是风盘绕我的身\n深情万万重不会有一天走远走\n\n[verse]\n深夜的钢琴弹起动人的旋律\n低音鼓砸进心底的每一次呼吸\n要是能将爱化作歌声传递\n你是否会听见我心里的真心实意",
         audio_duration=170.63997916666668,
+        seed=args.seed,
     )
 
-    save_audio(audio, f"tmp/ace_t2m_{seed}")
+    save_audio(audio, f"tmp/ace_t2m_{args.seed}")
     del pipe
diff --git a/pyproject.toml b/pyproject.toml
index 8b4c119..2c5dde9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
     "moviepy",
     "librosa",
     "scikit-image",
-    "trimesh"
+    "trimesh",
     "py3langid",
     "pypinyin",
     "hangul_romanize",
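
Note, not part of the patch: a minimal shape-check sketch for the reworked rotary embedding, assuming the diff above is applied and the snippet is run from the repository root with torch installed. The refactored Qwen2RotaryEmbedding is a plain callable rather than an nn.Module, builds its cos/sin tables on CPU at construction time (no more hard-coded CUDA device), and casts them to the query's device and dtype on each call; rope_apply then consumes the (cos, sin) pair instead of a complex freqs_cis tensor.

import torch

from diffsynth_engine.models.ace_step.ace_dit import Qwen2RotaryEmbedding, rope_apply

head_dim = 64
rope = Qwen2RotaryEmbedding(dim=head_dim, max_position_embeddings=2048, base=10000)

q = torch.randn(2, 128, 8, head_dim)  # [batch, seq_len, num_heads, head_dim]
cos, sin = rope(q)                    # each [1, 128, 1, 64], cast to q's dtype and device
out = rope_apply(q, (cos, sin))       # rotated queries, same shape as q
print(cos.shape, sin.shape, out.shape)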
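
Also not part of the patch: text2audio keeps its omega_scale argument but now squashes it through a logistic into a gain between 0.9 and 1.1 before rescaling the mean-removed Euler update, so extreme user-supplied values no longer multiply the latent update directly. A small sketch of the mapping, mirroring the nested helper added above:

import math

def logistic(x, L=0.9, U=1.1, x_0=0.0, k=0.1):
    # maps any real omega_scale into the open interval (L, U), with logistic(0) == 1.0
    return L + (U - L) / (1 + math.exp(-k * (x - x_0)))

for omega_scale in (-100.0, -10.0, 0.0, 10.0, 100.0):
    print(omega_scale, round(logistic(omega_scale), 3))
# prints roughly 0.9, 0.954, 1.0, 1.046, 1.1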
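
With the argparse changes in examples/ace_text_to_music.py, device and seed become command-line options instead of hard-coded values. The previous hard-coded run should be reproducible with "python examples/ace_text_to_music.py --device cuda:2 --seed 3299954530"; omitting --seed draws a random 32-bit seed, and the result is saved under tmp/ace_t2m_<seed>.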