In [1]:
import os 

import librosa as li
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime
import panel as pn
pn.extension('ipywidgets')
import sklearn.decomposition
import soundfile as sf
import torch
torch.set_grad_enabled(False)
from tqdm import tqdm

# Lower-cased names of every container format reported by soundfile; used below
# to decide which files in the samples folder are treated as audio.
_AVAILABLE_EXTS = [fmt_name.lower() for fmt_name in sf.available_formats()]

In [2]:
# --- Configuration -----------------------------------------------------------
# LOAD: read previously computed embeddings from 'embeddings.npy' instead of
#       running the ONNX model again.
# SAVE: when embeddings are recomputed, also write them to 'embeddings.npy'.
LOAD = True
SAVE = False
ort_session = onnxruntime.InferenceSession("/data/genova/msclap_onnx/clap.onnx", providers=["CPUExecutionProvider"])
samples_folder = '/data/genova/datasets/Drum kits/Hip Hop & Lofi/Boom Bap Essentials'
# Fail early on a missing/mistyped folder: os.walk would otherwise yield
# nothing and leave `files` silently empty.
if samples_folder is None or not os.path.isdir(samples_folder):
    raise ValueError('Please select a valid folder of samples')

# --- Collect audio files -----------------------------------------------------
# Compare the real file extension (text after the last dot, lower-cased) with
# the known format names, so 'Kick.WAV' is accepted and a name that merely
# ends in the letters 'wav' without a dot is not.
_audio_exts = frozenset(_AVAILABLE_EXTS)

# NOTE: row i of the embeddings is later paired with files[i]. After changing
# the folder or this filter, recompute the embeddings (LOAD = False) so a
# cached 'embeddings.npy' does not get out of step with this list.
files = []
for root, _, _files in os.walk(samples_folder):
    for file_name in tqdm(_files, leave=False):
        extension = os.path.splitext(file_name)[1][1:].lower()
        if extension in _audio_exts:
            files.append(os.path.join(root, file_name))
# --- Compute or load one embedding row per entry of `files` ------------------
TARGET_SR = 22050                  # sample rate the audio is decoded at
MIN_SAMPLES = int(.5 * TARGET_SR)  # Minimal duration for CLAP seems to be 0.5 seconds, so we need to pad everything

if LOAD:
    embeddings = np.load('embeddings.npy')
else:
    if not files:
        # np.concatenate([]) would fail with an unhelpful message.
        raise ValueError('No audio files found; nothing to embed')
    input_name = ort_session.get_inputs()[0].name  # loop-invariant
    embeddings = []
    for file in tqdm(files):
        # Ask librosa explicitly for mono audio at TARGET_SR. These were
        # already its defaults, so the waveform is always 1-D at 22050 Hz and
        # no separate down-mix / resample step is needed afterwards.
        y, _ = li.load(file, sr=TARGET_SR, mono=True)
        if y.shape[-1] < MIN_SAMPLES:
            # Right-pad short clips with silence up to the minimal duration.
            y = np.pad(y, (0, MIN_SAMPLES - y.shape[-1]))
        # Add a leading batch axis -> (1, n_samples), float32 for the model.
        y = y[None, :].astype(np.float32)
        ort_outs = ort_session.run(None, {input_name: y})
        embeddings.append(ort_outs[0])
    embeddings = np.concatenate(embeddings, axis=0)
    if SAVE:
        np.save('embeddings.npy', embeddings)

# Later cells pair embedding row i with files[i]; a stale cache (folder changed
# since it was saved) would silently play the wrong clip for a point.
if len(embeddings) != len(files):
    raise ValueError(
        f'Got {len(embeddings)} embeddings for {len(files)} files; '
        'recompute the embeddings with LOAD = False'
    )

                                                                                                                                                                                                                              

In [3]:
# Centre the embeddings (subtract the per-dimension mean) so that the plain
# matrix product with the principal axes used further down is a valid
# projection.
embs = embeddings - embeddings.mean(axis=0)
# random_state pins the result when scikit-learn selects a randomized SVD
# solver, so the 2-D map is the same on every Restart & Run All.
pca = sklearn.decomposition.PCA(n_components=2, random_state=0)
pca.fit(embs)

In [4]:
# Principal axes learned by the fitted PCA; with n_components=2 there are two
# of them, and the centred embeddings are projected onto them further down.
components = pca.components_

In [5]:
# Project every centred embedding onto the principal axes: one row per sample,
# one column per retained component.
embs_2d = np.matmul(embs, components.T)

In [6]:
# Plot coordinates derived from the 2-D projection.
NORM = False  # True: min-max rescale each axis to [-1, 1]
SCALE = 0     # multiplicative zoom applied to both axes; 0 (falsy) disables it


def _rescale_to_unit_range(values):
    """Min-max rescale a 1-D array to [-1, 1].

    A constant array (max == min) maps to all zeros instead of dividing by
    zero and producing NaNs.
    """
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        return np.zeros_like(values)
    return 2 * ((values - lowest) / span) - 1


if NORM:
    x = _rescale_to_unit_range(embs_2d[:, 0])
    y = _rescale_to_unit_range(embs_2d[:, 1])
else:
    x = embs_2d[:, 0]
    y = embs_2d[:, 1]

if SCALE:
    # Build new arrays instead of scaling in place: x and y may be views into
    # embs_2d, and an in-place `*=` would silently rescale embs_2d itself
    # (compounding on every re-run of this cell and distorting the distance
    # computation done on embs_2d later).
    x = x * SCALE
    y = y * SCALE

In [9]:
import plotly.graph_objects as go
from IPython.display import display, Audio

# Interactive scatter of the projected samples. The figure is bound to `fig`
# rather than a one-letter name: the click callback resolves the figure by
# name at click time, and a short name like `f` is easily rebound by loop
# variables or file handles elsewhere in the notebook.
fig = go.FigureWidget([go.Scatter(x=x, y=y, mode='markers')])

scatter = fig.data[0]
# One colour / size entry per plotted point.
n_points = len(x)
scatter.marker.color = ['#a3a7e4'] * n_points
scatter.marker.size = [10] * n_points

# Paths of every clicked sample, in click order, so they can be replayed later.
to_display = []


def update_point(trace, points, selector):
    """Click handler: play and highlight every clicked sample.

    For each clicked point index, the matching path from `files` is appended
    to `to_display`, an audio player is displayed for it, and its marker is
    recoloured and enlarged.

    Args:
        trace: The trace that received the click (unused; `scatter` is used).
        points: Click event data; `points.point_inds` lists the clicked indices.
        selector: Selection metadata passed by plotly (unused).
    """
    marker_colors = list(scatter.marker.color)
    marker_sizes = list(scatter.marker.size)
    for i in points.point_inds:
        to_display.append(files[i])
        display(Audio(filename=files[i]))
        marker_colors[i] = '#bae2be'
        marker_sizes[i] = 20
    # Push both marker arrays to the widget once, after all clicked points
    # have been processed, instead of once per point.
    with fig.batch_update():
        scatter.marker.color = marker_colors
        scatter.marker.size = marker_sizes


scatter.on_click(update_point)
fig

FigureWidget({
    'data': [{'marker': {'color': [#a3a7e4, #a3a7e4, #a3a7e4, ..., #a3a7e4,
                                   #a3a7e4, #a3a7e4],
                         'size': [10, 10, 10, ..., 10, 10, 10]},
              'mode': 'markers',
              'type': 'scatter',
              'uid': 'ca3b3fbf-05a8-4442-88f4-babf7a718ea4',
              'x': array([-14.009306 , -10.240442 ,  -7.3622766, ...,  -6.9002547, -12.474204 ,
                          -14.305709 ], dtype=float32),
              'y': array([-14.591681 ,  -1.1837027,  -1.2252506, ...,  10.14129  ,  -6.601142 ,
                           -3.7851405], dtype=float32)}],
    'layout': {'template': '...'}
})

In [26]:
import random

NUM_SAMPLES = 5  # how many neighbours to play for each embedding space


def nearest_neighbours(points, query_idx, k):
    """Return the indices of the `k` rows of `points` closest to row `query_idx`.

    Distances are Euclidean. The query row itself is excluded, so the result
    holds at most `k` *other* rows, ordered from nearest to farthest.

    Args:
        points: 2-D array with one row per sample.
        query_idx: Row index of the query sample.
        k: Maximum number of neighbours to return.
    """
    dists = np.linalg.norm(points - points[query_idx], axis=-1)
    order = np.argsort(dists)
    # Drop the query explicitly rather than assuming it sorts first: an exact
    # duplicate sample also has distance 0 and may come before it.
    return order[order != query_idx][:k]


# Pick a random target sample and play it.
rand_idx, rand_file = random.choice(list(enumerate(files)))
print('Target audio')
display(Audio(filename=rand_file))

# Neighbours in the full embedding space.
print('Without PCA')
for i in nearest_neighbours(embeddings, rand_idx, NUM_SAMPLES):
    display(Audio(filename=files[i]))

# Neighbours in the 2-D PCA projection, for comparison.
print('Using PCA')
for i in nearest_neighbours(embs_2d, rand_idx, NUM_SAMPLES):
    display(Audio(filename=files[i]))

Target audio


Without PCA


Using PCA


In [8]:
# Replay every sample that was clicked in the scatter plot, in click order.
# Use a descriptive loop variable rather than a short global name such as `f`,
# so replaying clips cannot rebind a notebook-level object that the click
# callback still needs.
for clicked_path in to_display:
    display(Audio(filename=clicked_path))