In [1]:
import torch

from Tokenizers import TokenizersConfig, Tokenizers

In [2]:
import numpy as np
import pandas as pd
import seaborn as sns
import os
import glob
import torchaudio as ta
import torch
import warnings
import torchaudio
warnings.filterwarnings("ignore", category=FutureWarning)


from tqdm import tqdm
from resemblyzer import VoiceEncoder, preprocess_wav, normalize_volume, trim_long_silences
from pathlib import Path
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from scipy.io import wavfile

In [3]:
# Root of the MoisesDB copy on the NAS. Set the MOISESDB_ROOT environment
# variable to run elsewhere; the default keeps the original location.
MOISESDB_ROOT = os.environ.get("MOISESDB_ROOT", "/home/buffett/NAS_NTU/moisesdb")
# Input: 10-second query clips (*.query-10s.npy).
QUERY_ROOT = os.path.join(MOISESDB_ROOT, "npyq")
# Output: STFTs of those clips, written into a tree mirroring QUERY_ROOT.
STFT_ROOT = os.path.join(MOISESDB_ROOT, "stft")

In [4]:
# All 10 s query clips under QUERY_ROOT. glob's order depends on the
# filesystem, so sort it to make the processing order reproducible.
query_files = sorted(
    glob.glob(os.path.join(QUERY_ROOT, "**", "*.query-10s.npy"), recursive=True)
)
len(query_files)

4241

In [5]:
# Complex-valued STFT transform (power=None keeps the complex output):
# 2048-point Hann window, hop of 512 samples, centered frames padded with
# constant values, normalized=True, one-sided spectrum (n_fft // 2 + 1 bins).
STFT = torchaudio.transforms.Spectrogram(
    n_fft=2048,
    win_length=2048,
    hop_length=512,
    pad=0,
    window_fn=torch.hann_window,
    wkwargs=None,
    power=None,
    normalized=True,
    center=True,
    pad_mode="constant",
    onesided=True,
)

In [6]:
# Compute the complex STFT of every 10 s query clip and save it under a path in
# STFT_ROOT that mirrors the clip's path in QUERY_ROOT. No model or GPU is used.
#
# A previous run died at 31% with "OSError: Transport endpoint is not connected"
# (the NAS mount dropped), so this loop is resumable:
#   * clips whose output already exists are skipped (unless OVERWRITE_STFT);
#   * each file is written under a temporary name and renamed into place, so an
#     interrupted write never leaves a truncated .stft.npy behind.
OVERWRITE_STFT = False  # set True to recompute files that already exist
QUERY_SUFFIX = ".query-10s.npy"

for query_file in tqdm(query_files):
    # QUERY_ROOT/<rel>/<name>.query-10s.npy -> STFT_ROOT/<rel>/<name>.stft.npy
    rel_path = os.path.relpath(query_file, QUERY_ROOT)
    stft_path = os.path.join(STFT_ROOT, rel_path[: -len(QUERY_SUFFIX)] + ".stft.npy")
    if not OVERWRITE_STFT and os.path.exists(stft_path):
        continue

    # Downmix by averaging over axis 0 (assumes the array is (channels, samples)
    # — TODO confirm), then add a leading batch dimension of size 1.
    wav = np.load(query_file)
    wav = torch.tensor(np.mean(wav, axis=0), dtype=torch.float32).unsqueeze(0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        stft = STFT(wav)
    stft_np = stft.cpu().numpy()

    os.makedirs(os.path.dirname(stft_path), exist_ok=True)
    # Write via an open handle so np.save does not append another ".npy" to the
    # temporary name, then atomically move it over the final path.
    tmp_path = stft_path + ".tmp"
    with open(tmp_path, "wb") as fh:
        np.save(fh, stft_np)
    os.replace(tmp_path, stft_path)

 31%|███       | 1320/4241 [16:39<36:51,  1.32it/s]  


OSError: [Errno 107] Transport endpoint is not connected

In [None]:
# Load every saved STFT together with its stem label. The label is the file-name
# prefix before the first "." (e.g. "<stem>.stft.npy" -> "<stem>").
# Files are sorted so the row order of the t-SNE input is reproducible.
# (The "passt" names are kept because the cells below use them.)
passts = sorted(glob.glob(os.path.join(STFT_ROOT, "**", "*.stft.npy"), recursive=True))

passt_data = [np.load(passt_file) for passt_file in passts]
stem_data = [os.path.basename(passt_file).split(".")[0] for passt_file in passts]

In [None]:
# Summarize the loaded data instead of dumping a whole complex spectrogram:
# the number of files, then the shape and dtype of the first one (None if empty).
# len() is used rather than truthiness so this also works on a stacked array.
(
    len(passt_data),
    passt_data[0].shape if len(passt_data) else None,
    passt_data[0].dtype if len(passt_data) else None,
)

In [None]:
# Join the per-file arrays along a new leading "file" axis
# (np.stack requires every file to have the same shape).
passt_data = np.stack(passt_data, axis=0)
passt_data.shape

In [None]:
# Remove the size-1 axis 1 (the batch dimension added before the STFT);
# raises if that axis is not of size 1.
passt_data = np.squeeze(passt_data, axis=1)
passt_data.shape

In [None]:
# Turn each complex STFT into a fixed-length real feature vector.
# torchaudio's Spectrogram returns (..., freq, time), so after the squeeze each
# row is (freq, time). Take the magnitude and average over time (last axis),
# giving one value per frequency bin.
# The previous `.mean(axis=1)` averaged complex values over *frequency*.
# scikit-learn's TSNE rejects complex input, and that reduction also collapses
# the spectral profile.
passt_data = np.abs(passt_data).mean(axis=-1)
passt_data.shape

In [None]:
# 2-D t-SNE of the per-file feature vectors. random_state is fixed so the
# layout (and therefore the saved figures) is reproducible between runs.
tsne = TSNE(n_components=2, perplexity=30, n_iter=1000, verbose=1, random_state=0)

tsne_data = tsne.fit_transform(passt_data)

In [None]:
# Two-level stem taxonomy: coarse family -> set of fine-grained stem names.
# The flat FINE / COARSE / ALL sets below are derived from this single table,
# so the views cannot drift out of sync. Note that "other_plucked" is both a
# coarse family and its own (only) fine stem.
COARSE_TO_FINE = {
    "vocals": {
        "lead_male_singer",
        "lead_female_singer",
        "human_choir",
        "background_vocals",
        "other_vocals",
    },
    "bass": {
        "bass_guitar",
        "bass_synthesizer",
        "contrabass_double_bass",
        "tuba",
        "bassoon",
    },
    "drums": {
        "snare_drum", "toms", "kick_drum", "cymbals",
        "overheads", "full_acoustic_drumkit", "drum_machine", "hihat",
    },
    "other": {"fx", "click_track"},
    "guitar": {
        "clean_electric_guitar",
        "distorted_electric_guitar",
        "lap_steel_guitar_or_slide_guitar",
        "acoustic_guitar",
    },
    "other_plucked": {"other_plucked"},
    "percussion": {"atonal_percussion", "pitched_percussion"},
    "piano": {"grand_piano", "electric_piano"},
    "other_keys": {"organ_electric_organ", "synth_pad", "synth_lead", "other_sounds"},
    "bowed_strings": {
        "violin", "viola", "cello",
        "violin_section", "viola_section", "cello_section",
        "string_section", "other_strings",
    },
    "wind": {"brass", "flutes", "reeds", "other_wind"},
}

# Flat views of the taxonomy.
FINE_LEVEL_INSTRUMENTS = set().union(*COARSE_TO_FINE.values())
COARSE_LEVEL_INSTRUMENTS = set(COARSE_TO_FINE)

# Inverse lookup: fine stem -> its coarse family.
FINE_TO_COARSE = {
    fine: coarse
    for coarse, fines in COARSE_TO_FINE.items()
    for fine in fines
}

ALL_LEVEL_INSTRUMENTS = COARSE_LEVEL_INSTRUMENTS | FINE_LEVEL_INSTRUMENTS

In [None]:
# Pair each 2-D t-SNE point with its stem label (row i <-> loaded file i);
# the two pieces are aligned on their default integer index.
tsne_data = pd.concat(
    [
        pd.DataFrame(tsne_data, columns=["x", "y"]),
        pd.Series(stem_data, name="stem"),
    ],
    axis="columns",
)

In [None]:
# Coarse family of each point: fine-level stems are mapped through
# FINE_TO_COARSE; any other label is kept unchanged.
tsne_data["coarse"] = tsne_data["stem"].map(
    lambda label: label if label not in FINE_LEVEL_INSTRUMENTS else FINE_TO_COARSE[label]
)

In [None]:
# Exploratory view: points labelled with a fine-level stem, colored by that stem.
tsne_fine = tsne_data[tsne_data["stem"].isin(FINE_LEVEL_INSTRUMENTS)]

fig, ax = plt.subplots(figsize=(10, 10))
sns.scatterplot(data=tsne_fine, x="x", y="y", hue="stem", alpha=0.5, ax=ax)
ax.set(title="t-SNE of STFT features (fine-level stems)")
plt.show()

In [None]:
# Exploratory view: points whose stem label is a coarse family name, colored by it.
tsne_coarse = tsne_data[tsne_data["stem"].isin(COARSE_LEVEL_INSTRUMENTS)]

fig, ax = plt.subplots(figsize=(10, 10))
sns.scatterplot(data=tsne_coarse, x="x", y="y", hue="stem", alpha=0.5, ax=ax)
ax.set(title="t-SNE of STFT features (coarse-level stems)")
plt.show()

In [None]:
# Fine-level points, colored and marker-coded by their coarse family.
tsne_fine_coarse = tsne_data[tsne_data["stem"].isin(FINE_LEVEL_INSTRUMENTS)]

f, ax = plt.subplots(1, 1, figsize=(4.5, 4.5))
sns.scatterplot(
    data=tsne_fine_coarse,
    x="x",
    y="y",
    hue="coarse",
    style="coarse",
    edgecolor="none",
    alpha=0.5,
    palette="tab20",
    ax=ax,
)

# The features plotted here are STFTs; the title previously said "BEATs embeddings".
ax.set(
    xlabel=None,
    ylabel=None,
    title="t-SNE of STFT features",
    xticklabels=[],
    yticklabels=[],
)

# Legend outside the axes; legend markers made opaque although points use alpha 0.5.
leg = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
for lh in leg.legend_handles:
    lh.set_alpha(1)

plt.show()

In [None]:
# Stems shown in the final figure, in legend order.
# Deliberately excluded fine stems: human_choir, other_vocals,
# contrabass_double_bass, tuba, bassoon, lap_steel_guitar_or_slide_guitar,
# other_plucked, violin, viola, cello, violin_section, viola_section,
# cello_section, other_strings, flutes, other_wind.
ALLOWED_INSTRUMENT = [
    "drums",
    "lead_male_singer",
    "lead_female_singer",
    "background_vocals",
    "bass_guitar",
    "bass_synthesizer",
    "fx",
    "clean_electric_guitar",
    "distorted_electric_guitar",
    "acoustic_guitar",
    "pitched_percussion",
    "grand_piano",
    "electric_piano",
    "organ_electric_organ",
    "synth_pad",
    "synth_lead",
    "string_section",
    "brass",
    "reeds",
]

# Ordered categorical dtype whose categories follow ALLOWED_INSTRUMENT exactly.
allowed_instruments_dtype = pd.CategoricalDtype(
    categories=ALLOWED_INSTRUMENT,
    ordered=True,
)

In [None]:
# Abbreviated display names for stems in the compact figure legend.
# Only these stems can be abbreviated; ALLOWED_INSTRUMENT should stay a subset.
short_stem = dict([
    ("drums", "Drums"),
    ("lead_male_singer", "Lead M. Vox"),
    ("lead_female_singer", "Lead F. Vox"),
    ("background_vocals", "Bg. Vox"),
    ("bass_guitar", "Bass Gtr."),
    ("bass_synthesizer", "Bass Synth"),
    ("fx", "Fx"),
    ("clean_electric_guitar", "Clean. E. Gtr."),
    ("distorted_electric_guitar", "Dist. E. Gtr."),
    ("acoustic_guitar", "A. Gtr"),
    ("other_plucked", "Other Plucked"),
    ("pitched_percussion", "Pitched Perc."),
    ("grand_piano", "Grand Piano"),
    ("electric_piano", "E. Piano"),
    ("organ_electric_organ", "Organ"),
    ("synth_pad", "Synth Pad"),
    ("synth_lead", "Synth Lead"),
    ("string_section", "Str. Sect."),
    ("other_strings", "Other Str."),
    ("brass", "Brass"),
    ("reeds", "Reeds"),
    ("other_wind", "Other Wind"),
])

In [None]:
# Paper figure: t-SNE of the selected stems, colored by fine stem and
# marker-coded by coarse family, with a compact custom legend.
tsne_allowed = tsne_data[tsne_data["stem"].isin(ALLOWED_INSTRUMENT)].copy()
# Ordered categorical so the hue levels follow ALLOWED_INSTRUMENT.
tsne_allowed["stem"] = tsne_allowed["stem"].astype(allowed_instruments_dtype)

LEGEND_PAD_FRAC = 0.25  # room left of the data, as a fraction of the x-span, for the legend

f, ax = plt.subplots(1, 1, figsize=(6, 3.5))
sns.scatterplot(
    data=tsne_allowed,
    x="x",
    y="y",
    hue="stem",
    edgecolor="none",
    alpha=0.5,
    size=1,
    style="coarse",
    palette="tab10",
    ax=ax,
)

ax.set(xlabel=None, ylabel=None, xticklabels=[], yticklabels=[])

# Recover the colors and markers seaborn assigned by reading its auto legend.
# Seaborn's legend lists the hue (stem) section before the style (coarse)
# section, and "drums" is both a stem and a coarse label. So keep the FIRST
# handle per label for colors (hue section) and the LAST handle per label for
# markers (style section). Previously a single dict comprehension let the
# style-section handle overwrite the drums color.
h, l = ax.get_legend_handles_labels()
stem_to_color = {}
for hh, ll in zip(h, l):
    stem_to_color.setdefault(ll, hh.get_markerfacecolor())
coarse_to_marker = {ll: hh.get_marker() for hh, ll in zip(h, l)}

# One legend entry per allowed stem: stem color + its coarse family's marker.
legend_handles = []
legend_labels = []
for stem in ALLOWED_INSTRUMENT:
    coarse_stem = FINE_TO_COARSE[stem] if stem in FINE_LEVEL_INSTRUMENTS else stem
    legend_handles.append(
        plt.Line2D(
            [0], [0],
            marker=coarse_to_marker[coarse_stem],
            color=stem_to_color[stem],
            markersize=5,
            linestyle="None",
        )
    )
    legend_labels.append(short_stem[stem])

# Replace seaborn's long legend with the compact one. A single ax.legend() is
# enough: the former plt.legend() + ax.add_artist() pair attached the same
# legend twice, drawing its semi-transparent frame twice.
ax.get_legend().remove()
ax.legend(
    legend_handles,
    legend_labels,
    loc="center left",
    borderaxespad=0.0,
    fontsize="small",
    frameon=True,
    handletextpad=0.0,
    columnspacing=0.8,
    labelspacing=0.1,
    framealpha=0.5,
    edgecolor="none",
)

# Data-driven x-limits (a fixed range can clip points when the embedding
# changes), with extra room on the left for the "center left" legend.
x_lo, x_hi = tsne_allowed["x"].min(), tsne_allowed["x"].max()
x_span = x_hi - x_lo
ax.set(xlim=[x_lo - LEGEND_PAD_FRAC * x_span, x_hi + 0.02 * x_span])

# Named after the STFT features rather than "beats.pdf".
plt.savefig("stft.pdf", bbox_inches="tight")

plt.show()