In [1]:
import torch, random
import numpy as np
import torch.nn.functional as F
from tqdm.auto import tqdm
from IPython.display import Audio
from matplotlib import pyplot as plt
from diffusers import DiffusionPipeline, AudioPipelineOutput
from torchaudio import transforms as AT
from torchvision import transforms as IT

import torchaudio

from transformers import ClapProcessor, ClapModel, AutoProcessor
from torch.optim import AdamW, Adam
from torch.utils.data import DataLoader, Dataset

from diffusers import MusicLDMPipeline


  from .autonotebook import tqdm as notebook_tqdm


## Load MusicCaps

In [23]:
from datasets import load_dataset

# Load the `train` split of the MusicCaps dataset from the Hugging Face Hub.
ds = load_dataset('google/MusicCaps', split='train')

import subprocess
import os
from pathlib import Path

def download_clip(
    video_identifier,
    output_filename,
    start_time,
    end_time,
    tmp_dir='/tmp/musiccaps/',
    num_attempts=5,
    url_base='https://www.youtube.com/watch?v='
):
    """Download one audio section of a video as WAV by invoking ``yt-dlp``.

    Args:
        video_identifier: Video id, appended to ``url_base`` to form the URL.
        output_filename: Path handed to ``yt-dlp -o``; its existence after the
            run is used as the success check.
        start_time: Section start, forwarded to ``--download-sections``.
        end_time: Section end, forwarded to ``--download-sections``.
        tmp_dir: Unused; kept so existing callers keep working.
        num_attempts: Maximum number of ``yt-dlp`` invocations. Values below 1
            are treated as a single attempt.
        url_base: URL prefix placed in front of ``video_identifier``.

    Returns:
        A ``(status, log)`` tuple. On success ``status`` tells whether
        ``output_filename`` exists and ``log`` is ``'Downloaded'``. After
        repeated failures ``status`` is ``False`` and ``log`` is the captured
        output of the last failed run. If ``yt-dlp`` cannot be started at all,
        ``status`` is ``False`` and ``log`` is the error message.
    """
    # Pass an argument list (no shell) so that ids, paths and timestamps are
    # never interpreted by a shell, whatever characters they contain.
    command = [
        'yt-dlp', '--quiet', '--no-warnings',
        '-x', '--audio-format', 'wav',
        '-f', 'bestaudio',
        '-o', str(output_filename),
        '--download-sections', f'*{start_time}-{end_time}',
        f'{url_base}{video_identifier}',
    ]

    last_output = b''
    # A bounded loop guarantees termination even for num_attempts <= 0.
    for _ in range(max(1, num_attempts)):
        try:
            subprocess.check_output(command, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as err:
            last_output = err.output
        except OSError as err:
            # The executable is missing or not runnable: retrying cannot help.
            return False, str(err)
        else:
            # Check if the audio file was actually written.
            return os.path.exists(output_filename), 'Downloaded'

    return False, last_output


from datasets import Audio as audiod

samples_to_load = 5521      # Number of dataset rows to keep (taken from the start of the split)
cores = 4                 # Number of worker processes used by `ds.map`
sampling_rate = 44100     # Sampling rate requested for the audio column, keep at 44100
writer_batch_size = 1000  # How many examples to keep in memory per worker. Reduce if OOM.
data_dir = "/srv/nfs-data/sisko/matteoc/music/music_data_caps" # Where the WAV files are saved (machine-specific absolute path)

# Keep only the first `samples_to_load` rows
ds = ds.select(range(samples_to_load))

# Create the directory where data will be saved (parents included, no error if it exists)
data_dir = Path(data_dir)
data_dir.mkdir(exist_ok=True, parents=True)

def process(example):
    """Attach the local WAV path and a download flag to one dataset row.

    The clip is fetched through ``download_clip`` only when the expected file
    ``<data_dir>/<ytid>.wav`` is not on disk yet. The row is returned with two
    extra keys: ``audio`` (the file path) and ``download_status``.
    """
    wav_path = str(data_dir / f"{example['ytid']}.wav")

    if os.path.exists(wav_path):
        # Already downloaded by a previous run.
        downloaded = True
    else:
        downloaded, _ = download_clip(
            example['ytid'],
            wav_path,
            example['start_s'],
            example['end_s'],
        )

    example['audio'] = wav_path
    example['download_status'] = downloaded
    return example


# Run `process` over every row with `cores` worker processes, then cast the
# `audio` column (a file path string) to an audio feature at `sampling_rate`.
ds = ds.map(
        process,
        num_proc=cores,
        writer_batch_size=writer_batch_size,
        keep_in_memory=False
    ).cast_column('audio', audiod(sampling_rate=sampling_rate))

In [24]:
song_idx = 3799
# Fetch the row once: every `ds[song_idx]` access reads the row and decodes
# its audio file again, so repeated indexing triples the work.
sample = ds[song_idx]
audio_array = sample["audio"]["array"]
audio_caption = sample["caption"]
sample_rate_dataset = sample["audio"]["sampling_rate"]
print("Audio array shape:", audio_array.shape)
print("Audio caption:", audio_caption)
print("Sample rate:", sample_rate_dataset)
# Play the first 10 seconds of the clip.
display(Audio(audio_array[0:10*sample_rate_dataset], rate=sample_rate_dataset))

Audio array shape: (880832,)
Audio caption: The low quality recording features a classical song that consists of a brass solo melody played over sustained brass melody and short flute lick. It sounds emotional, joyful and the recording is noisy.
Sample rate: 44100


In [25]:
# Display the complete record for the selected clip.
ds[song_idx]

{'ytid': 'dbBlYyaFKTQ',
 'start_s': 30,
 'end_s': 40,
 'audioset_positive_labels': '/m/01kcd,/m/0319l,/m/05pd6,/m/07c6l,/m/07gql',
 'aspect_list': "['low quality', 'noisy', 'classical', 'flute lick', 'brass solo melody', 'sustained brass melody', 'emotional', 'joyful']",
 'caption': 'The low quality recording features a classical song that consists of a brass solo melody played over sustained brass melody and short flute lick. It sounds emotional, joyful and the recording is noisy.',
 'author_id': 4,
 'is_balanced_subset': False,
 'is_audioset_eval': False,
 'audio': {'path': '/srv/nfs-data/sisko/matteoc/music/music_data_caps/dbBlYyaFKTQ.wav',
  'array': array([-0.00827911, -0.01138475, -0.01273868, ..., -0.01942613,
         -0.01393065,  0.        ]),
  'sampling_rate': 44100},
 'download_status': True}

## MusicBench

In [85]:
from datasets import load_dataset

# Load MusicBench from the Hugging Face Hub; no split is requested, so the
# result is indexed by split name (e.g. ds_music_bench['train']).
ds_music_bench = load_dataset("amaai-lab/MusicBench")

Downloading data: 100%|██████████| 85.7M/85.7M [00:02<00:00, 35.5MB/s]
Downloading data: 100%|██████████| 518k/518k [00:00<00:00, 1.19MB/s]
Downloading data: 100%|██████████| 591k/591k [00:00<00:00, 1.75MB/s]
Generating train split: 100%|██████████| 52768/52768 [00:00<00:00, 221389.24 examples/s]
Generating test split: 100%|██████████| 800/800 [00:00<00:00, 60154.95 examples/s]


In [107]:
# Number of examples in the MusicBench train split.
len(ds_music_bench['train'])

52768

In [116]:
# Peek at the `location` field of the first training example.
ds_music_bench['train'][0]['location']

'data_aug2/-0SdAVK79lg_1.wav'

## Build Dataset

In [None]:
def pad_or_truncate_audio(audio, max_len_samples):
    """Return ``audio`` as a float tensor of exactly ``max_len_samples`` samples.

    Shorter inputs are right-padded with zeros; longer (or equal-length)
    inputs are cut to the first ``max_len_samples`` samples.
    """
    waveform = torch.tensor(audio).float()
    missing = max_len_samples - len(waveform)
    if missing > 0:
        # Zero-pad on the right only.
        return F.pad(waveform, (0, missing), 'constant', 0)
    return waveform[:max_len_samples]

max_len_sec = 10
max_len_samples = sample_rate_dataset * max_len_sec

audio_tensor = []
text_tensor = []
for elem in tqdm(range(len(ds))):
    try:
        # Fetch the row once; each `ds[elem]` access decodes the audio file.
        row = ds[elem]
    except FileNotFoundError:
        # The clip could not be downloaded, so there is no file to decode.
        continue
    audio_tensor.append(pad_or_truncate_audio(row["audio"]["array"], max_len_samples))
    text_tensor.append(row["caption"])

# Waveforms all have the same length now, so they stack into (num_clips, max_len_samples).
audio_tensor = torch.stack(audio_tensor, dim=0)
# Captions are Python strings: `torch.stack` only accepts tensors and would
# raise a TypeError here, so `text_tensor` stays a plain list aligned with
# the rows of `audio_tensor`.




In [2]:
# Load the captions (converted to a Python list) and the audio tensor from disk.
# NOTE(review): presumably these files were written from the outputs of the
# "Build Dataset" cell in an earlier session — the saving code is not in this notebook.
text_tensor =  np.load('/srv/nfs-data/sisko/matteoc/music/music_data_caps_capt.npy').tolist()
audio_tensor = torch.load('/srv/nfs-data/sisko/matteoc/music/music_data_caps_audio.pt')

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# Load the pretrained MusicLDM pipeline in float32 and move it to the target device.
repo_id = "ucsd-reach/musicldm"
musicldm_pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float32)
# NOTE(review): the GPU index is hard-coded to 4; adjust for the machine in use.
device = "cuda:4" if torch.cuda.is_available() else "cpu"
musicldm_pipe = musicldm_pipe.to(device)

# Load a separate CLAP model (on the same device) and its processor.
clap_model_id = "laion/larger_clap_music_and_speech"
clap_model = ClapModel.from_pretrained(clap_model_id).to(device)
clap_process = AutoProcessor.from_pretrained(clap_model_id)

Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  6.24it/s]


In [4]:
from librosa.filters import mel as librosa_mel_fn
import sys
import os
# Add the parent directory to the Python path so the local `data` package can be imported
sys.path.append(os.path.abspath(".."))
# NOTE(review): star import; presumably provides `STFT` and `spectral_normalize`
# used below and in `get_mel_features` — confirm against data/audioLDM_pre.py.
from data.audioLDM_pre import *

# Mel-spectrogram settings used for the LDM input features
sampling_rate_ldm = 16000
n_mel_channels = 64
mel_fmin = 0
mel_fmax = 8000
duration = 10.0
filter_length = 1024
hop_length = 160
win_length = 1024 
window = 'hann'
# Number of STFT hops that fit in `duration` seconds at `sampling_rate_ldm`
target_length = int(duration * sampling_rate_ldm / hop_length)
pad_wav_start_sample = 0

# resampler = torchaudio.transforms.Resample(orig_freq=44100, new_freq=sampling_rate_ldm)
# audio_res = resampler(audio_tensor[0])

# STFT module shared by `get_mel_features`
stft = STFT(
    filter_length=filter_length, 
    hop_length=hop_length, 
    win_length=win_length,
    window=window
)




In [5]:
class AudioDataset(Dataset):
    """Dataset that resamples stored waveforms and optionally pairs them with captions.

    Args:
        dataset: Indexable collection of waveforms; ``dataset[idx]`` is passed
            to the resampler.
        captions: Optional indexable collection aligned with ``dataset``.
        sample_rate_dataset: Sampling rate of the stored waveforms.
        new_sr: Sampling rate the waveforms are resampled to.
    """

    def __init__(self, dataset, captions=None, sample_rate_dataset=44100, new_sr=sampling_rate_ldm):
        self.dataset = dataset
        self.captions = captions
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate_dataset,
            new_freq=new_sr,
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the resampled waveform, or ``(waveform, caption)`` when captions exist."""
        resampled = self.resampler(self.dataset[idx])
        if self.captions is None:
            return resampled
        return resampled, self.captions[idx]
    
    
# Mel filterbanks already built, keyed by the settings they were built from.
_MEL_BASIS_CACHE = {}


def _mel_filterbank():
    """Return the mel filterbank for the current global mel settings, building it only once."""
    key = (sampling_rate_ldm, filter_length, n_mel_channels, mel_fmin, mel_fmax)
    if key not in _MEL_BASIS_CACHE:
        basis = librosa_mel_fn(
            sr=sampling_rate_ldm, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
        )
        _MEL_BASIS_CACHE[key] = torch.from_numpy(basis).float()
    return _MEL_BASIS_CACHE[key]


def get_mel_features(audio):
    """Convert a batch of waveforms into log-mel spectrograms.

    The magnitude spectrogram from the module-level ``stft`` is projected onto
    the mel filterbank, passed through ``spectral_normalize`` with
    ``torch.log`` and permuted so that the last two axes are swapped. In this
    notebook a ``(8, 160000)`` batch produces ``(8, 1001, 64)`` features.

    The filterbank depends only on the global mel settings, so it is cached
    instead of being rebuilt with librosa on every call.
    """
    magnitude, _phase = stft.transform(audio)
    mel_output = torch.matmul(_mel_filterbank(), magnitude.data)
    return spectral_normalize(mel_output, torch.log).permute(0, 2, 1)
    
    


In [6]:
audio_dataset = AudioDataset(audio_tensor, text_tensor)

# 80/20 train/validation split.
train_size = int(0.8 * len(audio_dataset))
val_size = len(audio_dataset) - train_size
# Seed the split: without a fixed generator the partition changes on every
# kernel restart, so validation clips from one run can end up in the training
# set of the next and losses are not comparable across runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    audio_dataset, [train_size, val_size], generator=split_generator
)

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)



In [7]:
# Sanity check on the first validation batch only (the loop exits at `break`):
# compute mel features, encode the caption, run one clip through the VAE
# encoder/decoder and vocode the decoded mel back to a waveform.
for step, batch in (enumerate(val_dataloader)):
    batch_sample = 0  # index of the clip inspected inside the batch
    batch_audio = batch[0]
    print('batch_audio: ', batch_audio.shape)
    batch_text = batch[1]
    real_audio_val = get_mel_features(batch_audio).to(device)
    real_text_val = batch_text[batch_sample]
    # Keep one clip and add two leading singleton dimensions
    real_audio_val = real_audio_val[batch_sample].unsqueeze(0).unsqueeze(0)
    print('real_audio_val: ', real_audio_val.shape)
    print('real_text_raw: ', real_text_val)
    # NOTE(review): `audio_features_val` is not used again in this notebook
    audio_features_val = clap_model.get_audio_features(real_audio_val)
    prompt_embeds_val = musicldm_pipe._encode_prompt(
            real_text_val,
            device,
            num_waveforms_per_prompt=1,
            do_classifier_free_guidance=False,
            negative_prompt='',
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )
    print('prompt_embeds_val: ', prompt_embeds_val.shape)

    # inputs = clap_process(audios=real_audio_raw.float(), return_tensors="pt", sampling_rate=48_000)
    # outputs = clap_model(inputs=inputs['input_features'].to(device)).last_hidden_state
    # print('last_hidden_state: ', outputs.shape)

    # Encode the mel spectrogram into scaled VAE latents (sampled from the posterior)
    with torch.no_grad():
        encoded = musicldm_pipe.vae.encode(real_audio_val)
        latents_val = musicldm_pipe.vae.config.scaling_factor * encoded.latent_dist.sample()
        print('latents_val: ', latents_val.shape)
    # Undo the scaling and decode the latents back to a mel spectrogram
    with torch.no_grad():
        mel_spectrogram_val = musicldm_pipe.vae.decode(latents_val/musicldm_pipe.vae.config.scaling_factor).sample
        print('mel_spectrogram_val: ', mel_spectrogram_val.shape)

    # Vocode the decoded mel and trim the waveform to 10 seconds
    original_waveform_length_val = int(10.0 * musicldm_pipe.vocoder.config.sampling_rate)
    audio_val = musicldm_pipe.mel_spectrogram_to_waveform(mel_spectrogram_val.to(device=device))
    audio_val = audio_val[:, :original_waveform_length_val]
    output_type = "np"
    if output_type == "np":
        audio_val = audio_val.detach().numpy()
    audio_pipe_val = AudioPipelineOutput(audios=audio_val)
    break

batch_audio:  torch.Size([8, 160000])
real_audio_val:  torch.Size([1, 1, 1001, 64])
real_text_raw:  The Blues/Pop song features groovy hi hats, punchy snare and kick hits, tinny snare rolls in-between snare hits, addictive brass melody and smooth bass at the very end of the loop. It sounds a bit repetitive, but still addictive and energetic.
prompt_embeds_val:  torch.Size([1, 512])
latents_val:  torch.Size([1, 8, 250, 16])
mel_spectrogram_val:  torch.Size([1, 1, 1000, 64])


In [8]:
# Play the validation clip after its VAE encode/decode round trip and vocoding
# (a reconstruction of the real clip, not the raw recording).

Audio(audio_pipe_val[0], rate=16000)

## Training

In [9]:
# from peft import get_peft_model, LoraConfig, TaskType
from peft import LoraConfig, LoraModel, get_peft_model

# Earlier configuration, kept for reference:
# unet_lora_config = LoraConfig(
#     r=64,
#     lora_alpha=128,
#     lora_dropout=0.1,
#     target_modules=["attn1.to_q", "attn1.to_v", "ff.fc1"]
# )

# LoRA configuration: rank 16, alpha 32, Gaussian initialisation, applied to the
# modules whose names match the attention projections listed in `target_modules`.
unet_lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        init_lora_weights="gaussian",
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    )

# musicldm_pipe.unet.add_adapter(unet_lora_config)

In [10]:
# unet_lora = LoraModel(musicldm_pipe.unet, unet_lora_config, 'default')
# Wrap the pipeline's UNet with the LoRA adapters described by `unet_lora_config`.
unet_lora = get_peft_model(musicldm_pipe.unet, unet_lora_config)

In [11]:
def count_trainable_params(model):
    """Print and return the number of trainable parameters of ``model``.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The number of parameter elements whose ``requires_grad`` is true.
        A model without parameters reports a ratio of 0.0 instead of raising
        ``ZeroDivisionError``.
    """
    trainable_params = 0
    total_params = 0
    # Single pass over the parameters instead of three separate sums.
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    ratio = trainable_params / total_params if total_params else 0.0
    print(f"Trainable params : {trainable_params/1e6} M  over {total_params/1e6}M : {ratio}")
    return trainable_params

# Report how many parameters of the LoRA-wrapped UNet are trainable.
count_trainable_params(unet_lora)

Trainable params : 1.80224 M  over 186.838792M : 0.009645962600742998


1802240

In [9]:
# Global settings read by `inference_musicldm`.
output_type = "np"
audio_length_in_s = 10.0

num_inference_steps = 50
do_classifier_free_guidance = False
cross_attention_kwargs = None
guidance_scale = 2.0  # only used when do_classifier_free_guidance is True
callback = None
callback_steps = 1
# Extra keyword arguments forwarded to `scheduler.step`
extra_step_kwargs= {}
extra_step_kwargs["eta"] = 0.0
extra_step_kwargs["generator"] = torch.Generator(device=device).manual_seed(42)

In [10]:
num_epochs = 10
lr = 1e-4
grad_accumulation_steps = 2

# Give the optimizer only the parameters that are actually trainable. Once the
# LoRA adapters are injected the base UNet weights are frozen, and registering
# them would only make the optimizer track tensors that never receive a
# gradient. When nothing is frozen this selects every UNet parameter.
trainable_unet_params = [p for p in musicldm_pipe.unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_unet_params, lr=lr)

# Per-step losses, appended to by the training loop.
train_losses = []
val_losses = []


In [11]:
import os

import wandb

# SECURITY: never store an API key in the notebook. The key is read from the
# WANDB_API_KEY environment variable; when it is unset, `key=None` lets wandb
# fall back to its own credential lookup / interactive prompt.
# The key that used to be hard-coded in this cell has been exposed and should
# be revoked in the Weights & Biases settings.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
huggingfac

In [12]:
def training_step(batch_audio, batch_text, bs, generator=None):
    """Prepare the tensors needed to evaluate the denoising loss on one batch.

    Args:
        batch_audio: Batch of waveforms, passed to ``get_mel_features``.
        batch_text: Captions of the batch; converted to a list for the prompt encoder.
        bs: Batch size. Kept for backward compatibility; the noise shape is
            taken from the latents themselves.
        generator: Optional ``torch.Generator`` (on the latents' device) for
            reproducible noise. ``None`` draws fresh noise on every call.

    Returns:
        ``(latents_real, audio_features, text_features, noise)``: the scaled
        VAE latent means, the CLAP audio features, the prompt embeddings and
        a standard-normal noise tensor with the same shape as the latents.
    """
    # The VAE and the CLAP encoders act as fixed feature extractors here (the
    # optimizer in this notebook is built from the UNet parameters only), so
    # no autograd graph is needed for them; building one only costs memory.
    with torch.no_grad():
        batch_mel = get_mel_features(batch_audio).to(device).unsqueeze(1)
        encoded = musicldm_pipe.vae.encode(batch_mel)
        latents_real = musicldm_pipe.vae.config.scaling_factor * encoded.latent_dist.mean
        audio_features = clap_model.get_audio_features(batch_mel)
        text_features = musicldm_pipe._encode_prompt(
            list(batch_text),
            device,
            num_waveforms_per_prompt=1,
            do_classifier_free_guidance=False,
            negative_prompt='',
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )

    # Draw new noise for every call. Re-creating a generator seeded with 42 on
    # each call (as before) yields the exact same noise tensor for every batch,
    # so the network is only ever trained to predict one fixed noise sample.
    noise = torch.randn(
        latents_real.shape,
        generator=generator,
        device=latents_real.device,
        dtype=latents_real.dtype,
    )

    return latents_real, audio_features, text_features, noise

    

In [14]:
@torch.no_grad()
def inference_musicldm(latents, features_val):
    """Run the denoising loop from ``latents`` and return the generated audio.

    Reads the notebook-level settings ``num_inference_steps``,
    ``do_classifier_free_guidance``, ``guidance_scale``,
    ``cross_attention_kwargs``, ``extra_step_kwargs``, ``callback``,
    ``callback_steps``, ``output_type``, ``audio_length_in_s`` and ``device``.

    Args:
        latents: Initial noisy latents to denoise.
        features_val: Conditioning embeddings, passed to the UNet as ``class_labels``.

    Returns:
        ``AudioPipelineOutput`` whose ``audios`` is a NumPy array when
        ``output_type == "np"``, the raw latents when ``output_type == "latent"``,
        and the waveform tensor otherwise.

    The whole function runs under ``torch.no_grad()``: sampling needs no
    gradients, and tracking them across every denoising step keeps the graph of
    each UNet call alive.
    """
    scheduler = musicldm_pipe.scheduler
    # Use the shared setting instead of a hard-coded 50 so the timesteps agree
    # with the warm-up / progress-bar arithmetic below.
    scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device)
    timesteps = scheduler.timesteps
    num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
    with musicldm_pipe.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred_eval = musicldm_pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=None,
                class_labels=features_val,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred_eval.chunk(2)
                noise_pred_eval = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(noise_pred_eval, t, latents, **extra_step_kwargs).prev_sample

            # update the progress bar and call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(scheduler, "order", 1)
                    callback(step_idx, t, latents)

    # Latent output: nothing to decode (previously this path hit an undefined
    # `mel_spectrogram`).
    if output_type == "latent":
        return AudioPipelineOutput(audios=latents)

    latents = 1 / musicldm_pipe.vae.config.scaling_factor * latents
    mel_spectrogram = musicldm_pipe.vae.decode(latents).sample

    original_waveform_length = int(audio_length_in_s * musicldm_pipe.vocoder.config.sampling_rate)
    audio = musicldm_pipe.mel_spectrogram_to_waveform(mel_spectrogram.to(device=device))
    audio = audio[:, :original_waveform_length]

    # For any other output type the waveform tensor is returned as is
    # (previously `audio` was undefined in that case).
    if output_type == "np":
        audio = audio.detach().cpu().numpy()

    return AudioPipelineOutput(audios=audio)
            

In [None]:
def _noise_prediction_loss(batch):
    """Return ``(loss, noise, text_features)`` for one ``(audio, captions)`` batch.

    Shared by the training and validation passes so both compute the loss in
    exactly the same way.
    """
    batch_audio, batch_text = batch[0], batch[1]
    bs = batch_audio.shape[0]
    latents_real, _audio_features, text_features, noise = training_step(batch_audio, batch_text, bs)

    # Sample a random timestep for each clip
    timesteps = torch.randint(
        0,
        musicldm_pipe.scheduler.config.num_train_timesteps,
        (bs,),
        device=latents_real.device,
    ).long()

    # Add noise to the clean latents according to the noise magnitude at each
    # timestep (this is the forward diffusion process)
    noisy_latents = musicldm_pipe.scheduler.add_noise(latents_real, noise, timesteps)

    # Get the model prediction for the noise
    noise_pred = musicldm_pipe.unet(
        noisy_latents,
        timesteps,
        encoder_hidden_states=None,
        class_labels=text_features,
        cross_attention_kwargs=None,
        return_dict=False,
    )[0]

    # Compare the prediction with the actual noise
    return F.mse_loss(noise_pred, noise), noise, text_features


for epoch in range(num_epochs):

    # ---- training pass ----
    musicldm_pipe.unet.train()
    optimizer.zero_grad()
    num_train_steps = len(train_dataloader)
    for step, batch_train in tqdm(enumerate(train_dataloader), total=num_train_steps):

        loss, _, _ = _noise_prediction_loss(batch_train)
        del batch_train

        # Store for later plotting
        train_losses.append(loss.item())

        loss.backward()
        wandb.log({'train_loss_step': loss.item()})

        # Gradient accumulation. Also step on the last batch of the epoch so
        # that leftover gradients are applied instead of leaking into the next
        # epoch's first update.
        is_last_step = (step + 1) == num_train_steps
        if (step + 1) % grad_accumulation_steps == 0 or is_last_step:
            optimizer.step()
            optimizer.zero_grad()

    average_train_loss = sum(train_losses[-num_train_steps:]) / num_train_steps
    print(f"Epoch {epoch} average loss: {average_train_loss}")
    wandb.log({'average_train_loss': average_train_loss})

    # ---- validation pass ----
    musicldm_pipe.unet.eval()
    # No gradients are needed for validation or sampling; disabling autograd
    # avoids building (and holding) a graph for every validation batch.
    with torch.no_grad():
        for step, batch_val in tqdm(enumerate(val_dataloader), total=len(val_dataloader)):

            # Kept at loop scope: the last batch is logged to wandb below.
            batch_audio = batch_val[0]
            batch_text = batch_val[1]
            loss, noise, text_features = _noise_prediction_loss(batch_val)
            del batch_val

            # Store for later plotting
            val_losses.append(loss.item())
            wandb.log({'val_loss_step': loss.item()})

        wandb.log({'average_val_loss': sum(val_losses[-len(val_dataloader):])/len(val_dataloader)})

        # Generate one clip from the first noise sample / caption of the last validation batch
        audio_reconstr = inference_musicldm(noise[0:1], text_features[0:1])[0]

    wandb.log({"audio_reconstr": wandb.Audio(audio_reconstr.squeeze(), sample_rate=16000, caption=batch_text[0])})
    # wandb.Audio is documented for file paths or NumPy arrays, so convert the tensor explicitly.
    wandb.log({"audio_real": wandb.Audio(batch_audio[0].cpu().numpy(), sample_rate=16000, caption=batch_text[0])})



# Plot the per-step training loss curve:
plt.plot(train_losses)
plt.xlabel("training step")
plt.ylabel("MSE loss")

In [11]:
# musicldm_pipe.unet.save_adapter("/srv/nfs-data/sisko/matteoc/music/music_ldm_train", "lora_adapter")

# Planned steps:
# clone the model
# merge and unload
# save

import copy

# Output directory for the saved pipeline (machine-specific absolute path)
output_dir_lora = "/srv/nfs-data/sisko/matteoc/music/lora_saved_model"

# cloned_pipe = copy.deepcopy(musicldm_pipe)
# unet_lora.merge_and_unload()
# NOTE(review): the clone/merge steps above are commented out, so the pipeline is
# saved with the UNet in its current state — confirm the saved UNet reloads as
# expected with `from_pretrained`.
musicldm_pipe.save_pretrained(output_dir_lora)

