<h3>Notebook for Contrastive Learning and Drum Experiments</h3>

In [None]:
# Mount Google Drive at /content/drive (Google Colab only).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Move into the project folder on the mounted drive and list its contents.
# The relative paths used in later cells are resolved from this directory.
%cd drive/MyDrive/SSLMB/
%ls 

In [None]:
# Install the project requirements, then stempeg (imported in Step 1).
# NOTE(review): stempeg is installed without a version pin -- consider pinning it.
!pip install -r requirements.txt
!pip install stempeg

<h4>Step 1: import the packages and functions we need.</h4>

In [None]:
import os 
import random
import torch
import stempeg

import numpy             as np
import matplotlib.pyplot as plt
import librosa           as audio_lib
import IPython.display   as ipd

import processing.input_rep as IR

from models.models import Pretext_CNN

from spleeter.separator import Separator
import processing.source_separation as source_separation

# Root folder of the MUSDB18 data set, relative to the working directory.
# Step 2 appends 'train/' and 'test/' to this path.
fp_musdb18 = "Jupyter Data Sets/musdb18/"

<h5>The following functions verify whether a signal contains drums and, more importantly, make sure that the drums aren't too overbearing compared to the rest of the signal. They are used to pre-process our stems for the pretext task.</h5>

In [None]:
def compute_rms(signal):
    """
    Function for computing a signal's frame-wise Root Mean Square (RMS) energy.
    -- signal : input waveform

    Returns (rms, mean_rms, std_rms): the RMS values computed over
    2048-sample frames with a 512-sample hop, followed by their mean and
    their standard deviation.
    """
    frame_rms = audio_lib.feature.rms(y=signal, frame_length=2048, hop_length=512)

    # Summary statistics taken over every frame
    return frame_rms, np.mean(frame_rms), np.std(frame_rms)


def check_drum_stem(drums, other, low_bound):
    """
    Function for thresholding drums. Checks that the drum clip carries enough
    energy relative to the rest of the song, without overpowering it.
    -- drums     : drum signal
    -- other     : rest of song signal
    -- low_bound : lower bound on the fraction of frames that must be "in band"

    A frame is "in band" when the drum RMS is strictly above half of, and
    strictly below four times, the RMS of the rest of the song.

    Returns (status, rms_perc): status is True when low_bound < rms_perc < 1,
    where rms_perc is the fraction of in-band frames.
    """
    rest_rms = compute_rms(other)[0]
    drum_rms = compute_rms(drums)[0]

    # 0/1 masks over frames; [0] drops the leading axis of the RMS array
    loud_enough   = (drum_rms > rest_rms / 2).astype(int)[0]
    not_too_loud  = (drum_rms < rest_rms * 4).astype(int)[0]
    in_band       = loud_enough * not_too_loud

    rms_perc = np.sum(in_band) / len(in_band)

    # Both bounds are strict: a clip whose frames are all in band is rejected
    status = bool(low_bound < rms_perc < 1.)

    return status, rms_perc

    
def takeSecond(elem):
    """
    Sort-key helper: picks out the item at index 1 of a tuple or list, so that
    a list of tuples can be sorted by their second element.
    -- elem : tuple or list with at least two items
    """
    second = elem[1]
    return second


def gen_vqt(signal, sample_rate):
    """
    Generates a VQT (variable-Q transform) of a signal.
    -- signal      : signal to compute the VQT on
    -- sample_rate : sampling rate of the signal

    The transform spans 8 octaves of 12 bins each (96 bins) starting at note
    C0, with a hop of 256 samples. The output of the VQT call is returned
    as is: no magnitude or log scaling is applied here.
    """
    bins_per_octave = 12
    n_octaves       = 8

    return audio_lib.vqt(
        y=signal,
        sr=sample_rate,
        hop_length=256,
        fmin=audio_lib.note_to_hz('C0'),
        n_bins=n_octaves * bins_per_octave,
        bins_per_octave=bins_per_octave,
    )

def plot_vqt(signal, sample_rate, title, colorbar=False, axis=False, save=None):
    """
    Plots the log-magnitude VQT spectrogram of a signal.
    -- signal      : signal to compute the VQT on
    -- sample_rate : sampling rate of the signal
    -- title       : figure title, or None for no title
    -- colorbar    : if True, a colorbar is added to the figure
    -- axis        : if False, the axes (ticks and labels) are hidden
    -- save        : path to save the figure to (300 dpi), or None to skip saving

    The figure is shown with plt.show(); nothing is returned.
    """
    hop_length  = 256
    first_note  = 'C0'
    octave_reso = 12
    num_octaves = 8

    fmin = audio_lib.note_to_hz(first_note)

    VQT = audio_lib.vqt(y=signal, sr=sample_rate, hop_length=hop_length, fmin=fmin,
                        n_bins=num_octaves*octave_reso, bins_per_octave=octave_reso)

    # Log-compressed magnitude; the 0.1 offset keeps the log finite for silent bins
    VQT = np.log(0.1 + np.abs(VQT))

    fig, ax = plt.subplots()

    # fmin and bins_per_octave must be passed explicitly: without fmin the
    # 'cqt_note' axis assumes the lowest bin is C1 and every note label ends
    # up one octave too high for a transform that starts at C0.
    img = audio_lib.display.specshow(VQT, hop_length=hop_length, sr=sample_rate,
                                     fmin=fmin, bins_per_octave=octave_reso,
                                     x_axis='time', y_axis='cqt_note', ax=ax, cmap='gray_r')

    if title is not None:
        ax.set_title(title)

    if colorbar:
        # NOTE(review): values are natural-log magnitudes, not decibels, so the
        # "dB" suffix in this format string is misleading -- confirm the intent.
        fig.colorbar(img, ax=ax, format="%+2.0f dB")

    if not axis:
        ax.axis('off')

    if save is not None:
        fig.savefig(save, dpi=300)

    plt.show()

# Load the separation model:
# spl_mod selects the Spleeter configuration; "4stems" matches the four stems
# ('drums', 'bass', 'vocals', 'other') read from the separation output in Step 2.
# Both spl_mod and separator are reused by the Step 2 cell.
spl_mod   = "4stems"
m         = 'spleeter:{}'.format(spl_mod)
separator = Separator(m)

In [None]:
# List the available data sets (the folder that fp_musdb18 points into).
%ls Jupyter\ Data\ Sets/

<h4>Step 2: import XX-second stems. Load signals and log-VQTs from the MUSDB18 data set.</h4>

In [None]:
# Clip length in seconds. Can be 5 or 10
XX = 10

# Number of VQT frames per clip: 313 for 5 s clips, 626 otherwise (10 s clips)
VQT_len = 313 if XX == 5 else 626

source_sr = 44100   # rate the stems are read at
target_sr = 16000   # rate the clips are resampled to before computing the VQTs
max_tries = 26      # random windows tried per track before the track is skipped

mus_train = os.listdir(fp_musdb18 + 'train/')
mus_test  = os.listdir(fp_musdb18 + 'test/')

len_musdb = len(mus_train) + len(mus_test)

# At most one clip is kept per track, so len_musdb rows are always enough.
# Channel order in musdb: 0 = rest of song, 1 = drums, 2 = full mix.
musdb = np.zeros((len_musdb, 3, 96, VQT_len))

full__signals = np.zeros((len_musdb, XX * target_sr))
drums_signals = np.zeros((len_musdb, XX * target_sr))
other_signals = np.zeros((len_musdb, XX * target_sr))

# Minimum fraction of frames that must pass the RMS test in check_drum_stem
low_bound = 0.3


def _sample_clip(split, fp):
    """
    Reads one track, separates it, and draws a random XX-second window whose
    drums pass check_drum_stem.
    -- split : sub-folder of fp_musdb18 holding the file ('train/' or 'test/')
    -- fp    : file name of the track

    Returns (full, drums, other, rms_perc) with the three signals resampled to
    target_sr, or None when the track is too short or no window was accepted
    within max_tries attempts.
    """
    S, rate = stempeg.read_stems(fp_musdb18 + split + fp, stem_id=[0], sample_rate=source_sr)

    stems = source_separation.wv_run_spleeter(S, rate, separator, spl_mod)

    # Down-mix the stereo mixture to mono
    S = (S[:, 0] + S[:, 1]) / 2

    drums = np.zeros((S.shape[0]))
    other = np.zeros((S.shape[0]))

    # Mono drums, and the mono sum of every non-drum stem
    drums[:] = (stems['drums'][:, 0] + stems['drums'][:, 1]) / 2
    other[:] = (stems['other'][:, 0] + stems['other'][:, 1] + stems['vocals'][:, 0] + stems['vocals'][:, 1]
              + stems['bass'][:, 0] + stems['bass'][:, 1] ) / 2

    clip_len  = XX * rate
    max_start = len(drums) - clip_len - 1

    if (max_start < 0):
        # Track is shorter than one clip: skip it rather than let randint raise
        return None

    for _ in range(max_tries):
        start = random.randint(0, max_start)

        clip_drums = drums[start:start + clip_len]
        clip_other = other[start:start + clip_len]

        accepted, rms_perc = check_drum_stem(clip_drums, clip_other, low_bound)

        if accepted:
            break

    else:
        # No window passed the drum check
        return None

    clip_full = S[start:start + clip_len]

    # Sample rates are passed by keyword: recent librosa releases reject them
    # as positional arguments.
    clip_full  = audio_lib.resample(clip_full,  orig_sr=rate, target_sr=target_sr)
    clip_drums = audio_lib.resample(clip_drums, orig_sr=rate, target_sr=target_sr)
    clip_other = audio_lib.resample(clip_other, orig_sr=rate, target_sr=target_sr)

    return clip_full, clip_drums, clip_other, rms_perc


idx = 0
for split, file_names in (('train/', mus_train), ('test/', mus_test)):
    for fp in file_names:
        clip = _sample_clip(split, fp)

        if clip is None:
            continue

        temp__full, temp_drums, temp_other, stem_pow = clip

        VQT_drums = IR.generate_XQT(temp_drums, target_sr, 'vqt')
        VQT_other = IR.generate_XQT(temp_other, target_sr, 'vqt')
        VQT__full = IR.generate_XQT(temp__full, target_sr, 'vqt')

        drums_signals[idx, :] = temp_drums[:]
        other_signals[idx, :] = temp_other[:]
        full__signals[idx, :] = temp__full[:]

        musdb[idx, 0, :, :] = VQT_other[:, :]
        musdb[idx, 1, :, :] = VQT_drums[:, :]
        musdb[idx, 2, :, :] = VQT__full[:, :]

        print("{} -- {} : RMS Pow% is {:.3f}.".format(idx, split + fp, stem_pow))

        idx += 1

# Keep only the rows that were filled; copying lets the oversized buffers be freed
musdb = musdb[:idx].copy()

drums_signals = drums_signals[:idx].copy()
other_signals = other_signals[:idx].copy()
full__signals = full__signals[:idx].copy()

musdb = torch.from_numpy(musdb)
print("MUS DB bank shape is : {}.".format(musdb.shape))

<h4>Step 3: load the pretext task model on CPU or GPU.</h4>

In [None]:
# Build the network and restore the saved pretext-task weights.
# The checkpoint is always mapped onto the CPU first, so it loads on
# machines without a GPU too.
device     = torch.device('cpu')
state_dict = torch.load("models/saved/shift_pret_cnn_16.pth", map_location=device)

model = Pretext_CNN(pretext=False)
model.load_state_dict(state_dict)

# Move the model to the GPU when one is present
if torch.cuda.is_available():
    model = model.cuda()

# Inference mode
model.eval()

<h4>Step 4: compute embeddings for percussive and non-percussive signals. Plot and listen for fun!</h4>

In [None]:
torch.cuda.empty_cache()

use_cuda = torch.cuda.is_available()


def _as_batch(vqt):
    """Reshapes one VQT from musdb to a (1, 1, freq, time) float batch on the model's device."""
    batch = vqt.reshape((1, 1, musdb.shape[2], musdb.shape[3]))
    if use_cuda:
        batch = batch.cuda()
    return batch


def _as_curve(embedding):
    """Converts a model output to a 1-D numpy array of length VQT_len for plotting."""
    # .cpu() is a no-op for tensors already on the CPU, so no device branching is needed
    return embedding.detach().cpu().numpy().reshape(VQT_len)


def _plot_signal_with_embedding(signal, embedding):
    """Overlays a waveform and its embedding on a shared time axis (in seconds)."""
    plt.plot(x1, signal)
    plt.plot(x2, _as_curve(embedding))
    plt.legend(["Signal", "Embedding"])
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.grid(True)
    plt.ylim((-1, 1))
    plt.yticks([-1.0, -0.5, 0., 0.5, 1.0])
    plt.show()


# Randomly select an anchor
# NOTE(review): no random seed is set, so a different clip is drawn on every run.
anchor_idx = random.randint(0, len(musdb) - 1)

print("Randomly chosen index is : {}.".format(anchor_idx))

# The positive is the drum clip cut from the same window as the anchor
pos_idx = anchor_idx

# Channel 0 of musdb is the non-percussive clip, channel 1 the drums
anchor   = _as_batch(musdb[anchor_idx, 0, :, :])
positive = _as_batch(musdb[pos_idx, 1, :, :])

# Inference only: skip building the autograd graph
with torch.no_grad():
    anchor_emb   = model.anchor(anchor.float())
    positive_emb = model.postve(positive.float())

# Compute Cosine Similarity
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-08)

if use_cuda:
    cos = cos.cuda()

cos_output = cos(anchor_emb, positive_emb)

print("\nCosine similarity  between anchor and positive is: {:.3f}.".format(float(cos_output)))

# Listen to the drums, the rest of the song, and their sum
ipd.display(ipd.Audio(drums_signals[pos_idx], rate=16000))
ipd.display(ipd.Audio(other_signals[anchor_idx], rate=16000))
ipd.display(ipd.Audio(other_signals[anchor_idx] + drums_signals[pos_idx], rate=16000))

plt.rcParams["figure.figsize"] = (10,4)

#########################################################
print("\nPlots for Results section:")

# Both embeddings on one plot
plt.plot(_as_curve(anchor_emb))
plt.plot(_as_curve(positive_emb))
plt.legend(["Non-percussive", "Percussive"])
plt.xlabel("Time (samples)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.ylim((-0.1, 1))
plt.yticks([0., 0.25, 0.5, 0.75, 1.0])
plt.show()

# Time axes in seconds for the waveform (x1) and for the embedding (x2)
x1 = np.linspace(0, XX, XX * 16000)
x2 = np.linspace(0, XX, VQT_len)

_plot_signal_with_embedding(drums_signals[pos_idx], positive_emb)
_plot_signal_with_embedding(other_signals[anchor_idx], anchor_emb)