In [1]:
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader
import os
import pandas as pd 
import librosa
import numpy as np
import matplotlib.pyplot as plt
import torch 
import torchaudio
import tqdm

from torch import nn

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# MusicBench training rows, restricted to the modified JSON manifest.
dataset = load_dataset(
    path='amaai-lab/MusicBench',
    data_files="MusicBench_train_modified.json",
)
# Alternative (repository default configuration):
# dataset = load_dataset('amaai-lab/MusicBench')

In [3]:
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer


2024-04-19 15:20:56 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


In [4]:
# Transformer hyperparameters shared by both MuLaN towers.
_tower_kwargs = dict(dim=512, depth=6, heads=8, dim_head=64)

# Audio tower: spectrogram transformer plus its STFT / augmentation settings.
audio_transformer = AudioSpectrogramTransformer(
    **_tower_kwargs,
    spec_n_fft=128,
    spec_win_length=24,
    spec_aug_stretch_factor=0.8,
)

# Text tower with the same width, depth and attention layout.
text_transformer = TextTransformer(**_tower_kwargs)

# MuLaN model combining the two towers.
mulan = MuLaN(
    audio_transformer=audio_transformer,
    text_transformer=text_transformer,
)


In [5]:
def _keep_batch_as_list(batch):
    """Collate function that returns the batch (a list of example dicts) unchanged."""
    return batch


# Shuffled batches of 64 raw MusicBench examples; no tensor stacking is attempted.
dataloader = DataLoader(
    dataset['train'],
    batch_size=64,
    shuffle=True,
    collate_fn=_keep_batch_as_list,
)


In [6]:
# Peek at the first raw MusicBench example (a dict of metadata and captions).
dataloader.dataset[0]


{'dataset': 'MusicBench',
 'location': 'data_aug2/-0SdAVK79lg_1.wav',
 'main_caption': 'This mellow instrumental track showcases a dominant electric guitar that opens with a descending riff, followed by arpeggiated chords, hammer-ons, and a slide. The percussion section keeps it simple with rim shots and a common time count, while the bass adds a single note on the first beat of every bar. Minimalist piano chords round out the song while leaving space for the guitar to shine. There are no vocals, making it perfect for a coffee shop or some chill background music. The key is in E major, with a chord progression that centers around that key and a straightforward 4/4 time signature.',
 'alt_caption': 'This song features an electric guitar as the main instrument. The guitar plays a descending run in the beginning then plays an arpeggiated chord followed by a double stop hammer on to a higher note and a descending slide followed by a descending chord run. The percussion plays a simple beat 

## MusicCaps

In [2]:
from datasets import load_dataset

# MusicCaps metadata rows (YouTube id, clip start/end, caption, ...); audio is fetched separately.
ds = load_dataset(path='google/MusicCaps', split='train')

In [3]:
import subprocess
import os
from pathlib import Path

def download_clip(
    video_identifier,
    output_filename,
    start_time,
    end_time,
    tmp_dir='/tmp/musiccaps/',
    num_attempts=5,
    url_base='https://www.youtube.com/watch?v='
):
    """Download one section of a YouTube video as a WAV file using the yt-dlp CLI.

    Args:
        video_identifier: YouTube video id, appended to ``url_base``.
        output_filename: destination path handed to yt-dlp's ``-o`` option.
        start_time, end_time: section boundaries for ``--download-sections``.
        tmp_dir: unused; kept only for backward compatibility.
        num_attempts: how many times yt-dlp is run before giving up. Values
            below 1 are treated as 1 (previously 0 retried forever).
        url_base: URL prefix the video id is appended to.

    Returns:
        ``(status, log)``. If every attempt fails: ``(False, <yt-dlp combined
        stdout/stderr as bytes>)``. If yt-dlp cannot be started at all:
        ``(False, <error message as bytes>)``. Otherwise
        ``(os.path.exists(output_filename), 'Downloaded')``.
    """
    # Argument list instead of a shell string: ids and paths from the dataset
    # are never parsed by a shell, so quotes/metacharacters in them can neither
    # break the command nor inject extra commands.
    command = [
        'yt-dlp', '--quiet', '--no-warnings',
        '-x', '--audio-format', 'wav',
        '-f', 'bestaudio',
        '-o', str(output_filename),
        '--download-sections', f'*{start_time}-{end_time}',
        f'{url_base}{video_identifier}',
    ]

    last_output = b''
    for _ in range(max(1, num_attempts)):
        try:
            subprocess.check_output(command, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as err:
            last_output = err.output
        except OSError as err:
            # yt-dlp could not be launched (e.g. not on PATH); retrying won't help.
            return False, str(err).encode()
        else:
            break
    else:
        # Every attempt failed.
        return False, last_output

    # yt-dlp exited successfully; confirm the file actually landed on disk.
    return os.path.exists(output_filename), 'Downloaded'



In [4]:
from datasets import Audio

# --- MusicCaps download configuration ---------------------------------------
samples_to_load = 5521      # How many samples to load (clamped to the dataset size below)
cores = 4                   # How many processes to use for the loading
sampling_rate = 44100       # Rate the `audio` column is decoded at by `datasets` (cast_column below)
writer_batch_size = 1000    # How many examples to keep in memory per worker. Reduce if OOM.
data_dir = Path("/srv/nfs-data/sisko/matteoc/music_data_caps")  # Where to save the data

# Just select some samples; clamp so asking for more rows than exist doesn't raise.
ds = ds.select(range(min(samples_to_load, len(ds))))

# Create directory where data will be saved
data_dir.mkdir(exist_ok=True, parents=True)

def process(example):
    """Make sure the clip for one MusicCaps row is on disk and record its location.

    The download is skipped when ``<data_dir>/<ytid>.wav`` already exists.
    Two fields are written into ``example``, which is then returned:
    ``audio`` (the WAV path as a string) and ``download_status`` (True when
    the file was already present, otherwise the status from ``download_clip``).
    """
    wav_path = str(data_dir / f"{example['ytid']}.wav")

    if os.path.exists(wav_path):
        downloaded = True
    else:
        downloaded, _log = download_clip(
            example['ytid'],
            wav_path,
            example['start_s'],
            example['end_s'],
        )

    example['audio'] = wav_path
    example['download_status'] = downloaded
    return example



In [5]:
# Download every clip in parallel worker processes, then turn the `audio` path
# column into an Audio feature decoded at `sampling_rate`.
# NOTE(review): rows whose download failed keep a path that does not exist
# (download_status == False); decoding such rows later raises.
ds = (
    ds.map(process, num_proc=cores, writer_batch_size=writer_batch_size, keep_in_memory=False)
    .cast_column('audio', Audio(sampling_rate=sampling_rate))
)

In [6]:
# Count entries in the download directory to see how many clips reached disk.
# NOTE: this counts every entry (e.g. the `feature_clap` sub-directory written
# later in this notebook), not only .wav clips.
len(os.listdir(data_dir))

5406

In [7]:
# Inspect the last selected row; with the Audio cast applied, indexing decodes the
# clip into `audio['array']` (see output).
ds[5520]

{'ytid': 'zzNdwF40ID8',
 'start_s': 70,
 'end_s': 80,
 'audioset_positive_labels': '/m/04rlf,/m/0790c',
 'aspect_list': "['glitch', 'noise', 'instrumental', 'electronic', 'synth', 'granular', 'bells', 'flow', 'rising-and-falling', 'eerie', 'uneasy', 'robotic', 'analog sounding']",
 'caption': 'This is a glitch music piece. There is a synth sound rising in pitch that resembles a triangle wave. There are granular synth samples being played randomly. A virtual percussive low-to-mid bell sound is playing a melody that resembles a marimba. There is an eerie feeling of flow. This piece could be used in the soundtracks of dystopian sci-fi movies. It could also be used in exploration sequences of video games.',
 'author_id': 9,
 'is_balanced_subset': True,
 'is_audioset_eval': True,
 'audio': {'path': '/srv/nfs-data/sisko/matteoc/music_data_caps/zzNdwF40ID8.wav',
  'array': array([ 0.00069827, -0.00025624, -0.00119744, ..., -0.03152088,
         -0.03809599,  0.        ]),
  'sampling_rate': 4

In [8]:
# Load each clip at its native sampling rate together with its caption.
# Rows whose audio cannot be read (e.g. the download failed) are skipped, and
# their indices and errors are recorded in `caps_skipped` instead of vanishing.
caps_audio = []
caps_sr = []
caps_caption = []
caps_skipped = []
for i in tqdm.tqdm(range(len(ds))):
    try:
        row = ds[i]  # fetch (and decode) the row once instead of twice
        audio, sr = torchaudio.load(row['audio']['path'])
    except Exception as err:  # not bare: lets KeyboardInterrupt/SystemExit through
        caps_skipped.append((i, repr(err)))
        continue
    caps_audio.append(audio)
    caps_sr.append(sr)
    caps_caption.append(row['caption'])

print(f"Loaded {len(caps_audio)} clips, skipped {len(caps_skipped)}.")




100%|██████████| 5521/5521 [06:05<00:00, 15.11it/s]


In [None]:
# Down-mix each clip to mono by averaging over the channel axis, keeping a
# leading channel dimension of size 1.
caps_audio_mono = [torch.mean(clip, dim=0, keepdim=True) for clip in caps_audio]

In [None]:
# Number of clips that loaded successfully and were down-mixed to mono.
len(caps_audio_mono)

5402

In [15]:
from diffusers import MusicLDMPipeline

# MusicLDM generation pipeline, loaded in float16 and moved to the GPU.
repo_id = "ucsd-reach/musicldm"
pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16).to("cuda")

Loading pipeline components...: 100%|██████████| 7/7 [00:10<00:00,  1.55s/it]


In [14]:
from transformers import AutoFeatureExtractor, ClapModel, ClapProcessor

# CLAP checkpoint used to embed both audio clips and captions (model on GPU).
clap_checkpoint = "laion/clap-htsat-unfused"
model = ClapModel.from_pretrained(clap_checkpoint).to("cuda")
processor = ClapProcessor.from_pretrained(clap_checkpoint)
feature_extractor = AutoFeatureExtractor.from_pretrained(clap_checkpoint)

In [None]:
# Shape of the first mono clip: (1 channel, samples at its native sampling rate).
caps_audio_mono[0].shape

torch.Size([1, 958728])

In [None]:
# The CLAP processor is told the audio is at this rate; it must actually be true.
CLAP_SAMPLING_RATE = 48_000

audio_feat = []
text_feat = []

with torch.no_grad():
    for wv, tx, sr in tqdm.tqdm(list(zip(caps_audio_mono, caps_caption, caps_sr))):
        # Clips were loaded at their native rate (`caps_sr`). Previously that rate
        # was ignored and the audio was simply declared as 48 kHz, which silently
        # distorts the features whenever the file is at another rate.
        if sr != CLAP_SAMPLING_RATE:
            wv = torchaudio.functional.resample(wv, orig_freq=sr, new_freq=CLAP_SAMPLING_RATE)
        inputs = processor(
            text=tx,
            audios=wv.squeeze(),
            return_tensors="pt",
            sampling_rate=CLAP_SAMPLING_RATE,
        )
        outputs = model(**inputs.to('cuda'))
        # Embeddings stay on the GPU; later cells save and reload them as-is.
        audio_feat.append(outputs.audio_embeds)
        text_feat.append(outputs.text_embeds)

100%|██████████| 5402/5402 [22:15<00:00,  4.05it/s]


In [94]:
# inputs = processor(text=caps_caption[0], audios=caps_audio_mono[0].squeeze(), return_tensors="pt", sampling_rate=48_000)
# outputs = model(**inputs.to('cuda'))

In [102]:
# Stack the per-clip embeddings into (num_clips, dim) tensors and persist them.
# squeeze(1) drops only the per-clip batch axis, so a single clip still yields a
# 2-D tensor (a bare squeeze() would also drop the clip axis when N == 1).
feat_dir = Path("/srv/nfs-data/sisko/matteoc/music_data_caps/feature_clap")
feat_dir.mkdir(parents=True, exist_ok=True)  # torch.save does not create parent dirs

audio_feat = torch.stack(audio_feat).squeeze(1)
text_feat = torch.stack(text_feat).squeeze(1)
torch.save(audio_feat, feat_dir / "audio_feat.pt")
torch.save(text_feat, feat_dir / "text_feat.pt")

In [10]:
# Reload the cached CLAP embeddings (restored to the device they were saved from).
audio_feat, text_feat = (
    torch.load("/srv/nfs-data/sisko/matteoc/music_data_caps/feature_clap/audio_feat.pt"),
    torch.load("/srv/nfs-data/sisko/matteoc/music_data_caps/feature_clap/text_feat.pt"),
)

In [32]:
# text_feat=torch.stack(text_feat).squeeze()

In [11]:
# Both embedding matrices should have one row per clip.
print(audio_feat.shape, text_feat.shape, sep="\n")

torch.Size([5402, 512])
torch.Size([5402, 512])


In [36]:
# Caption of clip 4322, whose MusicLDM embedding is used below as a sanity-check prompt.
caps_caption[4322]

'Latin electronic music with a Cumbia feel featuring an autotuned male vocal, female vocal response and syncopated drum pattern with a large room ambience which is muffling the music.'

In [16]:
# Encode every caption with the MusicLDM pipeline's own prompt encoder (one
# waveform per prompt, no classifier-free guidance) so the embeddings can be
# passed back to the pipeline later via `prompt_embeds`.
with torch.no_grad():
    text_feat_ldm = [
        pipe._encode_prompt(
            caption,
            device='cuda',
            num_waveforms_per_prompt=1,
            do_classifier_free_guidance=False,
        )
        for caption in tqdm.tqdm(caps_caption)
    ]


100%|██████████| 5402/5402 [00:37<00:00, 144.12it/s]


In [17]:
# Collapse the per-caption embeddings into one (num_captions, dim) tensor.
# squeeze(1) removes only the per-prompt batch axis, so a single caption still
# gives a 2-D tensor (a bare squeeze() would also drop the caption axis).
text_feat_ldm = torch.stack(text_feat_ldm).squeeze(1)

In [37]:
# Sanity check: regenerate 10 s of audio from the stored MusicLDM embedding of
# caption 4322 (the slice keeps the leading batch dimension).
audio_prova = pipe(
    prompt_embeds=text_feat_ldm[4322:4323],
    audio_length_in_s=10.0,
    num_inference_steps=50,
).audios[0]

100%|██████████| 50/50 [00:01<00:00, 36.90it/s]


In [38]:
# Dot-product similarity between the CLAP audio and text embeddings of clip 1.
# Both rows are 1-D, so `@` already computes the inner product; the former
# `.T` on a 1-D tensor was a no-op that recent PyTorch flags as deprecated.
simil_score = audio_feat[1].float() @ text_feat[1].float()

In [39]:
# Display the similarity score (a 0-d tensor on the features' device).
simil_score

tensor(0.5887, device='cuda:0')

In [40]:
# NOTE(review): this import rebinds `Audio`, which earlier referred to
# `datasets.Audio` (used by the `cast_column` cell). Re-running that cell after
# this one calls the IPython player instead; re-import `datasets.Audio` (or
# restart the kernel) first.
from IPython.display import Audio

# Play the sanity-check clip; assumes MusicLDM output is 16 kHz (TODO confirm).
Audio(audio_prova, rate=16000)

In [26]:
import torch.nn as nn
import torch.optim as optim

# Learn a linear map from CLAP audio embeddings (`audio_feat`) to MusicLDM
# prompt embeddings (`text_feat_ldm`), so a clip's audio can be fed to the pipeline.
SEED = 0
torch.manual_seed(SEED)  # reproducible weight initialisation

input_dim = 512
output_dim = 512
linear_layer = nn.Linear(input_dim, output_dim).to('cuda')
criterion = nn.MSELoss()
optimizer = optim.Adam(linear_layer.parameters(), lr=0.0001, weight_decay=1e-6)

# Rows of both tensors must describe the same clip, in the same order.
num_samples = text_feat_ldm.shape[0]
assert audio_feat.shape[0] == num_samples, (audio_feat.shape, text_feat_ldm.shape)
train_size = int(0.8 * num_samples)
test_size = num_samples - train_size
# Contiguous 80/20 split: first 80% of rows train, the rest are held out.
train_audio_feat, test_audio_feat = torch.split(audio_feat.float(), [train_size, test_size], dim=0)
# Targets come from the float16 pipeline; cast once (not every epoch) so the
# training loss and the evaluation cells compare float32 with float32.
train_text_feat, test_text_feat = torch.split(text_feat_ldm.float(), [train_size, test_size], dim=0)

# Full-batch training loop
num_epochs = 1000
log_every = 20
for epoch in range(num_epochs):
    linear_layer.train()
    optimizer.zero_grad()
    audio_feat_output = linear_layer(train_audio_feat)
    loss = criterion(audio_feat_output, train_text_feat)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % log_every == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')



Epoch [20/1000], Loss: 0.0022374214604496956
Epoch [40/1000], Loss: 0.0017978204414248466
Epoch [60/1000], Loss: 0.0015939654549583793
Epoch [80/1000], Loss: 0.0014780862256884575
Epoch [100/1000], Loss: 0.0014023107942193747
Epoch [120/1000], Loss: 0.0013485312229022384
Epoch [140/1000], Loss: 0.001308442442677915
Epoch [160/1000], Loss: 0.0012775688664987683
Epoch [180/1000], Loss: 0.0012531852116808295
Epoch [200/1000], Loss: 0.0012335034552961588
Epoch [220/1000], Loss: 0.0012173040304332972
Epoch [240/1000], Loss: 0.0012037343112751842
Epoch [260/1000], Loss: 0.0011921871919184923
Epoch [280/1000], Loss: 0.0011822228552773595
Epoch [300/1000], Loss: 0.0011735177831724286
Epoch [320/1000], Loss: 0.0011658292496576905
Epoch [340/1000], Loss: 0.0011589732021093369
Epoch [360/1000], Loss: 0.0011528076138347387
Epoch [380/1000], Loss: 0.0011472211917862296
Epoch [400/1000], Loss: 0.0011421259259805083
Epoch [420/1000], Loss: 0.001137451035901904
Epoch [440/1000], Loss: 0.00113313901238

In [27]:
# Held-out MSE of the learned audio -> prompt-embedding mapping.
linear_layer.eval()
with torch.no_grad():
    test_audio_output = linear_layer(test_audio_feat)
    # Targets may still be float16; cast to float32 exactly as the training loss does.
    test_loss = criterion(test_audio_output, test_text_feat.float())

print(f'Test Loss: {test_loss.item()}')

Test Loss: 0.0011135325767099857


In [28]:
# Baseline: held-out MSE if the raw CLAP audio embeddings were used directly as
# MusicLDM prompt embeddings (identity mapping; the trained layer is not involved).
# Stored under its own name so the learned model's `test_loss` is not overwritten.
with torch.no_grad():
    baseline_loss = criterion(test_audio_feat.float(), test_text_feat.float())

print(f'Baseline (identity) Test Loss: {baseline_loss.item()}')

Test Loss: 0.0039655170403420925


In [29]:
# (num_test_clips, 512): one predicted prompt embedding per held-out clip.
test_audio_output.shape

torch.Size([1081, 512])

In [30]:
# Generate 10 s of audio from the *predicted* embedding of held-out clip 1.
audio_pred = pipe(
    prompt_embeds=test_audio_output[1:2],
    audio_length_in_s=10.0,
    num_inference_steps=50,
).audios[0]

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:01<00:00, 36.60it/s]


In [31]:
# Play the clip generated from the mapped audio embedding (16 kHz assumed — TODO confirm).
Audio(audio_pred, rate=16000)

In [32]:
# Reference: generate from the *true* MusicLDM caption embedding of the same held-out clip.
audio_vera = pipe(
    prompt_embeds=test_text_feat[1:2],
    audio_length_in_s=10.0,
    num_inference_steps=50,
).audios[0]

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:01<00:00, 36.24it/s]


In [33]:
# Play the reference clip for comparison with the prediction (16 kHz assumed — TODO confirm).
Audio(audio_vera, rate=16000)