In [None]:
#| default_exp extract_stoks

In [None]:
# Reload edited modules (e.g. `whisperspeech.*`) before each cell runs, so library changes are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import sys
import os
import itertools
from pathlib import Path

import numpy as np
import torch
import torchaudio
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

from fastprogress import progress_bar
from fastcore.script import *

from speechbrain.pretrained import EncoderClassifier
from whisperspeech import vq_stoks, utils, vad_merge
import webdataset as wds

from whisperspeech.inference import get_compute_device

# Semantic token extraction

We take a webdataset shard of audio and extract semantic tokens (`stoks.npy`) and speaker embeddings (`spk_emb.npy`) for each chunk.

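Although the S2A model should work on any random 30-second window, this extractor does use the VAD data. The audio is split into chunks with `maxvad` merging by default, or `eqvad` for more uniform chunk lengths. Each chunk is right-padded to a 30-second window (`rpad_s`), and the tokens that only cover the padding are trimmed off before saving.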

In [None]:
# Load the VQ tokenizer on the same device `prepare_stoks` would pick, instead of hard-coding CUDA
vq_model = vq_stoks.RQBottleneckTransformer.load_model("vqmodel-medium-en+pl-512c-dim64.model").to(get_compute_device())

In [None]:
# Presumably loads the Whisper model (`whmodel`) used by the VQ encoder; use the detected device rather than assuming CUDA
vq_model.ensure_whisper(get_compute_device())

In [None]:
# Peek at the Whisper audio encoder wrapped inside the VQ model
whisper_encoder = vq_model.whmodel[0].encoder
whisper_encoder

AudioEncoder(
  (conv1): Conv1d(80, 1024, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv2): Conv1d(1024, 1024, kernel_size=(3,), stride=(2,), padding=(1,))
  (blocks): ModuleList(
    (0-23): 24 x ResidualAttentionBlock(
      (attn): MultiHeadAttention(
        (query): Linear(in_features=1024, out_features=1024, bias=True)
        (key): Linear(in_features=1024, out_features=1024, bias=False)
        (value): Linear(in_features=1024, out_features=1024, bias=True)
        (out): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (attn_ln): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=1024, out_features=4096, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=4096, out_features=1024, bias=True)
      )
      (mlp_ln): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (ln_post): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)

In [None]:
#| exporti
@call_parse
def prepare_stoks(
    input:str,  # FLAC webdataset file path (or - to read the names from stdin)
    vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks-v2.model", # the model path (use repo_id:filename to download it from huggingface)
    n_samples:int=None, # process a limited amount of samples
    batch_size:int=64, # process several segments at once
    kind:str="maxvad", # could be eqvad to get more uniform chunk lengths
):
    "Extract semantic tokens and speaker embeddings for every VAD chunk of `input` and write them to a `{kind}-stoks` tar shard."
    window_s = 30    # every chunk is right-padded to this many seconds (`rpad_s` is the padding)
    toks_per_s = 25  # stoks per second of audio, used to trim the padded tail

    device = get_compute_device()
    vq_model = vq_stoks.RQBottleneckTransformer.load_model(vq_model).to(device)
    vq_model.ensure_whisper()
    # Optional speed-up (disabled):
#     vq_model.encode_mel = torch.compile(vq_model.encode_mel, mode="reduce-overhead", fullgraph=True)

    # Path.home() also works when $HOME is not set (e.g. on Windows)
    spk_classifier = EncoderClassifier.from_hparams("speechbrain/spkrec-ecapa-voxceleb",
                                                    savedir=str(Path.home()/".cache"/"speechbrain"),
                                                    run_opts={"device": device})

    # Round up: with floor division a request smaller than one batch gave a total of 0,
    # which disabled the limit entirely and processed the whole shard.
    if n_samples:
        total = (n_samples + batch_size - 1) // batch_size
    else:
        total = 'noinfer'

    ds = vad_merge.chunked_audio_dataset([input], kind).compose(
        utils.resampler(16000, 'samples_16k'),
        wds.to_tuple('__key__', 'rpad_s', 'samples_16k'),
        wds.batched(64),  # worker-side batching; re-batched to `batch_size` below
    )

    dl = wds.WebLoader(ds, num_workers=1, batch_size=None).unbatched().batched(batch_size)
    # Enforce the sample limit explicitly instead of relying on the progress bar to stop the loop
    batches = itertools.islice(dl, total) if n_samples else dl

    with utils.AtomicTarWriter(utils.derived_name(input, f'{kind}-stoks', dir="."), throwaway=n_samples is not None) as sink:
        for keys, rpad_ss, samples16k in progress_bar(batches, total=total):
            with torch.no_grad():
                # NOTE(review): half precision is assumed to be supported on `device` — confirm for CPU/MPS
                samples16k = samples16k.to(device).to(torch.float16)
                stoks = vq_model.encode_audio(samples16k).cpu().numpy().astype(np.int16)
                # relative length of the unpadded audio in each padded window
                wav_lens = torch.tensor(window_s - rpad_ss, dtype=torch.float)/window_s
                spk_embs = spk_classifier.encode_batch(samples16k, wav_lens=wav_lens)[:,0,:].cpu().numpy()
            for key, rpad_s, _stoks, spk_emb in zip(keys, rpad_ss, stoks, spk_embs):
                # drop tokens that only cover the right padding
                n_toks = int((window_s - rpad_s) * toks_per_s + .5)
                sink.write({
                    "__key__": key,
                    "stoks.npy": _stoks[:n_toks],
                    "spk_emb.npy": spk_emb,
                })
        sys.stdout.write("\n")

In [None]:
%pdb

Automatic pdb calling has been turned ON


In [None]:
prepare_stoks(input='../wolnelektury-wds2/wolnelektury-audio-000000.tar', batch_size=16, n_samples=1024)

In [None]:
prepare_stoks(input='../wolnelektury-wds2/wolnelektury-audio-000000.tar', batch_size=16, n_samples=1024)

In [None]:
prepare_stoks(input='../wolnelektury-wds2/wolnelektury-audio-000000.tar', batch_size=32, n_samples=1024)

In [None]:
prepare_stoks(input='../wolnelektury-wds2/wolnelektury-audio-000000.tar', batch_size=64, n_samples=1024)

In [None]:
prepare_stoks(input='../wolnelektury-wds2/wolnelektury-audio-000000.tar', batch_size=64, n_samples=1024)

In [None]:
!ls -lh ../wolnelektury-wds2/wolnelektury-maxvad-stoks-000000.tar
!tar -tf ../wolnelektury-wds2/wolnelektury-maxvad-stoks-000000.tar

In [None]:
#| hide
# Regenerate the exported Python module(s) from this notebook
import nbdev
nbdev.nbdev_export()