# Metrics

In [1]:
# NOTE(review): not referenced anywhere else in this notebook (the mixer cells
# use path_train / path_val instead); looks like a leftover from an ASR setup.
path_to_dataset = '/home/vladimir/PycharmProjects/ASR/data/datasets/librispeech/train-clean-100'

In [1]:
import os
import glob
from glob import glob

import numpy as np

import random

import librosa
import soundfile as sf
import pyloudnorm as pyln
import matplotlib.pyplot as plt

import torch
from torchmetrics.audio import SignalDistortionRatio, ScaleInvariantSignalDistortionRatio
from IPython import display

from concurrent.futures import ProcessPoolExecutor
import warnings
warnings.filterwarnings("ignore")

There are three common metrics for comparing an estimated signal $\hat{s}$ with the clean source $s$:

- SNR(Signal-To-Noise Ratio).
\begin{gather*}
    SNR(s, \hat{s}) = 10 \log_{10}{\|s\|^2 \over \|s - \hat{s}\|^2}
\end{gather*}

- SDR (Signal-to-Distortion Ratio, https://inria.hal.science/inria-00544230/document).
Decompose the estimate as
\begin{gather*}
    \hat{s} = s + \varepsilon^{\rm{spat}} + \varepsilon^{\rm{interf}} + \varepsilon^{\rm{artif}},
\end{gather*}
where $s$ is the true target source (image) and the $\varepsilon$ terms are the spatial-distortion, interference and artifact errors. The SDR score is then
\begin{gather*}
    SDR(\hat{s}, s) = 10 \log_{10}{\|s\|^2 \over \|\varepsilon^{\rm{spat}} + \varepsilon^{\rm{interf}} + \varepsilon^{\rm{artif}}\|^2}
\end{gather*}


- SI-SDR (Scale-Invariant Signal-to-Distortion Ratio, https://arxiv.org/abs/1811.02508).
\begin{gather*}
    SI\text{-}SDR(\hat{s}, s) = 10 \log_{10}{\|{\hat{s}^T s \over \|s\|^2}s\|^2 \over \|{\hat{s}^T s \over \|s\|^2}s - \hat{s}\|^2}
\end{gather*}

# Audio Mixer

In [2]:
def snr_mixer(clean, noise, snr):
    """Add `noise` to `clean`, rescaled so the mixture has the requested SNR.

    The noise is scaled so that ||clean|| / ||scaled_noise|| == 10 ** (snr / 20).

    Args:
        clean: Target signal (array-like).
        noise: Interfering signal, same length as `clean`.
        snr: Desired signal-to-noise ratio in dB.

    Returns:
        The mixture `clean + scaled_noise`. If `noise` is all zeros there is
        nothing to scale, so `clean + noise` (i.e. the clean signal) is returned
        instead of a NaN-filled mixture.
    """
    noise_l2 = np.linalg.norm(noise)
    if noise_l2 == 0:
        # Avoid 0/0, which would turn the whole mixture into NaN.
        return clean + np.asarray(noise)
    amp_noise = np.linalg.norm(clean) / 10 ** (snr / 20)
    noise_norm = (noise / noise_l2) * amp_noise
    return clean + noise_norm


def vad_merge(w, top_db):
    """Remove silent stretches from `w` and join the remaining speech.

    Args:
        w: 1-D audio signal.
        top_db: Threshold in dB below the reference (peak) level that
            ``librosa.effects.split`` treats as silence.

    Returns:
        1-D array with the non-silent intervals concatenated in order. If no
        non-silent interval is found, an empty array is returned
        (``np.concatenate`` would raise on an empty list).
    """
    intervals = librosa.effects.split(w, top_db=top_db)
    if len(intervals) == 0:
        return np.empty(0, dtype=np.asarray(w).dtype)
    return np.concatenate([w[start:end] for start, end in intervals], axis=None)


def cut_audios(s1, s2, sec, sr):
    """Split two signals into aligned, non-overlapping fixed-length segments.

    Only complete segments are produced, and only as many as fit into the
    shorter of the two signals. A signal whose length is an exact multiple of
    the segment length contributes all of its segments.

    Args:
        s1, s2: Sliceable 1-D signals.
        sec: Segment length in seconds (int or float).
        sr: Sample rate in Hz.

    Returns:
        Tuple ``(s1_cut, s2_cut)`` of equally long lists of segments.

    Raises:
        ValueError: if ``sr * sec`` is shorter than one sample, which would
            otherwise loop forever.
    """
    cut_len = int(sr * sec)  # slice bounds must be integers, even for float `sec`
    if cut_len <= 0:
        raise ValueError(f"segment length must be at least one sample, got sr={sr}, sec={sec}")

    n_segments = min(len(s1), len(s2)) // cut_len
    s1_cut = [s1[k * cut_len:(k + 1) * cut_len] for k in range(n_segments)]
    s2_cut = [s2[k * cut_len:(k + 1) * cut_len] for k in range(n_segments)]
    return s1_cut, s2_cut


def fix_length(s1, s2, min_or_max='max'):
    """Make two signals the same length.

    Args:
        s1, s2: 1-D signals (arrays or sequences).
        min_or_max: 'min' truncates both to the shorter length; any other
            value (default 'max') right-pads both with zeros to the longer
            length.

    Returns:
        Tuple ``(s1, s2)`` of equal length. In padding mode both results are
        flattened numpy arrays; the zeros are float64, so integer input is
        promoted.
    """
    if min_or_max == 'min':
        shortest = min(len(s1), len(s2))
        return s1[:shortest], s2[:shortest]

    longest = max(len(s1), len(s2))

    def _pad(signal):
        # Same as np.append(signal, zeros): flatten, then concatenate.
        return np.concatenate((np.ravel(signal), np.zeros(longest - len(signal))))

    return _pad(s1), _pad(s2)

In [3]:
def create_mix(idx, triplet, snr_levels, out_dir, test=False, sr=16000, **kwargs):
    """Create one target-speaker mixture from a triplet and write it to disk.

    The target and interfering ("noise") utterances are loudness-normalised to
    -29 LUFS and the reference utterance to -23 LUFS. After optional
    edge-trimming, target and noise are mixed at an SNR picked from
    `snr_levels`, and the mixture and target are re-normalised to -23 LUFS.
    Files are written to ``out_dir/<target_id>/``.

    Args:
        idx: Mixture index; part of the output file names.
        triplet: Dict with audio paths under "target", "noise", "reference"
            and speaker ids under "target_id", "noise_id".
        snr_levels: Candidate SNRs in dB; one is chosen at random.
        out_dir: Root output directory.
        test: If False, both signals are VAD-merged and cut into
            ``audioLen``-second segments, and one file triple is written per
            segment (suffix ``_<i>``). If True, the full signals are zero-padded
            to equal length and written once.
        sr: Expected sample rate of the input files; also used for the
            loudness meter and when writing.
        **kwargs: Must contain "trim_db" (``top_db`` for edge trimming; falsy
            disables it), "vad_db" (``top_db`` for VAD merging) and "audioLen"
            (segment length in seconds).

    Returns:
        None. Nothing is written if any input is silent or if the reference
        (after trimming) is shorter than one second.

    Raises:
        ValueError: if an input file's sample rate differs from `sr`.
    """
    trim_db, vad_db = kwargs["trim_db"], kwargs["vad_db"]
    audioLen = kwargs["audioLen"]

    target_id = triplet["target_id"]
    noise_id = triplet["noise_id"]

    s1, sr1 = sf.read(triplet["target"])
    s2, sr2 = sf.read(triplet["noise"])
    ref, sr_ref = sf.read(triplet["reference"])
    if not (sr1 == sr2 == sr_ref == sr):
        # Writing at `sr` would silently change pitch and speed.
        raise ValueError(f"expected sample rate {sr}, got {sr1}/{sr2}/{sr_ref} for {triplet}")

    meter = pyln.Meter(sr)  # create BS.1770 meter

    # Continue with the normalised arrays whether or not trimming is enabled.
    s1 = pyln.normalize.loudness(s1, meter.integrated_loudness(s1), -29)
    s2 = pyln.normalize.loudness(s2, meter.integrated_loudness(s2), -29)
    ref = pyln.normalize.loudness(ref, meter.integrated_loudness(ref), -23.0)

    # For silent input pyloudnorm reports -inf loudness, so the normalised
    # signal is non-finite rather than zero; skip those as well as all-zero ones.
    for signal in (s1, s2, ref):
        if not np.all(np.isfinite(signal)) or np.max(np.abs(signal)) == 0:
            return

    if trim_db:
        ref, _ = librosa.effects.trim(ref, top_db=trim_db)
        s1, _ = librosa.effects.trim(s1, top_db=trim_db)
        s2, _ = librosa.effects.trim(s2, top_db=trim_db)

    # Require at least one second of reference audio.
    if len(ref) < sr:
        return

    speaker_dir = os.path.join(out_dir, target_id)
    # exist_ok: several worker processes may create the same speaker dir at once.
    os.makedirs(speaker_dir, exist_ok=True)
    stem = os.path.join(speaker_dir, f"{target_id}_{noise_id}_" + "%06d" % idx)

    # NOTE(review): np.random's global state is inherited by forked worker
    # processes, so workers may draw correlated SNR sequences.
    snr = np.random.choice(snr_levels, 1).item()

    if not test:
        s1, s2 = vad_merge(s1, vad_db), vad_merge(s2, vad_db)
        s1_cut, s2_cut = cut_audios(s1, s2, audioLen, sr)

        for i, (target_seg, noise_seg) in enumerate(zip(s1_cut, s2_cut)):
            mix = snr_mixer(target_seg, noise_seg, snr)
            target_seg = pyln.normalize.loudness(target_seg, meter.integrated_loudness(target_seg), -23.0)
            mix = pyln.normalize.loudness(mix, meter.integrated_loudness(mix), -23.0)

            segment_stem = f"{stem}_{i}"
            sf.write(segment_stem + "-mixed.wav", mix, sr)
            sf.write(segment_stem + "-target.wav", target_seg, sr)
            sf.write(segment_stem + "-ref.wav", ref, sr)
    else:
        s1, s2 = fix_length(s1, s2, 'max')
        mix = snr_mixer(s1, s2, snr)
        s1 = pyln.normalize.loudness(s1, meter.integrated_loudness(s1), -23.0)
        mix = pyln.normalize.loudness(mix, meter.integrated_loudness(mix), -23.0)

        sf.write(stem + "-mixed.wav", mix, sr)
        sf.write(stem + "-target.wav", s1, sr)
        sf.write(stem + "-ref.wav", ref, sr)

In [4]:
class LibriSpeechSpeakerFiles:
    """Collect the audio files of one LibriSpeech speaker.

    Expects the layout ``<audios_dir>/<speaker_id>/<chapter_id>/<file>`` and
    gathers, from every chapter directory, the files matching `audioTemplate`.

    Attributes:
        id: Speaker id (the speaker directory name).
        audioTemplate: Glob pattern for audio files inside a chapter directory.
        files: Sorted list of matching file paths.
    """

    def __init__(self, speaker_id, audios_dir, audioTemplate="*-norm.wav"):
        self.id = speaker_id
        self.audioTemplate = audioTemplate
        self.files = self.find_files_by_worker(audios_dir)

    def find_files_by_worker(self, audios_dir):
        """Return the sorted paths matching `self.audioTemplate` in all chapter dirs.

        Sorting makes the list independent of filesystem listing order, so
        seeded sampling over `files` is reproducible across machines.
        """
        speaker_dir = os.path.join(audios_dir, self.id)
        found = []
        # The context manager closes the directory handle deterministically.
        with os.scandir(speaker_dir) as entries:
            for entry in entries:
                if entry.is_dir():
                    found.extend(glob(os.path.join(speaker_dir, entry.name, self.audioTemplate)))
        return sorted(found)

In [5]:
class MixtureGenerator:
    """Sample (target, reference, noise) triplets and render them as mixtures.

    Args:
        speakers_files: One object per speaker with ``id`` and ``files``
            attributes (e.g. ``LibriSpeechSpeakerFiles``).
        out_folder: Output directory for the generated audio; created if missing.
        nfiles: Number of triplets to sample.
        test: Forwarded to ``create_mix`` (full-length mixes instead of segments).
        randomState: Seed for triplet sampling.
    """

    def __init__(self, speakers_files, out_folder, nfiles=5000, test=False, randomState=42):
        self.speakers_files = speakers_files  # list of SpeakerFiles for every speaker_id
        self.nfiles = nfiles
        self.randomState = randomState
        self.out_folder = out_folder
        self.test = test
        # Private RNG: seeding the global `random` module would let the seed of
        # the next generator constructed override this one.
        self.rng = random.Random(self.randomState)
        os.makedirs(self.out_folder, exist_ok=True)

    def generate_triplets(self):
        """Sample `self.nfiles` triplets.

        Each triplet uses two distinct utterances (target and reference) of one
        speaker and one utterance (noise) of another speaker. Only speakers with
        at least two files take part.

        Returns:
            Dict of parallel lists with keys "reference", "target", "noise",
            "target_id", "noise_id".

        Raises:
            ValueError: if triplets are requested but fewer than two speakers
                have two or more files (no triplet could ever be drawn).
        """
        eligible = [spk for spk in self.speakers_files if len(spk.files) >= 2]
        if self.nfiles > 0 and len(eligible) < 2:
            raise ValueError("need at least two speakers with two or more files each")

        all_triplets = {"reference": [], "target": [], "noise": [], "target_id": [], "noise_id": []}
        for _ in range(self.nfiles):
            spk1, spk2 = self.rng.sample(eligible, 2)
            target, reference = self.rng.sample(spk1.files, 2)
            noise = self.rng.choice(spk2.files)
            all_triplets["reference"].append(reference)
            all_triplets["target"].append(target)
            all_triplets["noise"].append(noise)
            all_triplets["target_id"].append(spk1.id)
            all_triplets["noise_id"].append(spk2.id)

        return all_triplets

    def triplet_generator(self, target_speaker, noise_speaker, number_of_triplets):
        """Sample up to `number_of_triplets` triplets for one fixed speaker pair.

        Targets and noises are drawn without replacement, so the count is capped
        by the smaller file list. Each reference is a different utterance of the
        target speaker than its own target; with fewer than two target files no
        triplet is produced.
        """
        max_num_triplets = min(len(target_speaker.files), len(noise_speaker.files))
        number_of_triplets = min(max_num_triplets, number_of_triplets)
        if len(target_speaker.files) < 2:
            number_of_triplets = 0  # no distinct reference available

        target_samples = self.rng.sample(target_speaker.files, k=number_of_triplets)
        reference_samples = [
            self.rng.choice([f for f in target_speaker.files if f != target])
            for target in target_samples
        ]
        noise_samples = self.rng.sample(noise_speaker.files, k=number_of_triplets)

        return {"reference": reference_samples,
                "target": target_samples,
                "noise": noise_samples,
                "target_id": [target_speaker.id] * number_of_triplets,
                "noise_id": [noise_speaker.id] * number_of_triplets}

    def generate_mixes(self, snr_levels=(0,), num_workers=10, update_steps=10, **kwargs):
        """Sample triplets and render each one with ``create_mix`` in a process pool.

        Args:
            snr_levels: Candidate SNRs in dB, forwarded to ``create_mix``.
            num_workers: Number of worker processes.
            update_steps: Approximate number of progress messages to print.
            **kwargs: Forwarded to ``create_mix`` (trim_db, vad_db, audioLen).
        """
        triplets = self.generate_triplets()
        keys = ("reference", "target", "noise", "target_id", "noise_id")
        report_every = max(self.nfiles // max(update_steps, 1), 1)

        with ProcessPoolExecutor(max_workers=num_workers) as pool:
            futures = [
                pool.submit(create_mix, i, {key: triplets[key][i] for key in keys},
                            snr_levels, self.out_folder,
                            test=self.test, **kwargs)
                for i in range(self.nfiles)
            ]

            for i, future in enumerate(futures):
                future.result()  # re-raise worker exceptions here
                if (i + 1) % report_every == 0:
                    print(f"Files Processed | {i + 1} out of {self.nfiles}")

In [6]:
# Source LibriSpeech subsets and output folders for the generated mixtures.
# NOTE(review): absolute, machine-specific paths; adjust DATA_ROOT locally.
DATA_ROOT = '/home/vladimir/PycharmProjects/TTS'
path_train = os.path.join(DATA_ROOT, 'data/datasets/librispeech/train-clean-100')
path_val = os.path.join(DATA_ROOT, 'data/datasets/librispeech/test-clean')

path_mixtures_train = os.path.join(DATA_ROOT, 'temp_datasets/train_clean_big')
path_mixtures_val = os.path.join(DATA_ROOT, 'temp_datasets/test')

# Only speaker directories whose name has at most this many characters are used.
# NOTE(review): speaker ids of four digits exist (e.g. 1673 in the listings
# below), so this keeps only a subset of speakers; confirm this is intended
# (None keeps all).
MAX_SPEAKER_ID_LEN = 3


def list_speakers(root, max_id_len=MAX_SPEAKER_ID_LEN):
    """Return the sorted names of speaker directories under `root`.

    Non-directory entries are skipped. Sorting makes the speaker order, and
    therefore the seeded triplet sampling, independent of filesystem order.
    """
    with os.scandir(root) as entries:
        return sorted(
            entry.name
            for entry in entries
            if entry.is_dir() and (max_id_len is None or len(entry.name) <= max_id_len)
        )


speakersTrain = list_speakers(path_train)
speakersVal = list_speakers(path_val)

speakers_files_train = [LibriSpeechSpeakerFiles(i, path_train, audioTemplate="*.flac") for i in speakersTrain]
speakers_files_val = [LibriSpeechSpeakerFiles(i, path_val, audioTemplate="*.flac") for i in speakersVal]

mixer_train = MixtureGenerator(speakers_files_train,
                               path_mixtures_train,
                               nfiles=5000,
                               test=False)

mixer_val = MixtureGenerator(speakers_files_val,
                             path_mixtures_val,
                             nfiles=200,
                             test=True)

In [7]:
# Render the training mixtures: SNR is -5 or 5 dB per mixture, edges are
# trimmed at 20 dB, VAD merging uses 20 dB, and signals are cut into 3-second
# segments. Progress is printed every nfiles // update_steps files.
mixer_train.generate_mixes(snr_levels=[-5, 5],
                           num_workers=4,
                           update_steps=100,
                           trim_db=20,
                           vad_db=20,
                           audioLen=3)

Files Processed | 50 out of 5000
Files Processed | 100 out of 5000
Files Processed | 150 out of 5000
Files Processed | 200 out of 5000
Files Processed | 250 out of 5000
Files Processed | 300 out of 5000
Files Processed | 350 out of 5000
Files Processed | 400 out of 5000
Files Processed | 450 out of 5000
Files Processed | 500 out of 5000
Files Processed | 550 out of 5000
Files Processed | 600 out of 5000
Files Processed | 650 out of 5000
Files Processed | 700 out of 5000
Files Processed | 750 out of 5000
Files Processed | 800 out of 5000
Files Processed | 850 out of 5000
Files Processed | 900 out of 5000
Files Processed | 950 out of 5000
Files Processed | 1000 out of 5000
Files Processed | 1050 out of 5000
Files Processed | 1100 out of 5000
Files Processed | 1150 out of 5000
Files Processed | 1200 out of 5000
Files Processed | 1250 out of 5000
Files Processed | 1300 out of 5000
Files Processed | 1350 out of 5000
Files Processed | 1400 out of 5000
Files Processed | 1450 out of 5000
Files

In [10]:
# Render the validation mixtures. mixer_val has test=True, so create_mix keeps
# full-length signals (vad_db and audioLen are not used on that path), and
# trim_db=None disables edge trimming.
mixer_val.generate_mixes(snr_levels=[-5, 5],
                           num_workers=2,
                           update_steps=100,
                           trim_db=None,
                           vad_db=20,
                           audioLen=3)

Files Processed | 2 out of 200
Files Processed | 4 out of 200
Files Processed | 6 out of 200
Files Processed | 8 out of 200
Files Processed | 10 out of 200
Files Processed | 12 out of 200
Files Processed | 14 out of 200
Files Processed | 16 out of 200
Files Processed | 18 out of 200
Files Processed | 20 out of 200
Files Processed | 22 out of 200
Files Processed | 24 out of 200
Files Processed | 26 out of 200
Files Processed | 28 out of 200
Files Processed | 30 out of 200
Files Processed | 32 out of 200
Files Processed | 34 out of 200
Files Processed | 36 out of 200
Files Processed | 38 out of 200
Files Processed | 40 out of 200
Files Processed | 42 out of 200
Files Processed | 44 out of 200
Files Processed | 46 out of 200
Files Processed | 48 out of 200
Files Processed | 50 out of 200
Files Processed | 52 out of 200
Files Processed | 54 out of 200
Files Processed | 56 out of 200
Files Processed | 58 out of 200
Files Processed | 60 out of 200
Files Processed | 62 out of 200
Files Proces

In [21]:
# NOTE(review): create_mix writes into out_dir/<target_id>/, so these listings
# contain speaker directories rather than wav files. The output below appears
# to come from an earlier flat layout.
train_mixes = os.listdir(path_mixtures_train)
val_mixes = os.listdir(path_mixtures_val)

In [22]:
# Peek at the first directory entries of the training output folder.
train_mixes[:100]

['1673_2035_000000_0-mixed.wav',
 '1673_2035_000000_0-target.wav',
 '1673_2035_000000_0-ref.wav',
 '2086_2035_000002_0-mixed.wav',
 '2086_2035_000002_0-target.wav',
 '2086_2035_000002_0-ref.wav',
 '3536_1919_000003_0-mixed.wav',
 '3536_1919_000003_0-target.wav',
 '3536_1919_000003_0-ref.wav',
 '6319_3000_000004_0-mixed.wav',
 '6319_3000_000004_0-target.wav',
 '3536_1919_000003_1-mixed.wav',
 '6319_3000_000004_0-ref.wav',
 '3536_1919_000003_1-target.wav',
 '3536_1919_000003_1-ref.wav',
 '5338_2902_000006_0-mixed.wav',
 '5338_2902_000006_0-target.wav',
 '5338_2902_000006_0-ref.wav',
 '84_3752_000007_0-mixed.wav',
 '84_3752_000007_0-target.wav',
 '84_3752_000007_0-ref.wav',
 '1988_2086_000010_0-mixed.wav',
 '1988_2086_000010_0-target.wav',
 '1988_2086_000010_0-ref.wav',
 '3081_6345_000012_0-mixed.wav',
 '3081_6345_000012_0-target.wav',
 '3081_6345_000012_0-ref.wav',
 '6319_84_000011_0-mixed.wav',
 '6319_84_000011_0-target.wav',
 '1272_1988_000013_0-mixed.wav',
 '6319_84_000011_0-ref.wav',

In [23]:
def find_mixture_files(root):
    """Return index-aligned (refs, mixes, targets) path lists for mixtures under `root`.

    The search is recursive because create_mix writes into per-speaker
    subdirectories. Companion paths are derived from each mixture's name, and
    only mixtures whose -ref and -target files both exist are kept, so the i-th
    entries of the three lists always belong to the same mixture.
    """
    suffix = '-mixed.wav'
    mixes = sorted(glob(os.path.join(root, '**', '*' + suffix), recursive=True))
    mixes = [
        path for path in mixes
        if os.path.exists(path[:-len(suffix)] + '-ref.wav')
        and os.path.exists(path[:-len(suffix)] + '-target.wav')
    ]
    refs = [path[:-len(suffix)] + '-ref.wav' for path in mixes]
    targets = [path[:-len(suffix)] + '-target.wav' for path in mixes]
    return refs, mixes, targets


ref_train, mix_train, target_train = find_mixture_files(path_mixtures_train)

In [24]:
# First ten mixture paths.
mix_train[:10]

['/home/vladimir/Документы/data/train/1272_1988_000013_0-mixed.wav',
 '/home/vladimir/Документы/data/train/1673_174_000027_0-mixed.wav',
 '/home/vladimir/Документы/data/train/1673_2035_000000_0-mixed.wav',
 '/home/vladimir/Документы/data/train/1673_6241_000026_0-mixed.wav',
 '/home/vladimir/Документы/data/train/1919_1988_000024_0-mixed.wav',
 '/home/vladimir/Документы/data/train/1988_1993_000067_0-mixed.wav',
 '/home/vladimir/Документы/data/train/1988_2086_000010_0-mixed.wav',
 '/home/vladimir/Документы/data/train/1988_3000_000049_0-mixed.wav',
 '/home/vladimir/Документы/data/train/1993_251_000059_0-mixed.wav',
 '/home/vladimir/Документы/data/train/1993_6319_000038_0-mixed.wav']

In [25]:
# First ten clean-target paths.
target_train[:10]

['/home/vladimir/Документы/data/train/1272_1988_000013_0-target.wav',
 '/home/vladimir/Документы/data/train/1673_174_000027_0-target.wav',
 '/home/vladimir/Документы/data/train/1673_2035_000000_0-target.wav',
 '/home/vladimir/Документы/data/train/1673_6241_000026_0-target.wav',
 '/home/vladimir/Документы/data/train/1919_1988_000024_0-target.wav',
 '/home/vladimir/Документы/data/train/1988_1993_000067_0-target.wav',
 '/home/vladimir/Документы/data/train/1988_2086_000010_0-target.wav',
 '/home/vladimir/Документы/data/train/1988_3000_000049_0-target.wav',
 '/home/vladimir/Документы/data/train/1993_251_000059_0-target.wav',
 '/home/vladimir/Документы/data/train/1993_6319_000038_0-target.wav']

In [26]:
# Pick the first entry of each list (assumes the three lists are index-aligned,
# i.e. describe the same mixture).
ref, mix, target = ref_train[0], mix_train[0], target_train[0]

In [27]:
from IPython import display

# Play the reference utterance, the mixture and the clean target, in that order.
for clip_path in (ref, mix, target):
    display.display(display.Audio(clip_path, rate=16000))

In [28]:
from speechbrain.pretrained import SepformerSeparation as separator
import torchaudio

# Load the pretrained SepFormer (WSJ0-2mix) separator from the SpeechBrain hub;
# `savedir` is where the downloaded files are stored.
model = separator.from_hparams(source="speechbrain/sepformer-wsj02mix", savedir='pretrained_models/sepformer-wsj02mix')

# for custom file, change path
est_sources = model.separate_file(path='speechbrain/sepformer-wsj02mix/test_mixture.wav') 

# est_sources[:, :, k] is the k-th estimated source. The files are written at
# 8 kHz, the rate this model runs at (see the resampling log in the next cell).
torchaudio.save("source1hat.wav", est_sources[:, :, 0].detach().cpu(), 8000)
torchaudio.save("source2hat.wav", est_sources[:, :, 1].detach().cpu(), 8000)


In [47]:
# Separate our own 16 kHz mixture (SpeechBrain resamples it to 8 kHz). This
# overwrites the source*hat.wav files written by the previous cell.
est_sources = model.separate_file(mix)

torchaudio.save("source1hat.wav", est_sources[:, :, 0].detach().cpu(), 8000)
torchaudio.save("source2hat.wav", est_sources[:, :, 1].detach().cpu(), 8000)

Resampling the audio from 16000 Hz to 8000 Hz


In [48]:
# source1hat.wav was written at 8 kHz. For a file path, IPython reads the
# rate from the WAV header (`rate` only applies to raw sample arrays), so
# 8000 is stated here to match rather than the misleading 4000.
display.display(display.Audio("source1hat.wav", rate=8000))

In [49]:
# source2hat.wav was written at 8 kHz; state the matching rate (for a file
# path IPython takes the rate from the WAV header anyway).
display.display(display.Audio("source2hat.wav", rate=8000))

In [11]:
import speechbrain

In [15]:
from speechbrain.pretrained import EncoderDecoderASR

# Load a SpeechBrain wav2vec2 CommonVoice ASR model. The download was
# interrupted in the recorded run, and asr_model is not used later in this notebook.
asr_model = EncoderDecoderASR.from_hparams(source="speechbrain/asr-wav2vec2-commonvoice-en", savedir="pretrained_models/asr-wav2vec2-commonvoice-en")


(…)onvoice-en/resolve/main/hyperparams.yaml:   0%|          | 0.00/3.19k [00:00<?, ?B/s]

(…)60/resolve/main/preprocessor_config.json:   0%|          | 0.00/158 [00:00<?, ?B/s]

(…)vec2-large-lv60/resolve/main/config.json:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]


KeyboardInterrupt



In [27]:
from speechbrain.pretrained import EncoderASR

# Load a SpeechBrain wav2vec2 LibriSpeech ASR model. The download was
# interrupted in the recorded run, and asr_model is not used later in this notebook.
asr_model = EncoderASR.from_hparams(source="speechbrain/asr-wav2vec2-librispeech", savedir="pretrained_models/asr-wav2vec2-librispeech")


pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]


KeyboardInterrupt



In [None]:
# !/usr/bin/env python

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch
# !/usr/bin/env python

import torch as th
import torch.nn as nn


class ChannelwiseLayerNorm(nn.LayerNorm):
    """
    Channel-wise layer normalization based on nn.LayerNorm
    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape

    The channel axis is moved last, where nn.LayerNorm normalizes, so with
    normalized_shape == C every frame is normalized over its channels.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelwiseLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, x):
        if x.dim() != 3:
            # Instances have no __name__; use the class name for the message.
            raise RuntimeError("{} requires a 3D tensor input".format(
                type(self).__name__))
        x = th.transpose(x, 1, 2)
        x = super().forward(x)
        x = th.transpose(x, 1, 2)
        return x


class GlobalLayerNorm(nn.Module):
    """
    Global layer normalization
    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape

    Mean and variance are taken per example over both the channel and time
    axes. With elementwise_affine, a per-channel scale (gamma) and shift
    (beta) of shape [C, 1] are applied.
    """

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.eps = eps
        self.normalized_dim = dim
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.beta = nn.Parameter(th.zeros(dim, 1))
            self.gamma = nn.Parameter(th.ones(dim, 1))
        else:
            # Register under the same names forward() uses, so the attributes exist (as None).
            self.register_parameter("beta", None)
            self.register_parameter("gamma", None)

    def forward(self, x):
        if x.dim() != 3:
            raise RuntimeError("{} requires a 3D tensor input".format(
                type(self).__name__))
        # calculate the mean, variance over the channel and time dimensions
        mean = th.mean(x, (1, 2), keepdim=True)
        var = th.mean((x - mean) ** 2, (1, 2), keepdim=True)
        if self.elementwise_affine:
            x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta
        else:
            x = (x - mean) / th.sqrt(var + self.eps)
        return x

    def extra_repr(self):
        return "{normalized_dim}, eps={eps}, " \
               "elementwise_affine={elementwise_affine}".format(**self.__dict__)


class Conv1D(nn.Conv1d):
    """
    1D Conv based on nn.Conv1d for 2D or 3D tensor
    Input: 2D or 3D tensor with [N, L_in] or [N, C_in, L_in]
    Output: Default 3D tensor with [N, C_out, L_out]
            If C_out=1 and squeeze is true, return 2D tensor [N, L_out]

    A 2D input is treated as a single-channel signal (C_in must then be 1).
    """

    def __init__(self, *args, **kwargs):
        super(Conv1D, self).__init__(*args, **kwargs)

    def forward(self, x, squeeze=False):
        if x.dim() not in [2, 3]:
            raise RuntimeError("{} require a 2/3D tensor input".format(
                type(self).__name__))
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
        if squeeze:
            # Only the channel axis: a batch of size 1 must keep its N dimension.
            x = th.squeeze(x, 1)
        return x


class ConvTrans1D(nn.ConvTranspose1d):
    """
    1D Transposed Conv based on nn.ConvTranspose1d for 2D or 3D tensor
    Input: 2D or 3D tensor with [N, L_in] or [N, C_in, L_in]
    Output: 2D tensor with [N, L_out] (when C_out is 1)

    Used as a waveform decoder: the channel axis (size 1) is squeezed away.
    """

    def __init__(self, *args, **kwargs):
        super(ConvTrans1D, self).__init__(*args, **kwargs)

    def forward(self, x):
        if x.dim() not in [2, 3]:
            # Instances have no __name__; use the class name for the message.
            raise RuntimeError("{} require a 2/3D tensor input".format(
                type(self).__name__))
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))

        # squeeze the channel dimension 1 after reconstructing the signal
        return th.squeeze(x, 1)


class TCNBlock(nn.Module):
    """Temporal convolution block with a residual connection.

    1x1 conv (in_channels -> conv_channels) -> PReLU -> gLN -> depthwise
    dilated conv -> PReLU -> gLN -> 1x1 conv (conv_channels -> in_channels),
    added to the input. Input/output: [N, in_channels, T].

    Only the non-causal variant is implemented. The depthwise conv uses
    symmetric padding ``dilation * (kernel_size - 1) // 2``, which preserves
    T (and so the residual sum) only for odd `kernel_size`.

    Raises:
        NotImplementedError: if ``causal=True``.
    """

    def __init__(self,
                 in_channels=256,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1, causal=False):
        super().__init__()
        if causal:
            # Fail loudly instead of silently building a non-causal block.
            raise NotImplementedError("causal TCNBlock is not implemented")
        self.conv = nn.Conv1d(in_channels, conv_channels, 1)
        self.net = nn.Sequential(
            nn.PReLU(),
            GlobalLayerNorm(conv_channels, eps=1e-05),
            nn.Conv1d(
                conv_channels,
                conv_channels,
                kernel_size,
                groups=conv_channels,  # depthwise
                padding=(dilation * (kernel_size - 1)) // 2,
                dilation=dilation,
                bias=True),
            nn.PReLU(),
            GlobalLayerNorm(conv_channels, eps=1e-05),
            nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        )

    def forward(self, x):
        y = self.conv(x)
        y = self.net(y)
        return y + x


class TCNBlockSpeaker(TCNBlock):
    """TCNBlock whose input 1x1 conv also receives a speaker embedding.

    The embedding `aux` (presumably [N, spk_embed_dim], as produced in SpExPlus)
    is repeated along time and concatenated to the input channels before
    ``self.conv``. The residual connection still adds the original `x`.
    """

    def __init__(self,
                 in_channels=256,
                 spk_embed_dim=100,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False):
        # Forward `causal` so the flag is honoured (or rejected) by the parent.
        super().__init__(in_channels, conv_channels, kernel_size, dilation, causal=causal)
        # Replace the parent's input conv so it also accepts the embedding channels.
        self.conv = Conv1D(in_channels + spk_embed_dim, conv_channels, 1)

    def forward(self, x, aux):
        aux = th.unsqueeze(aux, -1)
        aux = aux.repeat(1, 1, x.shape[-1])  # one copy of the embedding per frame
        y = th.cat([x, aux], 1)
        y = self.conv(y)
        y = self.net(y)
        return y + x


class ResBlock(nn.Module):
    """Residual block used by the speaker encoder.

    Two 1x1 conv + BatchNorm stages with a PReLU in between, plus a skip
    connection (1x1-projected when the channel count changes). This is followed
    by PReLU and MaxPool1d(3), which divides the time length by 3 (floor).
    Input [N, input_size, T] -> output [N, out_size, T // 3].
    """

    def __init__(self, input_size, out_size):
        super().__init__()
        self.first = nn.Sequential(
            nn.Conv1d(input_size, out_size, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_size),
            nn.PReLU(),
            nn.Conv1d(out_size, out_size, kernel_size=1, bias=False),
            nn.BatchNorm1d(out_size),
        )
        # The skip path needs a projection only when the channel counts differ.
        self.downsample = (
            nn.Sequential(nn.Conv1d(input_size, out_size, kernel_size=1, bias=False))
            if input_size != out_size
            else None
        )
        self.second = nn.Sequential(nn.PReLU(), nn.MaxPool1d(3))

    def forward(self, x):
        residual = self.first(x)
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.second(residual + shortcut)

class SharedEncoder(nn.Module):
    """Multi-scale waveform encoder shared by the mixture and reference branches.

    Three Conv1D encoders with window lengths L1 (short), L2 (middle) and L3
    (long) all use hop ``L1 // 2``. The input is right-padded with zeros (or
    cropped) for the longer windows so that all three produce the same number
    of frames as the short encoder. Every output goes through ReLU.
    """

    def __init__(self, N, L1, L2, L3):
        super().__init__()
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        hop = L1 // 2
        self.encoder_1d_short = Conv1D(1, N, L1, stride=hop, padding=0)
        self.encoder_1d_middle = Conv1D(1, N, L2, stride=hop, padding=0)
        self.encoder_1d_long = Conv1D(1, N, L3, stride=hop, padding=0)

    def forward(self, x):
        hop = self.L1 // 2
        short = F.relu(self.encoder_1d_short(x))
        num_frames = short.shape[-1]
        in_len = x.shape[-1]

        def _encode(encoder, window):
            # Length that yields exactly num_frames frames for this window.
            needed = (num_frames - 1) * hop + window
            return F.relu(encoder(F.pad(x, (0, needed - in_len), "constant", 0)))

        middle = _encode(self.encoder_1d_middle, self.L2)
        long_ = _encode(self.encoder_1d_long, self.L3)
        return short, middle, long_


class SpExPlus(nn.Module):
    """SpEx+ speaker-extraction network.

    The mixture is encoded by a shared multi-scale encoder (windows L1, L2, L3;
    hop L1 // 2). Four TCN stacks, each starting with a speaker-conditioned
    block, process the encoded mixture. The speaker embedding comes from the
    reference utterance, encoded by the same shared encoder plus a speaker
    encoder. One ReLU mask per scale is applied to the mixture encoding, and
    each masked representation is decoded back to a waveform. A linear head on
    the embedding gives speaker logits.

    Args:
        L1, L2, L3: Short / middle / long encoder window lengths in samples.
        N: Encoder channels per scale.
        B: Blocks per TCN stack (1 speaker-conditioned + B - 1 plain).
        O: Channels of the TCN stacks after the 1x1 projection.
        P: Hidden channels of the TCN blocks and speaker ResBlocks.
        Q: Depthwise kernel size of the TCN blocks.
        num_spks: Number of speaker classes for ``pred_linear``.
        spk_embed_dim: Speaker embedding size.
        causal: Forwarded to the TCN blocks.
    """

    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=192,
                 B=4,
                 O=192,
                 P=256,
                 Q=3,
                 num_spks=4,
                 spk_embed_dim=128,
                 causal=False):
        super().__init__()
        # n x S => n x N x T, S = 4s*8000 = 32000
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        self.shared_encoder = SharedEncoder(N, L1, L2, L3)
        # before repeat blocks, always cLN
        self.ln = ChannelwiseLayerNorm(3 * N)
        # n x N x T => n x O x T
        self.proj = Conv1D(3 * N, O, 1)
        self.conv_block_1 = TCNBlockSpeaker(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q,
                                            causal=causal, dilation=1)
        self.conv_block_1_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q,
                                                     causal=causal)
        self.conv_block_2 = TCNBlockSpeaker(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q,
                                            causal=causal, dilation=1)
        self.conv_block_2_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q,
                                                     causal=causal)
        self.conv_block_3 = TCNBlockSpeaker(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q,
                                            causal=causal, dilation=1)
        self.conv_block_3_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q,
                                                     causal=causal)
        self.conv_block_4 = TCNBlockSpeaker(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q,
                                            causal=causal, dilation=1)
        self.conv_block_4_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q,
                                                     causal=causal)
        # n x O x T => n x N x T
        self.mask1 = Conv1D(O, N, 1)
        self.mask2 = Conv1D(O, N, 1)
        self.mask3 = Conv1D(O, N, 1)
        # using ConvTrans1D: n x N x T => n x 1 x To
        # To = (T - 1) * L // 2 + L
        self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1, stride=L1 // 2, bias=True)
        self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2, stride=L1 // 2, bias=True)
        self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3, stride=L1 // 2, bias=True)
        self.num_spks = num_spks

        # Speaker encoder; each ResBlock ends with MaxPool1d(3) (see forward()).
        self.spk_encoder = nn.Sequential(
            ChannelwiseLayerNorm(3 * N),
            Conv1D(3 * N, O, 1),
            ResBlock(O, O),
            ResBlock(O, P),
            ResBlock(P, P),
            Conv1D(P, spk_embed_dim, 1),
        )

        self.pred_linear = nn.Linear(spk_embed_dim, num_spks)

    def _build_stacks(self, num_blocks, **block_kwargs):
        """
        Stack num_blocks - 1 plain TCN blocks with dilations 2, 4, ..., 2**(num_blocks - 1).
        The first block of each stack (dilation 1, takes the speaker embedding)
        is the separate conv_block_k module.
        """
        blocks = [
            TCNBlock(**block_kwargs, dilation=(2 ** b))
            for b in range(1, num_blocks)
        ]
        return nn.Sequential(*blocks)

    def forward(self, mix_audio, reference_audio, reference_audio_len, **batch):
        """Extract the reference speaker's voice from the mixture.

        Args:
            mix_audio: Mixture waveform [N, S] (or [S] for one example).
            reference_audio: Reference utterance of the target speaker,
                [N, S_ref] (or [S_ref]).
            reference_audio_len: Valid reference length(s) in samples, as a
                tensor, sequence or single int.
            **batch: Ignored; allows passing a whole batch dict as kwargs.

        Returns:
            Dict with 'ests' (short-window decoder output), 'ests2' / 'ests3'
            (middle / long decoder outputs truncated to the mixture length) and
            'spk_pred' (speaker logits, [N, num_spks]).
        """
        x = mix_audio
        aux = reference_audio
        if x.dim() == 1:
            x = th.unsqueeze(x, 0)
        if aux.dim() == 1:
            aux = th.unsqueeze(aux, 0)
        # Accept ints / lists as well as tensors.
        aux_len = th.as_tensor(reference_audio_len, device=aux.device)

        xlen1 = x.shape[-1]
        w1, w2, w3 = self.shared_encoder(x)
        # n x 3N x T
        y = self.ln(th.cat([w1, w2, w3], 1))
        # n x O x T
        y = self.proj(y)

        # speaker encoder (share params from speech encoder)
        aux_w1, aux_w2, aux_w3 = self.shared_encoder(aux)
        aux = self.spk_encoder(th.cat([aux_w1, aux_w2, aux_w3], 1))
        # Encoder frames of each reference, then after the three MaxPool1d(3).
        aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1
        aux_T = ((aux_T // 3) // 3) // 3
        # Time-sum divided by the frame count; clamp avoids 0-division (NaN/inf)
        # for references too short to survive the pooling.
        aux = th.sum(aux, -1) / aux_T.view(-1, 1).float().clamp(min=1)

        # In
        y = self.conv_block_1(y, aux)
        y = self.conv_block_1_other(y)
        y = self.conv_block_2(y, aux)
        y = self.conv_block_2_other(y)
        y = self.conv_block_3(y, aux)
        y = self.conv_block_3_other(y)
        y = self.conv_block_4(y, aux)
        y = self.conv_block_4_other(y)

        # n x N x T
        m1 = F.relu(self.mask1(y))
        m2 = F.relu(self.mask2(y))
        m3 = F.relu(self.mask3(y))
        S1 = w1 * m1
        S2 = w2 * m2
        S3 = w3 * m3

        return {'ests': self.decoder_1d_short(S1),
                'ests2': self.decoder_1d_middle(S2)[:, :xlen1],
                'ests3': self.decoder_1d_long(S3)[:, :xlen1],
                'spk_pred': self.pred_linear(aux)}

