In [1]:
import os
import numpy as np
import torch
import pickle
from pathlib import Path
from scipy.io import wavfile
import speechbrain as sb
from speechbrain.dataio.dataio import read_audio
from IPython.display import Audio
from scipy.io.wavfile import write as wavwrite
from tqdm.notebook import tqdm, trange

In [2]:
from speechbrain.pretrained import SepformerSeparation as separator
import torchaudio
from lared_dataset.constants import (processed_audio_path)


In [3]:
# Expected audio sample rate in Hz.
fs = 8000

# SECURITY NOTE: pickle can execute arbitrary code while loading; only load
# this file if it was produced by a trusted step of this pipeline.
# The context manager guarantees the file handle is closed after loading.
with open('./valid_audio_segments.pkl', 'rb') as _segments_file:
    valid_segments_ms = pickle.load(_segments_file)

# Each loaded record carries the useful tuple at index 1; that tuple is
# (id, start, end). Build a lookup id -> (start, end). Per the variable name
# the times are in milliseconds.
valid_segments_ms = [el[1] for el in valid_segments_ms]
valid_segments_ms = {el[0]: (el[1], el[2]) for el in valid_segments_ms}

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still executes (slowly) on machines without CUDA.
run_opts = {"device": "cuda" if torch.cuda.is_available() else "cpu"}
# Pretrained SepFormer enhancement model; files are cached under `savedir`.
model = separator.from_hparams(source="speechbrain/sepformer-whamr-enhancement", savedir='pretrained_models/sepformer-whamr-enhancement', run_opts=run_opts)

# Input folder with the normalized recordings, output folder for the denoised
# ones, and a single sample file for quick experiments.
audios_path = os.path.join(processed_audio_path, 'normalized')
output_path = os.path.join(processed_audio_path, 'denoised')
audio_path = os.path.join(processed_audio_path, 'samples/7.wav')

In [5]:
def denoise_file(fpath, window_len=60):
    """Enhance one recording with the pretrained model, window by window.

    The file is loaded, averaged over its first dimension (channels, as
    returned by ``torchaudio.load``) into a single channel, resampled to the
    model's sample rate and processed in consecutive windows of
    ``window_len`` seconds. A window that lies entirely outside the valid
    segment stored for this recording in ``valid_segments_ms`` is not passed
    through the model; it is filled with zeros instead.

    Args:
        fpath: Path to the audio file. Its stem must be an integer id that
            is a key of ``valid_segments_ms``.
        window_len: Window length in seconds.

    Returns:
        A numpy array of shape ``(1, n_samples)`` at the model's sample
        rate, scaled so that its peak absolute value is 1 (an all-zero
        result is returned unchanged).

    Raises:
        KeyError: If the id parsed from the file name has no valid segment.
        ValueError: If the recording is shorter than one second, so that no
            window can be formed.
    """
    waveform, fs_file = torchaudio.load(fpath)
    pid = int(Path(fpath).stem)
    # Valid segment bounds are stored in milliseconds; compare in seconds.
    valid_start_s, valid_end_s = (t / 1000 for t in valid_segments_ms[pid])
    fs_model = model.hparams.sample_rate

    resample = torchaudio.transforms.Resample(
        orig_freq=fs_file, new_freq=fs_model
    ).to(model.device)
    waveform = waveform.to(model.device).mean(dim=0, keepdim=True)
    waveform = resample(waveform)

    n_samples = waveform.shape[1]
    chunks = []
    # Inference only: make sure no autograd graph is built.
    with torch.no_grad():
        # Window starts are enumerated in whole seconds of the resampled
        # signal; all sample indices use the model rate because that is the
        # rate `waveform` is at after resampling.
        for ini_time in trange(0, n_samples // fs_model, window_len):
            end_time = ini_time + window_len
            ini = ini_time * fs_model
            # The last window may be shorter than `window_len`.
            end = min(end_time * fs_model, n_samples)

            if end_time < valid_start_s or ini_time > valid_end_s:
                # Window is entirely outside the valid segment: emit silence
                # of the window's actual length so the output stays aligned
                # with the input.
                chunks.append(torch.zeros((1, end - ini)))
            else:
                # Assumes the model returns (batch, time, sources) with a
                # single source, so squeezing dim 2 yields (1, time).
                chunks.append(
                    model.separate_batch(waveform[:, ini:end]).cpu().squeeze(dim=2)
                )

    if not chunks:
        raise ValueError(
            f'{fpath} is shorter than one second; nothing to denoise'
        )

    res = torch.cat(chunks, dim=1)
    # Peak-normalize; the lower bound avoids 0/0 = NaN for an all-zero signal.
    peak = res.abs().max(dim=1, keepdim=True)[0]
    res = res / peak.clamp_min(torch.finfo(res.dtype).eps)
    return res.numpy()

In [6]:
def denoise_and_store(fpath, outpath):
    """Denoise the recording at ``fpath`` and write it to ``outpath`` as WAV.

    Args:
        fpath: Path of the input audio file, passed to ``denoise_file``.
        outpath: Destination path of the WAV file. Missing parent
            directories are created.
    """
    # denoise_file returns (1, n_samples); the WAV writer expects
    # (n_samples, n_channels).
    res = denoise_file(fpath).transpose()
    Path(outpath).parent.mkdir(parents=True, exist_ok=True)
    # The denoised signal is at the model's sample rate, so that is the rate
    # that must be recorded in the WAV header.
    wavwrite(outpath, int(model.hparams.sample_rate), res)

In [7]:
# Denoise every normalized recording and store the result under the same file
# name in the output folder. The glob result is sorted because its native
# order depends on the filesystem; sorting makes runs and logs reproducible.
for fpath in sorted(Path(audios_path).glob('*.wav')):
    print(fpath)
    out_path = os.path.join(output_path, fpath.name)
    denoise_and_store(fpath, out_path)

/mnt/e/data/lared/processed/audio/normalized/1.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/13.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/14.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/15.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/20.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/21.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/25.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/26.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/29.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/30.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/32.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/4.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/45.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/9.wav


  0%|          | 0/165 [00:00<?, ?it/s]