In [1]:
import os
import re
import numpy as np
import torch
import pickle
from pathlib import Path
from scipy.io import wavfile
import speechbrain as sb
from speechbrain.dataio.dataio import read_audio
from IPython.display import Audio
from scipy.io.wavfile import write as wavwrite
from tqdm.notebook import tqdm, trange

In [2]:
from speechbrain.pretrained import SepformerSeparation as separator
import torchaudio
from lared_dataset.constants import (processed_audio_path)


In [3]:
fs = 8000
valid_segments_ms = pickle.load(open('./valid_audio_segments.pkl', 'rb'))
valid_segments_ms = [el[1] for el in valid_segments_ms]
valid_segments_ms = {el[0]: (el[1], el[2]) for el in valid_segments_ms}

In [4]:
valid_segments_ms

{1: (0, 6652313),
 2: (0, 6082616),
 3: (0, 6130061),
 4: (817914, 6053613),
 5: (772051, 6430197),
 7: (784004, 8033955),
 35: (742437, 5647632),
 9: (0, 8572196),
 10: (0, 5610812),
 11: (1286398, 5908218),
 12: (807642, 4985841),
 13: (1387570, 5626548),
 14: (721227, 5936410),
 15: (678056, 6131083),
 16: (1442156, 2771749),
 45: (4288702, 5753146),
 17: (755955, 5935093),
 18: (768054, 5672919),
 19: (690662, 2739184),
 20: (751756, 8465912),
 21: (0, 9769708),
 22: (897660, 5328800),
 23: (926244, 8480968),
 24: (1161270, 8220734),
 25: (1359345, 7725473),
 26: (991109, 5510171),
 27: (2979805, 5468341),
 29: (3157766, 6973012),
 30: (3576005, 8426502),
 31: (4123547, 5398823),
 32: (4216462, 8415540),
 33: (4213314, 6638696),
 34: (3127441, 6328104)}

In [5]:
run_opts = {"device": "cuda"}
model = separator.from_hparams(source="speechbrain/sepformer-whamr-enhancement", savedir='pretrained_models/sepformer-whamr-enhancement', run_opts=run_opts)

audios_path = os.path.join(processed_audio_path, 'normalized')
output_path = os.path.join(processed_audio_path, 'denoised')
audio_path = os.path.join(processed_audio_path, 'samples/7.wav')

In [6]:
def denoise_file(fpath, window_len=60):
    batch, fs_file = torchaudio.load(fpath)
    pid = int(Path(fpath).stem)
    pid_valid_seg = valid_segments_ms[pid]
    batch = batch.to(model.device)
    fs_model = model.hparams.sample_rate

    tf = torchaudio.transforms.Resample(
        orig_freq=fs_file, new_freq=fs_model
    ).to(model.device)
    batch = batch.mean(dim=0, keepdim=True)
    batch = tf(batch)

    all = []
    # cnt = 0
    for ini_time in trange(0, batch.shape[1] // fs, window_len):
        end_time = ini_time + window_len

        if end_time < pid_valid_seg[0]/1000 or ini_time > pid_valid_seg[1]/1000:
            # segment is outside of valid seg
            all.append(
                torch.zeros((1, window_len * fs))
            )
        else:
            ini = ini_time*fs
            end = end_time*fs

            all.append(
                model.separate_batch(batch[:, ini: end]).cpu().squeeze(dim=2)
            )
        # cnt += 1
        # if cnt == 15:
        #     break

    res = torch.cat(all, axis=1)
    res = (
        res / res.abs().max(dim=1, keepdim=True)[0]
    )
    return res.numpy()

In [7]:
def denoise_and_store(fpath, outpath):
    res = denoise_file(fpath).transpose()
    wavwrite(outpath, fs, res)

In [8]:
for fpath in [f for f in Path(audios_path).glob('*.wav') if f.stem in ['16', '45']]:
    print(fpath)
    out_path = os.path.join(output_path, fpath.name)
    denoise_and_store(fpath, out_path)

/mnt/e/data/lared/processed/audio/normalized/16.wav


  0%|          | 0/165 [00:00<?, ?it/s]

/mnt/e/data/lared/processed/audio/normalized/45.wav


  0%|          | 0/165 [00:00<?, ?it/s]