In [1]:
import sys
import os
import time
import soundfile as sf
import matplotlib.pyplot as plt
import numpy as np
from scipy import signal as sg
from scipy import fftpack as fp
from scipy import linalg
from sklearn.cluster import KMeans

import librosa
import librosa.display
from museval.metrics import bss_eval_images, bss_eval_sources

from sklearn.utils import check_random_state
from sklearn.utils.extmath import randomized_svd
from numpy.linalg import norm 

# K-shape

In [2]:
### Function for getting a shape-based distance (SBD) ###
def get_SBD(x, y):
    """Shape-based distance (SBD) between two equal-length 1-D sequences (k-Shape).

    The normalized cross-correlation NCC(x, y) is evaluated at every lag
    s in [-(p-1), p-1] using zero-padded FFTs, and SBD = 1 - max_s NCC.

    Parameters
    ----------
    x : 1-D array of length p (the reference, e.g. a centroid).
    y : 1-D array of length p (the sequence to align to x).

    Returns
    -------
    dist : float
        1 - max NCC. It is 1.0 when either input has (near-)zero norm.
    y_shift : ndarray of length p
        y shifted by the optimal lag (with zero padding) so that it aligns
        with x. It is an unshifted copy of y when the norm product is ~0.
    """
    #Define FFT-size based on the length of input
    p = int(x.shape[0])
    #Next power of two >= 2p-1 so the circular correlation equals the linear one
    FFTlen = int(2**np.ceil(np.log2(2*p-1)))

    #Compute the cross-correlation for every lag
    CC = fp.ifft(fp.fft(x, FFTlen)*fp.fft(y, FFTlen).conjugate()).real

    #Reorder so that index k corresponds to lag k-(p-1).
    #An explicit start index is used because CC[-(p-1):] is CC[-0:] (the whole
    #array) when p == 1, which would produce a length-2 result.
    CC = np.concatenate((CC[FFTlen - (p-1):], CC[:p]), axis=0)

    #No shape information in a zero vector: maximal distance, no shift
    denom = linalg.norm(x) * linalg.norm(y)
    if denom < 1e-10:
        return 1.0, np.copy(y)
    NCC = CC / denom

    #Search for the argument to maximize NCC
    ndx = np.argmax(NCC, axis=0)
    dist = 1 - NCC[ndx]
    #Get the shift parameter (s=0 if no shift)
    s = ndx - p + 1

    #Insert zero padding based on the shift parameter s
    if s > 0:
        y_shift = np.concatenate((np.zeros(s), y[0:-s]), axis=0)
    elif s == 0:
        y_shift = np.copy(y)
    else:
        y_shift = np.concatenate((y[-s:], np.zeros(-s)), axis=0)

    return dist, y_shift

### Function for updating k-shape centroid ###
def shape_extraction(X, v):
    """k-Shape centroid update ("shape extraction").

    Each member sequence is aligned to the current centroid ``v`` with
    get_SBD. The new centroid is the dominant eigenvector of
    M = Q^T (Y^T Y) Q, where Y stacks the aligned members and Q is the
    centering matrix (the maximiser of the Rayleigh quotient).

    Parameters
    ----------
    X : array of shape (N, p), one member sequence per row. An empty array
        (e.g. ``np.array([])`` for an empty cluster) is accepted.
    v : array of shape (p,), the current centroid used as alignment reference.

    Returns
    -------
    ndarray of shape (p,). For an empty X this is a copy of ``v`` (no data
    to update from).
    """
    X = np.asarray(X)
    #An empty cluster arrives as np.array([]) with no second axis; keep old centroid
    if X.size == 0:
        return np.copy(v)
    X = np.atleast_2d(X)

    #Define the length of input
    N = int(X.shape[0])
    p = int(X.shape[1])

    #Construct the phase shifted signal
    Y = np.zeros((N, p))
    for i in range(N):
        #Call my function for getting the SBD between centeroid and data
        _, Y[i, :] = get_SBD(v, X[i, :])

    #Construct the matrix M for Rayleigh quotient
    S = Y.T @ Y
    Q = np.eye(p) - np.ones((p, p)) / p
    M = Q.T @ (S @ Q)

    #M is real symmetric: eigh returns real eigenpairs sorted in ascending
    #order, so the last column is the eigenvector of the largest eigenvalue
    _, eigen_vec = linalg.eigh(M)
    new_v = eigen_vec[:, -1]

    #The ill-posed problem has both +v and -v as solution
    MSE_plus = np.sum((Y - new_v)**2)
    MSE_minus = np.sum((Y + new_v)**2)
    if MSE_minus < MSE_plus:
        new_v = -1*new_v

    return new_v

### Function for checking empty clusters ###
def check_empty(label, num_clu):
    """List the cluster ids in 0..num_clu-1 that no element of ``label`` uses.

    Parameters
    ----------
    label : array-like of cluster ids (ints or integer-valued floats).
    num_clu : total number of clusters.

    Returns
    -------
    list of int, in ascending order; empty when every cluster is populated.
    """
    candidates = np.arange(num_clu)
    missing = np.isin(candidates, np.unique(label), invert=True)
    return np.flatnonzero(missing).tolist()

### Function for getting KShape clustering ###
def get_KShape(X, num_clu, max_iter, num_init):
    """k-Shape clustering of the rows of X (Paparrizos & Gravano, 2015).

    Parameters
    ----------
    X : array of shape (N, p); each row is one sequence to cluster.
    num_clu : number of clusters.
    max_iter : maximum number of refinement/assignment iterations per restart.
    num_init : number of random restarts; the run with the smallest total SBD
        is returned.

    Returns
    -------
    out_label : (N,) int16 array of cluster indices.
    out_center : (num_clu, p) array of z-normalized centroids.
    minloss : total SBD of the returned clustering.

    Uses the global numpy RNG (``np.random``), so results depend on
    ``np.random.seed``.
    """
    def _znorm(mat):
        #Z-normalize every row (zero mean, unit standard deviation)
        mat = mat - np.average(mat, axis=1)[:, np.newaxis]
        return mat / np.std(mat, axis=1)[:, np.newaxis]

    #Define the length of input
    N = int(X.shape[0])  #The number of data
    p = int(X.shape[1])  #The length of temporal axis

    #Best result over all restarts (None until the first restart finishes)
    minloss = np.inf
    out_label, out_center = None, None
    for init in range(num_init):

        #Initialize label, centroid, loss as random numbers
        label = np.round((num_clu-1) * np.random.rand(N))
        center = _znorm(np.random.rand(num_clu, p))
        loss = np.inf

        #Working copies updated in each iteration
        new_label = np.copy(label)
        new_center = np.copy(center)

        for rep in range(max_iter):

            new_loss = 0

            ### Refinement step (update center) ###
            for j in range(num_clu):
                members = X[label == j]
                #An empty cluster keeps its previous centroid
                if members.shape[0] > 0:
                    new_center[j, :] = shape_extraction(members, center[j, :])
            new_center = _znorm(new_center)

            ### Assignment step (update label) ###
            for i in range(N):
                mindist = np.inf
                for j in range(num_clu):
                    dist, _ = get_SBD(new_center[j, :], X[i, :])
                    if dist < mindist:
                        mindist = dist
                        new_label[i] = j
                #Accumulate the total SBD
                new_loss = new_loss + mindist

            ### Error handling (avoid empty clusters) ###
            #Heuristic: data point `ind` is moved into empty cluster `ind`
            #(assumes num_clu <= N)
            for ind in check_empty(new_label, num_clu):
                new_label[ind] = ind

            #Stop when the loss no longer decreases and labels are unchanged
            if loss - new_loss < 1e-6 and (new_label == label).all():
                break

            #Update parameters
            label = np.copy(new_label)
            center = np.copy(new_center)
            loss = np.copy(new_loss)

        #Keep the restart with minimum loss; always keep the first one so the
        #outputs are defined even if every loss is inf/NaN
        if out_label is None or loss < minloss:
            out_label = np.copy(label).astype(np.int16)
            out_center = np.copy(center)
            minloss = loss

    #Output the label vector and centroid matrix
    return out_label, out_center, minloss

# NMF

In [3]:
# Initial values for NMF
def initialize_nmf(X, n_components, init=None, eps=1e-6, random_state=None):
    """Compute initial factors W, H for NMF of X ~= W @ H.

    Adapted from scikit-learn's NMF initialisation.

    Parameters
    ----------
    X : array of shape (n_samples, n_features), non-negative.
    n_components : number of NMF components.
    init : None | 'random' | 'nndsvd' | 'nndsvda' | 'nndsvdar'
        None selects 'nndsvd' when n_components < n_features, else 'random'.
    eps : NNDSVD entries smaller than eps are set to zero.
    random_state : seed or RandomState passed to sklearn's check_random_state.

    Returns
    -------
    W : array of shape (n_samples, n_components)
    H : array of shape (n_components, n_features)

    Raises
    ------
    ValueError
        If ``init`` is not one of the options listed above.
    """
    valid_inits = (None, 'random', 'nndsvd', 'nndsvda', 'nndsvdar')
    # Validate up front instead of after the (comparatively costly) SVD
    if init not in valid_inits:
        raise ValueError(
            'Invalid init parameter: got %r instead of one of %r' %
            (init, valid_inits))

    n_samples, n_features = X.shape

    if init is None:
        if n_components < n_features:
            init = 'nndsvd'
        else:
            init = 'random'

    # Random initialization scaled to the mean of X
    if init == 'random':
        avg = np.sqrt(X.mean() / n_components)
        rng = check_random_state(random_state)
        H = avg * rng.randn(n_components, n_features)
        W = avg * rng.randn(n_samples, n_components)
        np.abs(H, out=H)
        np.abs(W, out=W)
        return W, H

    # NNDSVD initialization
    U, S, V = randomized_svd(X, n_components, random_state=random_state)
    W, H = np.zeros(U.shape), np.zeros(V.shape)

    # The leading singular triplet is non-negative
    # so it can be used as is for initialization.
    W[:, 0] = np.sqrt(S[0]) * np.abs(U[:, 0])
    H[0, :] = np.sqrt(S[0]) * np.abs(V[0, :])

    for j in range(1, n_components):
        x, y = U[:, j], V[j, :]

        # extract positive and negative parts of column vectors
        x_p, y_p = np.maximum(x, 0), np.maximum(y, 0)
        x_n, y_n = np.abs(np.minimum(x, 0)), np.abs(np.minimum(y, 0))

        # and their norms
        x_p_nrm, y_p_nrm = norm(x_p), norm(y_p)
        x_n_nrm, y_n_nrm = norm(x_n), norm(y_n)

        m_p, m_n = x_p_nrm * y_p_nrm, x_n_nrm * y_n_nrm

        # choose update (m_p > m_n >= 0 guarantees non-zero positive norms)
        if m_p > m_n:
            u = x_p / x_p_nrm
            v = y_p / y_p_nrm
            sigma = m_p
        elif m_n > 0:
            u = x_n / x_n_nrm
            v = y_n / y_n_nrm
            sigma = m_n
        else:
            # Both parts vanish: leave the zero column instead of dividing 0/0
            continue

        lbd = np.sqrt(S[j] * sigma)
        W[:, j] = lbd * u
        H[j, :] = lbd * v

    W[W < eps] = 0
    H[H < eps] = 0

    if init == "nndsvda":
        # Fill zeros with the mean of X
        avg = X.mean()
        W[W == 0] = avg
        H[H == 0] = avg
    elif init == "nndsvdar":
        # Fill zeros with small random values
        rng = check_random_state(random_state)
        avg = X.mean()
        W[W == 0] = abs(avg * rng.randn(len(W[W == 0])) / 100)
        H[H == 0] = abs(avg * rng.randn(len(H[H == 0])) / 100)

    return W, H


#Function for audio pre-processing
def pre_processing(data, Fs, down_sam):
    """Convert audio to mono and optionally resample it.

    Parameters
    ----------
    data : array of shape (n_samples,) or (n_samples, n_channels).
    Fs : sampling rate of ``data`` in Hz (integer, as returned by soundfile).
    down_sam : target sampling rate in Hz (integer), or None to keep Fs.

    Returns
    -------
    (wavdata, Fs) : the 1-D mono signal and its sampling rate.
    """
    data = np.asarray(data)

    #Average all channels into mono (for stereo this equals 0.5*L + 0.5*R).
    #Previously only channels 0 and 1 were used, crashing on mono 2-D input
    #and silently dropping any further channels.
    if data.ndim == 2:
        wavdata = data.mean(axis=1)
    else:
        wavdata = data

    #Polyphase resampling; no amplitude normalization is applied
    if down_sam is not None:
        wavdata = sg.resample_poly(wavdata, down_sam, Fs)
        Fs = down_sam

    return wavdata, Fs


#Function for getting STFT
def get_STFT(wav, Fs, frame_length, frame_shift):
    """Short-time Fourier transform split into magnitude and phase.

    frame_length and frame_shift are in seconds and are converted to sample
    counts with round(). A Hamming window is used. The spectrogram size is
    printed as a side effect.

    Returns
    -------
    (magnitude, phase, Fs, freqs, times)
    """
    win_len = round(frame_length * Fs)
    hop = round(frame_shift * Fs)

    freqs, times, spec = sg.stft(wav, fs=Fs, window='hamm', nperseg=win_len, noverlap=win_len - hop)
    magnitude = np.abs(spec)
    phase = np.angle(spec)

    print("Spectrogram size (freq, time) = " + str(magnitude.shape))

    return magnitude, phase, Fs, freqs, times

#Function for getting inverse STFT
def get_invSTFT(Y, arg, Fs, frame_length, frame_shift):
    """Inverse STFT from a magnitude spectrogram and a phase spectrogram.

    The complex spectrum Y * exp(1j*arg) is rebuilt and inverted with the
    same Hamming-window framing as get_STFT (lengths in seconds).

    Returns
    -------
    (rec_wav, Fs)
    """
    spec = Y * np.exp(1j*arg)

    win_len = round(frame_length * Fs)
    overlap = win_len - round(frame_shift * Fs)
    rec_wav = sg.istft(spec, fs=Fs, window='hamm', nperseg=win_len, noverlap=overlap)[1]

    return rec_wav, Fs

#Function for removing components closing to zero
def get_nonzero(tensor):
    """Shift entries with magnitude below 1e-10 by +1e-10.

    Other entries are left unchanged. Used to keep denominators and log
    arguments away from exactly zero.
    """
    tiny = 1e-10
    is_tiny = np.abs(tensor) < tiny
    return np.where(is_tiny, tensor + tiny, tensor)


#Function for getting basements and weights matrix by NMF
def _nmf_show_progress(bar, i, num_iter, unit):
    """Redraw the in-place text progress bar for iteration i and return the current bar string."""
    print("\rProgress:[{0}] {1}/{2} Processing..".format(bar, i, num_iter), end="")
    if i % unit == 0:
        bar = "#" * int(np.ceil(i/unit)) + " " * int(np.floor((num_iter-i)/unit))
        print("\rProgress:[{0}] {1}/{2} Processing..".format(bar, i, num_iter), end="")
    return bar


def _nmf_balance(H, U):
    """Rescale columns of H and rows of U to equal energy without changing H @ U."""
    A = np.sqrt(np.sum(U**2, axis=1)/np.sum(H**2, axis=0))
    return H * A[np.newaxis, :], U / A[:, np.newaxis]


def _nmf_step_eu(Y, H, U, norm_H):
    """One multiplicative update for the squared Euclidean distance; returns (H, U, loss)."""
    X = H @ U
    H = H * (Y @ U.T) / get_nonzero(X @ U.T)
    if norm_H == True:
        H = H / H.sum(axis=0, keepdims=True)
    X = H @ U
    U = U * (H.T @ Y) / get_nonzero(H.T @ X)
    if norm_H == False:
        H, U = _nmf_balance(H, U)
    X = H @ U
    return H, U, np.sum((Y - X)**2)


def _nmf_step_kl(Y, H, U, norm_H):
    """One multiplicative update for the (generalized) KL divergence; returns (H, U, loss)."""
    X = get_nonzero(H @ U)
    denom_H = U.T.sum(axis=0, keepdims=True)
    H = H * ((Y / X) @ U.T) / get_nonzero(denom_H)
    if norm_H == True:
        H = H / H.sum(axis=0, keepdims=True)
    X = get_nonzero(H @ U)
    denom_U = H.T.sum(axis=1, keepdims=True)
    U = U * (H.T @ (Y / X)) / get_nonzero(denom_U)
    if norm_H == False:
        H, U = _nmf_balance(H, U)
    X = get_nonzero(H @ U)
    return H, U, np.sum(Y*np.log(Y) - Y*np.log(X) - Y + X)


def _nmf_step_is(Y, H, U, norm_H):
    """One multiplicative update for the Itakura-Saito divergence; returns (H, U, loss).

    The bases are always normalized (recommended for IS), so norm_H is ignored.
    """
    X = get_nonzero(H @ U)
    denom_H = np.sqrt(X**-1 @ U.T)
    H = H * np.sqrt((Y / X**2) @ U.T) / get_nonzero(denom_H)
    H = H / H.sum(axis=0, keepdims=True)
    X = get_nonzero(H @ U)
    denom_U = np.sqrt(H.T @ X**-1)
    U = U * (np.sqrt(H.T @ (Y / X**2))) / get_nonzero(denom_U)
    #Evaluate the loss with the updated U (the previous code reused the stale X)
    X = get_nonzero(H @ U)
    return H, U, np.sum(Y / X - np.log(Y) + np.log(X) - 1)


def get_NMF(Y, num_iter, num_base, loss_func, norm_H, init_mode=None):
    """Factorize a non-negative matrix Y ~= H @ U with multiplicative updates.

    Parameters
    ----------
    Y : array of shape (K, N), e.g. an amplitude spectrogram (freq, time).
    num_iter : number of update iterations.
    num_base : number of bases; it must be smaller than both K and N.
    loss_func : "EU" (squared Euclidean), "KL" or "IS" divergence.
    norm_H : if True, columns of H are normalized to sum 1 each iteration;
        if False, H and U are rebalanced to equal energy instead.
        With "IS", H is always normalized.
    init_mode : "None" (string) for uniform random initialization, otherwise
        passed to initialize_nmf (None, 'random', 'nndsvd', 'nndsvda',
        'nndsvdar'). If omitted, the notebook-level ``init_mode`` setting is
        used; if that is undefined, initialize_nmf's automatic choice is used.

    Returns
    -------
    H : (K, num_base) bases, U : (num_base, N) weights, loss : (num_iter,) array.

    Raises
    ------
    ValueError
        If num_base is too large or loss_func is unknown.
    """
    steps = {"EU": _nmf_step_eu, "KL": _nmf_step_kl, "IS": _nmf_step_is}

    #Validate the inputs before doing any work
    K, N = Y.shape[0], Y.shape[1]
    if num_base >= K or num_base >= N:
        raise ValueError("The number of basements should be lower than input size.")
    if loss_func not in steps:
        raise ValueError("The deviation shold be either 'EU', 'KL', or 'IS'.")

    #Remove Y entries closing to zero
    Y = get_nonzero(Y)

    #Backward compatibility: fall back to the notebook-level setting
    if init_mode is None:
        init_mode = globals().get("init_mode")
    if init_mode == "None":
        init_H = np.random.rand(K, num_base) #basements (dictionaries)
        init_U = np.random.rand(num_base, N) #weights (coupling coefficients)
    else:
        init_H, init_U = initialize_nmf(Y, num_base, init_mode)
    H = get_nonzero(init_H)
    U = get_nonzero(init_U)

    loss = np.zeros(num_iter)

    #Progress bar; unit >= 1 avoids division by zero when num_iter < 10
    unit = max(1, num_iter // 10)
    bar = "#" + " " * int(np.floor(num_iter/unit))
    start = time.time()

    step = steps[loss_func]
    for i in range(num_iter):
        bar = _nmf_show_progress(bar, i, num_iter, unit)
        H, U, loss[i] = step(Y, H, U, norm_H)

    #Finish the progress bar
    bar = "#" * int(np.ceil(num_iter/unit))
    print("\rProgress:[{0}] {1}/{2} {3:.2f}sec Completed!".format(bar, num_iter, num_iter, time.time()-start), end="")
    print()

    return H, U, loss

#■Function for generating Mel-scale filters■
def melFilterBank(Fs, fftsize, Mel_channel, Mel_norm, Amax):
    """Build a bank of triangular Mel-scale filters.

    Mel is defined as m0 * log(f/700 + 1), with m0 chosen so that
    1000 mel == 1000 Hz. The Nyquist range is split into Mel_channel
    triangles whose centers are equally spaced in mel. Each triangle rises
    from the previous center (or bin 0) to its own center and falls to the
    next center (or the Nyquist bin).

    Parameters
    ----------
    Fs : sampling rate in Hz.
    fftsize : FFT length in samples.
    Mel_channel : number of filters.
    Mel_norm : if True, each filter is scaled by 2/(width in bins) so its area
        is ~1 (Ganchev et al., SPECOM 2005).
    Amax : accepted for interface compatibility; not used here.

    Returns
    -------
    array of shape (floor(fftsize/2)+1, Mel_channel): rows are frequency bins,
    columns are Mel channels.
    """
    mel_break = 700  # corner frequency of the Mel scale (700 or 1000 are common)
    m0 = 1000.0 / np.log(1000.0 / mel_break + 1.0)

    bin_hz = Fs / fftsize
    n_bins = int(np.floor(fftsize / 2)) + 1
    top_mel = m0 * np.log((Fs / 2) / mel_break + 1.0)

    #Channel+1 intervals: the last triangle ends (not peaks) at Nyquist
    step_mel = top_mel / (Mel_channel + 1)
    centers_mel = np.arange(1, Mel_channel + 1) * step_mel
    centers_hz = mel_break * (np.exp(centers_mel / m0) - 1.0)

    n_center = np.round(centers_hz / bin_hz)
    n_start = np.hstack(([0], n_center[:-1]))
    n_stop = np.hstack((n_center[1:], [n_bins]))

    bank = np.zeros((n_bins, Mel_channel))
    for ch in range(Mel_channel):
        lo, mid, hi = n_start[ch], n_center[ch], n_stop[ch]

        #Rising edge over bins [lo, mid)
        rise = 1.0 / (mid - lo)
        up_idx = np.arange(lo, mid).astype(int)
        bank[up_idx, ch] = (up_idx - lo) * rise

        #Falling edge over bins [mid, hi), value 1 at the center
        fall = 1.0 / (hi - mid)
        down_idx = np.arange(mid, hi).astype(int)
        bank[down_idx, ch] = 1.0 - ((down_idx - mid) * fall)

        if Mel_norm == True:
            bank[:, ch] = bank[:, ch] * 2 / (hi - lo)

    return bank

#■Function for calculating MFCC■
def get_Melfeature(A, Fs, frame_length, frame_shift, Mel_channel, Mel_norm, MFCC_num, Amax, clu_mode):
    """Log-Mel features (and optionally MFCCs) of the columns of a power matrix.

    Parameters
    ----------
    A : array of shape (floor(FL/2)+1, M), where FL = round(frame_length*Fs);
        each column is one spectrum (e.g. an NMF basis squared).
    Fs : sampling rate in Hz.
    frame_length : STFT frame length in seconds (defines the FFT size).
    frame_shift : accepted for interface compatibility; not used.
    Mel_channel, Mel_norm : passed to melFilterBank.
    MFCC_num : number of cepstral coefficients kept (C(0)..C(MFCC_num-1)).
    Amax : the largest Mel energy is rescaled to Amax before log10(x + 1).
    clu_mode : "kmeans" returns MFCCs; "2ndNMF" returns the log-Mel spectrum.

    Returns
    -------
    array of shape (M, MFCC_num) or (M, Mel_channel).

    Raises
    ------
    ValueError
        For an unknown clu_mode (previously an UnboundLocalError).
    """
    if clu_mode not in ("kmeans", "2ndNMF"):
        raise ValueError("clu_mode must be either 'kmeans' or '2ndNMF', got %r" % (clu_mode,))

    #FFT size used by the STFT
    FL = round(frame_length * Fs)

    #Mel filters (row: frequency bin, column: channel)
    filterbank = melFilterBank(Fs, FL, Mel_channel, Mel_norm, Amax)

    #Project each spectrum onto the Mel filters
    melA = A.T @ filterbank

    #Rescale to Amax (skip for an all-zero input to avoid 0/0 = NaN), then log
    peak = np.amax(melA)
    if peak > 0:
        melA = melA * Amax / peak
    melA = np.log10(melA + 1) #Non-negative value

    if clu_mode == "kmeans":
        #DCT-II (orthonormal) along the Mel axis, keep the first MFCC_num coefficients.
        #Use the public fp.dct instead of the private fp.realtransforms module.
        output = np.array(fp.dct(melA, type=2, norm="ortho", axis=1)[:, 0:MFCC_num])
    else:
        output = melA

    #Return MFCC or mel-spectrogram as (frames, order) numpy array
    return output

def get_clustering(H, num_clus,frame_length, frame_shift):
    """Cluster NMF bases with k-means on (standardized) Mel-domain features.

    The features come from get_Melfeature applied to H**2. That call reads
    the notebook-level settings Fs, Mel_channel, Mel_norm, MFCC_num, Amax and
    clu_mode (MFCCs for "kmeans", log-Mel for "2ndNMF").

    Parameters
    ----------
    H : array of shape (freq, num_base); each column is one basis.
    num_clus : number of k-means clusters.
    frame_length, frame_shift : STFT framing in seconds.

    Returns
    -------
    (features, labels) : standardized features (num_base, n_features) and
    one integer label per basis.
    """
    #Mel-domain features of each basis (uses notebook-level globals, see docstring)
    MFCC = get_Melfeature(H**2, Fs, frame_length, frame_shift, Mel_channel, Mel_norm, MFCC_num, Amax, clu_mode)

    #Standardize each feature across bases; constant features would give 0/0 = NaN
    #(which makes KMeans fail), so their std is replaced by 1
    MFCC = MFCC - np.average(MFCC, axis=0)[np.newaxis, :]
    std = np.std(MFCC, axis=0)
    std[std == 0] = 1.0
    MFCC = MFCC / std[np.newaxis, :]

    #Get clustering by kmeans++
    clf = KMeans(n_clusters=num_clus, init='k-means++')
    labels = np.array(clf.fit(MFCC).labels_)

    return MFCC, labels

def cos_sim(v1, v2):
    """Cosine similarity of two 1-D vectors: <v1, v2> / (|v1| |v2|)."""
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    return np.dot(v1, v2) / denom

def cos_sim_matrix(matrix):
    """Pairwise cosine similarity between the rows of ``matrix``.

    Given an (items x features) matrix, returns the (items x items) matrix
    whose (i, j) entry is the cosine similarity of rows i and j.
    """
    # Gram matrix: inner products between every pair of row vectors
    inner = matrix @ matrix.T

    # Euclidean length of each row, kept as a column vector for broadcasting
    # (named `lengths` to avoid shadowing the module-level `norm` import)
    lengths = np.sum(matrix * matrix, axis=1, keepdims=True) ** .5

    # Divide by the row lengths, then by the column lengths
    return (inner / lengths) / lengths.T

# 評価指標
def _scale_bss_eval(references, estimate, compute_sir_sar=True):
    """
    Helper for scale_bss_eval to avoid infinite recursion loop.
    """
    source = references
    source_energy = (source ** 2).sum()

    alpha = (
        source @ estimate / source_energy
    )

    e_true = source
    e_res = estimate - e_true

    signal = (e_true ** 2).sum()
    noise = (e_res ** 2).sum()

    snr = 10 * np.log10(signal / noise)

    e_true = source * alpha
    e_res = estimate - e_true

    signal = (e_true ** 2).sum()
    noise = (e_res ** 2).sum()

    si_sdr = 10 * np.log10(signal / noise)

    srr = -10 * np.log10((1 - (1/alpha)) ** 2)
    sd_sdr = snr + 10 * np.log10(alpha ** 2)

    si_sir = np.nan
    si_sar = np.nan

#     if compute_sir_sar:
#         references_projection = references.T @ references

#         references_onto_residual = np.dot(references.transpose(), e_res)
#         b = np.linalg.solve(references_projection, references_onto_residual)

#         e_interf = np.dot(references, b)
#         e_artif = e_res - e_interf

#         si_sir = 10 * np.log10(signal / (e_interf ** 2).sum())
#         si_sar = 10 * np.log10(signal / (e_artif ** 2).sum())

    #return si_sdr, si_sir, si_sar, sd_sdr, snr, srr
    return si_sdr


def preEmphasis(signal, p):
    """Pre-emphasis filter: y[n] = x[n] - p * x[n-1] (FIR with taps (1, -p))."""
    numerator = [1.0, -p]
    denominator = [1.0]
    return sg.lfilter(numerator, denominator, signal)

def deeEmphasis(signal, p):
    """De-emphasis filter, the inverse of preEmphasis: y[n] = x[n] + p * y[n-1].

    This is an IIR filter 1 / (1 - p z^-1), not an FIR filter (the previous
    docstring mislabelled it as a pre-emphasis filter).
    """
    # Numerator (1.0), denominator (1.0, -p)
    return sg.lfilter([1.0], [1.0, -p], signal)


# Spectrogram Class
class SpecClass():
    """NMF-based two-source separation of one mono mixture.

    On construction the mixture is transformed with get_STFT and factorized
    with get_NMF. The NMF bases are then labelled either by MFCC/k-means
    (get_clustering, when label_flag is truthy) or by k-shape on the NMF
    activations (get_KShape). It reads the notebook-level settings Fs,
    num_iter, num_base, loss_func, label_flag and num_clus.
    """

    def __init__(self, frame_length, frame_shift, data):
        self.frame_length = frame_length
        self.frame_shift = frame_shift
        # Remember the input length so get_wav can trim the ISTFT padding
        # without reading the notebook-global `data`
        self.n_samples = int(np.shape(data)[0])
        self.Y, self.arg, self.Fs, self.freqs, self.times = get_STFT(data, Fs, self.frame_length, self.frame_shift)

        #Call my function for updating NMF basements and weights
        self.H, self.U, self.loss = get_NMF(self.Y, num_iter, num_base, loss_func, False)

        if label_flag:
            self.MFCC, self.label = get_clustering(self.H, num_clus, self.frame_length, self.frame_shift)
        else:
            self.label, _, _ = get_KShape(self.U, num_clus, max_iter=100, num_init=10)

    def _reconstruct(self, weights):
        """Resynthesize the mixture part explained by the bases weighted by `weights` (one per basis)."""
        total = self.H @ self.U
        part = (self.H * weights) @ self.U
        # Soft (Wiener-like) mask; bins where the model is exactly zero get 0 instead of NaN
        sep_X = np.divide(self.Y * part, total, out=np.zeros_like(total), where=total > 0)
        wav, _ = get_invSTFT(sep_X, self.arg, self.Fs, self.frame_length, self.frame_shift)
        # The inverse STFT includes a residual part due to zero padding
        return wav[: self.n_samples]

    def get_wav(self):
        """Return (wave of label-0 bases, wave of label-1 bases); assumes two clusters."""
        label0 = self.label
        label1 = np.ones_like(label0) - label0
        return self._reconstruct(label0), self._reconstruct(label1)

In [4]:
#MFCC clustering follows "Source-filter based clustering for monaural blind source separation"
#[Ref] M. Spiertz and V. Gnann, (2009), Proc. International Conference on Digital Audio Effects
#[URL] http://dafx.de/paper-archive/2009/papers/paper_13.pdf

#Setup
down_sam = None        #Downsampling rate (Hz) [Default]None
num_iter = 100         #The number of iteration in NMF [Default]200
num_base = 20          #The number of basements in NMF [Default]20~30
loss_func = "KL"       #Select either EU, KL, or IS divergence [Default]KL
Mel_channel = 20       #The number of frequency channel for Mel-scale filters [Default]20
Mel_norm = True        #Normalize the area underneath each Mel-filter into 1 [Default]True
MFCC_num = 9           #The number of MFCCs including C(0) [Default]9
Amax = 1e4             #Normalization for log-Mel conversion [Default]1e4 (10000)
clu_mode = "2ndNMF"    #Clustering mode introduced by M. Spiertz's paper [Default]kmeans or 2ndNMF
clu_loss = "EU"        #Using 2ndNMF, select either EU, KL, or IS divergence [Default]EU
clu_iter = 100         #Using 2ndNMF, specify the number of iteration in 2nd NMF [Default]100
num_clus = 2           #The number of clusters/sources (SpecClass.get_wav assumes exactly 2)

#label_flag: truthy -> Mel/MFCC features + k-means (get_clustering),
#            falsy  -> k-shape on the NMF activations (get_KShape)
label_flag = False

#Define random seed
np.random.seed(seed=32)

#NMF initialization: the string "None" means uniform random init; any other value
#is passed to initialize_nmf (None, 'random', 'nndsvd', 'nndsvda', 'nndsvdar')
init_mode = "nndsvda"

dataset_path = "./mix_dataset/"

save_path = "./save_data-k-shape-only/"

class_pathes = ["glassbreak","gunshot", "babycry"]
sn_rates = ["-4.0", "-2.0", "0.0", "2.0", "4.0"]

In [6]:
# Batch separation: for every SNR and sound class, separate each mixture,
# keep the separated track with the higher SI-SDR, and log its score.
for sn_rate in sn_rates:

    for class_path in class_pathes:

        path = dataset_path + "sn" + sn_rate + "/" + class_path + "/"
        # Only sample directories (skip stray files such as .DS_Store), sorted so the
        # processing order -- and therefore the global RNG draws -- is reproducible
        data_files = sorted(d for d in os.listdir(path) if os.path.isdir(path + d))

        # Each run gets a new numbered output folder/file next to earlier runs
        sep_path = save_path + "sn" + sn_rate + "/separation/" + class_path + "/"
        os.makedirs(sep_path, exist_ok=True)
        sep_file_name = "sep" + str(len(os.listdir(sep_path)))
        os.mkdir(sep_path + sep_file_name)

        log_path = save_path + "sn" + sn_rate + "/log/" + class_path + "/"
        os.makedirs(log_path, exist_ok=True)
        log_file_name = "log" + str(len(os.listdir(log_path))) + ".txt"
        # Create/truncate the log file and close the handle (it was leaked before)
        with open(log_path + log_file_name, "w"):
            pass

        plot_path = save_path + "sn" + sn_rate + "/plot/" + class_path + "/"
        os.makedirs(plot_path, exist_ok=True)
        plot_file_name = "plot" + str(len(os.listdir(plot_path))) + ".txt"
        with open(plot_path + plot_file_name, "w"):
            pass

        for file in data_files:

            truth, Fs = sf.read(path + file + "/truth.wav")
            truth, Fs = pre_processing(truth, Fs, down_sam)

            data, Fs = sf.read(path + file + "/input.wav")
            data, Fs = pre_processing(data, Fs, down_sam)

            label_flag = False  # k-shape clustering of the NMF activations
            org1 = SpecClass(0.024, 0.012, data)

            step2_wav0, step2_wav1 = org1.get_wav()

            step2_sdr0 = _scale_bss_eval(truth, step2_wav0)
            step2_sdr1 = _scale_bss_eval(truth, step2_wav1)
            sdr_list = [step2_sdr0, step2_sdr1]
            wav_list = [step2_wav0, step2_wav1]
            best = sdr_list.index(max(sdr_list))

            with open(log_path + log_file_name, 'a') as f:

                print("---------------" + file + "---------------", file=f)
                print("" ,file=f)

                print("STEP2:" + "frame_length:" + str(org1.frame_length), file=f)
                print("" ,file=f)

                print(best, file=f)
                sf.write(sep_path + sep_file_name + "/sep_" + file + "_sdr_" + '{:.2f}'.format(sdr_list[best]) + ".wav", wav_list[best], Fs)

                print("" ,file=f)
                print("" ,file=f)
                print("" ,file=f)

            with open(plot_path + plot_file_name, 'a') as f:
                print(max(sdr_list), file=f)

Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.45sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.46sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.54sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.57sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.55sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.54sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.60sec

Progress:[##########] 100/100 2.61sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.55sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.60sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.57sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.56sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.54sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.55sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.55sec Completed!
Spectrogram size (freq, time) =

Progress:[##########] 100/100 2.56sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.55sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.63sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.56sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.62sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.57sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.82sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.54sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.54sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.55sec Completed!
Spectrogram size (freq, time) =

Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.57sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.62sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.63sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.56sec Completed!
Spectrogram size (freq, time) =

Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.57sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.60sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.57sec Completed!
Spectrogram size (freq, time) =

Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.62sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.61sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.55sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.57sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.62sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.57sec Completed!
Spectrogram size (freq, time) =

Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.62sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.60sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.62sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.55sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.54sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.57sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.55sec Completed!
Spectrogram size (freq, time) =

Progress:[##########] 100/100 2.55sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.57sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.56sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.62sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.61sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.58sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.61sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.59sec Completed!
Spectrogram size (freq, time) = (530, 835)
Progress:[##########] 100/100 2.60sec Completed!
Spectrogram size (freq, time) =