In [None]:
#| default_exp prepare_t2s_dataset

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import sys
import os
import itertools
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

from fastprogress import progress_bar
from fastcore.script import *

import whisper, whisperx
from whisperspeech import vad, wh_transcribe, vq_stoks, extract_acoustic
import webdataset as wds

# T2S dataset preparation

We take a webdataset shard and extract semantic tokens and transcriptions from it.

We merge the VAD chunks using a randomized maximum length, so that the dataset also contains some short samples.

In [None]:
#| exporti
def flac_to_t2s_name(input):
    """Derive the T2S output shard name from a FLAC shard path.

    The directory part (if any) is dropped, every occurrence of 'flac' in the
    file name is replaced with 't2s' and a '.gz' suffix is appended.
    """
    # [-1] instead of [1]: a bare file name without any "/" must not raise IndexError
    fname = input.rsplit("/", 1)[-1]
    return fname.replace('flac', 't2s') + ".gz"

In [None]:
# example: the directory is dropped, 'flac' becomes 't2s' and '.gz' is appended
flac_to_t2s_name('/data/whisperspeech-wds/librilight-large-6454-flac-000000.tar')

'librilight-large-6454-t2s-000000.tar.gz'

In [None]:
#| exporti
class Transcriber:
    """
    Thin wrapper around a whisperx ASR model that turns a batch of
    30 second audio chunks into text.
    """
    def __init__(self, model_size, lang=False):
        self.model = whisperx.asr.load_model(
            model_size,
            "cuda",
            compute_type="float16",
            language=lang,
        )
        # warm-up: unless the VAD model is invoked at least once here,
        # everything that follows segfaults (for reasons unknown)
        dummy_audio = {"waveform": torch.zeros(1, 16000), "sample_rate": 16000}
        self.model.vad_model(dummy_audio)

    def transcribe(self, batch):
        """Return the decoded transcriptions for a batch of audio samples, one per item."""
        mels = whisper.log_mel_spectrogram(batch)
        features = self.model.model.encode(mels.cpu().numpy())
        # the same text-only (no timestamps) prompt is used for every item in the batch
        prompt = self.model.model.get_prompt(self.model.tokenizer, [], without_timestamps=True)
        results = self.model.model.model.generate(features, [prompt] * len(mels))
        # keep only the first returned sequence of each result
        token_ids = [res.sequences_ids[0] for res in results]
        return self.model.tokenizer.tokenizer.decode_batch(token_ids)

In [None]:
#| exporti
@call_parse
def prepare_t2s(
    input:str,  # FLAC webdataset file path (or - to read the names from stdin)
    proc_dataset_path:Path, # processed VAD files path
    output:str=None, # output file name
    vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from Hugging Face)
    n_samples:int=None, # process a limited amount of samples
    batch_size:int=1, # process several segments at once
    transcription_model:str="small.en", # model name passed to the Transcriber
):
    """Extract semantic tokens and transcriptions from FLAC webdataset shard(s).

    Every written sample keeps the source `__key__` and carries a `txt`
    transcription plus a `stoks.npy` array of int16 semantic tokens with the
    part covering the right padding removed.

    The shard is written to `output + ".tmp"` and only renamed to `output`
    when `n_samples` is not set (limited runs are treated as benchmarks and
    leave the `.tmp` file behind). The speaker ids (second `/`-separated
    component of each key) are written to `output + ".speakers.txt"`.
    """
    # Resolve the inputs and the output name first so that a missing output
    # name fails before the (slow) model loading.
    if input == "-":
        assert output, "please provide the output shard name"
        # skip blank lines, they would otherwise turn into empty shard names
        input = [line.strip() for line in sys.stdin if line.strip()]
    else:
        if output is None: output = flac_to_t2s_name(input)
        input = [input]

    if ":" in vq_model:
        repo, fname = vq_model.split(":", 1)
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
    else:
        vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
    transcriber = Transcriber(transcription_model)

    # the progress bar `total` is what limits a benchmarking run to n_samples
    total = n_samples//batch_size if n_samples else 'noinfer'
    if n_samples: print(f"Benchmarking run of {n_samples} samples ({total} batches)")

    ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names).compose(
        wds.decode(wds.torch_audio),
        vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
        wds.map_dict(**{"vad.npy": lambda s: wh_transcribe.chunk_merger(s, wh_transcribe.random_cutter)}),
        lambda x: wh_transcribe.split_to_chunks(x),
        # drop the first and last segment because they tend to be inaccurate
        # (the transcriptions don't have the "LibriVox" header and "end of chapter" suffix)
        wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
        wds.to_tuple('__key__', 'rpad', 'samples'),
        wds.batched(64),
    )

    dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)

    speakers = set()
    tmp = output+".tmp"
    # TarWriter picks the compression from the file name it is given; since we
    # write to "<output>.tmp" we have to derive it from the final name instead,
    # otherwise a "*.tar.gz" shard would be written uncompressed.
    with wds.TarWriter(tmp, compress=output.endswith("gz")) as sink:
        for keys, rpads, samples in progress_bar(dl, total=total):
            with record_function('to_cuda'):
                csamples = samples.cuda()
            with record_function('transcribe'):
                txts = transcriber.transcribe(csamples)
            with record_function('vq_stoks'):
                stoks = vq_model.encode_audio(csamples)
            with record_function('from_cuda'):
                stoks = stoks.cpu().numpy().astype(np.int16)
            for key, rpad, txt, _stoks in zip(keys, rpads, txts, stoks):
                speakers.add(key.split('/')[1])
                # number of trailing tokens that fall into the right padding
                # (rpad / 16000 * 25 -- presumably 16 kHz samples and 25 tokens per second)
                n_pad = int(rpad / 16000 * 25)
                # explicit end index: `_stoks[:-n_pad]` would return an empty
                # array when there is less than one token of padding (n_pad == 0)
                keep = max(len(_stoks) - n_pad, 0)
                sink.write({
                    "__key__": key,
                    "txt": txt,
                    "stoks.npy": _stoks[:keep],
                })
    # sorted to make the speakers file reproducible between runs
    with open(output+".speakers.txt", "w") as f: f.write("\n".join(sorted(speakers)))
    if not n_samples:
        os.rename(tmp, output)

In [None]:
# total length of the audio in the shard: samples / 16000 / 3600 summed over all files
ds = wds.WebDataset(
    ['/data/whisperspeech-wds/librilight-large-6454-flac-000000.tar'],
    rename_files=vad.fix_dots_in_names,
).compose(
    wds.decode(wds.torch_audio),
)
sum(x['flac'][0].shape[-1]/16000/3600 for x in progress_bar(ds, total='noinfer'))

85.79258390624999

In [None]:
# same chunking pipeline as in `prepare_t2s`, used here to add up the lengths of all chunks
ds = wds.WebDataset(
    ['/data/whisperspeech-wds/librilight-large-6454-flac-000000.tar'],
    rename_files=vad.fix_dots_in_names,
).compose(
    wds.decode(wds.torch_audio),
    vq_stoks.merge_in(vq_stoks.derived_dataset('/data/whisperspeech-processed-wds/', 'vad')),
    wds.map_dict(**{"vad.npy": lambda s: wh_transcribe.chunk_merger(s, wh_transcribe.random_cutter)}),
    wh_transcribe.split_to_chunks,
    # the first and the last segment are dropped since they tend to be inaccurate
    # (the transcriptions lack the "LibriVox" header and the "end of chapter" suffix)
    wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
)
sum(x['samples'].shape[-1] for x in progress_bar(ds, total='noinfer'))

tensor([-15.7778, -14.8559, -13.3381,  ...,  -0.0186,   0.1787,   0.2196])

In [None]:
# limited (benchmark) run: 500 samples in batches of 32
prepare_t2s(
    '/data/whisperspeech-wds/librilight-large-6454-flac-000000.tar',
    '/data/whisperspeech-processed-wds/',
    vq_model='vqmodel-4e-hyptuned-32gpu.model',
    n_samples=500,
    batch_size=32,
)

Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.0.9. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../../.cache/torch/whisperx-vad-segmentation.bin`


Model was trained with pyannote.audio 0.0.1, yours is 2.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.0.1+cu118. Bad things might happen unless you revert torch to 1.x.


In [None]:
# the same limited run once more: 500 samples in batches of 32
prepare_t2s(
    '/data/whisperspeech-wds/librilight-large-6454-flac-000000.tar',
    '/data/whisperspeech-processed-wds/',
    vq_model='vqmodel-4e-hyptuned-32gpu.model',
    n_samples=500,
    batch_size=32,
)

Lightning automatically upgraded your loaded checkpoint from v1.5.4 to v2.0.9. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../../../.cache/torch/whisperx-vad-segmentation.bin`


Model was trained with pyannote.audio 0.0.1, yours is 2.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.10.0+cu102, yours is 2.0.1+cu118. Bad things might happen unless you revert torch to 1.x.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [None]:
## Batch size tests

In [None]:
# batch size test with the default batch_size
# (`prepare_s2a` is not defined in this notebook, the function under test is `prepare_t2s`)
prepare_t2s('/data/whisperspeech-wds/librilight-large-6454-flac-000000.tar', '/data/whisperspeech-processed-wds/', vq_model='vqmodel-4e-hyptuned-32gpu.model', n_samples=1000)

In [None]:
# batch size test with batch_size=2
# (`prepare_s2a` is not defined in this notebook, the function under test is `prepare_t2s`)
prepare_t2s('/data/whisperspeech-wds/librilight-large-6454-flac-000000.tar', '/data/whisperspeech-processed-wds/', vq_model='vqmodel-4e-hyptuned-32gpu.model', n_samples=1000, batch_size=2)

In [None]:
# batch size test with batch_size=4
# (`prepare_s2a` is not defined in this notebook, the function under test is `prepare_t2s`)
prepare_t2s('/data/whisperspeech-wds/librilight-large-6454-flac-000000.tar', '/data/whisperspeech-processed-wds/', vq_model='vqmodel-4e-hyptuned-32gpu.model', n_samples=1000, batch_size=4)

In [None]:
# batch size test with batch_size=8
# (`prepare_s2a` is not defined in this notebook, the function under test is `prepare_t2s`)
prepare_t2s('/data/whisperspeech-wds/librilight-large-6454-flac-000000.tar', '/data/whisperspeech-processed-wds/', vq_model='vqmodel-4e-hyptuned-32gpu.model', n_samples=1000, batch_size=8)

In [None]:
# batch_size=4 once more
# NOTE(review): this cell was labelled "stoks only" and called `prepare_s2a`, which is
# not defined in this notebook -- presumably copied from the S2A notebook; it now calls
# `prepare_t2s`, confirm that this is the intended benchmark.
prepare_t2s('/data/whisperspeech-wds/librilight-large-6454-flac-000000.tar', '/data/whisperspeech-processed-wds/', vq_model='vqmodel-4e-hyptuned-32gpu.model', n_samples=1000, batch_size=4)

In [None]:
# atoks only
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    prepare_s2a('/data/whisperspeech-wds/librilight-large-6454-flac-000000.tar', '/data/whisperspeech-processed-wds/', vq_model='vqmodel-4e-hyptuned-32gpu.model', n_samples=10, batch_size=1)
prof.export_chrome_trace("trace-bs1.json")

STAGE:2023-10-06 14:25:45 71030:71030 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


STAGE:2023-10-06 14:25:47 71030:71030 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-10-06 14:25:47 71030:71030 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [None]:
# NOTE(review): these commands inspect the "-s2a-" temporary shard, while `prepare_t2s`
# writes "<output>.tmp" with the name derived by `flac_to_t2s_name` ("-t2s-") --
# presumably left over from the S2A notebook, confirm which file should be listed.
!ls -lh librilight-large-6454-s2a-000000.tar.gz.tmp
!tar -tf librilight-large-6454-s2a-000000.tar.gz.tmp

-rw-r--r-- 1 root root 1.3M Oct  6 09:17 librilight-large-6454-s2a-000000.tar.gz.tmp
large/6454/abaft_funnel_1307_librivox_64kb_mp3/funnel_01_kipling_64kb_000.atoks.npy
large/6454/abaft_funnel_1307_librivox_64kb_mp3/funnel_01_kipling_64kb_000.stoks.npy
large/6454/abaft_funnel_1307_librivox_64kb_mp3/funnel_01_kipling_64kb_001.atoks.npy
large/6454/abaft_funnel_1307_librivox_64kb_mp3/funnel_01_kipling_64kb_001.stoks.npy
large/6454/abaft_funnel_1307_librivox_64kb_mp3/funnel_01_kipling_64kb_002.atoks.npy
large/6454/abaft_funnel_1307_librivox_64kb_mp3/funnel_01_kipling_64kb_002.stoks.npy
large/6454/abaft_funnel_1307_librivox_64kb_mp3/funnel_01_kipling_64kb_003.atoks.npy
large/6454/abaft_funnel_1307_librivox_64kb_mp3/funnel_01_kipling_64kb_003.stoks.npy
large/6454/abaft_funnel_1307_librivox_64kb_mp3/funnel_01_kipling_64kb_004.atoks.npy
large/6454/abaft_funnel_1307_librivox_64kb_mp3/funnel_01_kipling_64kb_004.stoks.npy
large/6454/abaft_funnel_1307_librivox_64kb_mp3/funnel_01_kipling_64kb_005.a

In [None]:
#| hide
# export the `#| export` / `#| exporti` cells of this notebook to the Python module
import nbdev
nbdev.nbdev_export()