MusicGen: Missing 'rtf' assignment in prompted samples generation. #458

Open
ODD2 opened this issue May 10, 2024 · 0 comments


ODD2 commented May 10, 2024

Hi, I'm currently experimenting with MusicGen and ran into an exception saying 'rtf' is undefined when the solver is configured to generate only prompted samples, i.e., when the 'generate' section in config/solver/musicgen/default.yaml is modified as follows:

```yaml
generate:
  every: 25
  num_workers: 5
  path: samples
  audio:
    format: wav
    strategy: loudness
    sample_rate: ${sample_rate}
    loudness_headroom_db: 14
  lm:
    prompted_samples: true
    unprompted_samples: false # <- this line is modified
    gen_gt_samples: false
    prompt_duration: null   # if not set, will use dataset.generate.segment_duration / 4
    gen_duration: null      # if not set, will use dataset.generate.segment_duration
    remove_prompts: false
    # generation params
    use_sampling: false
    temp: 1.0
    top_k: 0
    top_p: 0.0
```

I suspect the assignment of 'rtf' is missing from the prompted-sample branch:

```python
# line 577 in audiocraft/solvers/musicgen.py
if self.cfg.generate.lm.prompted_samples:
    gen_outputs = self.run_generate_step(
        batch, gen_duration=target_duration, prompt_duration=prompt_duration,
        **self.generation_params)
    gen_audio = gen_outputs['gen_audio'].cpu()
    prompt_audio = gen_outputs['prompt_audio'].cpu()
    sample_manager.add_samples(
        gen_audio, self.epoch, hydrated_conditions,
        prompt_wavs=prompt_audio, ground_truth_wavs=audio,
        generation_args=sample_generation_params)
    # rtf = gen_outputs['rtf']  <- missing?
```
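The failure follows from ordinary Python scoping: 'rtf' is bound only inside the unprompted branch, so when that branch is skipped, any later read of 'rtf' finds no binding. The sketch below is a simplified, hypothetical model of that control flow. The function name `original_flow`, the stand-in value `fake_rtf`, and the assumption that 'rtf' is read unconditionally after both branches (as the reported error implies) are illustrative and are not audiocraft code:

```python
def original_flow(prompted_samples: bool, unprompted_samples: bool,
                  gen_gt_samples: bool = False) -> dict:
    """Hypothetical model of the per-batch generate logic (not audiocraft code)."""
    fake_rtf = 0.5  # stand-in for gen_outputs['rtf'] returned by run_generate_step
    metrics = {}
    if unprompted_samples:
        if gen_gt_samples:
            rtf = 1.  # ground truth is used instead of generation
        else:
            rtf = fake_rtf
    if prompted_samples:
        pass  # prompted audio is generated here, but rtf is never assigned
    metrics['rtf'] = rtf  # raises UnboundLocalError if the unprompted branch was skipped
    return metrics


# Prompted-only configuration: reproduces the reported failure.
try:
    original_flow(prompted_samples=True, unprompted_samples=False)
except UnboundLocalError as exc:
    print("prompted-only run fails:", exc)

# Any configuration that enters the unprompted branch works.
print(original_flow(prompted_samples=True, unprompted_samples=True))
```

`UnboundLocalError` is a subclass of `NameError`, which is why the error reads as "undefined" even though the name appears in the function.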

For now, I've modified the generation section to collect the 'rtf' metric as follows:

```python
# line 560 in audiocraft/solvers/musicgen.py
rtf = []  # <- modified
if self.cfg.generate.lm.unprompted_samples:
    if self.cfg.generate.lm.gen_gt_samples:
        # get the ground truth instead of generation
        self.logger.warning(
            "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
        gen_unprompted_audio = audio
        rtf.append(1.)  # <- modified
    else:
        gen_unprompted_outputs = self.run_generate_step(
            batch, gen_duration=target_duration, prompt_duration=None,
            **self.generation_params)
        gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
        rtf.append(gen_unprompted_outputs['rtf'])  # <- modified
    sample_manager.add_samples(
        gen_unprompted_audio, self.epoch, hydrated_conditions,
        ground_truth_wavs=audio, generation_args=sample_generation_params)

if self.cfg.generate.lm.prompted_samples:
    gen_outputs = self.run_generate_step(
        batch, gen_duration=target_duration, prompt_duration=prompt_duration,
        **self.generation_params)
    gen_audio = gen_outputs['gen_audio'].cpu()
    prompt_audio = gen_outputs['prompt_audio'].cpu()
    sample_manager.add_samples(
        gen_audio, self.epoch, hydrated_conditions,
        prompt_wavs=prompt_audio, ground_truth_wavs=audio,
        generation_args=sample_generation_params)
    rtf.append(gen_outputs['rtf'])  # <- modified

# <- modified: mean of the collected values, guarded against an empty list
metrics['rtf'] = sum(rtf) / max(len(rtf), 1)
```
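For the final average, dividing by `max(len(rtf), 1)` yields the mean of the collected values and falls back to a divisor of 1 when the list is empty. With `min(len(rtf), 1)` the divisor is 1 whenever anything was collected (so the metric becomes a sum, not a mean) and 0 when nothing was, which raises `ZeroDivisionError`. A minimal sketch of the list-based approach follows; `collect_rtf`, `average_rtf`, and the default values `unprompted_rtf` and `prompted_rtf` are hypothetical stand-ins for the solver code and for `gen_outputs['rtf']`:

```python
from typing import List


def collect_rtf(prompted_samples: bool, unprompted_samples: bool,
                gen_gt_samples: bool = False,
                unprompted_rtf: float = 0.25, prompted_rtf: float = 0.75) -> List[float]:
    """Hypothetical model of the proposed fix: each branch that runs appends its RTF."""
    rtf: List[float] = []
    if unprompted_samples:
        # Ground-truth "generation" is recorded as an RTF of 1.0, as in the proposed patch.
        rtf.append(1. if gen_gt_samples else unprompted_rtf)
    if prompted_samples:
        rtf.append(prompted_rtf)
    return rtf


def average_rtf(rtf: List[float]) -> float:
    # max(..., 1): the true mean for a non-empty list, and no division by zero for an empty one.
    return sum(rtf) / max(len(rtf), 1)


print(average_rtf(collect_rtf(prompted_samples=True, unprompted_samples=False)))   # 0.75
print(average_rtf(collect_rtf(prompted_samples=True, unprompted_samples=True)))    # 0.5
print(average_rtf(collect_rtf(prompted_samples=False, unprompted_samples=False)))  # 0.0
```

Whether an empty list should report 0.0 or skip the metric entirely is a judgment call; skipping it avoids logging a misleading zero when no samples were generated.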

Please let me know if the modification is correct. Thanks for the great work.
