In [None]:
import os

# Pin this kernel to GPU index 1 through CUDA_VISIBLE_DEVICES.
# The variable is written here, ahead of the cell that imports torch.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [None]:
from musiclm_pytorch import MuLaNTrainer, MuLaN

import pickle
import torch

from torch.utils.data import Dataset

import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

In [None]:
# Width / depth / attention-head settings used by both encoders.
encoder_dims = dict(dim=512, depth=6, heads=8, dim_head=64)

# Audio encoder: the shared settings plus its spectrogram options.
audio_transformer = AudioSpectrogramTransformer(
    **encoder_dims,
    spec_n_fft=256,
    spec_win_length=24,
    spec_aug_stretch_factor=0.8,
)

# Text encoder: the shared settings only.
text_transformer = TextTransformer(**encoder_dims)

# MuLaN is built from the two encoders above.
mulan = MuLaN(audio_transformer=audio_transformer, text_transformer=text_transformer)

In [None]:

from torch.utils.data import Dataset
from pathlib import Path

class MuLanDataset(Dataset):
    """Paired (waveform, caption) dataset backed by two pickle files.

    Item ``i`` is ``(wavs[i], txts[i])``, so the two pickles are expected to
    hold index-aligned sequences of the same length.

    Args:
        txt_pickle_path: Pickle file containing the sequence of captions.
        wav_pickle_path: Pickle file containing the sequence of waveforms.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, txt_pickle_path: Path, wav_pickle_path: Path):
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing; only point these paths at files you trust.
        with open(wav_pickle_path, 'rb') as f:
            # sequence of waveforms, indexed in step with self.txts
            self.wavs = pickle.load(f)

        with open(txt_pickle_path, 'rb') as f:
            # sequence of captions, indexed in step with self.wavs
            self.txts = pickle.load(f)

        num_wavs, num_txts = len(self.wavs), len(self.txts)
        if num_wavs != num_txts:
            # Surface the mismatch at load time instead of as an IndexError
            # deep into iteration; only the common prefix can be paired.
            import warnings
            warnings.warn(
                f'MuLanDataset: {num_wavs} waveforms but {num_txts} captions; '
                f'using the first {min(num_wavs, num_txts)} pairs',
                stacklevel=2,
            )

        # number of (waveform, caption) pairs that can actually be served
        self.num_data = min(num_wavs, num_txts)

    def __len__(self) -> int:
        """Return the number of usable (waveform, caption) pairs."""
        return self.num_data

    def __getitem__(self, idx):
        """Return the ``(waveform, caption)`` pair stored at position ``idx``."""
        return self.wavs[idx], self.txts[idx]
    
    
    
# Build the paired dataset from the waveform and caption pickles under pkls/
training_data = MuLanDataset(
    wav_pickle_path=Path('pkls/wavs.pkl'),
    txt_pickle_path=Path('pkls/txts.pkl'),
)

### Train MuLaN

In [None]:
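# Training is disabled in this notebook; uncomment the two lines below to train MuLaN on `training_data`.
# mulan_trainer = MuLaNTrainer(mulan=mulan, dataset=training_data, num_train_steps=1000, batch_size=16, grad_accum_every=16)
# mulan_trainer.train()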

### Load the saved MuLaN model

In [None]:
# Checkpoint file to load (absolute path on the training machine).
min_loss_path = '/root/musiclm-pytorch/results/mulan_min_loss.pt'

# No train() call follows in this cell: the trainer is created and then
# immediately asked to load the checkpoint above.
mulan_trainer = MuLaNTrainer(
    mulan=mulan,
    dataset=training_data,
    num_train_steps=1000,
    batch_size=2,
    grad_accum_every=16,
)
mulan_trainer.load(min_loss_path)

### Get the most similar music for the given text description

In [None]:
from tqdm import tqdm

# Punctuation and newline characters that are blanked out of a text.
# NOTE(review): the opening curly double quote is listed but the closing one
# ('”') is not; confirm whether that omission is intentional.
special_characters = {'&', ',', '"', "'", '/', ';', '“', '(', '‘', '’', '.', ')', '-', '\n', ':'}


def replace_special_characters_with_space(text):
    """Return ``text`` with each character in ``special_characters`` turned into a space.

    The substitution is one-for-one: the result has the same length as the
    input, and a run of special characters becomes a run of spaces.
    """
    return ''.join(' ' if ch in special_characters else ch for ch in text)
# input: free-text description of the music to retrieve
query_text = ['This music features a classic piano solo, showcasing intricate melodies and expressive harmonies. The timeless elegance and nuanced performance create an immersive and captivating listening experience.']

# blank out punctuation/newlines in the query before encoding it
query_text = [replace_special_characters_with_space(text) for text in query_text]

# Retrieval is inference-only: switch the model to eval mode so training-only
# behaviour (e.g. dropout) is off and the scores are reproducible.
# NOTE(review): presumably this also disables the spectrogram augmentation
# configured on the audio transformer; confirm against musiclm_pytorch.
# eval() persists after this cell; call .train() again before further training.
mulan_trainer.mulan.eval()

# Start below every reachable cosine similarity (range [-1, 1]) so the best
# match is still recorded when all scores are zero or negative.
max_similarity = float('-inf')
max_similarity_idx = 0
idx_simliarity_text_list = []

# no_grad: no autograd graph is built while scanning the dataset
with torch.no_grad():
    # get the latent representation of the query text
    query_text_latent = mulan_trainer.mulan.get_text_latents(raw_texts=query_text)

    # score every clip in the dataset against the query
    for idx in tqdm(range(len(training_data))):
        wav, txt = training_data[idx]
        # add a leading batch dimension of size 1 and move to the trainer's device
        wav = torch.unsqueeze(wav, 0).to(mulan_trainer.device)
        audio_latent = mulan_trainer.mulan.get_audio_latents(wav)

        # cosine similarity between the two latents as a plain Python float;
        # .item() needs exactly one value, which holds for a single query text
        similarity = torch.nn.functional.cosine_similarity(query_text_latent, audio_latent).item()

        idx_simliarity_text_list.append((idx, similarity, txt))

        # strict '>' keeps the first clip among equal scores
        if similarity > max_similarity:
            max_similarity = similarity
            max_similarity_idx = idx
        


In [None]:
import IPython.display as ipd
from IPython.display import display

# Echo the query, then list the ten highest-scoring clips as "index score - caption".
print(query_text)

# In-place ranking of the (index, similarity, caption) tuples, best score first.
idx_simliarity_text_list.sort(reverse=True, key=lambda entry: entry[1])

for idx, sim, txt in idx_simliarity_text_list[:10]:
    print(f'{idx} {sim} - {txt}')
    # display(ipd.Audio(training_data[idx][0], rate=16000))

In [None]:
import IPython.display as ipd

# Fetch the (waveform, caption) pair that scored highest in the scan.
best_match = training_data[max_similarity_idx]

# caption of the best match
print(best_match[1])

# audio player for the best match; left as the last expression so it renders
ipd.Audio(best_match[0], rate=16000)

### Text Encoder

In [None]:

# Disabled experiment, kept inside a string literal so that none of it executes.
# The quoted code encodes three spellings of the same phrase (with a period,
# with extra whitespace and a newline, and plain) through mulan.get_text_latents
# and prints the latent shapes and sums for comparison.
# Note: as the last expression of the cell, the string itself is echoed as the
# cell output.
'''
text1 = ['his voice.  song']
text2 = ['his voice   \n  song']
text3 = ['his voice song']

embed1 = mulan.get_text_latents(raw_texts=text1)
print(embed1.shape)
embed2 = mulan.get_text_latents(raw_texts=text2)
print(embed2.shape)
embed3 = mulan.get_text_latents(raw_texts=text3)

print(embed1.sum())
print(embed2.sum())
print(embed3.sum())
'''
