In [3]:
import torch
import numpy as np
from scipy.io import wavfile
import librosa
from torchaudio.transforms import Spectrogram, InverseSpectrogram
import os

In [4]:
# Pick the device first so the checkpoint can be mapped onto it while loading.
device = torch.device("cpu")

# map_location remaps every stored tensor onto `device`. Without it,
# torch.jit.load tries to restore the module on the device it was saved from,
# which raises on a machine that lacks that device (e.g. a GPU-scripted model
# loaded on a CPU-only host) and can leave the model on a different device
# than the inputs.
model = torch.jit.load(
    "D:/Projects/LCT-GAN/.data/FTFNet_scripted.pt", map_location=device
)
model.eval()  # inference mode for layers such as dropout / batch-norm

In [5]:
# STFT / iSTFT configuration shared by the analysis and synthesis transforms.
framelen = 512  # FFT size in samples
hoplen = 256    # hop between consecutive frames (half of framelen)


def win(x):
    """Return a square-root Hann window of length ``x`` placed on ``device``."""
    hann = torch.hann_window(x)
    return torch.sqrt(hann).to(device)


# Both transforms must use the same frame size, hop and window.
_stft_kwargs = dict(n_fft=framelen, hop_length=hoplen, window_fn=win)

# power=None asks torchaudio for the complex spectrogram rather than magnitude/power.
to_spec = Spectrogram(power=None, **_stft_kwargs)
from_spec = InverseSpectrogram(**_stft_kwargs)

In [6]:
def infer(input_audio, output_audio):
    """Enhance one audio file with ``model`` and write it as a 16-bit PCM WAV.

    Parameters
    ----------
    input_audio : str
        Path of the audio file to read. It is loaded with ``librosa.load``
        at a sample rate of 16 kHz.
    output_audio : str
        Path of the WAV file to write (16 kHz, int16 samples).

    Notes
    -----
    Samples outside the int16 range after scaling are clipped instead of
    being allowed to wrap around.
    """
    x, _ = librosa.load(input_audio, sr=16000)
    x = torch.tensor(x.astype(np.float32)).to(device).unsqueeze(0)

    # Pure inference: skip autograd bookkeeping for the whole forward path.
    with torch.no_grad():
        # from waveform to spectrogram (channel, T, F)
        inputs = to_spec(x).permute(0, 2, 1).cfloat()
        outputs = model(inputs)
        # NOTE(review): assumes the model returns its spectrogram laid out as
        # (..., freq, time), which is what InverseSpectrogram expects, even
        # though the input is fed as (channel, T, F) -- TODO confirm.
        xtilde = from_spec(outputs)

    # convert the torch.tensor back to a numpy array
    xtilde = xtilde.detach().cpu().numpy()

    # wavfile.write interprets a 2-D array as (samples, channels), whereas the
    # waveform here carries its channel axis first. Move time to axis 0 and
    # drop the channel axis entirely for mono audio.
    if xtilde.ndim > 1:
        xtilde = xtilde.reshape(-1, xtilde.shape[-1]).T
        if xtilde.shape[1] == 1:
            xtilde = xtilde[:, 0]

    # Scale to the int16 range and clip first: a bare astype(np.int16) wraps
    # around on overflow (e.g. 1.0 * 32768 -> -32768), which is audible as
    # loud clicks.
    int16_range = np.iinfo(np.int16)
    scaled = np.clip(xtilde * (2 ** 15), int16_range.min, int16_range.max)
    wavfile.write(output_audio, 16000, scaled.astype(np.int16))


In [None]:
# Noisy recordings to enhance, one per noise condition of the subjective test.
input_audio_files = [
    "D:/Projects/LCT-GAN/.data/subjective_test_audios/impulse/noisy_fileid_1_snr14.25_tl-23.wav",
    "D:/Projects/LCT-GAN/.data/subjective_test_audios/music/noisy_fileid_4_snr14.07_tl-24.wav",
    "D:/Projects/LCT-GAN/.data/subjective_test_audios/roadside/noisy_fileid_0.wav",
    "D:/Projects/LCT-GAN/.data/subjective_test_audios/static1/noisy_fileid_0_snr-3.34_tl-17.wav",
    "D:/Projects/LCT-GAN/.data/subjective_test_audios/static2/noisy_fileid_1_snr14.25_tl-23.wav",
    "D:/Projects/LCT-GAN/.data/subjective_test_audios/water/noisy_fileid_2_snr-2.39_tl-30.wav",
]


def _enhanced_path(noisy_path):
    """Return the output path for ``noisy_path``: same folder, with the
    file name's ``noisy_`` prefix replaced by ``enhanced_``."""
    folder, sep, filename = noisy_path.rpartition("/")
    return folder + sep + filename.replace("noisy_", "enhanced_", 1)


# Derive the outputs from the inputs instead of maintaining a second list by
# hand: the two lists cannot drift apart, and zip() below can never silently
# drop a file because of a length mismatch.
output_audio_files = [_enhanced_path(path) for path in input_audio_files]

for input_audio, output_audio in zip(input_audio_files, output_audio_files):
    infer(input_audio, output_audio)