39 changes: 25 additions & 14 deletions diffsynth_engine/models/ace_step/ace_dit.py
@@ -10,30 +10,42 @@
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
 from diffsynth_engine.models.basic.attention import attention
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
-from diffsynth_engine.models.wan.wan_dit import rope_apply, modulate
+from diffsynth_engine.models.wan.wan_dit import modulate
 from diffsynth_engine.models.ace_step.ace_lyric_encoder import ConformerEncoder
 from diffsynth_engine.utils.constants import ACE_DIT_CONFIG_FILE


-class Qwen2RotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cuda:0"):
-        super().__init__()  # TODO: how to deal with meta device issue?
-        device = "cuda:2"
-        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim))
+def rope_apply(x, freqs):  # TODO: edit this into complex calculation
+    cos, sin = freqs  # [1, S, 1, D]
+    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+    x_out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+    return x_out.to(x.dtype).flatten(3)
+
+
+class Qwen2RotaryEmbedding:
+    def __init__(self, dim, max_position_embeddings=2048, base=10000):
+        super().__init__()
+        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self._set_cos_sin_cache(seq_len=max_position_embeddings)

     def _set_cos_sin_cache(self, seq_len):
         self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.int64).float()
+        t = torch.arange(self.max_seq_len_cached, dtype=torch.float32)
         freqs = torch.outer(t, self.inv_freq)
-        self.freqs_cis_cached = torch.polar(torch.ones_like(freqs), freqs)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.cos_cached = emb.cos()
+        self.sin_cached = emb.sin()

-    def forward(self, x: torch.Tensor):
+    def __call__(self, x: torch.Tensor):
         seq_len = x.shape[1]
         if seq_len > self.max_seq_len_cached:
             self._set_cos_sin_cache(seq_len=seq_len)

-        return self.freqs_cis_cached[:seq_len][None, :, None, :].to(x.device)
+        return (
+            self.cos_cached[None, :seq_len, None, :].to(x),
+            self.sin_cached[None, :seq_len, None, :].to(x),
+        )


 class SelfAttention(nn.Module):
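Note on this hunk: the previous Qwen2RotaryEmbedding buffered complex freqs_cis on a hard-coded CUDA device (hence the meta-device TODO and the stray device = "cuda:2"), which also forced a device argument through the DiT constructor. The rewrite caches real-valued cos/sin tables device-agnostically and moves them onto the input's device and dtype at call time via .to(x); the local rope_apply consumes that (cos, sin) pair instead of the complex-valued rope_apply previously imported from wan_dit. A minimal shape-check sketch (tensor sizes are illustrative assumptions taken from the shape comments above, not part of the diff):

import torch

rotary = Qwen2RotaryEmbedding(dim=64, max_position_embeddings=2048)
q = torch.randn(2, 128, 8, 64)      # [B, S, H, D]
cos, sin = rotary(q)                # each [1, 128, 1, 64], on q's device and dtype
q_rot = rope_apply(q, (cos, sin))   # [B, S, H, D], same dtype as q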
@@ -175,7 +187,7 @@ def __init__(
             kernel_size=3,
             groups=hidden_features * 2,
             use_bias=True,
-            act="silu",
+            act=None,
             device=device,
             dtype=dtype,
         )
@@ -194,7 +206,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.inverted_conv(x)
         x = self.depth_conv(x)
         x, gate = torch.chunk(x, 2, dim=1)
-        x *= self.glu_act(gate)
+        x = x * self.glu_act(gate)
         x = self.point_conv(x)
         x = x.transpose(1, 2)
         return x
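Note on x = x * self.glu_act(gate): torch.chunk returns views into the same storage, and an in-place multiply on such a view can trip autograd's in-place checks; the out-of-place form allocates a fresh tensor and side-steps that. A sketch of the failure mode under the simplest conditions (an assumed repro, not taken from the repository):

import torch

x = torch.randn(1, 4, 8, requires_grad=True)
a, gate = torch.chunk(x, 2, dim=1)   # a and gate are views of x
# a *= torch.sigmoid(gate)           # RuntimeError: in-place op on a view of a leaf requiring grad
a = a * torch.sigmoid(gate)          # safe: writes to a new tensor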
@@ -285,7 +297,7 @@ def forward(self, x, t):


 class ACEStepDiTStateDictConverter(StateDictConverter):
-    def convert(self, state_dict):
+    def convert(self, state_dict):  # TODO: can this be more elegant?
         for key in list(state_dict.keys()):
             # change all linear_q / linear_k / linear_v / linear_p to q / k / v / p
             if "linear_q" in key:
@@ -361,7 +373,6 @@ def __init__(
             dim=head_dim,
             max_position_embeddings=max_position,
             base=rope_theta,
-            device=device,
         )

         inner_dim = num_heads * head_dim
15 changes: 10 additions & 5 deletions diffsynth_engine/pipelines/ace_step.py
@@ -1,5 +1,6 @@
 from typing import Tuple

+import math
 import torch
 import torch.nn.functional as F
 import torch.distributed as dist
@@ -51,11 +52,11 @@ def hook(module, input, output):
     handlers.append(handler)

     with torch.no_grad():
-        prompt_emb = model_fwd_func(**inputs)
+        output = model_fwd_func(**inputs)

     for handler in handlers:
         handler.remove()
-    return prompt_emb
+    return output


 class MomentumBuffer:
@@ -124,7 +125,7 @@ def __init__(
         )
         self.config = config
         # sampler
-        self.noise_scheduler = RecifitedFlowScheduler(shift=3.0)
+        self.noise_scheduler = RecifitedFlowScheduler(shift=config.shift)
         self.sampler = FlowMatchEulerSampler()
         # models
         self.lyric_tokenizer = VoiceBpeTokenizer()
@@ -300,6 +301,10 @@ def text2audio(
         guidance_interval: float = 0.5,
         progress_callback: Optional[Callable[[int, int, str], None]] = None,
     ):
+        def logistic(x, L=0.9, U=1.1, x_0=0.0, k=0.1):
+            return L + (U - L) / (1 + math.exp(-k * (x - x_0)))
+        omega = logistic(omega_scale)
+
         prompt_emb, prompt_attn_mask = self.encode_prompt(prompt)
         prompt_emb_null, prompt_attn_mask_null = self.encode_prompt_null(prompt)
         if len(lyrics.strip()) > 0:
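The new logistic() squashes the user-facing omega_scale into a bounded multiplier omega in the open interval (L, U) = (0.9, 1.1), centered so that omega_scale = 0 maps to exactly 1.0. A quick worked check of the mapping (plain math, same constants as above):

import math

def logistic(x, L=0.9, U=1.1, x_0=0.0, k=0.1):
    return L + (U - L) / (1 + math.exp(-k * (x - x_0)))

print(logistic(-10.0))  # ~0.9538
print(logistic(0.0))    # 1.0
print(logistic(10.0))   # ~1.0462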
@@ -319,7 +324,7 @@ def text2audio(
         )
         # Initialize sampler
         self.sampler.initialize(sigmas=sigmas)
-        # guidance interval
+        # Guidance interval
         cfg_start_step = int(num_inference_steps * ((1 - guidance_interval) / 2))
         cfg_end_step = int(num_inference_steps * (guidance_interval / 2 + 0.5))
         momentum_buffer = MomentumBuffer()
@@ -369,7 +374,7 @@ def text2audio(
             dx: torch.Tensor = noise_pred * (self.sampler.sigmas[i + 1] - self.sampler.sigmas[i])
             dx_mean = dx.mean(dim=(1, 2, 3), keepdim=True)
             latents = latents.to(dtype=torch.float32)
-            latents += (dx - dx_mean) * omega_scale + dx
+            latents += (dx - dx_mean) * omega + dx_mean
             latents = latents.to(dtype=noise_pred.dtype)
             if progress_callback is not None:
                 progress_callback(i + 1, len(timesteps), "DENOISING")
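This is the substantive fix: the old update added dx twice and scaled the centered term by the raw, unbounded omega_scale, while the corrected line applies omega only to the zero-mean component of dx and adds the mean back unscaled, preserving the overall drift of the Euler step. The update in isolation (a standalone sketch with made-up shapes):

import torch

latents = torch.zeros(1, 8, 16, 16)
dx = torch.randn_like(latents)
dx_mean = dx.mean(dim=(1, 2, 3), keepdim=True)
omega = 1.0462  # logistic(omega_scale) for omega_scale = 10

latents = latents + (dx - dx_mean) * omega + dx_mean
# With omega == 1.0 this reduces exactly to latents + dx.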
2 changes: 1 addition & 1 deletion diffsynth_engine/pipelines/base.py
@@ -210,7 +210,7 @@ def prepare_latents(
         sigmas, timesteps = self.noise_scheduler.schedule(num_inference_steps)
         # k-diffusion
         # if you have any questions about this, please ask @dizhipeng.dzp for more details
-        latents = latents * sigmas[0] / ((sigmas[0] ** 2 + 1) ** 0.5)
+        # latents = latents * sigmas[0] / ((sigmas[0] ** 2 + 1) ** 0.5)
         init_latents = latents.clone()
         sigmas, timesteps = (
             sigmas.to(device=self.device, dtype=self.dtype),
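Commenting out the k-diffusion scaling looks intentional for the flow-matching pipelines: with a rectified-flow schedule the first sigma is typically 1.0, so the factor would shrink the initial noise by roughly 1/sqrt(2) instead of leaving it unit-variance. This rationale is an inference from the surrounding code, not stated in the diff:

sigma0 = 1.0  # assumed first sigma of a rectified-flow schedule
print(sigma0 / (sigma0**2 + 1) ** 0.5)  # ~0.7071, the damping the disabled line would apply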
18 changes: 14 additions & 4 deletions examples/ace_text_to_music.py
@@ -1,4 +1,5 @@
-# import random
+import random
+import argparse

 from diffsynth_engine.configs import ACEStepPipelineConfig
 from diffsynth_engine.pipelines.ace_step import ACEStepMusicPipeline
@@ -7,21 +8,30 @@


 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=random.randint(0, 2**32 - 1),
+    )
+    parser.add_argument("--device", type=str, default="cuda")
+    args = parser.parse_args()
+
     config = ACEStepPipelineConfig(
         model_path=fetch_model(
             model_uri="ACE-Step/ACE-Step-v1-3.5B",
             path="ace_step_transformer/diffusion_pytorch_model.safetensors",
         ),
-        device="cuda:2",
+        device=args.device,
     )
-    seed = 3299954530

     pipe = ACEStepMusicPipeline.from_pretrained(config)
     audio = pipe.text2audio(
         prompt="pop, rap, electronic, blues, hip-house, rhythm and blues",
         lyrics="[verse]\n我走过深夜的街道\n冷风吹乱思念的漂亮外套\n你的微笑像星光很炫耀\n照亮了我孤独的每分每秒\n\n[chorus]\n愿你是风吹过我的脸\n带我飞过最远最遥远的山间\n愿你是风轻触我的梦\n停在心头不再飘散无迹无踪\n\n[verse]\n一起在喧哗避开世俗的骚动\n独自在天台探望月色的朦胧\n你说爱像音乐带点重节奏\n一拍一跳让我忘了心的温度多空洞\n\n[bridge]\n唱起对你的想念不隐藏\n像诗又像画写满藏不了的渴望\n你的影子挥不掉像风的倔强\n追着你飞扬穿越云海一样泛光\n\n[chorus]\n愿你是风吹过我的手\n暖暖的触碰像春日细雨温柔\n愿你是风盘绕我的身\n深情万万重不会有一天走远走\n\n[verse]\n深夜的钢琴弹起动人的旋律\n低音鼓砸进心底的每一次呼吸\n要是能将爱化作歌声传递\n你是否会听见我心里的真心实意",
         audio_duration=170.63997916666668,
+        seed=args.seed,
     )
-    save_audio(audio, f"tmp/ace_t2m_{seed}")
+    save_audio(audio, f"tmp/ace_t2m_{args.seed}")

     del pipe
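With the new flags, the example no longer pins a device or seed in code; an invocation might look like (hypothetical values, reusing the previously hard-coded seed):

python examples/ace_text_to_music.py --device cuda:0 --seed 3299954530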
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
"moviepy",
"librosa",
"scikit-image",
"trimesh"
"trimesh",
"py3langid",
"pypinyin",
"hangul_romanize",