In [None]:
import torch
import torchaudio
from IPython.display import Audio
import causal_improved_sudormrf_v3
import time as t
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def load_sudormrf_causal_cpu(model_path, device="cpu"):
    """Load CausalSuDORMRF weights that were saved from a DataParallel model.

    The checkpoint was saved from a ``torch.nn.DataParallel`` wrapper, so its
    state_dict keys carry DataParallel's ``module.`` prefix. The bare model is
    therefore wrapped in DataParallel for loading and unwrapped afterwards.

    Args:
        model_path: Path to the saved state_dict (``.pt`` file).
        device: Device to place the model on (string or ``torch.device``).
            Defaults to ``"cpu"``; on Apple silicon ``"mps"`` can be tried.

    Returns:
        The unwrapped ``CausalSuDORMRF`` module on ``device``, in eval mode.
    """
    device = torch.device(device)
    # 1: instantiate the architecture. These hyper-parameters must match the
    #    ones used at training time, otherwise load_state_dict fails.
    model = causal_improved_sudormrf_v3.CausalSuDORMRF(
        in_audio_channels=1,
        out_channels=512,
        in_channels=256,
        num_blocks=16,
        upsampling_depth=5,
        enc_kernel_size=21,
        enc_num_basis=512,
        num_sources=1,
        )
    # 2: wrap in DataParallel so parameter names match the saved "module.*" keys.
    model = torch.nn.DataParallel(model)
    # 3: load the weights. map_location remaps tensors that may have been saved
    #    on CUDA devices; without it torch.load raises on a machine with no GPU.
    #    NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # 4: unwrap the plain module and move it to the requested device.
    model = model.module.to(device)
    # 5: evaluation mode (switches layers such as dropout to inference behaviour).
    model.eval()
    return model

In [None]:
# Checkpoint to restore.
# NOTE(review): the file name suggests the model was trained on 16 kHz audio;
# presumably speech.wav / noise.wav must be at that rate too -- confirm.
model_path = 'e39_sudo_whamr_16k_enhnoisy_augment.pt'
model = load_sudormrf_causal_cpu(model_path=model_path)

In [None]:
# Load clean speech and noise, and sum them into a noisy mixture.
speech, fs_speech = torchaudio.load('speech.wav')
noise, fs_noise = torchaudio.load('noise.wav')
# Summing signals recorded at different sample rates is meaningless, so fail
# loudly instead of silently keeping only the noise file's rate.
assert fs_speech == fs_noise, f"sample-rate mismatch: speech={fs_speech} Hz, noise={fs_noise} Hz"
fs = fs_speech
# torchaudio.load returns (channels, frames); trim both to the shorter length
# so they can be added sample by sample.
n_samples = min(speech.shape[-1], noise.shape[-1])
mixture = speech[..., :n_samples] + noise[..., :n_samples]
# Peak-normalise to [-1, 1]; an all-zero mixture cannot be normalised.
peak = torch.max(torch.abs(mixture))
assert peak > 0, "mixture is silent; cannot peak-normalise"
mixture = mixture / peak
fig, ax = plt.subplots()
ax.plot(mixture.numpy()[0])
ax.set(title='Noisy mixture (peak-normalised, channel 0)', xlabel='sample', ylabel='amplitude');

In [None]:
# Record the mixture's total energy (sum of squared samples) so the model
# output can later be rescaled to match it. This must run before the
# standardisation cell, which rebinds `mixture`.
ini_nrg = mixture.pow(2).sum()

In [None]:
Audio(data=mixture.numpy(), rate=fs)


In [None]:
# Standardise the mixture (zero mean, unit standard deviation) before feeding
# it to the model. The small epsilon keeps the result finite for a constant input.
# NOTE: this rebinds `mixture`; re-running the energy / playback cells above
# after this one would operate on the standardised signal.
mixture = (mixture - torch.mean(mixture)) / (torch.std(mixture) + 1e-8)

In [None]:
# Run the model. Inference needs no gradients, so no_grad avoids building and
# holding the autograd graph for the whole signal.
# unsqueeze(0) adds a leading batch dim to the (channels, frames) mixture --
# presumably the model expects (batch, channels, time); TODO confirm.
with torch.no_grad():
    denoised = model(mixture.unsqueeze(0))

In [None]:
# Rescale the model output so that its total energy equals the mixture's
# energy measured before standardisation (ini_nrg).
energy_ratio = torch.sum(denoised ** 2) / ini_nrg
denoised /= torch.sqrt(energy_ratio)

In [None]:
# Play the first source of the first batch item.
# Assumes the model output is (batch, sources, time) -- TODO confirm.
Audio(denoised[0, 0].numpy(), rate=fs)


In [None]:
# Write the enhanced signal to disk. Dropping the batch dim leaves the
# (channels, time) layout torchaudio.save expects by default -- assuming the
# model output is (batch, sources, time).
denoised_wave = denoised[0]
torchaudio.save('denoised.wav', denoised_wave, sample_rate=fs)