In [None]:
import logging
from pathlib import Path

from howl.context import InferenceContext
from howl.data.dataset.dataset import DatasetType, WakeWordDataset
from howl.data.dataset.dataset_loader import WakeWordDatasetLoader

In [None]:
# Make INFO-level records (the `print_stats` reports below) actually reach an output.
# `basicConfig` attaches a stderr handler only when the root logger has no handlers yet
# (e.g. when this notebook is exported and run as a plain script, where otherwise only
# WARNING+ would be shown); if handlers already exist it does nothing, so the explicit
# `setLevel` below is still required.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()  # root logger, passed to every print_stats call
logger.setLevel(logging.INFO)

In [None]:
# Dataset to analyse; its raw data is read from ../data/raw/<DATASET>/<sample_type>.
DATASET = "hey_fire_fox"
SAMPLE_TYPES = ("positive", "negative")
# Materialized as a tuple (not a generator expression) so it can be iterated every time
# the loading cell is re-run; a generator would be exhausted after the first pass and
# the loader would then silently produce empty datasets.
DATASET_PATHS = tuple(
    f"../data/raw/{DATASET}/{sample_type}" for sample_type in SAMPLE_TYPES
)
SAMPLE_RATE = 16000  # passed as `sample_rate` to every dataset/loader call
TOKEN_TYPE = "word"  # passed as `token_type` to InferenceContext
VOCAB = ["hey", "fire", "fox"]  # vocabulary handed to InferenceContext

In [None]:
# Build the inference context, then merge every raw sub-dataset (one per entry of
# DATASET_PATHS) into a single train / dev / test dataset.
ctx = InferenceContext(VOCAB, token_type=TOKEN_TYPE, use_blank=False)
loader = WakeWordDatasetLoader()
ds_kwargs = {"sample_rate": SAMPLE_RATE, "mono": True, "frame_labeler": ctx.labeler}

# One empty accumulator per split, in TRAINING / DEV / TEST order.
ww_train_ds, ww_dev_ds, ww_test_ds = (
    WakeWordDataset(metadata_list=[], set_type=split_type, **ds_kwargs)
    for split_type in (DatasetType.TRAINING, DatasetType.DEV, DatasetType.TEST)
)
for ds_path in map(Path, DATASET_PATHS):
    train_ds, dev_ds, test_ds = loader.load_splits(ds_path, **ds_kwargs)
    # Append each loaded split onto its matching accumulator.
    for merged_ds, split_ds in (
        (ww_train_ds, train_ds),
        (ww_dev_ds, dev_ds),
        (ww_test_ds, test_ds),
    ):
        merged_ds.extend(split_ds)

In [None]:
# Log statistics (with compute_length=True) for each merged split.
merged_splits = (ww_train_ds, ww_dev_ds, ww_test_ds)
for split_ds in merged_splits:
    split_header = f"Wake word dataset: {split_ds.set_type}"
    split_ds.print_stats(
        logger,
        compute_length=True,
        word_searcher=ctx.searcher,
        header=split_header,
    )

In [None]:
def _has_wake_word(sample):
    """Filter predicate for the "Pos" subset: result of searching the transcription."""
    return ctx.searcher.search(sample.transcription)


def _lacks_wake_word(sample):
    """Filter predicate for the "Neg" subset: negation of `_has_wake_word`'s search."""
    return not ctx.searcher.search(sample.transcription)


def _log_subset_stats(subset_ds, kind):
    """Log stats for a filtered subset under the header '<kind> dataset: <set_type>'."""
    subset_ds.print_stats(
        logger,
        header=f"{kind} dataset: {subset_ds.set_type}",
        word_searcher=ctx.searcher,
        compute_length=True,
    )


# Split dev and test into cloned positive / negative subsets and report each one.
for ds in ww_dev_ds, ww_test_ds:
    pos_ds = ds.filter(_has_wake_word, clone=True)
    _log_subset_stats(pos_ds, "Pos")

    neg_ds = ds.filter(_lacks_wake_word, clone=True)
    _log_subset_stats(neg_ds, "Neg")