[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbrotos/SoundSeg/blob/main/demo.ipynb)

# Singing Voice Separation by U-Net

In [None]:
if 'google.colab' in str(get_ipython()):
    print('Running on CoLab, cloning repo, and installing requirements ...')
    !git clone https://github.com/mbrotos/SoundSeg.git
    %cd SoundSeg/src
    !pip install -r requirements.txt > /dev/null
    !nvidia-smi
    print('Make sure to enable GPU acceleration in CoLab by going to Edit > Notebook Settings > Hardware Accelerator')
else:
    print('Not running on CoLab, creating local environment and installing requirements ...')
    %cd src
    !python -m venv venv
    !source ./venv/bin/activate
    !pip install -r requirements.txt

In [None]:
import os
import librosa
import IPython.display as ipd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import argparse
import config as cfg
import datetime
from scaler import normalize, denormalize
import pickle
import json
from augmentations import consecutive_oversample, blackout
from uuid import uuid4
import mir_eval

### Data Loading

In [None]:
# Obtain the dataset from Zenodo:
# MUSDB18 - a corpus for music separation https://doi.org/10.5281/zenodo.3338373
# NOTE: The full dataset is quite large (>20GB zipped) and may take a while to download.

# Alternatively, fetch a subset of the dataset (~9GB) from Google Drive with the gdown command below:
# https://drive.google.com/file/d/1_kdifA4ztVXBveb9FYzmY49fvAKZmIJF/view?usp=sharing
# The archive is unzipped into the current directory and then deleted.
!gdown 1_kdifA4ztVXBveb9FYzmY49fvAKZmIJF && unzip data_wav.zip && rm data_wav.zip


"""
# If you are running on Google Colab, you can mount your Google Drive and save the dataset there.
# This helps avoid having to download the dataset every time you run the notebook.
# Replace "School/EE8223/Code" with the path to your folder on Google Drive.

from google.colab import drive
drive.mount('/content/drive')
!cp /content/drive/MyDrive/School/EE8223/Code/data_wav.zip . && unzip data_wav.zip && rm data_wav.zip
"""

# This cell may take 5 minutes to run. Please be patient.

In [None]:
# Create the output directories used by later cells (models/ and processed_data/)
!mkdir -p models processed_data

# Preprocess the train and test splits (one invocation per split)
!python preprocessing.py --dsType train
!python preprocessing.py --dsType test

# TensorFlow dataset preparation
!python dataset_prep.py

# See ./src/preprocessing.py for more details on the preprocessing steps.
# See ./src/dataset_prep.py for more details on the dataset preparation steps.

"""Some of these details will be shown in the Evaluate the model section below."""

#### Song mixture and vocal example

In [None]:
# Open the preprocessed training arrays as read-only memory maps so that only
# the slices actually touched are read from disk.
train_array_paths = (
    "./processed_data/mix_mags_train_512x128.npy",
    "./processed_data/mix_phases_train_512x128.npy",
    "./processed_data/vocal_mags_train_512x128.npy",
)
mix_mags_train, mix_phases_train, vocal_train = (
    np.load(array_path, mmap_mode='r') for array_path in train_array_paths
)
print(mix_mags_train.shape, mix_phases_train.shape, vocal_train.shape)

In [None]:
# Rebuild audio for the first training segment. Both the mixture and the vocal
# magnitudes are combined with the *mixture* phase before the inverse STFT.
# NOTE(review): presumably the phase array holds complex unit phasors (as
# returned by librosa.magphase) - confirm against preprocessing.py.
first_segment_phase = mix_phases_train[0, :, :, 0]
mix_wav = librosa.istft(
    mix_mags_train[0, :, :, 0] * first_segment_phase,
    hop_length=cfg.HOP_SIZE,
)
vocal_wav = librosa.istft(
    vocal_train[0, :, :, 0] * first_segment_phase,
    hop_length=cfg.HOP_SIZE,
)

In [None]:
# The Audio object is the cell's last expression so the notebook renders a player for it.
print("Example of a song segment mixture used for training:")
ipd.Audio(mix_wav, rate=cfg.SR)

In [None]:
# The Audio object is the cell's last expression so the notebook renders a player for it.
print("Example of a song segment vocal used for training:")
ipd.Audio(vocal_wav, rate=cfg.SR)

#### Visualize the data

In [None]:
def plot_spectrogram_with_shading(spectrogram, title, subplot_index, shade_start=None, shade_end=None):
    """Draw a spectrogram in one subplot of the current figure.

    Args:
        spectrogram: Image-like array passed to ``imshow`` (origin at the bottom).
        title: Title shown above the subplot.
        subplot_index: Subplot specifier forwarded to ``plt.subplot`` (e.g. ``131``).
        shade_start: Left x-coordinate of an optional translucent red band.
        shade_end: Right x-coordinate of that band. The band is drawn only
            when both ``shade_start`` and ``shade_end`` are given.
    """
    ax = plt.subplot(subplot_index)
    image = ax.imshow(spectrogram, aspect='auto', origin='lower')

    shading_requested = not (shade_start is None or shade_end is None)
    if shading_requested:
        ax.axvspan(shade_start, shade_end, color='red', alpha=0.2)

    ax.set_title(title)
    plt.colorbar(image, ax=ax)

In [None]:
# Side-by-side magnitude spectrograms of the first training segment
# (subplot slots 131 and 132 of a 1x3 grid; the third slot is left empty).
plt.figure(figsize=(18, 6))
first_segment_panels = (
    (mix_mags_train, "Mixture Magnitude"),
    (vocal_train, "Vocal Magnitude"),
)
for panel_slot, (panel_mags, panel_title) in enumerate(first_segment_panels, start=131):
    plot_spectrogram_with_shading(panel_mags[0, :, :, 0], panel_title, panel_slot)

### Run unit tests

In [None]:
# Run the audio-processing unit tests as a separate Python process.
!python test_audio_processing.py
# See ./src/test_audio_processing.py for more details on the audio processing unit tests I created for this project.

### Define Model Architecture

In [None]:
import tensorflow as tf
from keras.layers import Activation, Conv2D, BatchNormalization, Conv2DTranspose, Concatenate, MaxPooling2D, Input, Conv1D, Normalization

def get_model(img_size, num_classes=1):
    """Build the (uncompiled) U-Net used to predict vocal magnitudes.

    Layout:
      * Encoder - four stages with 64, 128, 256 and 512 filters. Each stage is
        two ``Conv2D(3x3, same) -> BatchNormalization -> ReLU`` blocks followed
        by 2x2 max pooling.
      * Bottleneck - two such blocks with 1024 filters.
      * Decoder - four stages with 512, 256, 128 and 64 filters. Each stage is
        a stride-2 ``Conv2DTranspose``, concatenation with the output of the
        matching encoder stage (taken before pooling), and two conv blocks.
      * Head - kernel-size-1 ``Conv1D`` with ``num_classes`` filters and a
        linear activation.

    Args:
        img_size: ``(height, width)`` tuple; a trailing channel axis of size 1
            is appended. Because of the four 2x poolings, both sizes need to be
            multiples of 16 for the skip concatenations to line up.
        num_classes: Number of output channels.

    Returns:
        A ``tf.keras.Model`` mapping the input to the head's output.
    """
    encoder_filters = (64, 128, 256, 512)
    decoder_filters = (512, 256, 128, 64)

    def conv_bn_relu(tensor, filters):
        # One 3x3 "same" convolution followed by batch norm and ReLU.
        tensor = Conv2D(filters, 3, strides=1, padding="same")(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu")(tensor)

    def double_conv(tensor, filters):
        # Two consecutive conv/BN/ReLU blocks with the same filter count.
        return conv_bn_relu(conv_bn_relu(tensor, filters), filters)

    inputs = Input(shape=img_size + (1,))

    # Contracting path: remember each stage's pre-pooling output for the skips.
    skip_outputs = []
    features = inputs
    for filters in encoder_filters:
        features = double_conv(features, filters)
        skip_outputs.append(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    # Bottleneck.
    features = double_conv(features, 1024)

    # Expanding path: deepest skip connection is consumed first.
    for filters, skip in zip(decoder_filters, reversed(skip_outputs)):
        features = Conv2DTranspose(filters, 2, strides=2, padding="same")(features)
        features = Concatenate()([features, skip])
        features = double_conv(features, filters)

    # NOTE(review): the head is a Conv1D applied to a 4-D feature map. It is
    # kept as-is because changing the layer type would change the weight
    # shapes stored in existing checkpoints - confirm this is intended.
    output = Conv1D(num_classes, 1, activation="linear")(features)

    return tf.keras.Model(inputs, output)

In [None]:
# Command-line style configuration for the training run. The parser mirrors a
# script interface; in the notebook the argument list is supplied explicitly.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=10, help='Number of epochs to train for')
parser.add_argument("--batch_size", type=int, default=5, help='Batch size for training')
parser.add_argument("--normalization", type=str, default="frequency", help='Normalization axis (time or frequency)')
parser.add_argument("--lr", type=float, default=1e-3, help='Learning rate for training')
parser.add_argument("--mask", action="store_true", default=False, help='Experimental. Causes unstable training.')
parser.add_argument("--quantile_scaler", action="store_true", default=False, help='Toggle quantile scaling as the normalization method')
parser.add_argument("--q_min", type=float, default=25.0, help='Minimum quantile for quantile scaling')
parser.add_argument("--q_max", type=float, default=75.0, help='Maximum quantile for quantile scaling')
parser.add_argument("--loss", type=str, default="mse", help='Loss function to use (mse or mae)')
parser.add_argument("--dataset_size", type=int, default=None, help='Number of samples to use from the dataset (None = all)')
parser.add_argument("--augmentations", action="store_true", default=False, help='Toggle data augmentations (splicing, and blackout)')
parser.add_argument("--seed", type=int, default=42, help='Random seed for reproducibility')
# Memory mapping is on by default. A store_true flag whose default is already
# True can never be switched off, so a paired --no-mmap flag writes False to
# the same destination; passing --mmap is still accepted.
parser.add_argument("--mmap", dest="mmap", action="store_true", help='Enable memory mapping for dataset loading (default; helps with large datasets and limited RAM)')
parser.add_argument("--no-mmap", dest="mmap", action="store_false", help='Disable memory mapping and load the dataset fully into RAM')
parser.set_defaults(mmap=True)

# Top performing model args, the batch size has been reduced from 64 to 15 to fit on colab GPUs
# The epochs were also reduced from 20 to 5 to save time
args = parser.parse_args(['--normalization', 'frequency', '--epochs', '5', '--batch_size', '15', '--loss', 'mae', '--augmentations'])

print('Args:')
print(args)

# Seed NumPy's and TensorFlow's global random generators from --seed so that
# shuffling and weight initialisation are repeatable across runs.
np.random.seed(args.seed)
tf.random.set_seed(args.seed)

### Define datasets

In [None]:
# Load the training arrays (memory-mapped when --mmap is set) and keep at most
# --dataset_size leading samples; a size of None keeps everything.
mmap_mode = 'r' if args.mmap else None
sample_limit = slice(args.dataset_size)

mix_mags_train = np.load("./processed_data/mix_mags_train_512x128.npy", mmap_mode=mmap_mode)[sample_limit]
mix_phases_train = np.load("./processed_data/mix_phases_train_512x128.npy", mmap_mode=mmap_mode)[sample_limit]
vocal_train = np.load("./processed_data/vocal_mags_train_512x128.npy", mmap_mode=mmap_mode)[sample_limit]

# normalize() receives in-memory copies, so the loaded arrays themselves are
# never handed to it.
normalize_options = dict(
    normalization=args.normalization,
    quantile_scaler=args.quantile_scaler,
    q_min=args.q_min,
    q_max=args.q_max,
)
mix_mags_train_norm, vocal_train_norm, mix_mags_train_norm_factors = normalize(
    np.copy(mix_mags_train),
    np.copy(vocal_train),
    **normalize_options,
)

In [None]:
if args.augmentations:
    print('Appling augmentations...')

    # --- Outlier removal -------------------------------------------------
    # Bring the vocal targets back to their original scale, turn each segment
    # into a waveform using the mixture phase, and drop every segment whose
    # summed absolute amplitude is below 100.
    true_vocal = denormalize(
        vocal_train_norm,
        mix_mags_train_norm_factors,
        normalization=args.normalization,
        quantile_scaler=args.quantile_scaler,
    )

    vocal_waves = np.array([
        librosa.istft(
            true_vocal[segment][:, :, 0] * mix_phases_train[segment][:, :, 0],
            hop_length=cfg.HOP_SIZE,
            window="hann",
        )
        for segment in range(len(true_vocal))
    ])
    dist = np.abs(vocal_waves).sum(axis=1)
    indices = np.flatnonzero(dist < 100)

    # Remove the same rows from every per-segment array so they stay aligned.
    mix_mags_train_norm, vocal_train_norm, mix_mags_train_norm_factors, mix_phases_train = (
        np.delete(per_segment, indices, axis=0)
        for per_segment in (
            mix_mags_train_norm,
            vocal_train_norm,
            mix_mags_train_norm_factors,
            mix_phases_train,
        )
    )

    # --- Splicing and blackout --------------------------------------------
    # Only the first quarter of the blackout output and the first half of the
    # consecutive-oversample output are appended to the training set.
    mix_blackout, vocal_blackout = blackout(mix_mags_train_norm, vocal_train_norm)
    mix_blackout = mix_blackout[: len(mix_blackout) // 4]
    vocal_blackout = vocal_blackout[: len(vocal_blackout) // 4]

    mix_consec, vocal_consec = consecutive_oversample(mix_mags_train_norm, vocal_train_norm)
    mix_consec = mix_consec[: len(mix_consec) // 2]
    vocal_consec = vocal_consec[: len(vocal_consec) // 2]

    mix_mags_train_norm = np.concatenate((mix_mags_train_norm, mix_consec, mix_blackout), axis=0)
    vocal_train_norm = np.concatenate((vocal_train_norm, vocal_consec, vocal_blackout), axis=0)

print('Datasets:')
print(f'Mixes: {mix_mags_train_norm.shape}')
print(f'Vocals: {vocal_train_norm.shape}')

#### Augmentation Visualizations

In [None]:
# Apply the blackout augmentation to a single training pair and show the
# original and augmented spectrograms on a 3x2 grid (slots 321-324).
i = 2
mix_blackout_test, vocal_blackout_test = blackout(mix_mags_train_norm[i:i+1], vocal_train_norm[i:i+1])
print('Notice the blackout regions in the spectrograms below.')

plt.figure(figsize=(12, 8))
blackout_panels = (
    (mix_mags_train_norm[i], 'Original Mix'),
    (vocal_train_norm[i], 'Original Vocal'),
    (mix_blackout_test[0], 'Blackout Mix'),
    (vocal_blackout_test[0], 'Blackout Vocal'),
)
for panel_slot, (panel_image, panel_title) in enumerate(blackout_panels, start=321):
    plot_spectrogram_with_shading(panel_image, panel_title, panel_slot)

plt.tight_layout()
plt.show()

In [None]:
print('Notice the Consecutive Oversample Mix is a mix of the original mix and the mix with 1 sample shifted over.')
# Only the mixture output is plotted; the vocal output is discarded.
mix_consec_test, _ = consecutive_oversample(mix_mags_train_norm[i:i+2], vocal_train_norm[i:i+1])

plt.figure(figsize=(18, 6))
# (image, title, shade_start, shade_end): the last half of segment i and the
# first half of segment i+1 are highlighted; the spliced result is unshaded.
consec_panels = (
    (mix_mags_train_norm[i], 'Original Mix', 64, 128),
    (mix_mags_train_norm[i+1], 'Original Mix + 1', 0, 64),
    (mix_consec_test[0], 'Consecutive Oversample Mix', None, None),
)
for panel_slot, (panel_image, panel_title, band_start, band_end) in enumerate(consec_panels, start=131):
    plot_spectrogram_with_shading(
        panel_image, panel_title, panel_slot, shade_start=band_start, shade_end=band_end
    )

plt.tight_layout()
plt.show()


In [None]:
# Unique run identifier: wall-clock timestamp plus a random hex suffix.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
model_name = f"model_{timestamp}_{uuid4().hex}"

data_len = len(mix_mags_train_norm)
if data_len < 2:
    # One sample cannot be divided into a training part and a validation part.
    raise ValueError(f"Need at least 2 samples to create a train/validation split, got {data_len}")

# Save normalization factors as pkl
# NOTE(review): the factors are pickled before the shuffle below, so their
# order does not follow the shuffled arrays - confirm nothing downstream
# relies on index alignment.
with open(f"./models/scaler--{model_name}.pkl", "wb") as f:
    pickle.dump(mix_mags_train_norm_factors, f)

# shuffle datasets before splitting
indices = np.arange(data_len)
np.random.shuffle(indices)
mix_mags_train_norm = mix_mags_train_norm[indices]
vocal_train_norm = vocal_train_norm[indices]

# Hold out the last ~10% for validation. An explicit split index is used
# because slicing with a validation length of 0 ([-0:] / [:-0]) would put every
# sample in the validation set and leave the training set empty; at least one
# validation sample is always kept.
val_len = max(1, int(data_len * 0.1))
split_at = data_len - val_len
val_data = (
    mix_mags_train_norm[split_at:],
    vocal_train_norm[split_at:],
)
train_data = (
    mix_mags_train_norm[:split_at],
    vocal_train_norm[:split_at],
)

dataset = tf.data.Dataset.from_tensor_slices(train_data)
dataset = (
    dataset.shuffle(args.batch_size * 2)
    .batch(args.batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = tf.data.Dataset.from_tensor_slices(val_data)
val_ds = val_ds.batch(args.batch_size).prefetch(tf.data.AUTOTUNE)

### Build Model

In [None]:
# Instantiate the U-Net for (FREQUENCY_BINS x SAMPLE_SZ) inputs and compile it
# with Adam at the configured learning rate and the configured loss.
model = get_model((cfg.FREQUENCY_BINS, cfg.SAMPLE_SZ), num_classes=1)

model.compile(
    optimizer=tf.keras.optimizers.Adam(args.lr),
    loss=args.loss,
)

# Creating the logs directory also creates the run directory above it.
# exist_ok=True lets this cell be re-run without raising FileExistsError.
os.makedirs(f"./models/{model_name}/logs", exist_ok=True)

# Save args as json
with open(f"./models/{model_name}/args.json", "w") as f:
    json.dump(vars(args), f, indent=4)

model.summary()

### Train the model

In [None]:
# Two weight-only checkpoints are kept per run - the epoch with the lowest
# validation loss ("-val") and the epoch with the lowest training loss
# ("-train") - and TensorBoard logs are written next to them.
run_dir = f"./models/{model_name}"
checkpoint_options = dict(save_best_only=True, mode='min', save_weights_only=True)

training_callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath=f"{run_dir}/{model_name}-val.hdf5", monitor='val_loss', **checkpoint_options
    ),
    tf.keras.callbacks.ModelCheckpoint(
        filepath=f"{run_dir}/{model_name}-train.hdf5", monitor='loss', **checkpoint_options
    ),
    tf.keras.callbacks.TensorBoard(log_dir=f"{run_dir}/logs", histogram_freq=1),
]

history = model.fit(
    dataset,
    validation_data=val_ds,
    epochs=args.epochs,
    callbacks=training_callbacks,
)

In [None]:
# Per-epoch training and validation loss recorded by model.fit.
loss = history.history['loss']
val_loss = history.history['val_loss']

# Epochs are numbered from 1 on the x axis.
epochs = range(1, len(loss) + 1)

plt.figure(figsize=(10, 6))
loss_curves = (
    (loss, 'bo-', 'Training loss'),
    (val_loss, 'ro-', 'Validation loss'),
)
for curve_values, curve_format, curve_label in loss_curves:
    plt.plot(epochs, curve_values, curve_format, label=curve_label)
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()

### Evaluate the model

#### Test voice separation

##### Download and extract a song

In [None]:
# This requires ffmpeg to be installed on your system


# Replace url with the song you want to test
url = 'https://www.youtube.com/watch?v=TLV4_xaYynY' # "All Along The Watchtower" by The Jimi Hendrix Experience
!yt-dlp {url} -f mp4 -o ./test_audio.mp4 && ffmpeg -i ./test_audio.mp4 -ac 1 -ar {cfg.SR} ./test_audio.wav -y

song = librosa.load('./test_audio.wav', sr=cfg.SR, mono=True)[0]

In [None]:
# Trim song and test song audio
start_time_sec = 18
end_time_sec = 53
start_time = start_time_sec * cfg.SR
end_time = end_time_sec * cfg.SR
song = song[start_time:end_time]
print("Song waveform mixture:")
ipd.Audio(song, rate=cfg.SR)

In [None]:
# Extract spectrogram from waveform

mix_stft = librosa.stft(song, n_fft=cfg.FRAME_SIZE, hop_length=cfg.HOP_SIZE, window='hann')
mix_mag, mix_phase = librosa.magphase(mix_stft)

# Whole model-sized segments that fit in the song; the remainder at the end is dropped.
numOfSamples = mix_mag.shape[1] // cfg.SAMPLE_SZ
if numOfSamples == 0:
    raise ValueError(
        f"Song is too short: {mix_mag.shape[1]} STFT frames, "
        f"but one segment needs {cfg.SAMPLE_SZ}"
    )

# Print some info
print(f"Number of samples: {numOfSamples}")
print(f"Shape of mix_mag: {mix_mag.shape}")

# Crop to the model's input height (cfg.FREQUENCY_BINS, the same constant the
# model is built with) and to a whole number of segments, then cut along time
# into (numOfSamples, FREQUENCY_BINS, SAMPLE_SZ, 1) arrays.
usable_frames = numOfSamples * cfg.SAMPLE_SZ
mix_mag_samples = np.array(
    np.split(mix_mag[:cfg.FREQUENCY_BINS, :usable_frames], numOfSamples, axis=1)
)[:, :, :, np.newaxis]
# Trim phase information to match the shape of the magnitude
mix_phase = mix_phase[:cfg.FREQUENCY_BINS, :usable_frames]
mix_phase_samples = np.array(np.split(mix_phase, numOfSamples, axis=1))[:, :, :, np.newaxis]

# Print some info
print(f"Shape of mix_mag_samples: {mix_mag_samples.shape}")
print(f"Shape of mix_phase_samples: {mix_phase_samples.shape}")

In [None]:
# Example of a spectrogram: the segment at index 3, drawn as a single full-figure subplot (111)
plot_spectrogram_with_shading(mix_mag_samples[3], 'Mix Spectrogram', 111)

In [None]:
# Normalize the song segments with the same settings used for training.
# normalize() takes a (mixture, target) pair; there is no separate target for
# this song, so a second copy of the mixture fills that slot and the
# corresponding output is discarded.
song_normalize_options = dict(
    normalization=args.normalization,
    quantile_scaler=args.quantile_scaler,
    q_min=args.q_min,
    q_max=args.q_max,
)
youtube_x, _, youtube_x_norm_factors = normalize(
    np.copy(mix_mag_samples),
    np.copy(mix_mag_samples),
    **song_normalize_options,
)

In [None]:
def denormalize_and_istft(pred_norm, norm_factors, phase):
    """Turn normalized per-segment predictions back into one waveform.

    Args:
        pred_norm: Normalized predicted magnitudes, one entry per segment.
        norm_factors: Normalization factors passed to ``denormalize``.
        phase: Per-segment phase arrays aligned with ``pred_norm``.

    Returns:
        The waveform produced by ``librosa.istft`` (Hann window, hop length
        ``cfg.HOP_SIZE``) from the re-joined magnitude times phase.

    Reads the notebook-level ``args`` for the normalization settings.
    """
    magnitudes = denormalize(
        pred_norm,
        norm_factors,
        normalization=args.normalization,
        quantile_scaler=args.quantile_scaler,
    )
    # Joining the segments along axis 1 stitches them back into one long
    # spectrogram; index 0 on the last axis drops the channel dimension.
    joined_magnitude = np.concatenate(magnitudes, axis=1)[:, :, 0]
    joined_phase = np.concatenate(phase, axis=1)[:, :, 0]
    complex_spectrogram = joined_magnitude * joined_phase
    return librosa.istft(complex_spectrogram, hop_length=cfg.HOP_SIZE, window="hann")

##### Predict the vocals from current model

In [None]:
# Run the freshly trained model on the normalized song segments, then rebuild
# audio using the song's own normalization factors and the mixture phase.
pred_norm_cur_model = model.predict(youtube_x)
pred_wave_cur_model = denormalize_and_istft(pred_norm_cur_model, youtube_x_norm_factors, mix_phase_samples)

In [None]:
# The Audio object is the cell's last expression so the notebook renders a player for it.
print("Predicted vocal waveform from the current model:")
ipd.Audio(pred_wave_cur_model, rate=cfg.SR) # Not bad! But we can do better.

##### Predict the vocals from our best model

In [None]:
# Download and unzip the pretrained model
# Link: https://drive.google.com/file/d/1_n6yMqqxSdr2f_WhLhPwAtC2CSadIH7a/view?usp=sharing
# The command expects the downloaded archive to be named <best_model_name>.zip
# and deletes it after extraction.

best_model_name = 'model_20231208-002503_5244b97903ec49b58783dd64f9c9f5ca'
!gdown 1_n6yMqqxSdr2f_WhLhPwAtC2CSadIH7a && unzip {best_model_name}.zip && rm {best_model_name}.zip

In [None]:
# Build a second network with the same architecture and load the pretrained
# "-val" checkpoint weights from the extracted model directory into it.
best_model = get_model((cfg.FREQUENCY_BINS, cfg.SAMPLE_SZ), num_classes=1)
best_model.load_weights(f"./{best_model_name}/{best_model_name}-val.hdf5")

In [None]:
# Same inference path as for the current model, using the pretrained weights.
pred_norm_best_model = best_model.predict(youtube_x)
pred_wave_best_model = denormalize_and_istft(pred_norm_best_model, youtube_x_norm_factors, mix_phase_samples)

In [None]:
# The Audio object is the cell's last expression so the notebook renders a player for it.
print("Predicted vocal waveform from the best model:")
ipd.Audio(pred_wave_best_model, rate=cfg.SR) 
# Notice the song is slightly shorter than the original, this is due to the way we split the song into samples
# (only whole segments are kept; the leftover frames at the end are dropped).

In [None]:
# Play the trimmed input mixture again for comparison with the predictions.
print("Original mixed waveform:")
ipd.Audio(song, rate=cfg.SR)

### Comparing data normalization techniques

In [None]:
def plot_matrix(matrix, title, ax, decimals=2):
    """Show a small 2-D matrix as a colour map with each value printed in its cell.

    Args:
        matrix: 2-D array to display (viridis colour map).
        title: Title for the axes.
        ax: Matplotlib axes to draw on.
        decimals: Number of decimal places used for the cell labels.
    """
    image = ax.matshow(matrix, cmap=plt.cm.viridis)
    plt.colorbar(image, ax=ax)

    ax.set_title(title)
    # Time and frequency axis labels
    ax.set_xlabel('Time')
    ax.set_ylabel('Frequency')

    # Text coordinates are (x, y), i.e. (column, row).
    label_spec = f".{decimals}f"
    for (row, col), value in np.ndenumerate(matrix):
        ax.text(col, row, format(value, label_spec), ha='center', va='center', color='white')

In [None]:
# Build a small 10x5 toy "spectrogram" of random integers in [0, 100).
# The seed is fixed so the same matrix is produced on every run.
np.random.seed(0)
small_spectrogram = np.random.randint(0, 100, (10, 5))

# Plotting the matrix with actual number values
# (with a 1x1 grid, plt.subplots returns a single Axes object, not an array)
fig, axes = plt.subplots(1, 1, figsize=(4, 6))

# Plot original small spectrogram with integer values
plot_matrix(small_spectrogram, 'Original Small Spectrogram (Integers)', axes, decimals=0)

plt.tight_layout()

In [None]:
# Recreate the toy spectrogram. Reseeding with the seed used for the
# "original" plot makes randint return that very same matrix; without the
# reseed a *different* random matrix would be normalized here and the plots
# below would not correspond to the original shown earlier.
np.random.seed(0)
small_spectrogram = np.random.randint(0, 100, (10, 5))

# Perform min-max normalization across the frequency dimension (rows):
# every row is rescaled with its own minimum and maximum.
row_min = small_spectrogram.min(axis=1)[:, None]
row_max = small_spectrogram.max(axis=1)[:, None]
min_max_norm_freq_small = (small_spectrogram - row_min) / (row_max - row_min)

# Perform min-max normalization across the time dimension (columns):
# every column is rescaled with its own minimum and maximum.
col_min = small_spectrogram.min(axis=0)
col_max = small_spectrogram.max(axis=0)
min_max_norm_time_small = (small_spectrogram - col_min) / (col_max - col_min)

# Plotting the matrices with actual number values
fig, axes = plt.subplots(1, 2, figsize=(8, 6))

# Plot min-max normalization across frequency with float values
plot_matrix(min_max_norm_freq_small, 'Min-Max Normalization\n(Frequency)', axes[0])

# Plot min-max normalization across time with float values
plot_matrix(min_max_norm_time_small, 'Min-Max Normalization\n(Time)', axes[1])

plt.tight_layout()

In [None]:
from sklearn.preprocessing import RobustScaler

# Function to perform robust scaling across a specified axis
def robust_scale(matrix, axis):
    """Robust-scale a 2-D matrix along one axis with scikit-learn's RobustScaler.

    RobustScaler computes its statistics per column of the array it is given,
    so the matrix is passed as-is for column-wise scaling and transposed (then
    transposed back) for row-wise scaling.

    Args:
        matrix: 2-D array to scale.
        axis: 0 to scale each column independently, 1 to scale each row
            independently.

    Returns:
        The scaled matrix, with the same shape as ``matrix``.

    Raises:
        ValueError: If ``axis`` is neither 0 nor 1.
    """
    if axis not in (0, 1):
        raise ValueError(f"axis must be 0 (columns) or 1 (rows), got {axis!r}")

    scaler = RobustScaler()
    if axis == 0:
        # Columns are already the "features" RobustScaler operates on.
        return scaler.fit_transform(matrix)
    # Transpose so that rows become the columns the scaler works on.
    return scaler.fit_transform(matrix.T).T

# Apply robust scaling to the small_spectrogram across both axes
robust_scaled_freq_small = robust_scale(small_spectrogram, axis=1)  # Scale each row (frequency)
robust_scaled_time_small = robust_scale(small_spectrogram, axis=0)  # Scale each column (time)

# Plotting the matrices with robust scaling applied
fig, axes = plt.subplots(1, 2, figsize=(8, 6))



# Plot robust scaling across frequency with float values
plot_matrix(robust_scaled_freq_small, 'Robust Scaling\n(Frequency)', axes[0])

# Plot robust scaling across time with float values
plot_matrix(robust_scaled_time_small, 'Robust Scaling\n(Time)', axes[1])

plt.tight_layout()