In [None]:
import os
from glob import glob
import random

import pandas as pd
from tqdm.notebook import tqdm
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import joblib

import librosa

import warnings
warnings.filterwarnings('ignore') # to silence librosa warnings

from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv1D, Conv2D, MaxPooling1D, MaxPooling2D, Flatten, Dense, Lambda, Dropout, GlobalAveragePooling1D, GlobalMaxPooling1D, Concatenate, LeakyReLU
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras import backend as K
from tensorflow.keras.metrics import Precision, Recall

tqdm.pandas()

import seaborn as sns

from triplet_dataset import TripletDataset

%load_ext dotenv
%dotenv

In [None]:
# Hide every GPU so TensorFlow runs on CPU. This only takes effect if TensorFlow
# has not initialized its devices yet (i.e. before the first TF op is executed).
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})

## Dataset

In [None]:
# PATH_TO_TRACKS comes from the environment (loaded from .env by %dotenv above).
# NOTE(review): meaning of `n` is defined by TripletDataset — confirm there.
tracks_root = os.environ['PATH_TO_TRACKS']
triplets = TripletDataset(tracks_root, n=1)

In [None]:
# Number of triplets, then a peek at the first rows.
print(triplets.df.shape[0])
triplets.df.head()

In [None]:
class Track:
    """An audio file from which fixed-length, normalized mel spectrograms are extracted.

    Params
    ======
    `filepath`: path of the audio file, passed to `librosa.load`.
    `sr`: sampling rate the audio is resampled to when loaded.
    """

    def __init__(self, filepath: str, sr: int = 22050) -> None:
        self.filepath = filepath
        self.sr = sr

    @staticmethod
    def _normalize_mel_spectrogram(mel_spec: np.ndarray) -> np.ndarray:
        """Min-max scale `mel_spec` into [0, 1].

        A constant input (e.g. the spectrogram of a silent extract) has
        max == min; instead of dividing by zero and producing NaNs, an
        all-zero array of the same shape is returned.
        """
        min_val = np.min(mel_spec)
        value_range = np.max(mel_spec) - min_val
        if value_range == 0:
            return np.zeros_like(mel_spec)
        return (mel_spec - min_val) / value_range

    def audio_extract(self, from_sec: float, to_sec: float) -> np.ndarray:
        """Load the mono audio between `from_sec` and `to_sec` (in seconds).

        The result always has exactly `round((to_sec - from_sec) * sr)` samples.
        Tracks that end before `to_sec` are zero-padded, so every spectrogram
        computed over the same window has the same shape. The training arrays
        are stacked across tracks and need one common shape.

        Raises
        ======
        `ValueError`: if no audio could be read at all, e.g. the track is
        shorter than `from_sec`.
        """
        duration = to_sec - from_sec
        audio, _ = librosa.load(
            self.filepath,
            mono=True,
            sr=self.sr,
            offset=from_sec,
            duration=duration
        )

        if audio is None or audio.size == 0:
            raise ValueError(f"Something went wrong when reading extract from {self.filepath!r}")

        target_len = int(round(duration * self.sr))
        if audio.shape[0] < target_len:
            # track ended early: pad with silence up to the requested length
            audio = np.pad(audio, (0, target_len - audio.shape[0]))
        # trim the (rare) extra sample that resampling may produce
        return audio[:target_len]

    def spectrogram(self, from_sec: float = 30, to_sec: float = 36) -> np.ndarray:
        """Normalized, dB-scaled mel spectrogram of the audio between `from_sec` and `to_sec`.

        Uses a 512-sample FFT window and a 128-sample hop. The result has shape
        (n_mels, n_frames) with values in [0, 1]. Because the extract length is
        fixed, every track yields the same shape for a given window.
        """
        extract = self.audio_extract(from_sec, to_sec)

        spec = librosa.feature.melspectrogram(y=extract, sr=self.sr, n_fft=512, hop_length=128)
        # dB relative to the loudest bin, so the maximum maps to 0 dB
        spec_db = librosa.power_to_db(S=spec, ref=np.max)
        return self._normalize_mel_spectrogram(spec_db)

## Cache spectrograms

Reading audio files is slow, so we compute each track's spectrogram once and cache the spectrograms directly.

In [None]:
def save_dict_to_file(data_dict, filename, overwrite: bool = False):
    """Persist `data_dict` to `filename` with joblib.

    By default an existing file is left untouched and the call is a silent
    no-op, so a cache written once is never clobbered. Pass `overwrite=True`
    to replace it, e.g. after new entries were added to the dict.
    """
    if overwrite or not os.path.exists(filename):
        joblib.dump(data_dict, filename)

def load_dict_from_file(filename):
    """Return the object joblib-stored at `filename`, or a fresh empty dict if the file is absent.

    NOTE: joblib.load unpickles the file, so only load caches you wrote yourself.
    """
    if not os.path.exists(filename):
        return {}
    return joblib.load(filename)

SPECTROGRAM_CACHE_PATH = 'spectrogram_cache.joblib'
# Set to True to write the cache back to disk whenever new entries were computed.
PERSIST_SPECTROGRAM_CACHE = False

spectrogram_cache: dict = load_dict_from_file(SPECTROGRAM_CACHE_PATH)
all_unique_track_paths = pd.unique(triplets.df.values.ravel())

# Compute spectrograms only for tracks missing from the cache. This covers both
# a cold start (empty cache) and tracks added since the cache was written.
missing_track_paths = [path for path in all_unique_track_paths if path not in spectrogram_cache]
for track_path in tqdm(missing_track_paths):
    spectrogram_cache[track_path] = Track(track_path).spectrogram()

if missing_track_paths and PERSIST_SPECTROGRAM_CACHE:
    # dump directly: the cache file may already exist and must be replaced
    joblib.dump(spectrogram_cache, SPECTROGRAM_CACHE_PATH)

# Fail fast if spectrograms disagree in shape (e.g. a stale cache built with other
# settings): the training arrays built later require one common shape.
cached_shapes = {spectrogram_cache[path].shape for path in all_unique_track_paths}
assert len(cached_shapes) <= 1, f"Inconsistent spectrogram shapes in cache: {cached_shapes}"

In [None]:
# Shape of one cached spectrogram — expected (n_mels, n_frames).
list(spectrogram_cache.values())[0].shape

In [None]:
def plot_spec(track: str, spectrogram_cache: dict):
    """Plot the cached mel spectrogram of `track` with a colorbar, titled by its file name.

    Params
    ======
    `track`: path of the track; used as the cache key and (basename) as the title.
    `spectrogram_cache`: mapping from track path to a (n_mels, n_frames) spectrogram.
    """
    # `import librosa` alone does not reliably expose the `display` submodule.
    import librosa.display

    fig, ax = plt.subplots(figsize=(10, 2))
    # NOTE(review): specshow assumes sr=22050 by default, matching Track's default sr
    mesh = librosa.display.specshow(spectrogram_cache[track], y_axis='mel', ax=ax)
    fig.colorbar(mesh, ax=ax)
    ax.set_title(os.path.basename(track))

    # lay out after the title exists so tight_layout accounts for it
    fig.tight_layout()
    plt.show()

In [None]:
# Visual sanity check: spectrograms of the first three triplets, role by role.
TRIPLET_ROLES = (("Anchor", "anchor"), ("Positive", "positive"), ("Negative", "negative"))

for _, row in triplets.df.head(3).iterrows():
    print("=" * 80)
    for role_label, role_column in TRIPLET_ROLES:
        print(role_label)
        plot_spec(row[role_column], spectrogram_cache)

## Training data

In [None]:
def training_data(triplets, spectrogram_cache):
    """Stack the cached spectrograms of every triplet into three model-input arrays.

    Each cached (n_mels, n_frames) spectrogram is transposed so time comes first,
    giving arrays of shape (n_triplets, n_frames, n_mels). No targets are returned:
    the triplet loss does not use labels.

    Returns
    =======
    `(X_anchor, X_positive, X_negative)`, aligned row by row with `triplets.df`.
    """
    def stacked(role):
        return np.array([spectrogram_cache[track_path].T for track_path in triplets.df[role]])

    return tuple(stacked(role) for role in ("anchor", "positive", "negative"))

X_anchor, X_positive, X_negative = training_data(triplets, spectrogram_cache)

for _label, _array in (("X_anchor", X_anchor), ("X_positive", X_positive), ("X_negative", X_negative)):
    print(f"{_label} shape: {_array.shape}")

## Build model

In [None]:
# The X_* arrays are already (n_samples, n_frames, n_mels): the (steps, channels)
# layout the Conv1D base network consumes, so no channel axis is added (the former
# same-shape reshape was a no-op). The *_gray names are kept for the cells below.
X_anchor_gray = X_anchor
X_positive_gray = X_positive
X_negative_gray = X_negative

In [None]:
# Per-sample shape fed to each siamese branch: drop the leading batch axis.
input_shape = tuple(X_anchor_gray.shape)[1:]
input_shape

In [None]:
class GlobalL2Pooling1D(tf.keras.layers.Layer):
    """Global L2 pooling over axis 1 (time): the Euclidean norm of each feature across steps."""

    def call(self, inputs):
        energy_per_feature = tf.reduce_sum(tf.square(inputs), axis=1)
        return tf.sqrt(energy_per_feature)

# One weightless LeakyReLU instance, reused as the activation of every conv/dense layer below.
leaky_relu_layer = LeakyReLU(alpha=0.3)

# Base network is from:
# (1) Recommending music on Spotify with deep learning. Sander Dieleman. https://sander.ai/2014/08/05/spotify-cnns.html (accessed 2024-03-23).
def build_base_network(input_shape):
    """Build the embedding network that maps one spectrogram to a 64-dimensional vector.

    Three kernel-size-4 Conv1D layers over time (256, 256, 512 filters; max-pooling by
    4 and 2 after the first two) are followed by global mean, max and L2 pooling over
    time, two 2048-unit dense layers and a linear 64-unit output.
    """
    spectrogram_in = Input(shape=input_shape)

    # (filters, pool size applied after the conv; None means no pooling)
    conv_stack = ((256, 4), (256, 2), (512, None))
    features = spectrogram_in
    for n_filters, pool_size in conv_stack:
        features = Conv1D(filters=n_filters, kernel_size=4, activation=leaky_relu_layer)(features)
        if pool_size is not None:
            features = MaxPooling1D(pool_size=pool_size)(features)

    # global temporal pooling
    mean_pooled = GlobalAveragePooling1D()(features)
    max_pooled = GlobalMaxPooling1D()(features)
    l2_pooled = GlobalL2Pooling1D()(features)
    hidden = Concatenate()([mean_pooled, max_pooled, l2_pooled])

    for _ in range(2):
        hidden = Dense(2048, activation=leaky_relu_layer)(hidden)

    embedding = Dense(64)(hidden)
    return Model(spectrogram_in, embedding)

def build_siamese_network(base_network, input_shape):
    """Wrap `base_network` into a three-input model whose branches share weights.

    The same `base_network` instance embeds the anchor, positive and negative
    inputs. The output concatenates the three embeddings along axis 1 in the order
    [anchor | positive | negative]; a loss consuming it must split it back into
    three equal parts in that order.
    """
    branch_names = ("anchor_input", "positive_input", "negative_input")
    branch_inputs = [Input(shape=input_shape, name=branch_name) for branch_name in branch_names]
    branch_embeddings = [base_network(branch_input) for branch_input in branch_inputs]

    joined = Concatenate(axis=1)(branch_embeddings)
    return Model(inputs=branch_inputs, outputs=joined)

def triplet_loss(y_true, y_pred, margin = 0.2):
    """Per-sample triplet loss: max(0, ||a - p||^2 - ||a - n||^2 + margin).

    `y_pred` is the siamese output, i.e. the anchor, positive and negative
    embeddings concatenated along axis 1 as (batch, 3 * embedding_dim).
    `y_true` is ignored (dummy labels).

    Returns a (batch,) tensor; Keras averages it over the batch.
    """
    # Split the concatenated output back into the three embeddings. Plain column
    # indexing (y_pred[:, 0], ...) would select single scalar features instead.
    anchor, positive, negative = tf.split(y_pred, num_or_size_splits=3, axis=1)

    pos_dist = K.sum(K.square(anchor - positive), axis=-1)
    neg_dist = K.sum(K.square(anchor - negative), axis=-1)

    return K.maximum(0.0, pos_dist - neg_dist + margin)

LEARNING_RATE = 1e-5

# The siamese model reuses `base_network` for all three branches.
base_network = build_base_network(input_shape)
model = build_siamese_network(base_network, input_shape)

optimizer = Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=optimizer, loss=triplet_loss)
model.summary()

In [None]:
# Layer-by-layer summary of the shared embedding network (the model spec_dist uses for prediction).
base_network.summary()

## Train

In [None]:
# 80/20 split applied identically to all three arrays; random_state=0 makes it reproducible.
(
    X_train_anchor, X_test_anchor,
    X_train_positive, X_test_positive,
    X_train_negative, X_test_negative,
) = train_test_split(X_anchor_gray, X_positive_gray, X_negative_gray, test_size=0.2, random_state=0)

In [None]:
# Per-sample input shape passed to each branch of the siamese model.
input_shape

In [None]:
train_inputs = [X_train_anchor, X_train_positive, X_train_negative]
test_inputs = [X_test_anchor, X_test_positive, X_test_negative]

# triplet_loss ignores y_true, but Keras still wants one target per sample;
# an (n, 1) zero array avoids allocating a copy the size of the inputs.
history = model.fit(
    train_inputs,
    np.zeros((len(X_train_anchor), 1)),  # dummy labels
    validation_data=(test_inputs, np.zeros((len(X_test_anchor), 1))),
    epochs=5,
    batch_size=8
)

# summarize history for loss (validation is on the held-out test split)
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='train')
ax.plot(history.history['val_loss'], label='test')
ax.set(title='model loss', xlabel='epoch', ylabel='loss')
ax.legend(loc='upper left')
plt.show()

In [None]:
def spec_dist(left_spec: np.ndarray, right_spec: np.ndarray):
    """Euclidean distance between the base-network embeddings of two spectrograms.

    Each spectrogram is reshaped to one sample of `input_shape`, so it must already
    be in the model's (n_frames, n_mels) layout (transposed relative to the cache).
    """
    pair = np.stack([left_spec.reshape(input_shape), right_spec.reshape(input_shape)])
    # one predict call embeds both samples instead of two separate calls
    embeddings_left, embeddings_right = base_network.predict(pair, verbose=False)

    return np.linalg.norm(embeddings_left - embeddings_right)

def track_dist(left_path: str, right_path: str):
    """Embedding distance between two tracks, looked up in `spectrogram_cache` by path.

    Cache entries are (n_mels, n_frames) while the model expects (n_frames, n_mels),
    so they are transposed. A plain reshape would scramble time and frequency bins,
    and `input_shape` has only two dimensions, so a 4-D reshape would fail.
    """
    return spec_dist(spectrogram_cache[left_path].T, spectrogram_cache[right_path].T)

# Embed each test set with one batched predict call instead of two predict calls
# per pair; every anchor is now embedded once rather than twice.
test_embeddings_anchor = base_network.predict(X_test_anchor, verbose=False)
test_embeddings_positive = base_network.predict(X_test_positive, verbose=False)
test_embeddings_negative = base_network.predict(X_test_negative, verbose=False)

# Row-wise Euclidean distance for each test triplet.
similars_dist = np.linalg.norm(test_embeddings_anchor - test_embeddings_positive, axis=1).tolist()
differents_dist = np.linalg.norm(test_embeddings_anchor - test_embeddings_negative, axis=1).tolist()

In [None]:
# Distribution of anchor-positive embedding distances on the test split.
ax = sns.histplot(similars_dist)
ax.set(title='Anchor-positive distances (test)', xlabel='euclidean distance', ylabel='count')
plt.show()

In [None]:
# Distribution of anchor-negative embedding distances on the test split.
ax = sns.histplot(differents_dist)
ax.set(title='Anchor-negative distances (test)', xlabel='euclidean distance', ylabel='count')
plt.show()