# Non-negative Matrix Factorization (NMF) for Speech and Noise Signal Separation

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import linalg as LA
import scipy.io.wavfile as wav
import librosa
from IPython.display import Audio


def STFT(x, n_fft=1024, hop_length=512, win_length=None, window='hann'):
    """Return the complex short-time Fourier transform of signal ``x``.

    Thin wrapper around ``librosa.stft``. The defaults reproduce the original
    fixed settings (1024-point Hann window, hop of 512 samples), so existing
    calls ``STFT(x)`` behave exactly as before.

    Parameters
    ----------
    x : ndarray
        Time-domain audio signal.
    n_fft : int, optional
        FFT size.
    hop_length : int, optional
        Number of samples between successive frames.
    win_length : int or None, optional
        Window length; ``None`` means "same as ``n_fft``".
    window : str, optional
        Window specification passed through to ``librosa.stft``.

    Returns
    -------
    ndarray
        Complex STFT matrix as returned by ``librosa.stft``
        (frequency bins x frames).
    """
    if win_length is None:
        win_length = n_fft
    return librosa.stft(x, n_fft=n_fft, hop_length=hop_length,
                        win_length=win_length, window=window)


def W_update(W, H, X, eps=1e-20):
    """Apply one KL-divergence multiplicative update to the basis matrix W.

    Implements  W <- W * ((X / (W H)) H^T) / (1 H^T),  where 1 is an all-ones
    matrix shaped like X. Because every row of ``1 H^T`` equals the row sums
    of H, the denominator is computed by broadcasting ``H.sum(axis=1)``
    instead of materialising the ones matrix.

    Parameters
    ----------
    W : array_like, shape (F, K)
        Current non-negative basis matrix.
    H : array_like, shape (K, T)
        Current non-negative activation matrix.
    X : array_like, shape (F, T)
        Non-negative target matrix being factorised.
    eps : float, optional
        Small constant added to denominators to avoid division by zero.

    Returns
    -------
    ndarray, shape (F, K)
        The updated basis matrix.
    """
    W = np.asarray(W)
    H = np.asarray(H)
    X = np.asarray(X)
    ratio = X / (np.dot(W, H) + eps)        # X / (W H), guarded against /0
    num = np.dot(ratio, H.T)                # (F, K)
    # ones(F, T) @ H.T has every row equal to H.sum(axis=1); broadcasting the
    # row sums avoids an F x T allocation and an O(F*T*K) matrix product.
    den = H.sum(axis=1)[np.newaxis, :]      # (1, K), broadcasts over F
    return W * (num / (den + eps))


def H_update(W, H, X, eps=1e-20):
    """Apply one KL-divergence multiplicative update to the activation matrix H.

    Implements  H <- H * (W^T (X / (W H))) / (W^T 1),  where 1 is an all-ones
    matrix shaped like X. Every column of ``W^T 1`` equals the column sums of
    W, so the denominator is computed by broadcasting ``W.sum(axis=0)``
    instead of materialising the ones matrix.

    Parameters
    ----------
    W : array_like, shape (F, K)
        Current non-negative basis matrix.
    H : array_like, shape (K, T)
        Current non-negative activation matrix.
    X : array_like, shape (F, T)
        Non-negative target matrix being factorised.
    eps : float, optional
        Small constant added to denominators to avoid division by zero.

    Returns
    -------
    ndarray, shape (K, T)
        The updated activation matrix.
    """
    W = np.asarray(W)
    H = np.asarray(H)
    X = np.asarray(X)
    ratio = X / (np.dot(W, H) + eps)        # X / (W H), guarded against /0
    num = np.dot(W.T, ratio)                # (K, T)
    # W.T @ ones(F, T) has every column equal to W.sum(axis=0); broadcasting
    # avoids an F x T allocation and an O(F*T*K) matrix product.
    den = W.sum(axis=0)[:, np.newaxis]      # (K, 1), broadcasts over T
    return H * (num / (den + eps))


def cost_fun(X, X_hat, eps=1e-20):
    """Return the generalized Kullback-Leibler divergence D(X || X_hat).

    D = sum( X * log(X / X_hat) - X + X_hat ), the divergence whose
    multiplicative-update rules are ``W_update`` / ``H_update``.

    Entries where ``X == 0`` use the convention 0 * log 0 = 0 and therefore
    contribute just ``X_hat``; without this, exact zeros in a magnitude
    spectrogram (e.g. digital silence) turn the whole cost into NaN.

    Parameters
    ----------
    X : array_like
        Non-negative target matrix.
    X_hat : array_like, same shape as X
        Non-negative approximation of X.
    eps : float, optional
        Small constant added to ``X_hat`` inside the log to avoid division by
        zero.

    Returns
    -------
    numpy.float64
        The summed divergence.
    """
    X = np.asarray(X, dtype=float)
    X_hat = np.asarray(X_hat, dtype=float)
    positive = X > 0
    # Substitute 1 where X == 0 so the log stays finite; those entries are
    # then multiplied by X == 0 and masked out, giving 0 * log 0 := 0.
    safe_X = np.where(positive, X, 1.0)
    log_term = np.where(positive, X * np.log(safe_X / (X_hat + eps)), 0.0)
    return np.sum(log_term - X + X_hat)


def _run_nmf(V, W, H, max_iteration, tol, *, update_W=True):
    """Fit ``V ~= W @ H`` with KL multiplicative updates and return ``(W, H)``.

    Parameters
    ----------
    V : ndarray, shape (F, T)
        Non-negative matrix to factorise (here a magnitude spectrogram).
    W : ndarray, shape (F, K)
        Initial basis matrix.
    H : ndarray, shape (K, T)
        Initial activation matrix.
    max_iteration : int
        Upper bound on the number of update sweeps.
    tol : float
        Stop as soon as the generalized KL divergence falls below this value.
    update_W : bool, optional
        If False, ``W`` is held fixed and only ``H`` is learned (used when
        decoding a mixture with pre-trained bases).
    """
    for i in range(max_iteration):
        if update_W:
            W = W_update(W, H, V)
        H = H_update(W, H, V)
        error = cost_fun(V, np.dot(W, H))
        # NOTE(review): tol is compared with the divergence summed over all
        # time-frequency bins, so it scales with the spectrogram size and in
        # practice rarely triggers; a relative-change criterion may be better.
        if error < tol:
            print(i, "break")
            break
    return W, H


if __name__ == '__main__':
    # ``display`` is only a builtin inside IPython/Jupyter; import it so the
    # script also runs under plain ``python``.
    from IPython.display import display

    hop = 512      # hop length; matches the STFT helper's default
    win = 1024     # window / FFT length; matches the STFT helper's default
    rank = 30      # number of basis vectors learned per source
    tol = 1e-2
    max_iteration = 500
    rng = np.random.default_rng(0)  # fixed seed -> reproducible separation

    trs, sr_s = librosa.load('/content/trs.wav', sr=None)
    trn, sr_n = librosa.load('/content/trn.wav', sr=None)
    trx, sr = librosa.load('/content/x_nmf.wav', sr=None)
    # Bases learned at one sample rate describe different frequencies at
    # another, so all three recordings must share the same rate.
    if not sr_s == sr_n == sr:
        raise ValueError(
            f"sample rates differ: speech={sr_s}, noise={sr_n}, mixture={sr}"
        )

    trx_freq = STFT(trx)
    S = np.abs(STFT(trs))
    N = np.abs(STFT(trn))
    Y = np.abs(trx_freq)

    # NMF on the speech signal S and the noise signal N: learn W and H.
    WS, _ = _run_nmf(S, rng.random((S.shape[0], rank)),
                     rng.random((rank, S.shape[1])), max_iteration, tol)
    WN, _ = _run_nmf(N, rng.random((N.shape[0], rank)),
                     rng.random((rank, N.shape[1])), max_iteration, tol)

    # NMF on the mixture Y: keep the concatenated bases fixed, learn only H.
    WY = np.concatenate((WS, WN), axis=1)
    _, HY = _run_nmf(Y, WY, rng.random((2 * rank, Y.shape[1])),
                     max_iteration, tol, update_W=False)

    # Soft mask: share of the reconstructed mixture magnitude explained by
    # the speech bases (the first ``rank`` rows of HY pair with WS).
    M_bar = np.dot(WS, HY[:rank]) / (np.dot(WY, HY) + 1e-20)

    # Mask the complex mixture STFT (reusing the mixture phase) and invert;
    # ``length`` makes the output exactly as long as the input signal.
    result = M_bar * trx_freq
    result_time = librosa.istft(result, hop_length=hop, win_length=win,
                                window='hann', length=len(trx))

    print("\n Noisy Signal \n")
    display(Audio(data=trx, rate=sr))

    print("\n Clean Signal \n")
    display(Audio(data=result_time, rate=sr))


 Noisy Signal Signal 




 Clean Signal 

