In [7]:
import onnxruntime
import numpy as np
import tqdm

import glob
import soundfile as sf

import librosa
import scipy

import os

In [3]:
# Load model
# Create the ONNX Runtime inference session for "nsnetv2_converted.onnx",
# asking the runtime to apply every graph optimization level it offers.
so = onnxruntime.SessionOptions()
so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
ort_session = onnxruntime.InferenceSession(
    "nsnetv2_converted.onnx",
    sess_options=so,
)

In [4]:
# Feature extraction
def stft(signal, frame_size = 160, window_size = 320):
    """Short-time Fourier transform using a square-root Hann analysis window.

    The signal is zero-padded with ``window_size - frame_size`` samples in
    front, and with enough zeros at the end that the last (possibly partial)
    hop is still covered by a complete window.

    Parameters
    ----------
    signal : 1-D array_like
        Time-domain samples.
    frame_size : int, optional
        Hop between consecutive frames, in samples.
    window_size : int, optional
        Length of the analysis window and of the FFT, in samples.

    Returns
    -------
    numpy.ndarray
        Complex one-sided spectrum with ``window_size // 2 + 1`` bins on the
        last axis; frames are laid out along the first axis.
    """
    window = np.sqrt(np.hanning(window_size + 1)[:-1]).astype(np.float32)

    # Number of samples in the final hop; a signal that ends exactly on a hop
    # boundary counts as having one full final hop.
    last_frame = len(signal) % frame_size
    if last_frame == 0:
        last_frame = frame_size

    padded_signal = np.pad(signal, ((window_size - frame_size, window_size - last_frame),))
    # frame_length / hop_length are keyword-only in recent librosa releases,
    # so they are passed by name; this also works on older releases.
    frames = librosa.util.frame(
        padded_signal, frame_length=window_size, hop_length=frame_size, axis=0
    )
    return scipy.fft.rfft(frames * window, n=window_size)

def istft(signal, frame_size = 160, window_size = 320):
    """Rebuild a time-domain signal from STFT frames by windowed overlap-add.

    Parameters
    ----------
    signal : 2-D array
        One-sided spectra, one frame per row.
    frame_size : int, optional
        Hop between consecutive frames, in samples.
    window_size : int, optional
        Length of the square-root Hann synthesis window, in samples.

    Returns
    -------
    numpy.ndarray
        1-D array holding ``frame_size`` output samples for every frame except
        the first ``window_size // frame_size - 1`` ones.

    Raises
    ------
    AssertionError
        If fewer than ``window_size // frame_size`` frames are supplied.
    """
    synthesis_win = np.sqrt(np.hanning(window_size + 1)[:-1]).astype(np.float32)

    # Back to the time domain frame by frame; only the first window_size
    # samples are kept in case the DFT was longer than the window.
    time_frames = scipy.fft.irfft(signal, axis=-1)[:, :window_size] * synthesis_win

    overlap = window_size // frame_size
    n_frames = time_frames.shape[0]
    assert n_frames >= overlap

    # Start from the leading hop of every frame that has a full set of
    # predecessors, then add the matching later hop of each earlier frame.
    out = time_frames[overlap - 1:, :frame_size].copy()
    for shift in range(1, overlap):
        start = shift * frame_size
        out += time_frames[overlap - 1 - shift:n_frames - shift, start:start + frame_size]

    # Concatenate the per-frame hops into one contiguous signal.
    return out.reshape(-1)
  
def logpow_msrtc(sig):
    """Return the base-10 log of ``sig`` squared, with the square floored at 1e-12.

    The floor keeps the logarithm finite for zero-valued input, so the
    smallest value this function can return is -12.
    """
    floor = 1e-12
    return np.log10(np.maximum(sig ** 2, floor))
    
def build_features_logspec_plcmask(signal, is_lost):
    """Compute spectral features and packet-loss masks for one signal.

    Parameters
    ----------
    signal : 1-D array_like
        Time-domain samples, passed straight to ``stft``.
    is_lost : array_like
        Loss indicators; every entry is repeated twice to obtain one flag per
        STFT hop.

    Returns
    -------
    tuple of two lists
        ``[log_power, left_mask, right_mask, sin(phase), cos(phase)]`` followed
        by ``[magnitude, sin(phase), cos(phase)]``. Every array has one row
        per STFT frame and one column per frequency bin.
    """
    spectrum = stft(signal)
    magnitude = np.abs(spectrum)
    log_power = logpow_msrtc(magnitude)

    phase = np.angle(spectrum)
    phase_sin = np.sin(phase)
    phase_cos = np.cos(phase)

    # packet loss indicator mask
    lost_per_hop = np.repeat(is_lost, 2)
    n_bins = log_power.shape[1]

    def _spread_over_bins(flags):
        # Copy each per-frame flag across all frequency bins.
        return np.repeat(flags[:, np.newaxis], n_bins, axis=1)

    # The two masks are the same flag sequence offset by one frame: the left
    # mask is led by a zero, the right mask is trailed by one.
    left_mask = _spread_over_bins(np.append(0, lost_per_hop))
    right_mask = _spread_over_bins(np.append(lost_per_hop, 0))

    model_features = [log_power, left_mask, right_mask, phase_sin, phase_cos]
    recon_features = [magnitude, phase_sin, phase_cos]
    return model_features, recon_features

In [10]:
# Process data
# Run the model over every lossy recording and write the concealed audio to
# out_dir under the same file name.
# The glob pattern is assembled with os.path.join: a hard-coded backslash only
# acts as a path separator on Windows, so the pattern matched nothing elsewhere.
wav_files = glob.glob(os.path.join("blind", "lossy_signals", "*.wav"))
out_dir = "blind_out"
os.makedirs(out_dir, exist_ok=True)

for file in tqdm.tqdm(wav_files):
    data, sample_rate = sf.read(file)
    # The loss flags live next to the audio as "<name>_is_lost.txt". splitext
    # removes only the trailing extension, whereas splitting on ".wav" would
    # cut at the first occurrence anywhere in the path.
    lost_file = os.path.splitext(file)[0] + "_is_lost.txt"
    # ndmin=1 keeps a single-entry file from collapsing to a 0-d array, which
    # would break len(loss_mask) below.
    loss_mask = np.loadtxt(lost_file, ndmin=1)

    feats, feats_recon = build_features_logspec_plcmask(data, loss_mask)
    # Stack the per-feature arrays and flatten them to one row per frame.
    # The float32 cast the model input needs is done once here instead of
    # once per frame inside the loop.
    feats = np.array(feats).swapaxes(0, 1)
    feats = feats.reshape(feats.shape[0], -1).astype(np.float32)

    feats_recon = np.array(feats_recon).swapaxes(0, 1)
    feats_recon = feats_recon.reshape(feats_recon.shape[0], -1)

    # State tensors fed to the model inputs "h01"/"h02"; they start at zero
    # for each file and are replaced by the model's outputs after every frame.
    h0 = np.zeros((1, 1, 134), dtype=np.float32)
    h1 = np.zeros((1, 1, 100), dtype=np.float32)
    result = []
    for idx, feat_row in enumerate(feats):
        ort_inputs = {
            "input": feat_row.reshape(1, 1, -1),
            "h01": h0.astype(np.float32),
            "h02": h1.astype(np.float32),
        }
        # The model is run on every frame so its state keeps advancing, even
        # when the prediction itself is discarded below.
        y, h0, h1 = ort_session.run(None, ort_inputs)
        # Two consecutive frames share one loss_mask entry. Frames past the end
        # of the mask, or flagged as not lost, keep the features of the
        # received signal instead of the model output.
        packet = idx // 2
        if packet >= len(loss_mask) or loss_mask[packet] == 0:
            y = feats_recon[idx, :].reshape(y.shape)
        # The frame is laid out as magnitude, sin(phase), cos(phase).
        y_abs, y_sin, y_cos = np.split(y, 3, axis=1)
        y_complex = y_abs * (y_cos + 1j * y_sin)
        result.append(y_complex)
    result = istft(np.array(result).squeeze())
    # Write with the rate the file was read at rather than a hard-coded value,
    # so the output is never labelled with a rate that differs from its input.
    sf.write(os.path.join(out_dir, os.path.basename(file)), result, sample_rate)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 966/966 [02:48<00:00,  5.74it/s]
