In [2]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from diffusers import AudioLDM2Pipeline
from diffusers.pipelines.pipeline_utils import AudioPipelineOutput
from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput
import librosa
import random
import numpy as np
from torch.optim import AdamW
from tqdm import tqdm
import torch.nn as nn
import pandas as pd
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import scipy
import gc
from utils import *

In [18]:
# Music-oriented AudioLDM2 checkpoint (per its Hugging Face repo name).
repo_id = "cvssp/audioldm2-music"

# Everything in this notebook runs on CPU. To use a GPU instead, set:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print(f'device: {device}')

pipeline = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float32)
pipeline = pipeline.to(device)

device: cpu


Loading pipeline components...:   0%|          | 0/11 [00:00<?, ?it/s]

In [19]:
def custum_step_DDIM(ddim_scheduler, unet_output_epsilon, timestep, sample, X_0, reference_guidance_strength = 0.5, eta = 0.0):
    """One DDIM reverse step whose x_0 estimate is blended with a reference latent.

    Computes x_{t-1} from x_t for an epsilon-predicting UNet. Before forming x_{t-1},
    the UNet's x_0 estimate is replaced by
    ``(1 - s) * pred_x0 + s * X_0 / sqrt(alpha_prod_t)`` where
    ``s = reference_guidance_strength``.

    Args:
        ddim_scheduler: scheduler providing ``config.num_train_timesteps``,
            ``num_inference_steps``, ``alphas_cumprod``, ``final_alpha_cumprod`` and
            ``_get_variance(t, prev_t)``.
        unet_output_epsilon: noise predicted by the UNet for ``sample``.
        timestep: current timestep t, used as an index into ``alphas_cumprod``.
        sample: current noisy latent x_t.
        X_0: reference latent to steer the generation toward.
        reference_guidance_strength: blend weight ``s`` given to the reference.
        eta: DDIM stochasticity; here it only shrinks the direction term.

    Returns:
        DDIMSchedulerOutput whose ``prev_sample`` is x_{t-1} and whose
        ``pred_original_sample`` is the UNet's *unblended* x_0 estimate.

    NOTE(review): two deviations from a textbook DDIM step, possibly intentional since
    the decay values used in this notebook were tuned against them — confirm before
    changing:
      * the reference is divided by sqrt(alpha_prod_t), so it is amplified at
        high-noise steps instead of being used as-is;
      * when eta > 0 no random ``std_dev_t * noise`` term is added; eta only
        reduces the weight of the direction term.
    """
    stride = ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps
    prev_t = timestep - stride

    cumprods = ddim_scheduler.alphas_cumprod
    a_t = cumprods[timestep]
    # Past the first training step there is no alpha entry; fall back to the final one.
    a_prev = ddim_scheduler.final_alpha_cumprod if prev_t < 0 else cumprods[prev_t]
    sqrt_a_t = a_t ** (0.5)

    x0_from_unet = (sample - (1 - a_t) ** (0.5) * unet_output_epsilon) / sqrt_a_t
    x0_from_reference = X_0 / sqrt_a_t

    sigma_t = eta * ddim_scheduler._get_variance(timestep, prev_t) ** (0.5)
    direction = (1 - a_prev - sigma_t**2) ** (0.5) * unet_output_epsilon

    # Reference guidance: use the blended x_0 where plain DDIM would use x0_from_unet.
    x0_blend = x0_from_unet*(1-reference_guidance_strength) + x0_from_reference*reference_guidance_strength
    x_prev = a_prev ** (0.5) * x0_blend + direction

    return DDIMSchedulerOutput(prev_sample=x_prev, pred_original_sample=x0_from_unet)

In [20]:

def _reference_guidance_schedule(reference_guidance_decay, num_train_timesteps, timesteps):
    """Return the per-step reference-guidance strength used by ``inference_pipe``.

    The result has one entry per element of ``timesteps``. When
    ``reference_guidance_decay`` is None every entry is zero (no reference guidance).
    Otherwise the decay is treated as a factor per *training* timestep: it is raised to
    the stride ``num_train_timesteps / len(timesteps)`` to get a per-inference-step
    factor ``r``. Loop index ``i`` then gets ``r ** (i + 1)``, so the strength shrinks
    geometrically as the reverse process proceeds.
    """
    if reference_guidance_decay is None:
        return torch.zeros_like(timesteps)
    step_size = num_train_timesteps / len(timesteps)
    decay_rate_for_one_step = torch.pow(torch.tensor(reference_guidance_decay), step_size)
    return decay_rate_for_one_step ** torch.arange(1, len(timesteps) + 1)


@torch.no_grad()
def inference_pipe(
        pipeline,
        reference_latent_X_0,
        reference_guidance_decay = None,
        prompt: Union[str, List[str]] = None,
        transcription: Union[str, List[str]] = None,
        audio_length_in_s: Optional[float] = None,
        num_inference_steps: int = 200,
        guidance_scale: float = 0.8,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_waveforms_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        generated_prompt_embeds: Optional[torch.Tensor] = None,
        negative_generated_prompt_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        negative_attention_mask: Optional[torch.LongTensor] = None,
        max_new_tokens: Optional[int] = None,
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        output_type: Optional[str] = "np",
    ):
        """Generate audio with an AudioLDM2 pipeline while steering toward a reference latent.

        Follows the structure of the AudioLDM2 pipeline's generation steps. The difference
        is that the scheduler step is replaced by ``custum_step_DDIM``, so each denoising
        step blends the predicted x_0 with ``reference_latent_X_0``. The blend strength
        decays exponentially over the steps (see ``_reference_guidance_schedule``).

        The whole function runs under ``torch.no_grad()``. Module parameters require grad
        by default; without it, every UNet call would record an autograd graph that
        ``latents`` carries from step to step, keeping all intermediate activations alive
        for the entire loop.

        Args:
            pipeline: a loaded ``AudioLDM2Pipeline``.
            reference_latent_X_0: reference latent blended into each step's x_0 estimate.
            reference_guidance_decay: per-training-timestep decay of the reference
                strength; None disables reference guidance.
            guidance_scale: classifier-free guidance weight. It is only applied when CFG is
                active, which here means a negative prompt or negative embeddings were
                given (not ``guidance_scale > 1``).
            eta: forwarded to ``custum_step_DDIM``.
            output_type: ``'latent'`` additionally returns the final latents.
            The remaining arguments are forwarded to the pipeline's own helpers
            (``check_inputs``, ``encode_prompt``, ``prepare_latents``, ``score_waveforms``).
            ``cross_attention_kwargs`` is accepted but not used.

        Returns:
            ``(audio,)`` if ``return_dict`` is False (latents are not included). Otherwise
            ``[final_latents, AudioPipelineOutput]`` when ``output_type == 'latent'``, else
            ``AudioPipelineOutput``. ``audio`` is a numpy array.
        """
        # 0. Convert audio input length from seconds to spectrogram height
        vocoder_upsample_factor = np.prod(pipeline.vocoder.config.upsample_rates) / pipeline.vocoder.config.sampling_rate

        if audio_length_in_s is None:
            audio_length_in_s = pipeline.unet.config.sample_size * pipeline.vae_scale_factor * vocoder_upsample_factor

        height = int(audio_length_in_s / vocoder_upsample_factor)

        original_waveform_length = int(audio_length_in_s * pipeline.vocoder.config.sampling_rate)
        if height % pipeline.vae_scale_factor != 0:
            height = int(np.ceil(height / pipeline.vae_scale_factor)) * pipeline.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        pipeline.check_inputs(
            prompt,
            audio_length_in_s,
            vocoder_upsample_factor,
            callback_steps,
            transcription,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            generated_prompt_embeds,
            negative_generated_prompt_embeds,
            attention_mask,
            negative_attention_mask,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = pipeline._execution_device
        print(f'device:{device}')

        # `guidance_scale` is analogous to the guidance weight `w` of equation (2) of the
        # Imagen paper (https://arxiv.org/pdf/2205.11487.pdf). Here CFG is switched on by
        # the presence of a negative prompt/embedding rather than by `guidance_scale > 1`.
        do_classifier_free_guidance = negative_prompt is not None or negative_generated_prompt_embeds is not None or negative_prompt_embeds is not None

        # 3. Encode input prompt
        prompt_embeds, attention_mask, generated_prompt_embeds = pipeline.encode_prompt(
            prompt,
            device,
            num_waveforms_per_prompt,
            do_classifier_free_guidance,
            transcription,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            generated_prompt_embeds=generated_prompt_embeds,
            negative_generated_prompt_embeds=negative_generated_prompt_embeds,
            attention_mask=attention_mask,
            negative_attention_mask=negative_attention_mask,
            max_new_tokens=max_new_tokens,
        )

        # 4. Prepare timesteps
        pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = pipeline.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = pipeline.unet.config.in_channels
        latents = pipeline.prepare_latents(
            batch_size * num_waveforms_per_prompt,
            num_channels_latents,
            height,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Reference-guidance strength for each step, decaying exponentially as the
        #    reverse process proceeds. (`eta` goes straight to custum_step_DDIM, so the
        #    scheduler's extra step kwargs are not needed.)
        exp_reference_guidance_decay = _reference_guidance_schedule(
            reference_guidance_decay, pipeline.scheduler.config.num_train_timesteps, timesteps
        )

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
        with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = pipeline.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=generated_prompt_embeds,
                    encoder_hidden_states_1=prompt_embeds,
                    encoder_attention_mask_1=attention_mask,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1 with a DDIM step whose
                # x_0 estimate is blended with the reference latent
                latents = custum_step_DDIM(pipeline.scheduler, noise_pred, t, latents, reference_latent_X_0, reference_guidance_strength = exp_reference_guidance_decay[i], eta = eta).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(pipeline.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        pipeline.maybe_free_model_hooks()

        output_latents = latents

        # 8. Decode latents -> mel spectrogram -> waveform
        latents = 1 / pipeline.vae.config.scaling_factor * latents
        mel_spectrogram = pipeline.vae.decode(latents).sample

        audio = pipeline.mel_spectrogram_to_waveform(mel_spectrogram)

        audio = audio[:, :original_waveform_length]

        # 9. Automatic scoring
        if num_waveforms_per_prompt > 1 and prompt is not None:
            audio = pipeline.score_waveforms(
                text=prompt,
                audio=audio,
                num_waveforms_per_prompt=num_waveforms_per_prompt,
                device=device,
                dtype=prompt_embeds.dtype,
            )

        audio = audio.detach().cpu().numpy()

        if not return_dict:
            return (audio,)

        return [output_latents, AudioPipelineOutput(audios=audio)] if output_type == 'latent' else AudioPipelineOutput(audios=audio)

In [90]:

def synthesized_instrument_with_reference(instrument_class_prefix, music_type, prompt = "", background_music_path = None, latent_path=None, reference_guidance_decay = None, synthesis_strength = 0.7, num_inference_steps=20):
    """Generate an instrument-only rendition guided by a reference latent, then mix it back in.

    Steps:
      1. Load the reference latent (default ``{music_type}_latent_tensor.pt``) and the
         reference waveform (default ``{music_type}_as_reference.wav``).
      2. Run ``inference_pipe`` with ``prompt`` prefixed by
         ``[only played by <instrument>]``, using the reference latent as X_0 guidance.
      3. Write the volume-regularized result to
         ``[<instrument>]<music_type>_with_guidance_<decay>.wav``.
      4. Mix it with the reference track via ``mix_with_original``
         (``synthesis_strength`` is passed as ``weight1``), then save the latent of
         ``[<instrument>_mixed_with_original]<music_type>_with_guidance_<decay>.wav``.

    Relies on the notebook globals ``pipeline`` and ``device`` and on helpers imported
    from ``utils`` (``audio_volumn_regularization``, ``mix_with_original``,
    ``waveform_to_latent``).
    """
    duration = 10.24
    # Generated audio comes out at the vocoder's rate (16 kHz for audioldm2-music), so
    # load the reference and write the result at that same rate.
    sr = pipeline.vocoder.config.sampling_rate

    prompt_with_prefix = f'[only played by {instrument_class_prefix}]' + prompt

    latent_path = music_type + '_latent_tensor.pt' if latent_path is None else latent_path
    # map_location lets a latent saved from a GPU session load on a CPU-only machine.
    # The file holds a plain tensor, so weights_only=True is sufficient; it also avoids
    # unpickling arbitrary objects and torch's FutureWarning about the default.
    music_latent = torch.load(latent_path, map_location=device, weights_only=True)
    print(f'latent shape: {music_latent.shape}')

    background_music_path = music_type + '_as_reference.wav' if background_music_path is None else background_music_path

    # Loaded only as a sanity check on the reference track (length / loudness report);
    # the mixing step below re-reads it from background_music_path.
    original_music, _ = librosa.load(background_music_path, sr=sr, duration=duration)
    original_music = audio_volumn_regularization(original_music)

    print(f'original_music shape: {original_music.shape}')

    # output_type='latent' makes inference_pipe return [latents, AudioPipelineOutput];
    # only the audio is needed here.
    _, instrument_added_music = inference_pipe(pipeline,
        reference_latent_X_0=music_latent,
        reference_guidance_decay = reference_guidance_decay,
        eta = 0.1,
        prompt = prompt_with_prefix,
        num_inference_steps=num_inference_steps,
        audio_length_in_s=duration,
        output_type='latent'
    )

    instrument_added_music = instrument_added_music.audios[0]
    instrument_added_music = audio_volumn_regularization(instrument_added_music)

    print(f'instrument_added_music shape: {instrument_added_music.shape}')
    only_instrument_file_path = f'[{instrument_class_prefix}]{music_type}_with_guidance_{reference_guidance_decay}.wav'
    scipy.io.wavfile.write(only_instrument_file_path, rate=sr, data=instrument_added_music)

    mix_with_original(inst= instrument_class_prefix, music_type=music_type, guidance=reference_guidance_decay, weight1=synthesis_strength, background_music_path=background_music_path)

    # NOTE(review): assumes mix_with_original writes its output under exactly this name.
    mixed_latent_path = f'[{instrument_class_prefix}]{music_type}_with_guidance_{reference_guidance_decay}_latent_tensor.pt'
    mixed_file_path = f'[{instrument_class_prefix}_mixed_with_original]{music_type}_with_guidance_{reference_guidance_decay}.wav'
    torch.save(waveform_to_latent(mixed_file_path, pipeline), mixed_latent_path)

In [86]:
# Prefix shared by every file name derived below (latent, reference wav, outputs).
music_type = "classic_music"
# make_latent_from_original (from utils) encodes the reference track for `music_type`
# and returns the path of the saved latent tensor, which is echoed here.
print(latent_path := make_latent_from_original(music_type))

before audio reg(RMS): 0.040106818079948425
mel spec length: 1025


  latent = torch.load(latent_path).to(device)


classic_music_latent_tensor.pt


# Generate baseline reference music *(only if you do not have a reference track)* #
If you already have your own soundtrack, this step is not necessary.


In [None]:
# Generate a synthetic reference clip for `music_type` from a text prompt.
# Guarded so that "Run All" never overwrites an existing reference (e.g. the real
# track encoded by make_latent_from_original above) with a synthetic one.
# Delete the two files below to force regeneration.
prompt = "soft and harmonic melody based colorful hook part of FUTURE BASS MUSIC"
audio_length_in_s = 10.24
reference_latent_path = music_type + '_latent_tensor.pt'
reference_wav_path = music_type + '_as_reference.wav'

if os.path.exists(reference_latent_path) or os.path.exists(reference_wav_path):
    print(f'Reference for {music_type} already exists; skipping baseline generation.')
else:
    # Inference only: keep the VAE / vocoder calls from recording an autograd graph.
    with torch.no_grad():
        music_latent = pipeline(
            prompt,
            num_inference_steps=100,
            audio_length_in_s=audio_length_in_s,
            output_type="latent",
        ).audios
        torch.save(music_latent, reference_latent_path)

        # Decode latent -> mel spectrogram -> waveform.
        scaled_latent = 1 / pipeline.vae.config.scaling_factor * music_latent
        mel_spectrogram = pipeline.vae.decode(scaled_latent).sample
        audio = pipeline.mel_spectrogram_to_waveform(mel_spectrogram)

    # Trim to the requested duration at the vocoder's output rate.
    sampling_rate = pipeline.vocoder.config.sampling_rate
    original_waveform_length = int(audio_length_in_s * sampling_rate)
    audio = audio[:, :original_waveform_length].detach().cpu().numpy()

    scipy.io.wavfile.write(reference_wav_path, rate=sampling_rate, data=audio[0])

# Synthesize instruments *(the main part of this notebook)* #

In [108]:
gc.collect()  # release memory left over from earlier generations before starting a new one
instrument_class_prefix = 'Guitar'  # TODO: set to the instrument you want to convert the music style to

# Style descriptors appended to the "[only played by <instrument>]" prompt prefix.
inst_prompt = {
    'Grand Piano': 'soft, harmonic melody but discrete timbre, colorful and vibrant',
    'Drum': 'rhythmical, kick and snare included in the sound, various, vibrant',
    'Guitar': 'fast, anomalous, reckless, frantic, irregular',
    'Violin': 'continuous, consistent, harmonic, rigorous',
}

# Experimental findings for reference_guidance_decay:
#   decay >= 0.943        -> the output almost copies the reference
#   0.925 < decay < 0.94  -> the right value depends on the instrument (needs experimentation)
REFERENCE_GUIDANCE_DECAY = 0.927
SYNTHESIS_STRENGTH = 0.15  # weight of the generated instrument when mixed with the original
NUM_INFERENCE_STEPS = 25

synthesized_instrument_with_reference(
    instrument_class_prefix,
    music_type,
    inst_prompt[instrument_class_prefix],
    reference_guidance_decay=REFERENCE_GUIDANCE_DECAY,
    synthesis_strength=SYNTHESIS_STRENGTH,
    num_inference_steps=NUM_INFERENCE_STEPS,
)

  music_latent = torch.load(latent_path).to(device)


latent shape: torch.Size([1, 8, 256, 16])
before audio reg(RMS): 0.040106818079948425
original_music shape: (163840,)
device:cpu


  0%|          | 0/25 [00:00<?, ?it/s]

before audio reg(RMS): 0.10560053586959839
instrument_added_music shape: (163840,)
before audio reg(RMS): 0.0352860726416111
before audio reg(RMS): 0.15000002086162567
mel spec length: 1025


# Mix the audio of each instrument #

In [113]:
# Guidance decay each instrument track was generated with (selects the input files).
Piano_guidance = 0.93
Drum_guidance = 0.93
Guitar_guidance = 0.927

# Mixing weights: Drum relative to Grand Piano, then Guitar relative to the Drum+Piano mix.
DtoP = 1.0
GtoDP = 0.8

mixed_path = mix_two_inst(
    'Drum', 'Grand Piano', music_type,
    guidance1=Drum_guidance, guidance2=Piano_guidance, weight1=DtoP,
)
path = mix_with_original(
    inst= 'Drum_Grand Piano_Guitar',
    music_type = music_type,
    weight1=GtoDP,
    background_music_path=f'[Guitar]{music_type}_with_guidance_{Guitar_guidance}.wav',
    inst_music_path=mixed_path,
)

# NOTE(review): these shares assume weight1 in mix_with_original weights the *background*
# (Guitar) track. In synthesized_instrument_with_reference, weight1 is the synthesis
# strength of the instrument track, so confirm which side weight1 applies to in utils.
drum_share = DtoP / (DtoP + 1) / (GtoDP + 1)
piano_share = 1 / (DtoP + 1) / (GtoDP + 1)
guitar_share = GtoDP / (GtoDP + 1)
print(f'weight distribution: Drum({drum_share:.3f}), Grand Piano({piano_share:.3f}), Guitar({guitar_share:.3f}) ')

before audio reg(RMS): 0.10553175956010818
before audio reg(RMS): 0.10872796177864075
weight distribution: Drum(0.278), Grand Piano(0.278), Guitar(0.444) 


In [None]:
# no_guide_instrument_added_music = inference_pipe(pipeline,
#     reference_latent_X_0=music_latent,
#     #reference_guidance_strength = 1,
#     eta = 0.1,
#     prompt = prefix_without_prompt,
#     num_inference_steps=20,
#     audio_length_in_s=10.24,
# ).audios

# no_guide_save_file_name = '[No_guidance_to_compare]' + instrument_class_prefix + file_name + ".wav"
# scipy.io.wavfile.write(no_guide_save_file_name, rate=16000, data=no_guide_instrument_added_music[0])
# scipy.io.wavfile.write('[mixed_with_original]' + save_file_name, rate=16000, data=mix_audios(instrument_added_music[0], audio[0], weight1=0.8, weight2=0.2))
