In [1]:
import torch, random
import numpy as np
import torch.nn.functional as F
from tqdm.auto import tqdm
from IPython.display import Audio
from matplotlib import pyplot as plt
from diffusers import DiffusionPipeline, AudioPipelineOutput
from torchaudio import transforms as AT
from torchvision import transforms as IT

import torchaudio
import transformers
from transformers import ClapProcessor, ClapModel, AutoProcessor
from torch.optim import AdamW, Adam
from torch.utils.data import DataLoader, Dataset
import os

from diffusers import MusicLDMPipeline

# Turn off Hugging Face tokenizers' own parallelism — presumably to avoid
# clashes with the multi-worker DataLoaders created later (num_workers=4).
os.environ.update(TOKENIZERS_PARALLELISM="false")


  from .autonotebook import tqdm as notebook_tqdm


## MusicBench dataset (exploration — to be continued)

In [85]:
from datasets import load_dataset

# Download MusicBench from the Hugging Face Hub (yields train and test splits).
ds_music_bench = load_dataset(path="amaai-lab/MusicBench")

Downloading data: 100%|██████████| 85.7M/85.7M [00:02<00:00, 35.5MB/s]
Downloading data: 100%|██████████| 518k/518k [00:00<00:00, 1.19MB/s]
Downloading data: 100%|██████████| 591k/591k [00:00<00:00, 1.75MB/s]
Generating train split: 100%|██████████| 52768/52768 [00:00<00:00, 221389.24 examples/s]
Generating test split: 100%|██████████| 800/800 [00:00<00:00, 60154.95 examples/s]


In [107]:
# Number of examples in the MusicBench train split.
ds_music_bench["train"].num_rows

52768

In [116]:
# Relative audio path stored for the first training example.
ds_music_bench["train"][0]["location"]

'data_aug2/-0SdAVK79lg_1.wav'

## Build Dataset

In [2]:
# Root of the pre-extracted caption/audio files. Override with the
# MUSIC_DATA_DIR environment variable instead of editing the path here.
MUSIC_DATA_DIR = os.environ.get("MUSIC_DATA_DIR", "/srv/nfs-data/sisko/matteoc/music")

# Captions, converted to a plain Python list (assumed index-aligned with audio_tensor — TODO confirm).
text_tensor = np.load(os.path.join(MUSIC_DATA_DIR, "music_data_caps_capt.npy")).tolist()
# Waveforms, forced onto the CPU: they are resampled inside DataLoader worker
# processes, so they must not come back on whatever device they were saved from.
audio_tensor = torch.load(os.path.join(MUSIC_DATA_DIR, "music_data_caps_audio.pt"), map_location="cpu")

In [3]:
from diffusers import AudioLDM2Pipeline

# Device for every model below: a fixed GPU when CUDA exists, else the CPU.
cuda_dev = "cuda:1"
device = cuda_dev if torch.cuda.is_available() else "cpu"

# AudioLDM2 music checkpoint, loaded in half precision.
repo_id = "cvssp/audioldm2-music"
audioldm_pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16).to(device)

# CLAP model and its processor (HTSAT, unfused variant).
clap_model_id = "laion/clap-htsat-unfused"
clap_model = ClapModel.from_pretrained(clap_model_id).to(device)
clap_process = AutoProcessor.from_pretrained(clap_model_id)

Loading pipeline components...: 100%|██████████| 11/11 [00:39<00:00,  3.63s/it]
  return self.fget.__get__(instance, owner)()


In [4]:
from librosa.filters import mel as librosa_mel_fn
import sys

# Make the repository root importable (for the local `data` package). Only add
# it once, so re-running this cell does not keep growing sys.path.
_repo_root = os.path.abspath("..")
if _repo_root not in sys.path:
    sys.path.append(_repo_root)
# Star import kept on purpose: it provides at least `STFT` (used below) and,
# presumably, `spectral_normalize` (used by get_mel_features).
from data.audioLDM_pre import *

# Mel-spectrogram front-end settings used to feed the AudioLDM VAE.
sampling_rate_ldm = 16000   # Hz
n_mel_channels = 64
mel_fmin = 0                # Hz
mel_fmax = 8000             # Hz
duration = 10.24            # seconds
filter_length = 1024        # FFT size
hop_length = 160            # samples between frames
win_length = 1024
window = 'hann'
# Frames covering `duration` seconds: int(10.24 * 16000 / 160) == 1024.
target_length = int(duration * sampling_rate_ldm / hop_length)
pad_wav_start_sample = 0

stft = STFT(
    filter_length=filter_length,
    hop_length=hop_length,
    win_length=win_length,
    window=window
)




In [5]:
class AudioDataset(Dataset):
    """Map-style dataset that resamples stored waveforms when they are accessed.

    Args:
        dataset: indexable collection of waveform tensors (needs ``len`` and ``[idx]``).
        captions: optional sequence of captions, indexed in parallel with ``dataset``.
        sample_rate_dataset: sample rate of the stored waveforms, in Hz.
        new_sr: sample rate of the waveforms returned by ``__getitem__``, in Hz.

    ``__getitem__`` returns the resampled waveform alone, or a
    ``(waveform, caption)`` pair when captions were supplied.
    """

    def __init__(self, dataset, captions=None, sample_rate_dataset=44100, new_sr=sampling_rate_ldm):
        self.dataset = dataset
        self.captions = captions
        self.resampler = torchaudio.transforms.Resample(orig_freq=sample_rate_dataset, new_freq=new_sr)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        waveform = self.resampler(self.dataset[idx])
        if self.captions is None:
            return waveform
        return waveform, self.captions[idx]
    
    
# Mel filterbanks already built, keyed by their parameters plus target device/dtype,
# so librosa is not re-run on every call.
_MEL_BASIS_CACHE = {}


def get_mel_features(audio):
    """Compute a log-mel spectrogram using the AudioLDM front-end settings.

    Args:
        audio: waveform tensor passed straight to ``stft.transform``; this
            notebook calls it with shape (1, num_samples) — other shapes untested.

    Returns:
        Time-major log-mel tensor of shape (batch, frames, n_mel_channels),
        as produced by the final ``permute(0, 2, 1)``.
    """
    magnitude, phase = stft.transform(audio)
    magnitudes = magnitude.data

    # Build (or reuse) the filterbank on the same device/dtype as the magnitudes,
    # so the matmul also works when the STFT output is not CPU float32.
    key = (sampling_rate_ldm, filter_length, n_mel_channels, mel_fmin, mel_fmax,
           magnitudes.device, magnitudes.dtype)
    mel_basis = _MEL_BASIS_CACHE.get(key)
    if mel_basis is None:
        mel_basis = torch.from_numpy(
            librosa_mel_fn(
                sr=sampling_rate_ldm, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
            )
        ).float().to(device=magnitudes.device, dtype=magnitudes.dtype)
        _MEL_BASIS_CACHE[key] = mel_basis

    mel_output = torch.matmul(mel_basis, magnitudes)
    return spectral_normalize(mel_output, torch.log).permute(0, 2, 1)
    
    


In [6]:
audio_dataset = AudioDataset(audio_tensor, text_tensor)

# Seed every RNG the notebook touches.
seed = 70
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# 80/20 train/validation split. A dedicated, seeded generator keeps the split
# independent of how much global torch randomness earlier cells consumed.
train_size = int(0.8 * len(audio_dataset))
val_size = len(audio_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    audio_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(seed)
)

# Seeded generator for the shuffle order as well, so every run shuffles identically.
train_dataloader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, num_workers=4,
    generator=torch.Generator().manual_seed(seed),
)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)



In [7]:
import librosa

# Take the first validation example and send its mel spectrogram through the
# VAE encoder/decoder and the vocoder (a reconstruction, no diffusion).
batch = next(iter(val_dataloader))
batch_sample = 0
batch_audio = batch[0][batch_sample].unsqueeze(0)
print('batch_audio: ', batch_audio.shape)
batch_text = batch[1][batch_sample]

real_audio_val = get_mel_features(batch_audio).to(device).unsqueeze(1)
real_text_val = batch_text
print('real_audio_val: ', real_audio_val.shape)
print('real_text_raw: ', real_text_val)

output_type = "np"

# Pure inference: no autograd graph through the text encoders, VAE or vocoder.
with torch.no_grad():
    prompt_embeds_val, attention_mask, generated_prompt_embeds = audioldm_pipe.encode_prompt(
        real_text_val,
        device,
        num_waveforms_per_prompt=1,
        do_classifier_free_guidance=True,
        transcription=None,
        negative_prompt='',
        prompt_embeds=None,
        negative_prompt_embeds=None,
        generated_prompt_embeds=None,
        negative_generated_prompt_embeds=None,
        attention_mask=None,
        negative_attention_mask=None,
        max_new_tokens=8,
    )
    print('prompt_embeds_val: ', prompt_embeds_val.shape)
    print('generated_prompt_embeds: ', generated_prompt_embeds.shape)

    # Match the VAE's weight dtype instead of hard-coding half precision.
    encoded = audioldm_pipe.vae.encode(real_audio_val.to(audioldm_pipe.vae.dtype))
    latents_val = audioldm_pipe.vae.config.scaling_factor * encoded.latent_dist.sample()
    print('latents_val: ', latents_val.shape)

    mel_spectrogram_val = audioldm_pipe.vae.decode(latents_val / audioldm_pipe.vae.config.scaling_factor).sample
    print('mel_spectrogram_val: ', mel_spectrogram_val.shape)

    original_waveform_length_val = int(duration * audioldm_pipe.vocoder.config.sampling_rate)
    audio_val = audioldm_pipe.mel_spectrogram_to_waveform(mel_spectrogram_val)
    audio_val = audio_val[:, :original_waveform_length_val]

if output_type == "np":
    audio_val = audio_val.detach().numpy()
audio_pipe_val = AudioPipelineOutput(audios=audio_val)

batch_audio:  torch.Size([1, 160000])
real_audio_val:  torch.Size([1, 1, 1001, 64])
real_text_raw:  This is an electro swing/funk music piece. It is an instrumental piece. There is a brass section making up the most of the melody with the saxophone as the lead. The bass line is being played by the tuba. There is also a piano adding texture to the melody. The rhythmic background is provided by an electronic drum beat. The atmosphere is upbeat and eccentric. This piece could be playing at a nightclub or a dance club. It could also work well in the soundtrack of a comedy animation movie.
prompt_embeds_val:  torch.Size([2, 116, 1024])
generated_prompt_embeds:  torch.Size([2, 8, 768])
latents_val:  torch.Size([1, 8, 250, 16])
mel_spectrogram_val:  torch.Size([1, 1, 1000, 64])


In [8]:
# VAE-reconstructed validation clip (encode -> decode -> vocoder), not the raw input waveform.

Audio(audio_pipe_val.audios, rate=16000)

## Inference

In [9]:
# Encode the validation caption for classifier-free guidance; the batch dimension
# of 2 holds the unconditional ('') and the conditional embeddings.
# no_grad: pure inference, so no autograd graph is built through the text encoders.
with torch.no_grad():
    prompt_embeds_val, attention_mask, generated_prompt_embeds = audioldm_pipe.encode_prompt(
        real_text_val,
        device,
        num_waveforms_per_prompt=1,
        do_classifier_free_guidance=True,
        transcription=None,
        negative_prompt='',
        prompt_embeds=None,
        negative_prompt_embeds=None,
        generated_prompt_embeds=None,
        negative_generated_prompt_embeds=None,
        attention_mask=None,
        negative_attention_mask=None,
        max_new_tokens=8,
    )
print('prompt_embeds_val: ', prompt_embeds_val.shape)
print('generated_prompt_embeds: ', generated_prompt_embeds.shape)

prompt_embeds_val:  torch.Size([2, 116, 1024])
generated_prompt_embeds:  torch.Size([2, 8, 768])


In [10]:
# Prepare a 50-step sampling schedule on the pipeline's execution device.
num_inference_steps = 50
device = audioldm_pipe._execution_device
audioldm_pipe.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = audioldm_pipe.scheduler.timesteps

In [11]:
num_channels_latents = audioldm_pipe.unet.config.in_channels
num_waveforms_per_prompt = 1
# prod(upsample_rates) / sampling_rate: seconds of audio per mel frame
# (assuming the vocoder expands each mel frame into prod(upsample_rates) samples).
vocoder_upsample_factor = np.prod(audioldm_pipe.vocoder.config.upsample_rates) / audioldm_pipe.vocoder.config.sampling_rate
# Default clip length implied by the UNet sample size, in seconds.
audio_length_in_s = audioldm_pipe.unet.config.sample_size * audioldm_pipe.vae_scale_factor * vocoder_upsample_factor
# Mel frames to generate. The VAE works on a time axis reduced by vae_scale_factor,
# so round up to a multiple of it.
height = int(audio_length_in_s / vocoder_upsample_factor)
if height % audioldm_pipe.vae_scale_factor != 0:
    height = int(np.ceil(height / audioldm_pipe.vae_scale_factor)) * audioldm_pipe.vae_scale_factor
original_waveform_length = int(audio_length_in_s * audioldm_pipe.vocoder.config.sampling_rate)
# Seed on the device actually in use, so this also works on the CPU fallback
# (a hard-coded "cuda:1" generator fails there).
generator = torch.Generator(device=device).manual_seed(42)

latents = audioldm_pipe.prepare_latents(
    1 * num_waveforms_per_prompt,
    num_channels_latents,
    height,
    dtype=prompt_embeds_val.dtype,
    device=device,
    generator=generator,
    latents=None
)

In [12]:
# Classifier-free guidance strength; it must stay > 1 to agree with the
# do_classifier_free_guidance=True used when encoding the prompt.
guidance_scale = 3.5
do_classifier_free_guidance = guidance_scale > 1

# Extra keyword arguments (built from eta and the generator) forwarded to scheduler.step.
eta = 0.0
extra_step_kwargs = audioldm_pipe.prepare_extra_step_kwargs(generator, eta)

In [14]:
# Manual denoising loop following the AudioLDM2 pipeline's steps.
# - Runs under no_grad so no autograd graph accumulates across the UNet calls.
# - Works on `denoised_latents` and leaves the initial noise in `latents`
#   untouched, so re-running this cell starts again from the same noise.
num_warmup_steps = len(timesteps) - num_inference_steps * audioldm_pipe.scheduler.order
denoised_latents = latents
with torch.no_grad(), audioldm_pipe.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([denoised_latents] * 2) if do_classifier_free_guidance else denoised_latents
        latent_model_input = audioldm_pipe.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual; pass the prompt attention mask so padded
        # positions in prompt_embeds_val (presumably those of the short negative
        # prompt '') are masked out
        noise_pred = audioldm_pipe.unet(
            latent_model_input,
            t,
            encoder_hidden_states=generated_prompt_embeds,
            encoder_hidden_states_1=prompt_embeds_val,
            encoder_attention_mask_1=attention_mask,
            return_dict=False,
        )[0]

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        denoised_latents = audioldm_pipe.scheduler.step(noise_pred, t, denoised_latents, **extra_step_kwargs).prev_sample

        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % audioldm_pipe.scheduler.order == 0):
            progress_bar.update()

# Decode to a waveform unless only the latents are requested.
output_type = "np"
with torch.no_grad():
    if output_type == "latent":
        audio = denoised_latents
    else:
        mel_spectrogram = audioldm_pipe.vae.decode(1 / audioldm_pipe.vae.config.scaling_factor * denoised_latents).sample
        audio = audioldm_pipe.mel_spectrogram_to_waveform(mel_spectrogram)
        audio = audio[:, :original_waveform_length]

audioldm_pipe.maybe_free_model_hooks()

if output_type == "np":
    audio = audio.detach().numpy()
audio_pipe = AudioPipelineOutput(audios=audio)

100%|██████████| 50/50 [00:08<00:00,  5.96it/s]


In [15]:
# Play the generated waveform (16 kHz).
Audio(data=audio, rate=16000)