Fix most problems in conrevo/ace-step branch #166
Conversation
Summary of Changes
Hello @continue-revolution, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly expands the capabilities of the diffsynth_engine by introducing the "ACE-Step" music generation pipeline. This new feature enables users to generate music from text prompts and lyrics, leveraging advanced deep learning models for diffusion, text encoding, and audio synthesis. The changes encompass new model definitions, configuration files, and a comprehensive pipeline orchestrating these components, along with robust multilingual text processing utilities to handle diverse lyrical inputs.
Highlights
- New ACE-Step Music Generation Pipeline: Introduces a complete pipeline for text-to-music generation, integrating several new model architectures for diffusion, text encoding, and audio synthesis.
- Advanced Diffusion Transformer (DiT) Model: Implements a sophisticated Diffusion Transformer (`ACEStepDiT`) that processes time, speaker, genre, and lyric embeddings, leveraging a dedicated `ConformerEncoder` for lyric processing.
- Multilingual Lyric Processing: Adds comprehensive multilingual lyric tokenization and language segmentation capabilities, supporting various languages (e.g., Chinese, Japanese, English, Korean) and incorporating number-to-word conversion for SSML tags.
- Integrated VAE and Vocoder: Incorporates a Deep Convolutional Autoencoder (DCAE) for efficient latent representation of audio and a high-fidelity HiFi-GAN variant vocoder for synthesizing audio waveforms.
- Enhanced Guidance Mechanisms: The pipeline supports advanced generation control through Classifier-Free Guidance (CFG) and Advanced Projection Guidance (APG) for improved output quality.
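
To make the last highlight concrete, here is a minimal sketch of how classifier-free guidance is typically combined at each denoising step. The function and tensor names are illustrative assumptions, not the pipeline's actual API, and the APG projection step is omitted.

```python
import torch

def apply_cfg(cond_pred: torch.Tensor, uncond_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    """Classifier-free guidance: extrapolate from the unconditional prediction
    toward the conditional one by `guidance_scale`."""
    return uncond_pred + guidance_scale * (cond_pred - uncond_pred)

# Hypothetical use inside a denoising loop (all names are illustrative):
# noise_cond = dit(latents, text_emb, lyric_emb, t)
# noise_uncond = dit(latents, null_text_emb, null_lyric_emb, t)
# noise_pred = apply_cfg(noise_cond, noise_uncond, guidance_scale=7.5)
```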
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a `.gemini/` folder in the root of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its suggestions are incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This pull request introduces several changes across the ACE-Step model and pipeline, including refactoring rotary position embeddings, updating attention mechanisms, and adding decoding capabilities to the lyric tokenizer. My review focuses on improving code quality, fixing bugs, and ensuring portability. Key issues identified include a hardcoded model path, a bug where a user-provided seed is ignored in favor of a hardcoded value, inefficient function definitions inside loops, and several instances of dead or commented-out code that should be removed. I've also pointed out some style issues like imports inside functions and an inaccurate docstring.
```python
text_encoder = ACETextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype, **t5_config)
# text_encoder = ACETextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype, **t5_config)
from transformers import UMT5EncoderModel
text_encoder = UMT5EncoderModel.from_pretrained("/home/zhangchengsong.zcs/.cache/diffsynth/modelscope/ACE-Step/ACE-Step-v1-3.5B/__version/umt5-base")
```
A hardcoded absolute path is used to load the `UMT5EncoderModel`. This is not portable and will cause the code to fail on any other machine. Please replace this with a mechanism to fetch the model from a repository (like Hugging Face Hub) or use a relative path.
```diff
-text_encoder = UMT5EncoderModel.from_pretrained("/home/zhangchengsong.zcs/.cache/diffsynth/modelscope/ACE-Step/ACE-Step-v1-3.5B/__version/umt5-base")
+text_encoder = UMT5EncoderModel.from_pretrained("your-model-identifier")
```
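
If a hub checkpoint is acceptable, one way to make this portable is sketched below: resolve the path from an environment variable with a hub repo id as fallback. The variable name `DIFFSYNTH_UMT5_PATH` and the fallback `google/umt5-base` are illustrative assumptions, not part of diffsynth_engine; the actual ACE-Step encoder checkpoint may differ.

```python
import os
from transformers import UMT5EncoderModel

# Resolve the encoder location from the environment, falling back to a hub repo id.
# Both names here are assumptions for illustration only.
umt5_path = os.environ.get("DIFFSYNTH_UMT5_PATH", "google/umt5-base")
text_encoder = UMT5EncoderModel.from_pretrained(umt5_path)
```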
```python
# noise = randn_tensor(
#     shape=(1, 8, 16, num_frames),
#     generator=torch.Generator(device=self.device).manual_seed(seed),
#     device=self.device,
#     dtype=self.dtype,
# )
noise = torch.randn(1, 8, 16, num_frames, device=self.device, dtype=self.dtype, generator=torch.Generator(device=self.device).manual_seed(42))
```
The `seed` parameter of the `text2audio` function is being ignored. Instead, a hardcoded seed `42` is used for generating noise. This will cause the function to produce the same output every time, which is likely not the intended behavior. The commented-out code should also be removed.
```diff
-# noise = randn_tensor(
-#     shape=(1, 8, 16, num_frames),
-#     generator=torch.Generator(device=self.device).manual_seed(seed),
-#     device=self.device,
-#     dtype=self.dtype,
-# )
-noise = torch.randn(1, 8, 16, num_frames, device=self.device, dtype=self.dtype, generator=torch.Generator(device=self.device).manual_seed(42))
+noise = torch.randn(1, 8, 16, num_frames, device=self.device, dtype=self.dtype, generator=torch.Generator(device=self.device).manual_seed(seed))
```
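
A common pattern for honoring a caller-supplied seed while still allowing non-deterministic runs is sketched below; the `Optional[int]` seed signature is an assumption about how `text2audio` might expose it, not the current implementation.

```python
from typing import Optional
import torch

def make_noise(num_frames: int, device, dtype, seed: Optional[int] = None) -> torch.Tensor:
    """Draw initial latents; only seed the generator when the caller asks for determinism."""
    generator = None
    if seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(1, 8, 16, num_frames, device=device, dtype=dtype, generator=generator)
```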
```python
        return self.tokenizer.encode(txt).ids

    def decode(self, seq):
        import torch
```
The `torch` library is imported inside the `decode` method. It's a best practice to place all imports at the top of the file. This makes dependencies clear and avoids potential overhead from re-importing modules within a function that might be called frequently. Please move this import to the top of the file.
| """ | ||
| Convert a list of lists of token ids into a list of strings by calling decode. | ||
| Args: | ||
| sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`): | ||
| List of tokenized input ids. Can be obtained using the `__call__` method. | ||
| skip_special_tokens (`bool`, *optional*, defaults to `False`): | ||
| Whether or not to remove special tokens in the decoding. | ||
| kwargs (additional keyword arguments, *optional*): | ||
| Will be passed to the underlying model specific decode method. | ||
| Returns: | ||
| `List[str]`: The list of decoded sentences. | ||
| """ |
The docstring for `batch_decode` appears to be copied from another source and is inaccurate. It mentions `skip_special_tokens` and `kwargs` parameters, which are not part of this function's signature. This is misleading and should be corrected to reflect the actual implementation.
"""
Convert a list of lists of token ids into a list of strings by calling decode.
Args:
sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
List of tokenized input ids. Can be obtained using the `__call__` method.
Returns:
`List[str]`: The list of decoded sentences.
"""| def randn_tensor( | ||
| shape: Union[Tuple, List], | ||
| generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, | ||
| device: Optional[Union[str, "torch.device"]] = None, | ||
| dtype: Optional["torch.dtype"] = None, | ||
| layout: Optional["torch.layout"] = None, | ||
| ): | ||
| """A helper function to create random tensors on the desired `device` with the desired `dtype`. When | ||
| passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor | ||
| is always created on the CPU. | ||
| """ | ||
| # device on which tensor is created defaults to device | ||
| if isinstance(device, str): | ||
| device = torch.device(device) | ||
| rand_device = device | ||
| batch_size = shape[0] | ||
|
|
||
| layout = layout or torch.strided | ||
| device = device or torch.device("cpu") | ||
|
|
||
| if generator is not None: | ||
| gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type | ||
| if gen_device_type != device.type and gen_device_type == "cpu": | ||
| rand_device = "cpu" | ||
| if device != "mps": | ||
| logger.info( | ||
| f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." | ||
| f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" | ||
| f" slightly speed up this function by passing a generator that was created on the {device} device." | ||
| ) | ||
| elif gen_device_type != device.type and gen_device_type == "cuda": | ||
| raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") | ||
|
|
||
| # make sure generator list of length 1 is treated like a non-list | ||
| if isinstance(generator, list) and len(generator) == 1: | ||
| generator = generator[0] | ||
|
|
||
| if isinstance(generator, list): | ||
| shape = (1,) + shape[1:] | ||
| latents = [ | ||
| torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) | ||
| for i in range(batch_size) | ||
| ] | ||
| latents = torch.cat(latents, dim=0).to(device) | ||
| else: | ||
| latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) | ||
|
|
||
| return latents |
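
For context, a small usage sketch of the helper above with per-sample generators; the shape and seeds are illustrative. CPU generators are used, so the tensors are drawn on CPU and then moved to the target device, as the helper documents.

```python
import torch

# Seed each sample of a batch of two independently.
device = "cuda" if torch.cuda.is_available() else "cpu"
generators = [torch.Generator().manual_seed(s) for s in (0, 1)]
latents = randn_tensor((2, 8, 16, 512), generator=generators, device=device, dtype=torch.float32)
print(latents.shape)  # torch.Size([2, 8, 16, 512])
```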
```python
tokenizer = WanT5Tokenizer(WAN_TOKENIZER_CONF_PATH, seq_len=256, clean="whitespace")
text_encoder = ACETextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype, **t5_config)
# text_encoder = ACETextEncoder.from_state_dict(state_dicts.t5, device=init_device, dtype=config.t5_dtype, **t5_config)
from transformers import UMT5EncoderModel
```
No description provided.