### comparing embedding spaces

I couldn't replicate the openl3 embedding to full precision, so I'm going to apply dimensionality reduction to both embedding spaces (the original openl3 model and the torch port) on some audio samples and see how the two spaces behave with respect to each other.


In [1]:
import numpy as np

In [5]:
# first, write a MARL embed
import openl3 

def load_openl3_model():
    """Load the reference openl3 audio embedding model used in this notebook.

    The model is requested with ``input_repr="mel128"``,
    ``embedding_size=512`` and ``content_type="music"``.

    Returns:
        Whatever ``openl3.models.load_audio_embedding_model`` returns for
        that configuration.
    """
    model_config = {
        "input_repr": "mel128",
        "embedding_size": 512,
        "content_type": "music",
    }
    return openl3.models.load_audio_embedding_model(**model_config)

def openl3_marl(X, sr, model):
    """Embed a batch of single-channel audio clips with the reference openl3 package.

    Args:
        X: array of shape (batch, channel, sample). The channel axis is
            squeezed away before embedding, so it is expected to have size 1.
        sr: sample rate of the audio, forwarded unchanged to openl3.
        model: openl3 embedding model, e.g. from ``load_openl3_model()``.

    Returns:
        numpy array with one row per clip, holding the first embedding frame
        that ``openl3.get_audio_embedding`` produced for that clip.

    Raises:
        AssertionError: if ``X`` is not 3-dimensional.
    """
    assert X.ndim == 3, "must be shape (batch, channel, sample)"

    embeddings = []
    for x in X:
        # drop the channel axis: (channel, sample) -> (sample,)
        x = x.squeeze(0)
        # the second return value is discarded; only the embedding is used
        emb, _ = openl3.get_audio_embedding(x, sr, model=model, center=False, verbose=False)
        # keep only the first embedding frame of each clip
        embeddings.append(emb[0])

    # The original built this array but never returned it, so callers got None.
    return np.stack(embeddings)



In [6]:
# now, write a torch embed
import torch
from instrument_recognition.models import torchopenl3

def load_torchl3_model():
    """Load the torch port of the openl3 model used for the comparison.

    Returns:
        The object returned by ``torchopenl3.get_model(128, 512)``.
    """
    # NOTE(review): presumably (mel bins, embedding size), mirroring the
    # openl3 configuration -- confirm against torchopenl3.get_model.
    model_args = (128, 512)
    return torchopenl3.get_model(*model_args)

def openl3_torch(X, sr, model):
    """Embed a batch of audio clips with the torch port of openl3.

    Args:
        X: numpy array of shape (batch, channel, sample).
        sr: sample rate of the audio; this wrapper only accepts 48000.
        model: torch model, e.g. from ``load_torchl3_model()``.

    Returns:
        numpy array holding the model output for ``X``.

    Raises:
        AssertionError: if ``X`` is not 3-dimensional or ``sr`` is not 48000.
    """
    assert X.ndim == 3, "must be shape (batch, channel, sample)"
    assert sr == 48000, "sample rate must be 48000"

    # Run on whatever device the model lives on instead of hard-coding CUDA,
    # so the function also works on CPU-only hosts.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        # no parameters to inspect: use the GPU when present (the original
        # behavior), otherwise the CPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # inference only, so skip building the autograd graph
    with torch.no_grad():
        embeddings = model(torch.from_numpy(X).to(device))

    return embeddings.detach().cpu().numpy()




In [7]:
from instrument_recognition.datasets.base_dataset import BaseDataModule
import tqdm

# sample budget for the embedding comparison; compared against the loop index
# in the extraction loop below
max_samples = 5000

def debatch(batch):
    """Unwrap list-valued entries of a batch dict, in place.

    Every value that is a ``list`` is replaced by its first element; all
    other values are left untouched.

    Args:
        batch: dict to unwrap. It is modified in place.

    Returns:
        The same ``batch`` object, after modification.
    """
    list_keys = [key for key, value in batch.items() if isinstance(value, list)]
    for key in list_keys:
        batch[key] = batch[key][0]
    return batch
        

path_to_data = "/home/hugo/CHONK/data/mdb-hop-0.25-chunk-1-AUGMENTED/splits"
dm = BaseDataModule(path_to_data=path_to_data, batch_size=1, num_workers=10, use_embeddings=False,
                    class_subset=['flute', 'trumpet'])
dm.setup()

# load models
torch_model = load_torchl3_model()
torch_model.eval()
torch_model.cuda()

marl_model = load_openl3_model()

# NOTE(review): the comment here used to say "validation set", but the code
# has always pulled the *training* dataloader -- confirm which split is wanted.
dl = dm.train_dataloader()

# one list per implementation; the loop appends one embedding per entry
torch_embeddings = []
marl_embeddings = []

pbar = tqdm.tqdm(dl)
for idx, entry in enumerate(pbar):
    # Stop after max_samples entries. The original used `continue` as the last
    # statement of the body, which is a no-op, so the whole dataset was iterated.
    if idx >= max_samples:
        break

    entry = debatch(entry)

    # both implementations get the same numpy audio and sample rate
    audio = entry['X'].numpy()
    torch_embedding = openl3_torch(audio, entry['sr'], torch_model)
    marl_embedding = openl3_marl(audio, entry['sr'], marl_model)

    torch_embeddings.append(torch_embedding)
    marl_embeddings.append(marl_embedding)

found 831404 entries
(('Main System', 57060), ('acoustic guitar', 28536), ('auxiliary percussion', 9768), ('banjo', 1128), ('bassoon', 7536), ('brass section', 2512), ('cello', 19708), ('claps', 272), ('clarinet', 6784), ('clarinet section', 496), ('clean electric guitar', 51816), ('cymbal', 2876), ('distorted electric guitar', 30320), ('double bass', 34244), ('drum machine', 18900), ('drum set', 88268), ('electric bass', 64696), ('female singer', 30580), ('flute', 14420), ('flute section', 212), ('french horn', 2796), ('fx/processed sound', 24456), ('glockenspiel', 2004), ('harmonica', 940), ('harp', 5868), ('male rapper', 2488), ('male singer', 29460), ('mandolin', 11656), ('oboe', 2436), ('piano', 72856), ('piccolo', 216), ('snare drum', 576), ('string section', 4552), ('synthesizer', 39908), ('tabla', 24080), ('tack piano', 3748), ('tambourine', 1380), ('tenor saxophone', 19796), ('timpani', 760), ('trombone', 1380), ('trumpet', 2248), ('trumpet section', 7356), ('tuba', 144), ('vi

KeyboardInterrupt: 