In [None]:
#| default_exp vad_merge

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import random

import numpy as np
import torch
import torch.nn.functional as F

from fastprogress import progress_bar
from fastcore.script import *

from whisperspeech import utils
import webdataset as wds

# VAD merging

We merge the short VAD segments into longer chunks to make training more efficient (otherwise a lot of compute would be wasted on padding).

In [None]:
#| export
def derived_name(input, kind, base="audio"):
    """Build the name of a derived, gzip-compressed shard from a source shard name.

    Every occurrence of `base` inside `input` is swapped for `kind` (a plain
    `str.replace`, so directory components containing `base` are rewritten as
    well) and a `.gz` suffix is appended.
    """
    renamed = input.replace(base, kind)
    return f"{renamed}.gz"

In [None]:
# Exploration pipeline over a single audio shard: decode the audio and attach
# the derived 'vad' and 'spk_emb' data to it.
# NOTE(review): the utils.* stages live in whisperspeech.utils (not shown here);
# presumably the stream ends up with one sample per VAD segment — confirm there.
ds = wds.WebDataset(['../wolnelektury-wds2/wolnelektury-audio-000000.tar']).compose(
    wds.decode(wds.torch_audio),
    utils.merge_in(utils.derived_dataset('vad')),
    utils.find_audio,
    utils.split_to_chunks,
    utils.merge_in(utils.derived_dataset('spk_emb')),
)

In [None]:
import IPython
import time

In [None]:
# Listen to the segments one by one (progress bar sized for 20 samples) and
# flag the places where the speaker embedding suggests a new speaker.
prev = None
for s in progress_bar(ds, total=20):
    # compare against the last long segment (or against itself on the first pass)
    sim = F.cosine_similarity(
        torch.tensor(s['spk_emb.npy']),
        torch.tensor((s if prev is None else prev)['spk_emb.npy']),
        dim=0,
    )
    secs = s['tend'] - s['tstart']
    # segments of 2 s or less get a much looser similarity threshold
    same = sim > (0.6 if secs > 2 else 0.1)
    if not same:
        print("new")
    print(s['__key__'], sim, secs)
    display(IPython.display.Audio(s['samples'], rate=s['sample_rate']))
    # only segments longer than 2 s become the new comparison reference
    if secs > 2:
        prev = s
    time.sleep(.5)
s

In [None]:
#| exporti
# we need to split first to merge in the spk_emb.npy data
# this is similar to utils.split_to_chunks but works without the audio data
def split(stream, ikey='vad.npy'):
    """Explode every record of `stream` into one record per segment listed under `ikey`.

    Each yielded record carries the segment index `i`, the last index `imax`,
    the segment bounds (`tstart`, `tend`) and the key/url of its source record.

    Records without any segment produce no output of their own. Instead a small
    placeholder for them is queued and attached, under the `empty` key, to the
    next segment record that gets yielded. This side channel keeps the output
    mergeable with naively split data.

    NOTE: placeholders queued after the last record that has segments are never
    emitted, because there is no later segment record to carry them.
    """
    pending_empty = []
    for sample in stream:
        segments = sample[ikey]
        last_idx = len(segments) - 1
        if last_idx < 0:
            # No speech in this file: remember it for the next yielded segment.
            pending_empty.append({
                "__key__": sample['__key__'] + "_none",
                "src_key": sample['__key__'],
                "__url__": sample['__url__'],
            })
            continue
        for idx, (seg_start, seg_end) in enumerate(segments):
            yield {
                "__key__": sample['__key__'] + f"_{idx:03d}",
                "src_key": sample['__key__'],
                "__url__": sample['__url__'],
                "i": idx, "imax": last_idx,
                "tstart": seg_start, "tend": seg_end,
                "empty": pending_empty,
            }
            # the queued placeholders were handed over; start a fresh list
            pending_empty = []

def merge_by_src_key(stream):
    """Regroup per-segment records into one record per source file.

    Consecutive records sharing the same `src_key` are collapsed into a single
    record whose `spk_emb.npy` is the list of their embeddings and whose
    `vad.npy` is the list of their `[tstart, tend]` pairs. The merged record is
    keyed by `src_key` and keeps the `__url__` of its first segment.

    Placeholders found under a record's optional `empty` key (files without any
    segments, see `split`) are emitted as records with empty `spk_emb.npy` and
    `vad.npy` lists, right before the record that carried them is accumulated.

    Nothing is yielded for an empty input stream.
    """
    ms = None
    for s in stream:
        # push accumulated data once the source file changes
        if ms is not None and s['src_key'] != ms['__key__']:
            yield ms
            ms = None
        # push all empty files we might have lost
        for vs in s.get("empty", []):
            yield {
                "__url__": vs['__url__'],
                "__key__": vs['src_key'],
                "spk_emb.npy": [],
                "vad.npy": [],
            }
        # prepare a merged record for the new data
        if ms is None:
            ms = {
                "__url__": s['__url__'],
                "__key__": s['src_key'],
                "spk_emb.npy": [],
                "vad.npy": [],
            }
        ms["spk_emb.npy"].append(s["spk_emb.npy"])
        ms["vad.npy"].append([s['tstart'], s['tend']])
    # flush the last group; an empty stream accumulated nothing, so do not
    # emit a spurious None that would crash the downstream stages
    if ms is not None:
        yield ms

In [None]:
# Audio-free pipeline: start from the derived 'vad' shard, explode it into
# per-segment records with split, join in the 'spk_emb' data and regroup the
# segments per source file with merge_by_src_key.
# NOTE(review): utils.merge_in / utils.derived_dataset are defined in
# whisperspeech.utils (not shown here); presumably they join samples by key.
ds = wds.WebDataset([utils.derived_name('../wolnelektury-wds2/wolnelektury-audio-000000.tar', 'vad')]).compose(
    wds.decode(),
    split,
    utils.merge_in(utils.derived_dataset('spk_emb', base='vad', suffix='')),
    merge_by_src_key,
)

In [None]:
# Pull only the first record out of the pipeline and display it.
for s in ds: break
s

In [None]:
#| exporti
def random_cutter(dur):
    """Randomized cut rule for `chunk_merger`.

    Half of the time the usual 30 s limit applies; otherwise the limit is drawn
    uniformly between 5% and 100% of 30 s. Returns True when `dur` exceeds the
    limit that was picked.
    """
    if random.random() >= 0.5:
        return dur > 30
    limit = 30 * (random.random()*0.95+0.05)
    return dur > limit

def chunk_merger(stream, should_cut=lambda x: x > 30):
    """Merge the consecutive VAD segments of every record into longer chunks.

    Each record of `stream` must provide `vad.npy` (a sequence of
    `(tstart, tend)` pairs) and `spk_emb.npy` (one speaker embedding per
    segment). A new chunk is started whenever

    * the speaker seems to change: the cosine similarity to the previous
      segment's embedding drops below 0.5 (below 0.1 for segments of 2 s or
      less), or
    * `should_cut(duration)` returns True for the duration the current chunk
      would have if the segment were added to it.

    The record is updated in place and yielded:

    * `vad.npy`     - array of merged `(start, end)` pairs,
    * `spk_emb.npy` - one embedding per merged chunk: the mean of the chunk's
                      first segment and of all its later segments longer than
                      2 s,
    * `subvads.pyd` - for every merged chunk the list of original segments.

    Records without segments are yielded with empty arrays and no
    `subvads.pyd` key.
    """
    for s in stream:
        segments, speakers = s['vad.npy'], s['spk_emb.npy']
        if len(segments) == 0:
            s['vad.npy'], s['spk_emb.npy'] = np.array([]), np.array([])
            yield s
            continue
        curr_start = segments[0][0]
        curr_end = 0
        curr_spk = None
        curr_chunks = []
        # running sum / count of the embeddings averaged for the current chunk
        spk_acc = None
        spk_acc_N = 0
        merged = []
        merged_chunks = []
        merged_spk = []

        for (ts,te),new_spk in zip(segments, speakers):
            secs = te - ts
            new_spk = torch.tensor(new_spk)
            spk_change = False
            if curr_spk is not None:
                sim = F.cosine_similarity(curr_spk, new_spk, dim=0)
                # embeddings of short segments are noisy, be more lenient with them
                spk_change = sim < 0.5 if secs > 2 else sim < 0.1
            # never cut before the current chunk has any duration
            if (spk_change or should_cut(te - curr_start)) and curr_end - curr_start > 0:
                merged.append((curr_start, curr_end))
                merged_spk.append(spk_acc / spk_acc_N)
                merged_chunks.append(curr_chunks)
                curr_start = ts
                curr_chunks = []
            curr_spk = new_spk
            if not curr_chunks:
                # First segment of a chunk: (re)start the average with it, so that
                # every chunk has an embedding even if all its segments are short.
                # The count restarts with the chunk and the clone keeps the
                # in-place additions below away from `curr_spk`.
                spk_acc = new_spk.clone()
                spk_acc_N = 1
            elif secs > 2:
                spk_acc += new_spk
                spk_acc_N += 1
            curr_end = te
            curr_chunks.append((ts, te))
        # flush the chunk that was still open at the end of the file
        merged.append((curr_start, curr_end))
        merged_spk.append(spk_acc / spk_acc_N)
        merged_chunks.append(curr_chunks)
        s['vad.npy'], s['spk_emb.npy'] = np.array(merged), torch.stack(merged_spk).numpy()
        s['subvads.pyd'] = merged_chunks
        yield s

In [None]:
# Audio-free pipeline with chunk_merger appended: the per-file records built by
# merge_by_src_key get their VAD segments merged into longer chunks.
# NOTE(review): utils.merge_in / utils.derived_dataset are defined in
# whisperspeech.utils (not shown here); presumably they join samples by key.
ds = wds.WebDataset([utils.derived_name('../wolnelektury-wds2/wolnelektury-audio-000000.tar', 'vad')]).compose(
    wds.decode(),
    split,
    utils.merge_in(utils.derived_dataset('spk_emb', base='vad', suffix='')),
    merge_by_src_key,
    chunk_merger,
)

In [None]:
# Pull only the first merged record out of the pipeline and display it.
for s in ds: break
s

In [None]:
# End-to-end check with audio: split the decoded audio into the original VAD
# segments, regroup and merge them (merge_by_src_key + chunk_merger), then join
# the audio back in and split it again, this time along the merged chunks.
# NOTE(review): the utils.* stages live in whisperspeech.utils (not shown here);
# presumably metakeys=['spk_emb.npy'] makes split_to_chunks hand each chunk its
# own row of 'spk_emb.npy' — confirm there.
ds = wds.WebDataset(['../wolnelektury-wds2/wolnelektury-audio-000000.tar']).compose(
    wds.decode(wds.torch_audio),
    utils.merge_in(utils.derived_dataset('vad')),
    utils.find_audio,
    utils.split_to_chunks,
    utils.merge_in(utils.derived_dataset('spk_emb')),
    merge_by_src_key,
    chunk_merger,
    utils.merge_in(utils.derived_dataset('audio', suffix='', decoders=[wds.torch_audio])),
    utils.find_audio,
    lambda x: utils.split_to_chunks(x, metakeys=['spk_emb.npy']),
)

In [None]:
# Pull only the first merged audio chunk out of the pipeline and display it.
for s in ds: break
s

In [None]:
# Listen to the merged chunks (progress bar sized for 20 samples). Next to the
# chunk duration we print the summed duration of the original segments in it.
prev = None
for s in progress_bar(ds, total=20):
    # compare against the last long chunk (or against itself on the first pass)
    sim = F.cosine_similarity(
        torch.tensor(s['spk_emb.npy']),
        torch.tensor((s if prev is None else prev)['spk_emb.npy']),
        dim=0,
    )
    secs = s['tend'] - s['tstart']
    # chunks of 2 s or less get a much looser similarity threshold
    same = sim > (0.6 if secs > 2 else 0.1)
    if not same:
        print("new")
    print(
        s['__key__'], sim, secs,
        sum([seg_end - seg_start for seg_start, seg_end in s['orig_s']['subvads.pyd'][s['i']]]),
    )
    display(IPython.display.Audio(s['samples'], rate=s['sample_rate']))
    # only chunks longer than 2 s become the new comparison reference
    if secs > 2:
        prev = s
    time.sleep(.5)

In [None]:
#| exporti
@call_parse
def prepare_mvad(
    input:str,  # FLAC webdataset file path (or - to read the names from stdin)
    output:str=None, # output file name
    eqvad:bool=False, # make the chunk length distribution more uniform
):
    """Merge the VAD segments of one shard into longer chunks and save them as a new shard.

    The `vad` and `spk_emb` shards derived from `input` are read, the segments
    of every file are merged with `chunk_merger` and the result is written to
    `output`. When `output` is not given the name is derived from `input`
    using the kind `maxvad` (or `eqvad` when `eqvad` is set, in which case
    `random_cutter` randomizes the chunk lengths).
    """
    # NOTE(review): the help text above mentions "-" for reading names from
    # stdin, but nothing in this function handles that value.
    if eqvad:
        def merger(x):
            return chunk_merger(x, random_cutter)
        kind = 'eqvad'
    else:
        merger = chunk_merger
        kind = 'maxvad'

    ds = wds.WebDataset([utils.derived_name(input, 'vad')]).compose(
        wds.decode(),
        split,
        utils.merge_in(utils.derived_dataset('spk_emb', base='vad', suffix='')),
        merge_by_src_key,
        merger,
    )

    # honor an explicitly requested output name, fall back to the derived one
    if output is None:
        output = derived_name(input, kind)

    with utils.AtomicTarWriter(output) as sink:
        for s in progress_bar(ds, total='noinfer'):
            sink.write(s)

In [None]:
# Smoke test on a single shard; with the defaults this writes the 'maxvad'
# shard whose name is derived from the input path.
prepare_mvad('../wolnelektury-wds2/wolnelektury-audio-000000.tar')

In [None]:
# List the members of the maxvad shard written by the prepare_mvad call.
!tar tf ../wolnelektury-wds2/wolnelektury-maxvad-000000.tar.gz

./kornhauser-wiatr/kornhauser-wiatr_001.spk_emb.npy
./kornhauser-wiatr/kornhauser-wiatr_001.subvads.pyd
./kornhauser-wiatr/kornhauser-wiatr_001.vad.npy
./fraszki-ksiegi-pierwsze-epitafium-wysockiemu/jan-kochanowski-fraszki-ksiegi-pierwsze-epitafium-wysockiemu.spk_emb.npy
./fraszki-ksiegi-pierwsze-epitafium-wysockiemu/jan-kochanowski-fraszki-ksiegi-pierwsze-epitafium-wysockiemu.subvads.pyd
./fraszki-ksiegi-pierwsze-epitafium-wysockiemu/jan-kochanowski-fraszki-ksiegi-pierwsze-epitafium-wysockiemu.vad.npy
./kucharczyk-jak-modlitwa-ochrania-przed-zlodziejami/jak-modlitwa-ochrania-przed-zlodziejami.spk_emb.npy
./kucharczyk-jak-modlitwa-ochrania-przed-zlodziejami/jak-modlitwa-ochrania-przed-zlodziejami.subvads.pyd
./kucharczyk-jak-modlitwa-ochrania-przed-zlodziejami/jak-modlitwa-ochrania-przed-zlodziejami.vad.npy
./nowakowska-niska-rozdzielczosc-proba-wody/proba-wody.spk_emb.npy
./nowakowska-niska-rozdzielczosc-proba-wody/proba-wody.subvads.pyd
./nowakowska-niska-rozdzielczosc-pro

./grabinski-ksiega-ognia-bialy-wyrak/grabinski-ksiega-ognia-bialy-wyrak.subvads.pyd
./grabinski-ksiega-ognia-bialy-wyrak/grabinski-ksiega-ognia-bialy-wyrak.vad.npy
./konopnicka-w-polu/w-polu-pojdziemy-w-pole-w-ranny-czas.spk_emb.npy
./konopnicka-w-polu/w-polu-pojdziemy-w-pole-w-ranny-czas.subvads.pyd
./konopnicka-w-polu/w-polu-pojdziemy-w-pole-w-ranny-czas.vad.npy
./lis-i-osiel/ignacy-krasicki-bajki-i-przypowiesci-lis-i-osiel.spk_emb.npy
./lis-i-osiel/ignacy-krasicki-bajki-i-przypowiesci-lis-i-osiel.subvads.pyd
./lis-i-osiel/ignacy-krasicki-bajki-i-przypowiesci-lis-i-osiel.vad.npy
./do-delljusa/do-delljusa.spk_emb.npy
./do-delljusa/do-delljusa.subvads.pyd
./do-delljusa/do-delljusa.vad.npy
./satyry-czesc-pierwsza-zona-modna/satyry-czesc-pierwsza-zona-modna.spk_emb.npy
./satyry-czesc-pierwsza-zona-modna/satyry-czesc-pierwsza-zona-modna.subvads.pyd
./satyry-czesc-pierwsza-zona-modna/satyry-czesc-pierwsza-zona-modna.vad.npy
./janicki-i-nas-wybawi/i-nas-wybawi.spk_emb.npy
./j

In [None]:
#| exporti
def chunked_audio_dataset(shards, kind='maxvad'):
    """Build a WebDataset pipeline over `shards` that yields merged audio chunks.

    The audio is decoded with `utils.torch_audio_opus`, the derived dataset of
    the given `kind` (`maxvad` by default) is merged in and the stream is split
    with `utils.split_to_chunks`, passing `spk_emb.npy` as a meta key.
    """
    def to_chunks(samples):
        return utils.split_to_chunks(samples, metakeys=['spk_emb.npy'])

    dataset = wds.WebDataset(shards)
    return dataset.compose(
        wds.decode(utils.torch_audio_opus),
        utils.merge_in(utils.derived_dataset(kind)),
        utils.find_audio,
        to_chunks,
    )

In [None]:
# Listen to the chunks served by chunked_audio_dataset (progress bar sized for
# 6 samples) and flag the chunks whose embedding differs from the previous one.
ds = chunked_audio_dataset(['../wolnelektury-wds2/wolnelektury-audio-000000.tar'])
prev = None
for s in progress_bar(ds, total=6):
    # compare against the previous chunk (or against itself on the first pass)
    sim = F.cosine_similarity(
        torch.tensor(s['spk_emb.npy']),
        torch.tensor((s if prev is None else prev)['spk_emb.npy']),
        dim=0,
    )
    if sim < 0.5:
        print("new")
    print(
        s['__key__'], sim, s['tend'] - s['tstart'],
        sum([seg_end - seg_start for seg_start, seg_end in s['orig_s']['subvads.pyd'][s['i']]]),
    )
    display(IPython.display.Audio(s['samples'], rate=s['sample_rate']))
    time.sleep(.5)
    prev = s