In [None]:
#@title Utils

def mono_resample(wav: torch.Tensor, sr: int=16000, target_sr: int = 16000) -> torch.Tensor:
    """Down-mix a waveform to one channel and resample it to ``target_sr``.

    A 1-D input ``[T]`` gains a leading channel axis; a 2-D input ``[C, T]``
    with more than one channel is averaged over the channel axis. Inputs of
    any other shape skip the down-mix step. Resampling only happens when
    ``sr`` differs from ``target_sr``.
    """
    n_dims = wav.dim()
    if n_dims == 1:
        wav = wav.unsqueeze(0)
    elif n_dims == 2 and wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sr == target_sr:
        return wav
    return torchaudio.functional.resample(wav, sr, target_sr)


def tstamp2framelabel(tstamps, labels, speclen):
    """Expand per-segment labels into a per-frame label tensor of length ``speclen``.

    Every entry of ``tstamps`` is a segment end time; floor-dividing it by 10
    gives the segment's (exclusive) end frame. Segment ``i`` starts where
    segment ``i-1`` ended (frame 0 for the first one). Frames not covered by
    any segment keep the initial value 0.
    """
    end_frames = [end_time // 10 for end_time in tstamps]
    begin_frames = [0] + end_frames[:-1]

    per_frame = np.zeros(speclen, dtype=int)
    for begin, end, seg_label in zip(begin_frames, end_frames, labels):
        per_frame[begin:end] = seg_label
    return torch.tensor(per_frame)


class Augmentor:
    """Waveform augmentation with optional reverberation (RIR) and additive noise.

    Augmentation is disabled entirely when both probabilities are 0; then no
    transform is built and ``augment`` returns its input unchanged. A transform
    whose own probability is 0 is not built either, so its path argument may be
    left as ``None`` when only the other augmentation is wanted.

    Args:
        rir_paths: impulse-response location(s), forwarded to ``ApplyImpulseResponse``.
        noise_paths: background-noise location(s), forwarded to ``AddBackgroundNoise``.
        rir_prob: probability passed as ``p`` to the RIR transform.
        noise_prob: probability passed as ``p`` to the noise transform.
        min_snr_in_db, max_snr_in_db: SNR bounds forwarded to the noise transform.
    """

    def __init__(self, rir_paths=None, noise_paths=None, rir_prob=0, noise_prob=0, min_snr_in_db=3, max_snr_in_db=30):
        self.do_aug = not (rir_prob == 0 and noise_prob == 0)
        self.rir_augmentor = None
        # NOTE: the misspelled attribute name is kept so existing references keep working.
        self.noise_augentor = None
        if not self.do_aug:
            return

        # Build only the transforms that can actually fire. The library resolves
        # the given paths at construction time, so a disabled transform with
        # paths=None would presumably fail there (confirm against torch_audiomentations).
        if rir_prob != 0:
            self.rir_augmentor = ApplyImpulseResponse(ir_paths=rir_paths,
                                                      sample_rate=16000,
                                                      p=rir_prob,
                                                      output_type='tensor')
        if noise_prob != 0:
            self.noise_augentor = AddBackgroundNoise(background_paths=noise_paths,
                                                     sample_rate=16000,
                                                     min_snr_in_db=min_snr_in_db,
                                                     max_snr_in_db=max_snr_in_db,
                                                     p=noise_prob,
                                                     output_type='tensor')

    def __call__(self, wav):
        """Alias for :meth:`augment`."""
        return self.augment(wav)

    def augment(self, wav):
        """Apply the enabled transforms to ``wav`` and drop the batch axis.

        All samples of ``wav`` are flattened into a single (1, 1, T) signal
        before the transforms run; the leading axis is squeezed afterwards.
        When augmentation is disabled the input is returned untouched.
        """
        if not self.do_aug:
            return wav
        # reshape (not view) so non-contiguous inputs are accepted as well
        wav = wav.reshape(1, 1, -1)
        if self.rir_augmentor is not None:
            wav = self.rir_augmentor(wav)
        if self.noise_augentor is not None:
            wav = self.noise_augentor(wav)
        return wav.squeeze(0)


class Extractor:
    def __init__(self, apply_cmvn=True, **feature_args):
        if feature_args==None: feature_args={}
        basic_args = {'sample_rate':16000, 'n_fft':400, 'n_mels':24, 'win_length':400, 'hop_length':160}
        basic_args.update(feature_args)
        self.extractor = torchaudio.transforms.MelSpectrogram(**basic_args)
        self.apply_cmvn = apply_cmvn

    def __call__(self, wav):
        return self.extract(wav)

    def extract(self, wav:list|str):
        if isinstance(wav,str): wav, sr = torchaudio.load(wav)
        spec = self.extractor(wav)  # (1,F,T)
        spec = spec.squeeze(0).transpose(0,1)  # (T,F)
        spec = torch.log10(spec + 1e-6)
        if self.apply_cmvn==True:
            mean = spec.mean(dim=0, keepdim=True)
            std = spec.std(dim=0, keepdim=True)
            spec = (spec - mean) / (std + 1e-9)
        return spec

In [None]:
import os, math, random, pickle, gzip, warnings, json
from pathlib import Path
from os.path import join as pjoin
from collections import defaultdict
from decimal import Decimal
from multiprocessing import Pool, cpu_count

from tqdm.auto import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import kaldiio
from textgrid import TextGrid
from torch_audiomentations import AddBackgroundNoise, ApplyImpulseResponse

from personal_VAD.utils import *

# ignore torchaudio warnings
warnings.filterwarnings("ignore", message=".*torchaudio.load_with_torchcodec.*")
warnings.filterwarnings("ignore", message=".*StreamingMediaDecoder.*")
warnings.filterwarnings("ignore", message=".*deprecated.*")

In [None]:
def load_librispeech(data_dir:str=None, algn_dir:str=None, save_path=None, overwrite=False):
    """Collect LibriSpeech utterances together with their word-alignment labels.

    Walks ``data_dir/<speaker>/<chapter>/*.flac`` and reads the matching
    ``algn_dir/<speaker>/<chapter>/<utt>.TextGrid``. Utterances whose alignment
    cannot be read or is inconsistent are reported with ``print`` and skipped.

    Args:
        data_dir: root of one LibriSpeech split.
        algn_dir: root of the alignment files for the same split.
        save_path: optional gzip-pickle cache. If it exists and ``overwrite``
            is false, the cache is returned and the directories are not read.
        overwrite: rebuild and rewrite the cache even if it exists.

    Returns:
        ``{spk_id: [(utt_path, label, tstamp)]}`` where ``label`` is a char
        array ('W' for a non-empty interval, '' otherwise) and ``tstamp`` holds
        the interval end times in milliseconds.

    Raises:
        ValueError: a scan is required but ``data_dir`` or ``algn_dir`` is None.
    """
    if save_path and os.path.exists(save_path) and not overwrite:
        # NOTE(security): pickle executes arbitrary code on load; only open caches you created.
        with gzip.open(save_path, 'rb') as f: return pickle.load(f)

    if data_dir is None or algn_dir is None:
        # os.listdir(None) would silently scan the current directory instead
        raise ValueError("data_dir and algn_dir are required when no cached data is available")

    def get_label(algn_path):
        # get label and end times(in milliseconds) from TextGrid file
        try: tg = TextGrid.fromFile(algn_path)
        except Exception as e: raise RuntimeError(f"Failed to open file: [{type(e).__name__}] {e}. So skipped") from e
        word_tier = tg.getFirst("words")
        label = np.char.array([('W' if interval.mark else '') for interval in word_tier.intervals], itemsize=4)
        # Decimal(str(...)) avoids float rounding when converting seconds to integer milliseconds
        tstamp = np.array([int(Decimal(str(interval.maxTime))*1000) for interval in word_tier.intervals])
        if tstamp[-1]!=int(Decimal(str(tg.maxTime))*1000): raise ValueError(f"Max time label is inconsistent in the file {algn_path}: {tstamp[-1]}!={tg.maxTime}. So skipped")
        return label, tstamp

    data = defaultdict(list)  # {spk_id : [(utt_path, label, tstamp)]}
    for spk_id in tqdm(os.listdir(data_dir)):
        spk_dir = pjoin(data_dir,spk_id)
        if not os.path.isdir(spk_dir): continue

        for chp_id in os.listdir(spk_dir):
            chp_dir = pjoin(spk_dir,chp_id)
            if not os.path.isdir(chp_dir): continue

            for utt_id in os.listdir(chp_dir):
                if not utt_id.endswith('.flac'): continue
                utt_path = pjoin(chp_dir,utt_id)

                algn_path = pjoin(algn_dir,spk_id,chp_id,utt_id[:-len('.flac')]+'.TextGrid')
                try:
                    label, tstamp = get_label(algn_path)
                except Exception as e:
                    # best effort: report the problem and move on to the next utterance
                    print(e)
                    continue

                data[spk_id].append((utt_path, label, tstamp))

    if save_path:
        with gzip.open(save_path, 'wb') as f: pickle.dump(data, f)

    print("NOTE: you should call process_data function to properly use this data")
    return data





def process_data(data_path, gender_path, save_path):
    """Convert the raw pickle written by ``load_librispeech`` into its processed form.

    For every utterance the path is reduced to its last four components,
    timestamps are floored to a multiple of 10 ms, labels become 0/1 integers
    (1 where the raw label is "W"), and the speaker's gender flag is appended.

    Args:
        data_path: gzip-pickle holding ``{spk_id: [(utt_path, label, tstamp)]}``.
        gender_path: speaker metadata file with ``|``-separated fields; lines
            that are empty or start with ``;`` are ignored.
        save_path: where to write the processed gzip-pickle; falsy to skip writing.

    Returns:
        ``{int(spk_id): [(utt_path, label, tstamp, is_male)]}``.

    Raises:
        KeyError: a speaker in the data has no entry in the metadata file.
    """

    def load_gender_map(meta_path: str):
        # maps speaker id -> True when the second field equals 'M'
        gender_map = {}
        with open(meta_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith(";"): continue
                parts = [x.strip() for x in line.split("|")]
                spk_id = int(parts[0])
                gender = parts[1]=='M'
                gender_map[spk_id] = gender
        return gender_map

    gender_map = load_gender_map(gender_path)
    # NOTE(security): pickle executes arbitrary code on load; only open files you created.
    with gzip.open(data_path, 'rb') as f: data = pickle.load(f)
    data_proc = defaultdict(list)
    for spk, utts in tqdm(data.items()):
        spk = int(spk)
        for utt_path, label, tstamp in utts:
            utt_path = pjoin(*Path(utt_path).parts[-4:])  # remove data_dir
            tstamp = tstamp//10*10  # ensure the timestamp is 10ms unit for the 10ms unit label
            label = (label=="W").astype(int)
            data_proc[spk].append((utt_path, label, tstamp, gender_map[spk]))
    if save_path:
        with gzip.open(save_path, 'wb') as f: pickle.dump(data_proc, f)
    # hand the result back so it is not lost when nothing is written to disk
    return data_proc


""" The disabled functions below operate on the raw data.pkl.gz
def data2datalist(data_path, save_path):
    with gzip.open(data_path, 'rb') as f: data = pickle.load(f)
    datalist = []
    for spk, utts in tqdm(data.items()):
        for utt_path, label, tstamp in utts:
            utt_path = pjoin(*Path(utt_path).parts[-4:])  # remove data_dir
            tstamp = tstamp//10*10  # ensure the timestamp is 10ms unit for the 10ms unit label
            label = (label=="W").astype(int)
            datalist.append((spk, utt_path, label, tstamp))
    with open(save_path, 'w') as f:
        for spk, utt_path, label, tstamp in datalist:
            f.write(f'{spk}, {utt_path}, ')
            np.savetxt(f, label, fmt="%d", newline=" ")
            f.write(', ')
            np.savetxt(f, tstamp, fmt="%d", newline=" ")
            f.write('\n')
#data2datalist("./data/LibriSpeech/train-other-500/data.pkl.gz", "./data/LibriSpeech/train-other-500/datalist.txt")


def load_datalist(datalist_path):
    datalist = []
    with open(datalist_path, 'r') as f:
        for line in f:
            spk, utt_path, label, tstamp = line.split(',')
            spk = int(spk)
            utt_path = pjoin(data_dir,utt_path.strip())
            label = np.fromstring(label,sep=" ",dtype=int)
            tstamp = np.fromstring(tstamp,sep=" ",dtype=int)
            datalist.append((spk, utt_path, label, tstamp))
    return datalist
"""

""" process raw data.pkl.gz
libri_processed_dir = "./data/LibriSpeech"
gender_path = "./data/audio_datasets/LibriSpeech/LibriSpeech/SPEAKERS.TXT"
splits = ["test-clean","dev-clean","test-other","dev-other","train-clean-100", "train-clean-360", "train-other-500"]
for split in splits:
    data_path = pjoin(libri_processed_dir,split,'data.pkl.gz')
    data_proc_path = pjoin(libri_processed_dir,split,'data_proc.pkl.gz')
    process_data(data_path, gender_path, data_proc_path)
"""

In [None]:
class Augmentor:
    """Waveform augmentation with optional reverberation (RIR) and additive noise.

    Augmentation is disabled entirely when both probabilities are 0; then no
    transform is built and ``augment`` returns its input unchanged. A transform
    whose own probability is 0 is not built either, so its path argument may be
    left as ``None`` when only the other augmentation is wanted.

    Args:
        rir_paths: impulse-response location(s), forwarded to ``ApplyImpulseResponse``.
        noise_paths: background-noise location(s), forwarded to ``AddBackgroundNoise``.
        rir_prob: probability passed as ``p`` to the RIR transform.
        noise_prob: probability passed as ``p`` to the noise transform.
        min_snr_in_db, max_snr_in_db: SNR bounds forwarded to the noise transform.
    """

    def __init__(self, rir_paths=None, noise_paths=None, rir_prob=0, noise_prob=0, min_snr_in_db=3, max_snr_in_db=30):
        self.do_aug = not (rir_prob == 0 and noise_prob == 0)
        self.rir_augmentor = None
        # NOTE: the misspelled attribute name is kept so existing references keep working.
        self.noise_augentor = None
        if not self.do_aug:
            return

        # Build only the transforms that can actually fire. The library resolves
        # the given paths at construction time, so a disabled transform with
        # paths=None would presumably fail there (confirm against torch_audiomentations).
        if rir_prob != 0:
            self.rir_augmentor = ApplyImpulseResponse(ir_paths=rir_paths,
                                                      sample_rate=16000,
                                                      p=rir_prob,
                                                      output_type='tensor')
        if noise_prob != 0:
            self.noise_augentor = AddBackgroundNoise(background_paths=noise_paths,
                                                     sample_rate=16000,
                                                     min_snr_in_db=min_snr_in_db,
                                                     max_snr_in_db=max_snr_in_db,
                                                     p=noise_prob,
                                                     output_type='tensor')

    def __call__(self, wav):
        """Alias for :meth:`augment`."""
        return self.augment(wav)

    def augment(self, wav):
        """Apply the enabled transforms to ``wav`` and drop the batch axis.

        All samples of ``wav`` are flattened into a single (1, 1, T) signal
        before the transforms run; the leading axis is squeezed afterwards.
        When augmentation is disabled the input is returned untouched.
        """
        if not self.do_aug:
            return wav
        # reshape (not view) so non-contiguous inputs are accepted as well
        wav = wav.reshape(1, 1, -1)
        if self.rir_augmentor is not None:
            wav = self.rir_augmentor(wav)
        if self.noise_augentor is not None:
            wav = self.noise_augentor(wav)
        return wav.squeeze(0)


class Extractor:
    def __init__(self, apply_cmvn=True, **feature_args):
        if feature_args==None: feature_args={}
        basic_args = {'sample_rate':16000, 'n_fft':400, 'n_mels':24, 'win_length':400, 'hop_length':160}
        basic_args.update(feature_args)
        self.extractor = torchaudio.transforms.MelSpectrogram(**basic_args)
        self.apply_cmvn = apply_cmvn

    def __call__(self, wav):
        return self.extract(wav)

    def extract(self, wav:list|str):
        if isinstance(wav,str): wav, sr = torchaudio.load(wav)
        spec = self.extractor(wav)  # (1,F,T)
        spec = spec.squeeze(0).transpose(0,1)  # (T,F)
        spec = torch.log10(spec + 1e-6)
        if self.apply_cmvn==True:
            mean = spec.mean(dim=0, keepdim=True)
            std = spec.std(dim=0, keepdim=True)
            spec = (spec - mean) / (std + 1e-9)
        return spec


"""
def generate_feat(data_dir, data, feature_args:dict=None, aug_args:dict=None, save_path=None):
    augmentor = Augmentor(**aug_args)
    extractor = Extractor(**feature_args)
    feats = defaultdict(list)
    for spk, samples in tqdm(data.items()):
        for utt_path, label, tstamp, gender in samples:
            wav, sr = torchaudio.load(pjoin(data_dir,utt_path))
            wav = augmentor.augment(wav)
            feat = extractor.extract(wav)
            feat = feat[:tstamp[-1]//10,:]
            feats[spk].append((feat, label, tstamp, gender))
    with gzip.open(save_path, 'wb') as f: pickle.dump(feats, f)
    return feats
"""
# Per-process state: filled in once per worker by _init_worker and read by _process_item.
_augmentor = None
_extractor = None
_data_dir = None
def _init_worker(data_dir, feature_args, aug_args):
    """Pool initializer: remember ``data_dir`` and build this worker's transforms.

    ``feature_args`` / ``aug_args`` may be ``None`` (the defaults of
    ``generate_feat``); that is treated as "no overrides".
    """
    global _augmentor, _extractor, _data_dir
    _data_dir = data_dir
    # `**None` raises TypeError, so fall back to an empty mapping
    _augmentor = Augmentor(**(aug_args or {}))
    _extractor = Extractor(**(feature_args or {}))
def _process_item(item):
    """Worker task: load one utterance, augment it and extract its features.

    ``item`` is ``(spk, utt_path, label, tstamp, gender)`` with ``utt_path``
    relative to the data directory stored by ``_init_worker``. The feature
    matrix is cut to ``tstamp[-1] // 10`` rows so that it ends with the last
    timestamp.
    """
    spk, rel_path, label, tstamp, gender = item
    waveform, _ = torchaudio.load(pjoin(_data_dir, rel_path))
    features = _extractor.extract(_augmentor.augment(waveform))
    n_frames = tstamp[-1] // 10
    return spk, (features[:n_frames, :], label, tstamp, gender)
def generate_feat(data_dir, data, feature_args:dict=None, aug_args:dict=None, save_path=None, num_workers=None):
    """Extract features for every utterance in ``data`` using a process pool.

    Args:
        data_dir: directory the relative utterance paths are joined onto.
        data: ``{spk: [(utt_path, label, tstamp, gender)]}``.
        feature_args: keyword arguments for ``Extractor``; ``None`` means none.
        aug_args: keyword arguments for ``Augmentor``; ``None`` means none.
        save_path: optional gzip-pickle destination for the result.
        num_workers: pool size; defaults to all CPUs but one (at least 1).

    Returns:
        ``{spk: [(feat, label, tstamp, gender)]}``. Results are collected with
        ``imap_unordered``, so the order inside each speaker's list follows
        task completion, not the order in ``data``.
    """
    # the worker initializer unpacks both dicts with `**`, which rejects None
    if feature_args is None: feature_args = {}
    if aug_args is None: aug_args = {}
    # cpu_count() - 1 is 0 on a single-core host and Pool needs at least one process
    if num_workers is None: num_workers = max(1, cpu_count() - 1)
    flat_data = [
        (spk, utt_path, label, tstamp, gender)
        for spk, samples in data.items()
        for utt_path, label, tstamp, gender in samples
    ]
    feats = defaultdict(list)
    with Pool(
        processes=num_workers,
        initializer=_init_worker,
        initargs=(data_dir, feature_args, aug_args),
    ) as pool:
        for spk, out in tqdm(pool.imap_unordered(_process_item, flat_data), total=len(flat_data)):
            feats[spk].append(out)
    if save_path is not None:
        with gzip.open(save_path, "wb") as f: pickle.dump(feats, f)
    return feats

In [None]:
# Driver cell: generate per-utterance features for every LibriSpeech split.
# The `...` values are placeholders that must be replaced with real paths.
feature_args = {'n_mels':24}
feature_args_noCMVN = {'n_mels':24, 'apply_cmvn':False}
noaug_args = {'rir_prob':0, 'noise_prob':0}
aug_args = {'rir_paths':...,
            'noise_paths':...,
            'rir_prob':0.5, 'noise_prob':0.5}

libri_data_dir = ...
libri_processed_dir = ...
splits = ["test-clean","dev-clean","test-other","dev-other","train-clean-100", "train-clean-360", "train-other-500"] 
for split in splits:
    feat_dir = pjoin(libri_processed_dir,split,"features")
    # validate_path comes from personal_VAD.utils; presumably it checks/creates the directory — TODO confirm
    validate_path(feat_dir, is_dir=True)
    data_path = pjoin(libri_processed_dir,split,'data_proc.pkl.gz')
    feat_path = pjoin(feat_dir,"feats_24dim_aug_noCMVN.pkl.gz")
    # NOTE(review): load_librispeech is used here only as a gzip+pickle loader for the
    # processed file; no data_dir/algn_dir is passed, so data_proc.pkl.gz must already exist.
    data = load_librispeech(save_path=data_path)
    # this run uses augmentation without CMVN; feature_args and noaug_args are unused in this cell
    generate_feat(libri_data_dir, data, feature_args_noCMVN, aug_args, feat_path)

In [None]:
def simulate(data:dict[str, list], sample_num, unique=True, min_enroll_duration=5, save_path=None):
    """Split off enrollment utterances and simulate multi-speaker conversations.

    NOTE: ``data`` is modified in place — every speaker's list is shuffled and
    utterances are popped from it (always for enrollment, and for the
    simulation too when ``unique`` is true).

    input:
        data: individual utterances {spk_id: [(utt_path, label, tstamp)]},
              with ``tstamp`` holding end times in milliseconds
        sample_num: maximum number of conversations to simulate
        unique: use every utterance at most once; simulation stops early when
                the data runs out
        min_enroll_duration: minimum enrollment speech per speaker, in seconds
        save_path : *.pkl.gz
    output:
        enroll_data: enrollment utterances for each speaker {spk_id:[utt_paths]}
        simul_data: simulated conversations [(spks, utt_paths, lengths, labels, tstamps)]
                    where ``labels`` carry the speaker id on "W" intervals and 0 elsewhere
    """

    # shuffle for randomness
    whole_spks = tuple(data.keys())
    for spk in whole_spks: random.shuffle(data[spk])

    # split enrollment data
    # timestamps are in milliseconds, so compare against the requirement in the same unit
    min_enroll_ms = min_enroll_duration * 1000
    enroll_data = defaultdict(list)
    for spk, samples in data.items():
        utts = []
        remaining_ms = min_enroll_ms
        while remaining_ms > 0 and samples:
            utt_path, label, tstamp = samples.pop()
            utts.append(utt_path)
            remaining_ms -= tstamp[-1]
        enroll_data[int(spk)] = utts
        if not samples:
            print(f'speaker {spk} has no enough speech for enrollment')

    # Keep only speakers that still have utterances left. Built as a new list:
    # deleting while enumerating skips entries and would leave empty speakers behind.
    remained_spks = [spk for spk in whole_spks if data[spk]]  # a list, because random.sample needs a sequence

    # start simulation
    simul_data = []
    for _ in tqdm(range(sample_num)):
        if not remained_spks:
            print(f"run out of data. {len(simul_data)}")
            break

        spk_num = min(random.randint(1,3), len(remained_spks))
        sampled_spks = random.sample(remained_spks, spk_num)

        spks = []
        utt_paths = []
        labels = []
        lengths = []
        tstamps = []
        last_time = 0
        for spk in sampled_spks:
            if unique:
                utt_path, label, tstamp = data[spk].pop()
                if len(data[spk])==0: remained_spks.remove(spk)
            else:
                utt_path, label, tstamp = random.choice(data[spk])
            spk = int(spk)
            label = (label=="W").astype(int)*spk
            tstamp = tstamp//10*10  # ensure the timestamp is 10ms unit for the 10ms unit label
            spks.append(spk)
            utt_paths.append(utt_path)
            labels.append(label)
            lengths.append(tstamp[-1])
            # shift this utterance's timestamps so they continue after the previous one
            tstamps.append(tstamp+last_time)
            last_time = tstamps[-1][-1]
        spks = np.array(spks)
        utt_paths = np.array(utt_paths)
        lengths = np.array(lengths)
        labels = np.concatenate(labels)
        tstamps = np.concatenate(tstamps)
        simul_data.append((spks, utt_paths, lengths, labels, tstamps))

    if save_path:
        with gzip.open(save_path, 'wb') as f: pickle.dump((enroll_data,simul_data), f)

    return enroll_data, simul_data

In [None]:
def concat_wavs(utt_paths:list[str], lengths:list[int]=None, sr=16000):
    """
    concatenate utterances from paths.
    provide last tstamp to ensure matching between wav and tstamp.

    Args:
        utt_paths: audio files to load, in playback order.
        lengths: optional per-utterance durations in milliseconds (list or
            array); each waveform is cut to that duration before joining.
            ``None`` keeps every waveform whole.
        sr: sample rate every file must have.

    Returns:
        Tensor of shape (1, sum_T) — assumes single-channel files.

    Raises:
        ValueError: a file's sample rate differs from ``sr``.
    """
    if lengths is None:
        sample_counts = [None] * len(utt_paths)
    else:
        # np.asarray: a plain list would be *repeated* by `* sr` instead of scaled
        sample_counts = (np.asarray(lengths) * sr // 1000).tolist()

    pieces = []
    for utt_path, n_samples in zip(utt_paths, sample_counts):
        wav, file_sr = torchaudio.load(utt_path)  # (1, T)
        if file_sr != sr:
            raise ValueError(f"{utt_path}: expected sample rate {sr}, got {file_sr}")
        pieces.append(wav if n_samples is None else wav[:, :n_samples])
    return torch.cat(pieces, dim=-1)  # (1, sum_T)


class Augmentor:
    """Waveform augmentation with optional reverberation (RIR) and additive noise.

    Augmentation is disabled entirely when both probabilities are 0; then no
    transform is built and ``augment`` returns its input unchanged. A transform
    whose own probability is 0 is not built either, so its path argument may be
    left as ``None`` when only the other augmentation is wanted.

    Args:
        rir_paths: impulse-response location(s), forwarded to ``ApplyImpulseResponse``.
        noise_paths: background-noise location(s), forwarded to ``AddBackgroundNoise``.
        rir_prob: probability passed as ``p`` to the RIR transform.
        noise_prob: probability passed as ``p`` to the noise transform.
        min_snr_in_db, max_snr_in_db: SNR bounds forwarded to the noise transform.
    """

    def __init__(self, rir_paths=None, noise_paths=None, rir_prob=0, noise_prob=0, min_snr_in_db=3, max_snr_in_db=30):
        self.do_aug = not (rir_prob == 0 and noise_prob == 0)
        self.rir_augmentor = None
        # NOTE: the misspelled attribute name is kept so existing references keep working.
        self.noise_augentor = None
        if not self.do_aug:
            return

        # Build only the transforms that can actually fire. The library resolves
        # the given paths at construction time, so a disabled transform with
        # paths=None would presumably fail there (confirm against torch_audiomentations).
        if rir_prob != 0:
            self.rir_augmentor = ApplyImpulseResponse(ir_paths=rir_paths,
                                                      sample_rate=16000,
                                                      p=rir_prob,
                                                      output_type='tensor')
        if noise_prob != 0:
            self.noise_augentor = AddBackgroundNoise(background_paths=noise_paths,
                                                     sample_rate=16000,
                                                     min_snr_in_db=min_snr_in_db,
                                                     max_snr_in_db=max_snr_in_db,
                                                     p=noise_prob,
                                                     output_type='tensor')

    def __call__(self, wav):
        """Alias for :meth:`augment`."""
        return self.augment(wav)

    def augment(self, wav):
        """Apply the enabled transforms to ``wav`` and drop the batch axis.

        All samples of ``wav`` are flattened into a single (1, 1, T) signal
        before the transforms run; the leading axis is squeezed afterwards.
        When augmentation is disabled the input is returned untouched.
        """
        if not self.do_aug:
            return wav
        # reshape (not view) so non-contiguous inputs are accepted as well
        wav = wav.reshape(1, 1, -1)
        if self.rir_augmentor is not None:
            wav = self.rir_augmentor(wav)
        if self.noise_augentor is not None:
            wav = self.noise_augentor(wav)
        return wav.squeeze(0)


class Extractor:
    def __init__(self, apply_cmvn=True, **feature_args):
        if feature_args==None: feature_args={}
        basic_args = {'sample_rate':16000, 'n_fft':400, 'n_mels':24, 'win_length':400, 'hop_length':160}
        basic_args.update(feature_args)
        self.extractor = torchaudio.transforms.MelSpectrogram(**basic_args)
        self.apply_cmvn = apply_cmvn

    def extract(self, wav:list|str):
        if isinstance(wav,str): wav, sr = torchaudio.load(wav)
        spec = self.extractor(wav)  # (1,F,T)
        spec = spec.squeeze(0).transpose(0,1)  # (T,F)
        spec = torch.log10(spec + 1e-6)
        if self.apply_cmvn==True:
            mean = spec.mean(dim=0, keepdim=True)
            std = spec.std(dim=0, keepdim=True)
            spec = (spec - mean) / (std + 1e-9)
        return spec


def get_feat_from_enroll(enroll_data:dict[int, list[str]], feature_args:dict, aug_args=None, save_path=None):
    """Extract one feature matrix per speaker from their enrollment utterances.

    Each speaker's utterances are concatenated in full, optionally augmented,
    and turned into features.

    Args:
        enroll_data: ``{spk_id: [utt_paths]}``.
        feature_args: keyword arguments for ``Extractor``.
        aug_args: keyword arguments for ``Augmentor``; ``None`` disables augmentation.
        save_path: optional gzip destination. NOTE: the file is written with
            ``pickle.dump`` regardless of its extension.

    Returns:
        ``{spk_id: feat}``.
    """
    # `**None` raises TypeError; an empty dict gives an Augmentor that does nothing
    if aug_args is None: aug_args = {}
    augmentor = Augmentor(**aug_args)
    extractor = Extractor(**feature_args)
    feats = {}
    for spk, utts in tqdm(enroll_data.items()):
        wav = concat_wavs(utts)
        wav = augmentor.augment(wav)
        feats[spk] = extractor.extract(wav)
    if save_path is not None:
        with gzip.open(save_path, 'wb') as f: pickle.dump(feats, f)
    return feats

"""
def get_feat_from_data(data, feature_args:dict, aug_args:dict=None, save_path=None):
    if aug_args==None: aug_args={}
    augmentor = Augmentor(**aug_args)
    extractor = Extractor(**feature_args)
    feats = []
    for spks, utt_paths, lengths, labels, tstamps in tqdm(data):
        wav = concat_wavs(utt_paths, lengths)
        wav = augmentor.augment(wav)
        feat = extractor.extract(wav)
        feats.append(feat)
    with gzip.open(save_path, 'wb') as f: torch.save(feats, f)
    return feats
"""
# Per-process singletons; every Pool worker fills them in through init_worker.
augmentor = None
extractor = None
def init_worker(aug_args, feature_args):
    """Pool initializer: build this worker's Augmentor and Extractor once."""
    global augmentor, extractor
    augmentor, extractor = Augmentor(**aug_args), Extractor(**feature_args)
def process_one(entry):
    """Worker task: build the features of one simulated conversation.

    ``entry`` is ``(spks, utt_paths, lengths, labels, tstamps)``; only the
    paths and lengths are needed to assemble the waveform.
    """
    _spks, utt_paths, lengths, _labels, _tstamps = entry
    joined = concat_wavs(utt_paths, lengths)
    return extractor.extract(augmentor.augment(joined))
def get_feat_from_data(data, feature_args: dict, aug_args: dict = None, save_path=None, num_workers=None):
    """Extract features for every simulated conversation using a process pool.

    Results keep the order of ``data`` (``Pool.imap``). When ``save_path`` is
    given, the list is written with ``torch.save`` into a gzip file.
    ``num_workers`` defaults to the CPU count, which is then printed.
    """
    aug_args = {} if aug_args is None else aug_args
    if num_workers is None:
        num_workers = cpu_count()
        print('num_workers:',num_workers)

    pool_options = dict(
        processes=num_workers,
        initializer=init_worker,
        initargs=(aug_args, feature_args),
    )
    with Pool(**pool_options) as pool:
        ordered_results = pool.imap(process_one, data)
        feats = [feat for feat in tqdm(ordered_results, total=len(data))]

    if save_path is None:
        return feats
    with gzip.open(save_path, 'wb') as f:
        torch.save(feats, f)
    return feats

In [None]:
# Driver cell: for every LibriSpeech split, load (or build) the raw utterance data,
# simulate conversations, and extract enrollment and conversation features.
# The `...` values are placeholders that must be replaced with real paths.
libri_data_dir = ...
libri_algn_dir = ...
libri_processed_dir = ...

feature_args = {'n_mels':24}
noaug_args = {'rir_prob':0, 'noise_prob':0}
aug_args = {'rir_paths':...,
            'noise_paths':...,
            'rir_prob':0.5, 'noise_prob':0.5}

splits = ["test-clean","dev-clean","test-other","dev-other","train-clean-100", "train-clean-360", "train-other-500"]
for split in splits:
    print(f"process {split}")
    data_dir = pjoin(libri_data_dir, split)
    algn_dir = pjoin(libri_algn_dir, split)

    simul_dir = pjoin(libri_processed_dir,split,"simul_infos")
    feat_dir = pjoin(libri_processed_dir,split,"features")
    # validate_path comes from personal_VAD.utils; presumably it checks/creates the directory — TODO confirm
    validate_path(simul_dir, is_dir=True)
    validate_path(feat_dir, is_dir=True)

    data_path = pjoin(libri_processed_dir,split,'data.pkl.gz')
    simul_path = pjoin(simul_dir,'max.pkl.gz')
    # NOTE(review): the file names say "aug" for both outputs, but below only the enrollment
    # features use aug_args; the simulated features use noaug_args — confirm the intended naming.
    simul_feat_path = pjoin(feat_dir,'aug_24dim.pt.gz') # noaug_24dim_noCMVN
    enroll_feat_path = pjoin(feat_dir,"enroll_aug_24dim.pt.gz") # enroll_aug_24dim_noCMVN

    # reuse the cached raw data when present, otherwise scan the corpus and cache it
    if os.path.exists(data_path):
        print('load processed data')
        with gzip.open(data_path, 'rb') as f: data = pickle.load(f)
    else:
        data = load_librispeech(data_dir, algn_dir, save_path=data_path)

    # sample_num is far larger than the data allows; simulate stops by itself once it runs out of utterances
    enroll_data, simul_data = simulate(data, sample_num=100000000, save_path=simul_path)
    # re-reads what simulate has just written to simul_path
    with gzip.open(simul_path, 'rb') as f: enroll_data, simul_data = pickle.load(f)

    get_feat_from_enroll(enroll_data, feature_args, aug_args, save_path=enroll_feat_path)
    get_feat_from_data(simul_data, feature_args, noaug_args, save_path=simul_feat_path)