[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mbrotos/SoungSeg/blob/main/demo.ipynb)

# Singing Voice Separation by U-Net

In [None]:
IN_COLAB = False
if 'google.colab' in str(get_ipython()):
    print('Running on CoLab, cloning repo, and installing requirements ...')
    IN_COLAB = True
    !git clone https://github.com/mbrotos/SoungSeg.git
    %cd SoungSeg/src
    !pip install -r requirements.txt
else:
    print('Not running on CoLab, creating local environment and installing requirements ...')
    %cd src
    !python -m venv venv
    !source ./venv/bin/activate
    !pip install -r requirements.txt

In [None]:
import os
import librosa
import IPython.display as ipd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from scipy import stats
import argparse
import config as cfg
import datetime
from scaler import normalize, denormalize
import pickle
import json
from augmentations import consecutive_oversample, blackout
from uuid import uuid4

### Data Loading

In [None]:
# Obtain the dataset.
# Full dataset: MUSDB18 - a corpus for music separation, hosted on Zenodo:
#   https://doi.org/10.5281/zenodo.3338373
# PLEASE NOTE: The full dataset is quite large (>20GB -- zipped) and may take a while to download.

# Alternatively, a subset of the dataset (~9GB) is available on Google Drive:
#   https://drive.google.com/file/d/1_kdifA4ztVXBveb9FYzmY49fvAKZmIJF/view?usp=sharing
# The command below fetches that file by its Drive id, extracts `data_wav.zip`
# and then deletes the archive. The steps are chained with `&&`, so each one
# only runs if the previous one succeeded. Requires the `gdown` and `unzip`
# commands to be available on the PATH.
!gdown 1_kdifA4ztVXBveb9FYzmY49fvAKZmIJF && unzip data_wav.zip && rm data_wav.zip

# This cell may take 5 minutes to run. Please be patient.

In [None]:
# Create required directories (`-p`: no error if they already exist).
# `./models` and `./processed_data` are the locations later cells write to
# and read from.
!mkdir -p models processed_data

# Preprocess dataset: run the preprocessing script once per split.
# Presumably this produces the `./processed_data/*_512x128.npy` files loaded
# by later cells -- verify against preprocessing.py.
!python preprocessing.py --dsType train
!python preprocessing.py --dsType test

# Tensorflow dataset prep
!python dataset_prep.py

#### Song mixture and vocal example

In [None]:
# Open the preprocessed training arrays as read-only memory maps
# (`mmap_mode='r'`), so data is read from disk on access rather than loaded
# up front. The same three files are loaded again, honouring `args`, in the
# "Define datasets" section.
mix_mags_train = np.load("./processed_data/mix_mags_train_512x128.npy", mmap_mode='r' )
mix_phases_train = np.load("./processed_data/mix_phases_train_512x128.npy", mmap_mode='r')
vocal_train = np.load(f"./processed_data/vocal_mags_train_512x128.npy",mmap_mode='r')

### Preprocessing

#### Visualize the data

### Run unit tests

In [None]:
# Run the audio-processing test script in a subprocess; its output is shown below the cell.
!python test_audio_processing.py

### Define Model Architecture

In [None]:
import tensorflow as tf
from keras.layers import Activation, Conv2D, BatchNormalization, Conv2DTranspose, Concatenate, MaxPooling2D, Input, Conv1D, Normalization

def _conv_bn_relu(x, filters):
    """Apply one 3x3 same-padded Conv2D -> BatchNormalization -> ReLU stage to ``x``."""
    x = Conv2D(filters, 3, strides=1, padding="same")(x)
    x = BatchNormalization()(x)
    return Activation("relu")(x)


def _double_conv(x, filters):
    """Apply two consecutive conv/BN/ReLU stages that share the same filter count."""
    return _conv_bn_relu(_conv_bn_relu(x, filters), filters)


def get_model(img_size, num_classes=1):
    """Build the (uncompiled) U-Net used for vocal separation.

    The network has four encoder levels (64, 128, 256, 512 filters), each made
    of two conv/BN/ReLU stages followed by 2x2 max pooling, a 1024-filter
    bottleneck, and four decoder levels that upsample with a stride-2
    transposed convolution, concatenate the matching encoder output (skip
    connection) and apply two more conv/BN/ReLU stages.

    Args:
        img_size: The two spatial dimensions of the input as a tuple (or other
            sequence); a single channel axis is appended to form the input
            shape. ``None`` entries are passed through unchecked.
        num_classes: Number of output channels produced by the final
            kernel-size-1 convolution with linear activation.

    Returns:
        A ``tf.keras.Model`` mapping the input to the final convolution output.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 16. Four
            2x2 poolings followed by four 2x upsamplings only reproduce the
            encoder shapes (required by the skip concatenations) in that case.
    """
    encoder_filters = (64, 128, 256, 512)
    bottleneck_filters = 1024

    # Validate up front: otherwise the mismatch only surfaces later as a
    # Concatenate shape error deep inside the decoder.
    img_size = tuple(img_size)
    downscale = 2 ** len(encoder_filters)
    for dim in img_size:
        if dim is not None and dim % downscale != 0:
            raise ValueError(
                f"Each dimension of img_size must be divisible by {downscale}, got {img_size}"
            )

    inputs = Input(shape=img_size + (1,))

    # Encoder: remember each level's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = _double_conv(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = _double_conv(x, bottleneck_filters)

    # Decoder: deepest skip first, mirroring the encoder filter counts.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = Conv2DTranspose(filters, 2, strides=2, padding="same")(x)
        x = Concatenate()([x, skip])
        x = _double_conv(x, filters)

    # Kept as Conv1D, exactly as before, so the layer type and its weights stay
    # unchanged for weight files saved from this architecture.
    output = Conv1D(num_classes, 1, activation="linear")(x)

    # Define the model
    model = tf.keras.Model(inputs, output)
    return model

In [None]:
# Command-line style configuration, parsed from an explicit list below so the
# notebook does not pick up the Jupyter kernel's own argv.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=10, help='Number of epochs to train for')
parser.add_argument("--batch_size", type=int, default=5, help='Batch size for training')
parser.add_argument("--normalization", type=str, default="frequency", help='Normalization axis (time or frequency)')
parser.add_argument("--lr", type=float, default=1e-3, help='Learning rate for training')
parser.add_argument("--mask", action="store_true", default=False, help='Experimental. Causes unstable training.')
parser.add_argument("--quantile_scaler", action="store_true", default=False, help='Toggle quantile scaling as the normalization method')
parser.add_argument("--q_min", type=float, default=25.0, help='Minimum quantile for quantile scaling')
parser.add_argument("--q_max", type=float, default=75.0, help='Maximum quantile for quantile scaling')
parser.add_argument("--loss", type=str, default="mse", help='Loss function to use (mse or mae)')
parser.add_argument("--dataset_size", type=int, default=None, help='Number of samples to use from the dataset (None = all)')
parser.add_argument("--augmentations", action="store_true", default=False, help='Toggle data augmentations (splicing, and blackout)')
parser.add_argument("--seed", type=int, default=42, help='Random seed for reproducibility')
# Memory mapping is on by default. A lone `store_true` flag whose default is
# already True can never be switched off, so `--no-mmap` writes False to the
# same destination. `--mmap` is kept so existing invocations still parse.
parser.add_argument("--mmap", dest="mmap", action="store_true", default=True, help='Toggle memory mapping for dataset loading (helps with large datasets and limited RAM)')
parser.add_argument("--no-mmap", dest="mmap", action="store_false", help='Disable memory mapping and load the datasets fully into RAM')

# Top performing model args; the batch size has been reduced from 64 to 25 to fit on Colab GPUs
args = parser.parse_args(['--normalization', 'frequency', '--epochs', '20', '--batch_size', '25', '--loss', 'mae', '--augmentations'])

print('Args:')
print(args)

# Seed both NumPy and TensorFlow so shuffling and weight initialisation are repeatable.
np.random.seed(args.seed)
tf.random.set_seed(args.seed)

### Define datasets

In [None]:
# Load the preprocessed training arrays.
# - `mmap_mode='r'` (when args.mmap is set) opens the files as read-only
#   memory maps instead of reading them fully into RAM.
# - `[:args.dataset_size]` keeps the first N samples; with the default of
#   None the slice keeps everything.
mix_mags_train = np.load("./processed_data/mix_mags_train_512x128.npy", mmap_mode='r' if args.mmap else None)[:args.dataset_size]
mix_phases_train = np.load("./processed_data/mix_phases_train_512x128.npy", mmap_mode='r' if args.mmap else None)[:args.dataset_size]
vocal_train = np.load( f"./processed_data/vocal_mags_train_512x128.npy", mmap_mode='r' if args.mmap else None)[:args.dataset_size]

# `np.copy` hands `normalize` in-memory copies, so the (possibly read-only,
# memory-mapped) source arrays are left untouched.
# Returns the normalized mix, the normalized vocals and the normalization
# factors; the factors are passed to `denormalize` and pickled in later cells.
mix_mags_train_norm, vocal_train_norm, mix_mags_train_norm_factors = normalize(
    np.copy(mix_mags_train),
    np.copy(vocal_train),
    normalization=args.normalization,
    quantile_scaler=args.quantile_scaler,
    q_min=args.q_min,
    q_max=args.q_max,
)

In [None]:
# Optional augmentation stage: drops low-energy vocal samples, then appends
# blackout and consecutive-oversample variants to the training arrays.
# NOTE(review): this cell overwrites `mix_mags_train_norm` / `vocal_train_norm`
# in place, so re-running it without re-running the previous cell augments
# already-augmented data.
# NOTE(review): `mix_blackout`, `vocal_blackout` and `mix_consec` are only
# defined inside this branch, but the visualization cell below reads them.
if args.augmentations:
    print('Appling augmentations...')

    # Remove outliers
    # Undo the normalization of the vocal targets using the stored factors
    # (presumably back to the original magnitude scale -- confirm in scaler.py).
    true_vocal = denormalize(
        vocal_train_norm,
        mix_mags_train_norm_factors,
        normalization=args.normalization,
        quantile_scaler=args.quantile_scaler,
    )

    vocal_waves = []

    for i in range(0, len(true_vocal)):
        # Concatenating the single-element slice [i : i + 1] along axis 1
        # simply yields (a copy of) sample i.
        cur_phase = np.concatenate(mix_phases_train[i : i + 1], axis=1)
        cur_true_vocal = np.concatenate(true_vocal[i : i + 1], axis=1)
        # Rebuild a waveform from magnitude * phase (channel 0 of each).
        # Assumes `mix_phases_train` holds complex phase factors -- TODO confirm.
        vocal_waves.append(librosa.istft(
                cur_true_vocal[:, :, 0] * cur_phase[:, :, 0],
                hop_length=cfg.HOP_SIZE,
                window="hann",
            )
        )
    vocal_waves = np.array(vocal_waves)
    # Total absolute amplitude per reconstructed waveform; samples below the
    # fixed threshold of 100 are removed from all four arrays so they stay aligned.
    dist = np.abs(vocal_waves).sum(axis=1)
    indices = np.where(dist < 100)[0]

    mix_mags_train_norm = np.delete(mix_mags_train_norm, indices, axis=0)
    vocal_train_norm = np.delete(vocal_train_norm, indices, axis=0)
    mix_mags_train_norm_factors = np.delete(mix_mags_train_norm_factors, indices, axis=0)
    mix_phases_train = np.delete(mix_phases_train, indices, axis=0)

    # Splicing and blackout
    # Only the first quarter of the blackout output and the first half of the
    # consecutive-oversample output are kept.
    mix_blackout, vocal_blackout = blackout(mix_mags_train_norm, vocal_train_norm)
    mix_blackout = mix_blackout[:mix_blackout.shape[0]//4]
    vocal_blackout = vocal_blackout[:vocal_blackout.shape[0]//4]
    mix_consec, vocal_consec = consecutive_oversample(mix_mags_train_norm, vocal_train_norm)
    mix_consec = mix_consec[:mix_consec.shape[0]//2]
    vocal_consec = vocal_consec[:vocal_consec.shape[0]//2]

    # Augmented samples are appended after the (filtered) originals.
    mix_mags_train_norm = np.concatenate((mix_mags_train_norm, mix_consec, mix_blackout), axis=0)
    vocal_train_norm = np.concatenate((vocal_train_norm, vocal_consec, vocal_blackout), axis=0)




# Report the final training-set shapes (printed whether or not augmentations ran).
print('Datasets:')
print(f'Mixes: {mix_mags_train_norm.shape}')
print(f'Vocals: {vocal_train_norm.shape}')

#### Augmentation Visualizations

In [None]:
def plot_spectrogram_with_shading(spectrogram, title, subplot_index, shade_start=None, shade_end=None):
    """Draw ``spectrogram`` into one subplot of the current figure.

    Args:
        spectrogram: Image data handed to ``plt.imshow``.
        title: Title shown above the subplot.
        subplot_index: Subplot selector passed to ``plt.subplot`` (e.g. ``131``).
        shade_start: Left x-coordinate of an optional highlighted span.
        shade_end: Right x-coordinate of the optional highlighted span.

    The translucent red span is drawn only when both ``shade_start`` and
    ``shade_end`` are given. A colorbar is always attached.
    """
    plt.subplot(subplot_index)
    plt.imshow(spectrogram, origin='lower', aspect='auto')
    plt.title(title)

    span_requested = not (shade_start is None or shade_end is None)
    if span_requested:
        plt.axvspan(shade_start, shade_end, alpha=0.2, color='red')

    plt.colorbar()

# Index of the training example to visualize.
i = 0

# `mix_blackout`, `vocal_blackout` and `mix_consec` are only created by the
# augmentation cell when `args.augmentations` is enabled; without this guard
# the cell fails with a NameError whenever augmentations are turned off.
if args.augmentations:
    # Figure 1: original vs. blackout-augmented spectrograms.
    # Four panels -> 2x2 grid (a 3x2 grid would leave the bottom row empty).
    plt.figure(figsize=(12, 8))

    # Original Mix Spectrogram
    plot_spectrogram_with_shading(mix_mags_train_norm[i], 'Original Mix', 221)

    # Original Vocal Spectrogram
    plot_spectrogram_with_shading(vocal_train_norm[i], 'Original Vocal', 222)

    # Blackout Mix Spectrogram
    plot_spectrogram_with_shading(mix_blackout[i], 'Blackout Mix', 223)

    # Blackout Vocal Spectrogram
    plot_spectrogram_with_shading(vocal_blackout[i], 'Blackout Vocal', 224)

    plt.tight_layout()
    plt.show()

    # Figure 2: consecutive oversampling, with the contributing halves of the
    # two neighbouring originals highlighted.
    plt.figure(figsize=(18, 6))

    # Original Mix Spectrogram with shading on the last half
    plot_spectrogram_with_shading(mix_mags_train_norm[i], 'Original Mix', 131, shade_start=64, shade_end=128)

    # Original Mix + 1 Spectrogram with shading on the first half
    plot_spectrogram_with_shading(mix_mags_train_norm[i+1], 'Original Mix + 1', 132, shade_start=0, shade_end=64)

    # Consecutive Oversample Mix Spectrogram (drawn without shading)
    plot_spectrogram_with_shading(mix_consec[i], 'Consecutive Oversample Mix', 133)

    plt.tight_layout()
    plt.show()
else:
    print('Augmentations are disabled (args.augmentations is False); nothing to visualize.')


In [None]:
# Unique run identifier: wall-clock timestamp plus a random hex suffix.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
model_name = f"model_{timestamp}_{uuid4().hex}"

# Save normalization factors as pkl
with open(f"./models/scaler--{model_name}.pkl", "wb") as f:
    pickle.dump(mix_mags_train_norm_factors, f)

data_len = len(mix_mags_train_norm)

# shuffle datasets before splitting (same permutation for inputs and targets)
indices = np.arange(data_len)
np.random.shuffle(indices)
mix_mags_train_norm = mix_mags_train_norm[indices]
vocal_train_norm = vocal_train_norm[indices]

# Hold out the last ~10% for validation.
# The split uses an explicit index rather than negative slicing: with a
# validation length of 0 (fewer than 10 samples), `x[-0:]` is the *whole*
# array and `x[:-0]` is empty, which would swap the two sets. At least one
# sample is held out whenever there are two or more samples.
val_len = max(1, int(data_len * 0.1)) if data_len > 1 else 0
train_len = data_len - val_len
val_data = (
    mix_mags_train_norm[train_len:],
    vocal_train_norm[train_len:],
)
train_data = (
    mix_mags_train_norm[:train_len],
    vocal_train_norm[:train_len],
)

# Training pipeline: small shuffle buffer, batching, background prefetch.
dataset = tf.data.Dataset.from_tensor_slices(train_data)
dataset = (
    dataset.shuffle(args.batch_size * 2)
    .batch(args.batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
# Validation pipeline: batched and prefetched, never shuffled.
val_ds = tf.data.Dataset.from_tensor_slices(val_data)
val_ds = val_ds.batch(args.batch_size).prefetch(tf.data.AUTOTUNE)

### Build Model

In [None]:
# Instantiate the U-Net for the configured spectrogram size and compile it
# with Adam at the requested learning rate and the selected loss.
model = get_model((cfg.FREQUENCY_BINS, cfg.SAMPLE_SZ), num_classes=1)

model.compile(
    optimizer=tf.keras.optimizers.Adam(args.lr),
    loss=args.loss,
)

# Create ./models/<model_name>/ and its logs/ subdirectory in one call.
# `exist_ok=True` keeps the cell re-runnable: without it a second run raises
# FileExistsError because the directories already exist.
os.makedirs(f"./models/{model_name}/logs", exist_ok=True)

# Save args as json
with open(f"./models/{model_name}/args.json", "w") as f:
    json.dump(vars(args), f, indent=4)

model.summary()

### Train the model

In [None]:
# Train the model. `history` is used by the next cell to plot the loss curves.
# Two weight-only checkpoints are kept in the run directory: one for the best
# (lowest) validation loss and one for the best training loss. TensorBoard
# logs, including per-epoch histograms, go to the run's `logs` directory.
# NOTE(review): newer Keras releases may require a `.weights.h5` suffix when
# `save_weights_only=True` -- confirm against the Keras version pinned in
# requirements.txt before upgrading.
history = model.fit(
    dataset,
    validation_data=val_ds,
    epochs=args.epochs,
    callbacks=[
        # Best-by-validation-loss weights
        tf.keras.callbacks.ModelCheckpoint(
            filepath=f"./models/{model_name}/{model_name}-val.hdf5", save_best_only=True, monitor='val_loss', mode='min', save_weights_only=True
        ),
        # Best-by-training-loss weights
        tf.keras.callbacks.ModelCheckpoint(
            filepath=f"./models/{model_name}/{model_name}-train.hdf5", save_best_only=True, monitor='loss', mode='min', save_weights_only=True
        ),
        tf.keras.callbacks.TensorBoard(
            log_dir=f"./models/{model_name}/logs", histogram_freq=1
        ),
    ],
)

In [None]:
# Per-epoch losses recorded by `model.fit`
loss = history.history['loss']
val_loss = history.history['val_loss']

# x-axis: one point per completed epoch, counted from 1
epochs = range(1, 1 + len(loss))

# Training (blue) and validation (red) loss on a single set of axes
plt.figure(figsize=(10, 6))
plt.plot(epochs, loss, color='b', marker='o', linestyle='-', label='Training loss')
plt.plot(epochs, val_loss, color='r', marker='o', linestyle='-', label='Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

plt.show()

### Evaluate the model

#### Test voice separation

#### MIR_Eval

### Comparing data normalization techniques