# Expectation-Maximization Clustering with MRF Smoothing

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import linalg as LA
import scipy.io.wavfile as wav
import librosa
from IPython.display import Audio


def STFT(x):
    hop = 512
    win = 1024
    D = librosa.stft(x, n_fft=1024, hop_length=hop, win_length=win, window='hann')
    return D


def B_update(B, T, X):
    ones = np.ones((len(B), len(B)))
    BT = np.dot(B, T)
    Tt = np.transpose(T)
    num1 = np.divide(X, BT + 1e-20)
    num = np.dot(num1, Tt)
    B = B * num

    den = np.dot(ones, B)
    B = np.divide(B, den + 1e-20)

    return B


def T_update(B, T, X):
    ones = np.ones((len(B[0]), len(B[0])))
    BT = np.dot(B, T)
    Bt = np.transpose(B)
    num1 = np.divide(X, BT + 1e-20)
    num = np.dot(Bt, num1)
    T = num * T

    den = np.dot(ones, T)
    T = np.divide(T, den + 1e-20)

    return T


def cost_fun(X, X_hat):
    div = np.divide(X, X_hat)
    log_div = np.log(div)
    term1 = X * log_div
    E = term1 - X + X_hat
    return np.sum(E)


if __name__ == '__main__':
    trs, sr = librosa.load('/content/trs.wav', sr=None)
    trs_freq = STFT(trs)

    trn, sr = librosa.load('/content/trn.wav', sr=None)
    trn_freq = STFT(trn)

    trx, sr = librosa.load('/content/tex.wav', sr=None)
    trx_freq = STFT(trx)

    S = np.abs(trs_freq)
    N = np.abs(trn_freq)
    Y = np.abs(trx_freq)

    seed_value = 1  # pick any value here
    rng = np.random.default_rng(seed_value)
    lower_bound = 0.0
    upper_bound = 1.0
    BS = rng.uniform(lower_bound, upper_bound, size=(len(S), 30))
    TS = rng.uniform(lower_bound, upper_bound, size=(30, len(S[0])))
    BN = rng.uniform(lower_bound, upper_bound, size=(len(N), 30))
    TN = rng.uniform(lower_bound, upper_bound, size=(30, len(N[0])))
    TY = rng.uniform(lower_bound, upper_bound, size=(60, len(Y[0])))

    tol = 1e-2
    max_iteration = 500

    for i in range(max_iteration):
        BS = B_update(BS, TS, S)
        TS = T_update(BS, TS, S)
        S_hat = np.dot(BS, TS)
        error = cost_fun(S, S_hat)
        if error < tol:
            break

    for i in range(max_iteration):
        BN = B_update(BN, TN, N)
        TN = T_update(BN, TN, N)
        N_hat = np.dot(BN, TN)
        error1 = cost_fun(N, N_hat)
        if error1 < tol:
            break

    BY = np.concatenate((BS, BN), axis=1)
    for i in range(max_iteration):
        TY = T_update(BY, TY, Y)
        Y_hat = np.dot(BY, TY)
        error2 = cost_fun(Y, Y_hat)
        if error2 < tol:
            break

    M_bar_num = np.dot(BS, TY[0:30, :])
    M_bar_den = np.dot(BY, TY)
    M_bar = np.divide(M_bar_num, M_bar_den + 1e-20)

    result = M_bar * trx_freq

    result_time = librosa.istft(result, n_fft=1024, hop_length=512, win_length=1024, window='hann')

    # wav.write("Q1.wav", sr, result_time)

    ### SNR
    s_hat = result_time

    ts = trx[0:len(s_hat)]

    num = np.dot(ts.T, ts)
    den = np.dot((ts - s_hat).T, (ts - s_hat))

    SNR = 10 * np.log(num / den)

    print("\n Noisy Signal \n")
    display(Audio(data=trx, rate=sr))

    print("\n Clean Signal \n")
    display(Audio(data=result_time, rate=sr))

    print("\n SNR Value obtained is: ", SNR)



 Noisy Signal 




 Clean Signal 




 SNR Value obtained is:  13.194795352615472
