In [None]:
import os, math, random, pickle, gzip, warnings, json
from pathlib import Path, PureWindowsPath
from os.path import join as pjoin
from collections import defaultdict
from decimal import Decimal
from multiprocessing import Pool, cpu_count

from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
import kaldiio
from textgrid import TextGrid
from torch_audiomentations import AddBackgroundNoise, ApplyImpulseResponse

from personal_VAD.utils import *

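# Silence torchaudio's torchcodec/StreamingMediaDecoder warnings; note the last filter hides ANY warning whose message contains "deprecated"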
warnings.filterwarnings("ignore", message=".*torchaudio.load_with_torchcodec.*")
warnings.filterwarnings("ignore", message=".*StreamingMediaDecoder.*")
warnings.filterwarnings("ignore", message=".*deprecated.*")

In [None]:
def load_chime(meta_path=None, algn_dir:str=None, save_path=None, overwrite=False):
    """
    Build per-speaker conversation and utterance indexes from a metadata CSV
    and a directory of alignment JSON files.

    Args:
        meta_path: CSV with `speaker_id`, `start_time`, `end_time` and `audio_path` columns.
        algn_dir: directory whose files are JSON lists of segments, each holding
            `speaker`, `start_time` and `end_time` (seconds, as string or number).
        save_path: optional gzip-pickle cache. It is loaded when it exists and
            `overwrite` is False; otherwise the result is written to it.
        overwrite: rebuild even when the cache already exists.

    Returns:
        (convs, utts)
            convs: {spk : (conv_path, label, tstamp)} where `label` alternates 0/1 and
                   `tstamp` holds the matching segment start/end times in ms
                   (floored to 10 ms), i.e. label[i] applies up to tstamp[i].
            utts:  {spk : [(conv_path, start_time, end_time)]}

    Speakers present in the metadata but missing from every alignment file are skipped.
    """
    if save_path and os.path.exists(save_path) and not overwrite:
        # NOTE: pickle.load can execute arbitrary code; only point save_path at trusted caches.
        with gzip.open(save_path, 'rb') as f: return pickle.load(f)
    # validate_path comes from personal_VAD.utils (presumably prepares the parent
    # directory -- TODO confirm). Only call it when there is a path to validate.
    if save_path: validate_path(save_path)

    def to_ms(seconds):
        # Go through str() so JSON floats are not expanded to their exact binary value
        # (Decimal(0.29)*100 truncates to 28); strings are passed through unchanged.
        return int(Decimal(str(seconds))*100)*10

    def get_label(algn_path):
        with open(algn_path, "r", encoding="utf-8") as f: algn = json.load(f)
        algn.sort(key=lambda x: (x["speaker"], float(x["start_time"])))
        spk2label = defaultdict(list)
        spk2tstamp = defaultdict(list)
        for seg in algn:
            spk = seg["speaker"]
            # 0 up to the segment start, 1 up to the segment end
            spk2label[spk] += [0, 1]
            spk2tstamp[spk] += [to_ms(seg["start_time"]), to_ms(seg["end_time"])]
        return {spk:(spk2label[spk],spk2tstamp[spk]) for spk in spk2label.keys()}

    # NOTE(review): labels are keyed by speaker only, so a speaker that appears in more
    # than one alignment file keeps just the last file's labels. Sorting makes "last"
    # deterministic; confirm speakers are unique per split.
    spk2label = {}
    for filename in sorted(os.listdir(algn_dir)): spk2label.update(get_label(pjoin(algn_dir,filename)))

    convs = {}  # {spk : (conv_path, label, tstamp)}
    utts = defaultdict(list)  # {spk : [(conv_path, start_time, end_time)]}
    df = pd.read_csv(meta_path)
    for _, row in df.iterrows():
        spk, start_time, end_time, conv_path = row['speaker_id'], row['start_time'], row['end_time'], row['audio_path']
        if spk not in spk2label: continue  # no alignment for this speaker
        # audio_path is a Windows-style path; drop its first 6 components to get a
        # path relative to the dataset directory.
        conv_path = pjoin(*Path(PureWindowsPath(conv_path)).parts[6:])
        convs[spk] = (conv_path, *spk2label[spk])
        utts[spk].append((conv_path, start_time, end_time))

    if save_path:
        with gzip.open(save_path, 'wb') as f: pickle.dump((convs, utts), f)

    return convs, utts

In [None]:
def mono_resample(wav: torch.Tensor, sr: int=16000, target_sr: int = 16000) -> torch.Tensor:
    """
    Down-mix a waveform to a single channel and bring it to `target_sr`.

    Accepts [T] or [C,T] and returns [1,T]: a 1-D input gains a channel axis,
    a multi-channel input is averaged over channels, and a [1,T] input is
    passed through untouched. Resampling only happens when `sr` differs
    from `target_sr`.
    """
    n_dims = wav.dim()
    if n_dims == 1:
        # [T] -> [1,T]
        wav = wav.unsqueeze(0)
    elif n_dims == 2 and wav.size(0) > 1:
        # [C,T] -> [1,T] by channel average
        wav = wav.mean(dim=0, keepdim=True)

    if sr == target_sr:
        return wav
    return torchaudio.functional.resample(wav, sr, target_sr)


def tstamp2framelabel(tstamps, labels, speclen, frame_ms=10):
    """
    Expand (timestamp, label) pairs into one label per feature frame.

    `labels[i]` is written to every frame from the end of the previous span up to
    `tstamps[i]` (milliseconds). Frames after the last timestamp stay 0 and
    timestamps past `speclen` frames are clipped by the slice.

    The write cursor never moves backwards: a timestamp that falls inside an
    already-labelled region is ignored. Without this, a segment nested in (or
    overlapping) an earlier one rewinds the cursor and the following "0" gap
    erases frames that were already marked 1.

    Args:
        tstamps: span end times in ms, parallel to `labels`.
        labels: label for each span.
        speclen: number of frames in the output.
        frame_ms: frame shift in ms (default 10, i.e. hop 160 at 16 kHz).

    Returns:
        1-D integer torch.Tensor of length `speclen`.
    """
    frame_labels = np.zeros(speclen, dtype=int)
    cursor = 0
    for end_time, label in zip(tstamps, labels):
        end_frame = int(end_time // frame_ms)
        if end_frame <= cursor: continue  # span already covered; keep cursor monotonic
        frame_labels[cursor:end_frame] = label
        cursor = end_frame
    return torch.tensor(frame_labels)


class Augmentor:
    """
    Optional room-impulse-response and background-noise augmentation for a mono waveform.

    Args:
        rir_paths: impulse-response files/dirs passed to `ApplyImpulseResponse`.
        noise_paths: noise files/dirs passed to `AddBackgroundNoise`.
        rir_prob: probability of applying the impulse response (0 disables it).
        noise_prob: probability of adding noise (0 disables it).
        min_snr_in_db, max_snr_in_db: SNR range for the added noise.

    Each transform is only constructed when its own probability is non-zero, so
    enabling just one of them does not require paths for the other.
    Both transforms are configured for 16 kHz audio.
    """
    def __init__(self, rir_paths=None, noise_paths=None, rir_prob=0, noise_prob=0, min_snr_in_db=3, max_snr_in_db=30):
        self.do_aug = rir_prob != 0 or noise_prob != 0
        self.rir_augmentor = None
        # Attribute name keeps its historical spelling so existing references still work.
        self.noise_augentor = None
        if not self.do_aug: return
        if rir_prob != 0:
            self.rir_augmentor = ApplyImpulseResponse(ir_paths=rir_paths,
                                    sample_rate=16000,
                                    p=rir_prob,
                                    output_type='tensor')
        if noise_prob != 0:
            self.noise_augentor = AddBackgroundNoise(background_paths=noise_paths,
                                sample_rate=16000,
                                min_snr_in_db=min_snr_in_db,
                                max_snr_in_db=max_snr_in_db,
                                p=noise_prob,
                                output_type='tensor')

    def __call__(self, wav):
        return self.augment(wav)

    def augment(self, wav):
        """
        Apply the enabled augmentations.

        With augmentation disabled `wav` is returned as-is. Otherwise the samples are
        flattened to (batch=1, channel=1, T), augmented, and returned as (1, T);
        the input is therefore expected to be mono.
        """
        if not self.do_aug: return wav
        # reshape (not view) so non-contiguous inputs are accepted as well
        wav = wav.reshape(1,1,-1)
        if self.rir_augmentor is not None: wav = self.rir_augmentor(wav)
        if self.noise_augentor is not None: wav = self.noise_augentor(wav)
        return wav.squeeze(0)


class Extractor:
    def __init__(self, apply_cmvn=True, **feature_args):
        if feature_args==None: feature_args={}
        basic_args = {'sample_rate':16000, 'n_fft':400, 'n_mels':24, 'win_length':400, 'hop_length':160}
        basic_args.update(feature_args)
        self.extractor = torchaudio.transforms.MelSpectrogram(**basic_args)
        self.apply_cmvn = apply_cmvn

    def __call__(self, wav):
        return self.extract(wav)

    def extract(self, wav:list|str):
        if isinstance(wav,str): wav, sr = torchaudio.load(wav)
        spec = self.extractor(wav)  # (1,F,T)
        spec = spec.squeeze(0).transpose(0,1)  # (T,F)
        spec = torch.log10(spec + 1e-6)
        if self.apply_cmvn==True:
            mean = spec.mean(dim=0, keepdim=True)
            std = spec.std(dim=0, keepdim=True)
            spec = (spec - mean) / (std + 1e-9)
        return spec



def generate_feat_conv(data_dir, data, feature_args:dict=None, aug_args:dict=None, save_path=None, overwrite=False):
    """
    Extract features and frame labels for every conversation in `data`.

    Args:
        data_dir: directory the conversation paths are relative to.
        data: {spk : (conv_path, label, tstamp)}
        feature_args: keyword arguments for `Extractor` (None -> defaults).
        aug_args: keyword arguments for `Augmentor` (None -> defaults, i.e. no augmentation).
        save_path: optional gzip-pickle cache. It is loaded when it exists and
            `overwrite` is False; otherwise the result is written to it.
            With None nothing is read or written.
        overwrite: recompute even when the cache already exists.

    Returns:
        {spk : (feat, frame_labels)} with `feat` as returned by `Extractor.extract`
        and `frame_labels` holding one label per feature frame.
    """
    if save_path and os.path.exists(save_path) and not overwrite:
        # NOTE: pickle.load can execute arbitrary code; only point save_path at trusted caches.
        with gzip.open(save_path, 'rb') as f: return pickle.load(f)
    # The declared defaults are None, which cannot be **-unpacked.
    augmentor = Augmentor(**(aug_args or {}))
    extractor = Extractor(**(feature_args or {}))
    feats = {}
    for spk, (conv_path, label, tstamp) in tqdm(data.items()):
        wav, sr = torchaudio.load(pjoin(data_dir,conv_path))
        wav = mono_resample(wav, sr)
        wav = augmentor.augment(wav)
        feat = extractor.extract(wav)
        feats[spk] = (feat, tstamp2framelabel(tstamp, label, len(feat)))
    # Only persist when a cache path was given, so the computed features are
    # not lost to an error on save_path=None.
    if save_path:
        with gzip.open(save_path, 'wb') as f: pickle.dump(feats, f)
    return feats

In [None]:
# Driver cell: for every CHiME split, build (or load from cache) the speaker index and
# the un-augmented, un-normalised 24-dim conversation features.
#
# The four directories below are left as `...` placeholders and must be set to real
# paths before this cell can run (pjoin cannot join an Ellipsis).
chime_data_dir = ...
chime_meta_dir = ...
chime_align_dir = ...
chime_processed_dir = ...

# Extractor / Augmentor keyword sets. Only feature_args_noCMVN and noaug_args are used
# in this cell; feature_args and aug_args are defined here but not referenced below.
feature_args = {'n_mels':24}
feature_args_noCMVN = {'n_mels':24, 'apply_cmvn':False}
noaug_args = {'rir_prob':0, 'noise_prob':0}
# 'rir_paths' / 'noise_paths' are placeholders as well.
aug_args = {'rir_paths':...,
            'noise_paths':...,
            'rir_prob':0.5, 'noise_prob':0.5}

splits = ["eval","dev","train"]
for split in splits:
    print(split)
    # Per-split inputs and the cache file for the speaker index.
    meta_file = pjoin(chime_meta_dir, f"CHiME_{split}_meta.csv")
    algn_dir = pjoin(chime_align_dir, split)
    data_path = pjoin(chime_processed_dir,split,'data.pkl.gz')
    data_dir = pjoin(chime_data_dir, f'CHiME6_{split}')
    # Returns the cached index when data_path exists (overwrite defaults to False).
    convs, utts = load_chime(meta_file, algn_dir, save_path=data_path)

    conv_path = pjoin(chime_processed_dir,split,"conv_24dim_noaug_noCMVN.pkl.gz")
    # validate_path is imported from personal_VAD.utils; presumably it prepares the
    # output location -- TODO confirm.
    validate_path(conv_path)
    # Features are cached at conv_path; each iteration rebinds convs/utts/conv_feats,
    # so after the loop they hold the last split ("train").
    conv_feats = generate_feat_conv(data_dir, convs, feature_args_noCMVN, noaug_args, save_path=conv_path)

    # Disabled: dumping the utterance index to its own file.
    #utt_path = pjoin(chime_processed_dir,split,"utt_24dim_aug.pkl.gz")
    #with gzip.open(utt_path, 'wb') as f: pickle.dump(utts, f)