In [1]:
import os

# Expose only physical GPU 1 to this process. CUDA reads this variable when it is
# first initialised, so it is set in the very first cell, before torch is imported.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [2]:
from musiclm_pytorch import MuLaNTrainer, MuLaN

import pickle
import torch

from torch.utils.data import Dataset

import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

2024-06-02 15:05:36.389693: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-02 15:05:36.391351: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-02 15:05:36.427928: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Seed torch's global RNG before any weights are created so parameter
# initialisation is reproducible on Restart & Run All. MusicDataset's random
# SDD crops (torch.randint) also draw from this generator.
SEED = 42
torch.manual_seed(SEED)

# Audio tower: transformer over patchified spectrograms of the input waveform.
audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 256,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

# Text tower: same width / depth / head configuration as the audio tower.
text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

# Joint audio-text model trained contrastively by MuLaNTrainer below.
mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

In [4]:

from torch.utils.data import Dataset
from pathlib import Path


class MusicDataset(Dataset):
    """Concatenation of MusicCaps and SDD (waveform, caption) pairs.

    Indices ``[0, num_musiccaps)`` return MusicCaps items exactly as stored.
    Indices ``[num_musiccaps, num_data)`` return SDD items, from which a random
    window of ``wav_duration`` samples is cropped on every access (so repeated
    lookups of the same SDD index generally return different audio).

    Both pickles are read fully into memory at construction time.

    Args:
        musiccaps_dataset_pkl_path: pickle of a sequence whose items are indexed as
            ``item[0]`` = waveform, ``item[1]`` = caption.
        sdd_dataset_pkl_path: pickle of a sequence whose items are indexed as
            ``item[0]`` = waveform, ``item[1]`` = caption,
            ``item[2]`` = waveform length in samples.
        sample_rate: samples per second used to size the SDD crop (default 16000).
        clip_seconds: length of the SDD crop in seconds (default 10).
    """

    def __init__(self, musiccaps_dataset_pkl_path: Path, sdd_dataset_pkl_path: Path,
                 sample_rate: int = 16000, clip_seconds: int = 10):
        self.musiccaps_dataset = self._load_pickle(musiccaps_dataset_pkl_path)
        self.num_musiccaps = len(self.musiccaps_dataset)

        self.sdd_dataset = self._load_pickle(sdd_dataset_pkl_path)
        self.num_sdd = len(self.sdd_dataset)

        self.num_data = self.num_musiccaps + self.num_sdd

        # Crop length in samples (default 16000 * 10 = 160000, i.e. 10 s at 16 kHz).
        self.wav_duration = sample_rate * clip_seconds

    @staticmethod
    def _load_pickle(path):
        """Load a pickle file, closing the handle afterwards."""
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
        with open(path, 'rb') as f:
            return pickle.load(f)

    def __len__(self):
        return self.num_data

    def __getitem__(self, idx):
        """Return ``(wav, caption)`` for global index ``idx``."""
        if idx < self.num_musiccaps:
            item = self.musiccaps_dataset[idx]
            return item[0], item[1]

        # Map the global index into the SDD list.
        item = self.sdd_dataset[idx - self.num_musiccaps]
        real_wav_len = item[2]

        # Valid start offsets are 0..max_start inclusive; torch.randint's upper bound
        # is exclusive, hence the +1. Clamping at 0 keeps clips that are not longer
        # than wav_duration from making randint raise.
        # NOTE: clips shorter than wav_duration are returned whole (i.e. shorter).
        max_start = max(real_wav_len - self.wav_duration, 0)
        start_point = torch.randint(0, max_start + 1, (1,)).item()
        wav = item[0][start_point:start_point + self.wav_duration]
        return wav, item[1]

In [5]:
# Both pickles are loaded entirely into memory: requires ~ 45GB of DRAM.
pkl_dir = Path('pkls')
training_data = MusicDataset(
    sdd_dataset_pkl_path=pkl_dir / 'sdd_dataset.pkl',
    musiccaps_dataset_pkl_path=pkl_dir / 'musiccaps_dataset.pkl',
)

In [6]:
# Sanity-check MusicDataset on one clip from each source: num_musiccaps + 10 falls
# in the (randomly cropped) SDD range, index 10 in the MusicCaps range.
num_musiccap_data = training_data.num_musiccaps
print(num_musiccap_data)
for sample_idx in (num_musiccap_data + 10, 10):
    wav, txt = training_data[sample_idx]
    print(wav.shape, txt)
    # Both sources are expected to yield clips of exactly wav_duration samples
    # (the recorded output shows 160000 for both).
    assert wav.shape[-1] == training_data.wav_duration, (sample_idx, wav.shape)

5480
torch.Size([160000]) Rainy piano vibe well mix and spatially pan vocal
torch.Size([160000]) this record contain break and shoot sound . there be also a lot of deep rumble noise . the whole audio be pan to the right side of the speaker . this be an amateur record and of poor audio quality . this audio may be play in a video game .


In [7]:
# Total number of samples (MusicCaps + SDD), as reported by MusicDataset.__len__.
print(len(training_data))

6186


### Train MuLaN

In [8]:
# Trainer for the full run. grad_accum_every presumably sets how many batches of
# batch_size clips are accumulated per optimizer step -- confirm in MuLaNTrainer.
mulan_trainer = MuLaNTrainer(
    mulan=mulan,
    dataset=training_data,
    num_train_steps=2000,
    batch_size=24,
    grad_accum_every=24,
)

training with dataset of 5876 samples and validating with randomly splitted 310 samples


In [9]:
# Run the training loop. The trainer logs the loss per step and saves checkpoints
# (including a minimum-loss one) to `results`; the recorded run was interrupted
# manually (KeyboardInterrupt) after step 12.
mulan_trainer.train()

spectrogram yielded shape of (129, 1251), but had to be cropped to (128, 1248) to be patchified for transformer
0: loss: 3.146928042173385
0: saving model to results
0: saving model with minimum loss to results
1: loss: 3.5955953995386754
2: loss: 3.1361465950806937
2: saving model with minimum loss to results
3: loss: 3.135649800300598
3: saving model with minimum loss to results
4: loss: 3.129435191551844
4: saving model with minimum loss to results
5: loss: 3.1293199757734933
5: saving model with minimum loss to results
6: loss: 3.1254666348298397
6: saving model with minimum loss to results
7: loss: 3.1177236338456473
7: saving model with minimum loss to results
8: loss: 3.1072869002819066
8: saving model with minimum loss to results
9: loss: 3.091572354237239
9: saving model with minimum loss to results
10: loss: 3.0837446848551435
10: saving model with minimum loss to results
11: loss: 3.0354432463645935
11: saving model with minimum loss to results
12: loss: 3.0267563958962755
1

KeyboardInterrupt: 

### Load the saved (minimum-loss) MuLaN checkpoint

In [None]:
# Rebuild a trainer around the same MuLaN model and dataset, then restore the
# minimum-loss checkpoint written during training.
mulan_trainer = MuLaNTrainer(mulan=mulan, dataset=training_data, num_train_steps=1000, batch_size=2, grad_accum_every=16)

# The training log reports checkpoints being saved to `results`. Resolve that
# relative to the working directory (like the `pkls/` data paths) instead of
# hard-coding an absolute /root/... path.
RESULTS_DIR = Path('results')
min_loss_path = str(RESULTS_DIR / 'mulan_min_loss.pt')
if not Path(min_loss_path).exists():
    raise FileNotFoundError(
        f'checkpoint not found: {Path(min_loss_path).resolve()} '
        '(run the notebook from the directory that contains results/)'
    )
mulan_trainer.load(min_loss_path)

### Retrieve the music most similar to a given text description

In [None]:
from tqdm import tqdm

# Characters stripped from free-text queries before they are encoded; each one is
# turned into a single space so neighbouring words stay separated.
special_characters = {'&', ',', '"', "'", '/', ';', '“', '(', '‘', '’', '.', ')', '-', '\n', ':'}


def replace_special_characters_with_space(text):
    """Return ``text`` with every character in ``special_characters`` replaced by ' '.

    Each special character maps to exactly one space, so the result has the same
    length as the input.
    """
    translation = {ord(ch): ' ' for ch in special_characters}
    return text.translate(translation)
# Free-text query to match against every clip in the training set.
query_text = ['This music features a classic piano solo, showcasing intricate melodies and expressive harmonies. The timeless elegance and nuanced performance create an immersive and captivating listening experience.']

# Normalise punctuation in the query before encoding it.
query_text = [replace_special_characters_with_space(text) for text in query_text]

# Inference only: eval() switches off training-time behaviour (e.g. dropout and,
# presumably, the spectrogram augmentation configured via spec_aug_stretch_factor),
# and no_grad() avoids building autograd graphs for thousands of forward passes.
mulan_trainer.mulan.eval()

# Score every clip by cosine similarity to the query and track the best one.
# Start from -inf: cosine similarity lies in [-1, 1], so starting at 0 would leave
# max_similarity_idx at 0 whenever every score is negative.
# NOTE: SDD clips are randomly cropped on each access, so their scores (and thus
# the ranking) vary between runs.
max_similarity = float('-inf')
max_similarity_idx = 0
idx_simliarity_text_list = []
with torch.no_grad():
    query_text_latent = mulan_trainer.mulan.get_text_latents(raw_texts=query_text)
    for idx in tqdm(range(len(training_data))):
        wav, txt = training_data[idx]
        # add a batch dimension of size 1
        wav = torch.unsqueeze(wav, 0).to(mulan_trainer.device)
        audio_latent = mulan_trainer.mulan.get_audio_latents(wav)

        # single query vs. single clip -> one scalar similarity
        similarity = torch.nn.functional.cosine_similarity(query_text_latent, audio_latent).item()

        idx_simliarity_text_list.append((idx, similarity, txt))

        if similarity > max_similarity:
            max_similarity = similarity
            max_similarity_idx = idx
        


In [None]:
import IPython.display as ipd
from IPython.display import display

# Print the query, then the ten clips with the highest cosine similarity to it.
# The list is sorted in place (highest similarity first).
print(query_text)
idx_simliarity_text_list.sort(reverse=True, key=lambda entry: entry[1])
for idx, sim, txt in idx_simliarity_text_list[:10]:
    print(f'{idx} {sim} - {txt}')
    # display(ipd.Audio(training_data[idx][0], rate=16000))

In [None]:
# Show the caption of the single best match and play its audio (16 kHz).
import IPython.display as ipd

# Fetch the item once so the caption and the audio come from the same call.
# NOTE: if the best match is an SDD clip, __getitem__ draws a fresh random crop,
# so the audio played here is not necessarily the segment that was scored.
best_wav, best_caption = training_data[max_similarity_idx]
print(best_caption)
ipd.Audio(best_wav, rate=16000)

### Text encoder: effect of punctuation and whitespace on text latents (disabled)

In [None]:

# Disabled check, kept as a string literal so it never executes: it encodes three
# captions that differ only in punctuation / whitespace ('.', '\n', extra spaces)
# with `mulan`, printing each latent's shape and sum so they can be compared.
# Because the string is this cell's last expression, Jupyter echoes it as output.
'''
text1 = ['his voice.  song']
text2 = ['his voice   \n  song']
text3 = ['his voice song']

embed1 = mulan.get_text_latents(raw_texts=text1)
print(embed1.shape)
embed2 = mulan.get_text_latents(raw_texts=text2)
print(embed2.shape)
embed3 = mulan.get_text_latents(raw_texts=text3)

print(embed1.sum())
print(embed2.sum())
print(embed3.sum())
'''
