Extending Lhotse dataloading to text/multimodal data #1295

Merged
merged 9 commits into from
Mar 7, 2024
70 changes: 70 additions & 0 deletions docs/datasets.rst
@@ -141,6 +141,76 @@ However, many functions and classes in Lhotse accept either a random seed or an
.. note:: The lazy seed resolution is done by calling :func:`lhotse.dataset.dataloading.resolve_seed`.


Customizing sampling constraints
--------------------------------

Since version 1.22.0, Lhotse provides a mechanism to customize how samplers measure the "length" of each example
for the purpose of determining dynamic batch size. To leverage this option, use the keyword argument ``constraint``
in :class:`~lhotse.dataset.sampling.DynamicCutSampler` or :class:`~lhotse.dataset.sampling.DynamicBucketingSampler`.
The sampling criteria are defined by implementing a subclass of :class:`~lhotse.dataset.sampling.base.SamplingConstraint`:

.. autoclass:: lhotse.dataset.sampling.base.SamplingConstraint
    :members:

The default constraint is :class:`~lhotse.dataset.sampling.base.TimeConstraint`, which is created from
the ``max_duration``, ``max_cuts``, and ``quadratic_duration`` arguments passed to the sampler's constructor.
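
As a rough illustration of what ``quadratic_duration`` is for, the sketch below assumes that the effective length of an example is its duration plus a quadratic penalty, ``d + d**2 / quadratic_duration``, so that long examples take up disproportionately more of the batch budget. The helper name ``effective_duration`` is hypothetical and not part of Lhotse; the ``TimeConstraint`` source remains the authoritative reference for the exact formula.

```python
from typing import Optional


def effective_duration(duration: float, quadratic_duration: Optional[float] = None) -> float:
    """Hypothetical helper (not a Lhotse API): add a quadratic penalty to long examples."""
    if quadratic_duration is None:
        # No penalty configured: the effective length is the plain duration.
        return duration
    return duration + duration ** 2 / quadratic_duration


# Assumed setting: quadratic_duration=30 seconds.
short = effective_duration(5.0, 30.0)   # 5 + 25/30
long = effective_duration(30.0, 30.0)   # 30 + 900/30
```

Under this assumption a 30-second example counts as 60 seconds towards ``max_duration``, while a 5-second example is barely affected.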

Sampling non-audio data
***********************

Because :class:`~lhotse.dataset.sampling.base.SamplingConstraint` defines the method ``measure_length``,
it's possible to use a different attribute than duration (or a different formula) for computing the effective batch size.
This enables re-using Lhotse's sampling algorithms for data other than speech, and passing around objects other than :class:`~lhotse.cut.Cut`.
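
To make the idea concrete, here is a minimal, self-contained sketch of a constraint-like object that measures length in tokens and drives a greedy batching loop. It deliberately does not import Lhotse: ``MaxLengthConstraint`` and ``greedy_batches`` are hypothetical names, and the method set (``measure_length``, ``add``, ``exceeded``, ``reset``) is only assumed to mirror the spirit of ``SamplingConstraint``, not its exact interface.

```python
from dataclasses import dataclass
from typing import Any, Iterable, List


@dataclass
class MaxLengthConstraint:
    """Hypothetical stand-in for a SamplingConstraint subclass (no Lhotse import)."""

    max_total: float
    current: float = 0.0

    def measure_length(self, example: Any) -> float:
        # Any attribute or formula could be used here; this sketch counts tokens.
        return float(len(example["tokens"]))

    def add(self, example: Any) -> None:
        self.current += self.measure_length(example)

    def exceeded(self) -> bool:
        return self.current > self.max_total

    def reset(self) -> None:
        self.current = 0.0


def greedy_batches(examples: Iterable[Any], constraint: MaxLengthConstraint) -> List[List[Any]]:
    """Group examples so each batch stays within the budget.

    A single example larger than the budget still forms its own batch.
    """
    batches: List[List[Any]] = []
    batch: List[Any] = []
    for ex in examples:
        constraint.add(ex)
        if constraint.exceeded() and batch:
            # Close the current batch and start a new one with this example.
            batches.append(batch)
            batch = []
            constraint.reset()
            constraint.add(ex)
        batch.append(ex)
    if batch:
        batches.append(batch)
    return batches


data = [{"tokens": [0] * n} for n in (3, 4, 5, 2)]
batches = greedy_batches(data, MaxLengthConstraint(max_total=8))
sizes = [[len(ex["tokens"]) for ex in b] for b in batches]
```

Swapping the body of ``measure_length`` is all it takes to batch by a different quantity, which is the property the section above relies on.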

To showcase this, we added experimental support for text-only dataloading. We introduced a few classes specifically for this purpose:

.. autoclass:: lhotse.cut.text.TextExample
    :members:

.. autoclass:: lhotse.cut.text.TextPairExample
    :members:

.. autoclass:: lhotse.lazy.LazyTxtIterator
    :members:

.. autoclass:: lhotse.dataset.sampling.base.TokenConstraint
    :members:

A minimal example of how to perform text-only dataloading is available below (note that any of these classes may be replaced by your own implementation if that is more suitable to your work)::

    import torch
    import numpy as np
    from lhotse import CutSet
    from lhotse.lazy import LazyTxtIterator
    from lhotse.cut.text import TextPairExample
    from lhotse.dataset import DynamicBucketingSampler, TokenConstraint
    from lhotse.dataset.collation import collate_vectors

    examples = CutSet(LazyTxtIterator("data.txt"))

    def tokenize(example):
        # tokenize as individual bytes; BPE or another technique may be used here instead
        example.tokens = np.frombuffer(example.text.encode("utf-8"), np.int8)
        return example

    examples = examples.map(tokenize, apply_fn=None)

    sampler = DynamicBucketingSampler(examples, constraint=TokenConstraint(max_tokens=1024, quadratic_length=128), num_buckets=2)

    class ExampleTextDataset(torch.utils.data.Dataset):
        def __getitem__(self, examples: CutSet):
            tokens = [ex.tokens for ex in examples]
            token_lens = torch.tensor([len(t) for t in tokens])
            tokens = collate_vectors(tokens, padding_value=-1)
            return tokens, token_lens

    dloader = torch.utils.data.DataLoader(ExampleTextDataset(), sampler=sampler, batch_size=None)

    for batch in dloader:
        print(batch)

.. note:: Support for this kind of dataloading is experimental in Lhotse. If you run into any rough edges, please let us know.
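
For readers who want to see what the tokenization and padding steps of the example produce without installing NumPy or PyTorch, the following pure-Python sketch mimics them. ``byte_tokenize`` and ``pad_batch`` are hypothetical helpers written for this illustration only. Token values here lie in the range 0..255, whereas the ``np.int8`` view used in the example above represents bytes of 128 and higher as negative numbers.

```python
from typing import List, Tuple


def byte_tokenize(text: str) -> List[int]:
    """Hypothetical helper: one token per UTF-8 byte."""
    return list(text.encode("utf-8"))


def pad_batch(seqs: List[List[int]], padding_value: int = -1) -> Tuple[List[List[int]], List[int]]:
    """Right-pad every sequence to the longest one and return the original lengths."""
    max_len = max(len(s) for s in seqs)
    padded = [s + [padding_value] * (max_len - len(s)) for s in seqs]
    lens = [len(s) for s in seqs]
    return padded, lens


tokens = [byte_tokenize(t) for t in ("hi", "héllo")]
padded, lens = pad_batch(tokens)
```

Note that "héllo" yields six tokens rather than five, because "é" occupies two bytes in UTF-8. Sequence lengths measured in bytes can therefore differ from lengths measured in characters.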

Dataset's list
--------------

130 changes: 130 additions & 0 deletions lhotse/custom.py
@@ -0,0 +1,130 @@
from functools import partial
from typing import Any, Dict, Optional

import numpy as np

from lhotse import Recording
from lhotse.utils import ifnone


class CustomFieldMixin:
    """
    :class:`CustomFieldMixin` is intended for classes such as Cut or SupervisionSegment
    that support holding custom, user-defined fields.

    .. caution:: Due to the way inheritance and dataclasses work before Python 3.10,
        it is necessary to re-define the ``custom`` attribute in dataclasses that
        inherit from this mixin.
    """

    def __init__(self, custom: Optional[Dict[str, Any]]) -> None:
        self.custom: Optional[Dict[str, Any]] = custom

    def __setattr__(self, key: str, value: Any) -> None:
        """
        This magic function is called when the user tries to set an attribute.
        We use it as syntactic sugar to store custom attributes in ``self.custom``
        field, so that they can be (de)serialized later.
        Setting a ``None`` value will remove the attribute from ``custom``.
        """
        if key in self.__dataclass_fields__:
            super().__setattr__(key, value)
        else:
            custom = ifnone(self.custom, {})
            if value is None:
                custom.pop(key, None)
            else:
                custom[key] = value
            if custom:
                self.custom = custom

    def __getattr__(self, name: str) -> Any:
        """
        This magic function is called when the user tries to access an attribute
        of :class:`.MonoCut` that doesn't exist. It is used for accessing the custom
        attributes of cuts.

        We use it to look up the ``custom`` field: when it's None or empty,
        we'll just raise AttributeError as usual.
        If ``name`` is found in ``custom``, we'll return ``custom[name]``.
        If ``name`` starts with "load_", we'll assume the name of the relevant
        attribute comes after that, and that the value of that field is of type
        :class:`~lhotse.array.Array` or :class:`~lhotse.array.TemporalArray`.
        We'll return its ``load`` method for the user to call.

        Example of attaching and reading an alignment as TemporalArray::

            >>> cut = MonoCut('cut1', start=0, duration=4, channel=0)
            >>> cut.alignment = TemporalArray(...)
            >>> ali = cut.load_alignment()

        """
        custom = self.custom
        if custom is None:
            raise AttributeError(f"No such attribute: {name}")
        if name in custom:
            # Somebody accesses raw [Temporal]Array manifest
            # or wrote a custom piece of metadata into MonoCut.
            return self.custom[name]
        elif name.startswith("load_"):
            # Return the method for loading [Temporal]Arrays,
            # to be invoked by the user.
            attr_name = name[5:]
            return partial(self.load_custom, attr_name)
        raise AttributeError(f"No such attribute: {name}")

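    def __delattr__(self, key: str) -> None:
        """Used to support ``del cut.custom_attr`` syntax."""
        if key in self.__dataclass_fields__:
            super().__delattr__(key)
            # Regular dataclass fields are not stored in ``custom``; stop here
            # instead of falling through to the AttributeError below.
            return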
        if self.custom is None or key not in self.custom:
            raise AttributeError(f"No such member: '{key}'")
        del self.custom[key]

    def load_custom(self, name: str) -> np.ndarray:
        """
        Load custom data as numpy array. The custom data is expected to have
        been stored in the cut's ``custom`` field as an :class:`~lhotse.array.Array` or
        :class:`~lhotse.array.TemporalArray` manifest.

        .. note:: It works with Array manifests stored via attribute assignments,
            e.g.: ``cut.my_custom_data = Array(...)``.

        :param name: name of the custom attribute.
        :return: a numpy array with the data.
        """
        from lhotse.array import Array, TemporalArray

        value = self.custom.get(name)
        if isinstance(value, Array):
            # Array does not support slicing.
            return value.load()
        elif isinstance(value, TemporalArray):
            # TemporalArray supports slicing.
            return value.load(start=self.start, duration=self.duration)
        elif isinstance(value, Recording):
            # Recording supports slicing. Note: we will not slice the channels
            # as cut.channels refers to cut.recording and not the custom field.
            return value.load_audio(offset=self.start, duration=self.duration)
        else:
            raise ValueError(
                f"To load {name}, the cut needs to have field {name} (or cut.custom['{name}']) "
                f"defined, and its value has to be a manifest of type Array, TemporalArray, or Recording."
            )

    def has_custom(self, name: str) -> bool:
        """
        Check if the Cut has a custom attribute with name ``name``.

        :param name: name of the custom attribute.
        :return: a boolean.
        """
        if self.custom is None:
            return False
        return name in self.custom

    def drop_custom(self, name: str):
        if self.custom is None or name not in self.custom:
            return None
        del self.custom[name]
        return self
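
The attribute routing performed by the mixin can be tried out in isolation with the simplified sketch below. It is a hypothetical re-implementation written for illustration (``CustomFieldSketch`` and ``Item`` are not Lhotse classes). It omits the ``load_`` prefix handling and replaces ``ifnone`` with an inline check, so it shows only how unknown attributes end up in ``custom``.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


class CustomFieldSketch:
    """Simplified, hypothetical version of the mixin's attribute routing."""

    def __setattr__(self, key: str, value: Any) -> None:
        if key in self.__dataclass_fields__:
            # Declared dataclass fields are set the regular way.
            super().__setattr__(key, value)
            return
        custom = self.custom if self.custom is not None else {}
        if value is None:
            custom.pop(key, None)
        else:
            custom[key] = value
        if custom:
            self.custom = custom

    def __getattr__(self, name: str) -> Any:
        # Only invoked when regular attribute lookup has already failed.
        custom = self.__dict__.get("custom")
        if custom is not None and name in custom:
            return custom[name]
        raise AttributeError(f"No such attribute: {name}")


@dataclass
class Item(CustomFieldSketch):
    id: str
    custom: Optional[Dict[str, Any]] = None


item = Item("item-1")
item.speaker = "alice"       # not a declared field, so it is stored in item.custom
stored = dict(item.custom)   # snapshot of the custom dict at this point
value = item.speaker         # resolved through __getattr__
item.speaker = None          # assigning None removes the key again
```

In this sketch, as in the mixin, removing the last custom attribute leaves an empty dictionary behind rather than resetting ``custom`` to ``None``.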

108 changes: 2 additions & 106 deletions lhotse/cut/data.py
@@ -2,7 +2,6 @@
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass, field
 from decimal import ROUND_DOWN
-from functools import partial
 from math import isclose
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

@@ -12,6 +11,7 @@

 from lhotse.audio import Recording, VideoInfo
 from lhotse.augmentation import AugmentFn
+from lhotse.custom import CustomFieldMixin
 from lhotse.cut.base import Cut
 from lhotse.features import FeatureExtractor, Features
 from lhotse.features.io import FeaturesWriter
@@ -25,7 +25,6 @@
     compute_num_frames,
     compute_num_samples,
     fastcopy,
-    ifnone,
     measure_overlap,
     overlaps,
     overspans,
@@ -36,7 +35,7 @@


 @dataclass
-class DataCut(Cut, metaclass=ABCMeta):
+class DataCut(Cut, CustomFieldMixin, metaclass=ABCMeta):
     """
     :class:`~lhotse.cut.DataCut` is a base class for cuts that point to actual audio data.
     It can be either a :class:`~lhotse.cut.MonoCut` or a :class:`~lhotse.cut.MultiCut`.
@@ -71,109 +70,6 @@ class DataCut(Cut, metaclass=ABCMeta):
     # Store anything else the user might want.
     custom: Optional[Dict[str, Any]] = None

-    def __setattr__(self, key: str, value: Any) -> None:
-        """
-        This magic function is called when the user tries to set an attribute.
-        We use it as syntactic sugar to store custom attributes in ``self.custom``
-        field, so that they can be (de)serialized later.
-        Setting a ``None`` value will remove the attribute from ``custom``.
-        """
-        if key in self.__dataclass_fields__:
-            super().__setattr__(key, value)
-        else:
-            custom = ifnone(self.custom, {})
-            if value is None:
-                custom.pop(key, None)
-            else:
-                custom[key] = value
-            if custom:
-                self.custom = custom
-
-    def __getattr__(self, name: str) -> Any:
-        """
-        This magic function is called when the user tries to access an attribute
-        of :class:`.MonoCut` that doesn't exist. It is used for accessing the custom
-        attributes of cuts.
-
-        We use it to look up the ``custom`` field: when it's None or empty,
-        we'll just raise AttributeError as usual.
-        If ``item`` is found in ``custom``, we'll return ``custom[item]``.
-        If ``item`` starts with "load_", we'll assume the name of the relevant
-        attribute comes after that, and that value of that field is of type
-        :class:`~lhotse.array.Array` or :class:`~lhotse.array.TemporalArray`.
-        We'll return its ``load`` method to call by the user.
-
-        Example of attaching and reading an alignment as TemporalArray::
-
-            >>> cut = MonoCut('cut1', start=0, duration=4, channel=0)
-            >>> cut.alignment = TemporalArray(...)
-            >>> ali = cut.load_alignment()
-
-        """
-        custom = self.custom
-        if custom is None:
-            raise AttributeError(f"No such attribute: {name}")
-        if name in custom:
-            # Somebody accesses raw [Temporal]Array manifest
-            # or wrote a custom piece of metadata into MonoCut.
-            return self.custom[name]
-        elif name.startswith("load_"):
-            # Return the method for loading [Temporal]Arrays,
-            # to be invoked by the user.
-            attr_name = name[5:]
-            return partial(self.load_custom, attr_name)
-        raise AttributeError(f"No such attribute: {name}")
-
-    def __delattr__(self, key: str) -> None:
-        """Used to support ``del cut.custom_attr`` syntax."""
-        if key in self.__dataclass_fields__:
-            super().__delattr__(key)
-        if self.custom is None or key not in self.custom:
-            raise AttributeError(f"No such member: '{key}'")
-        del self.custom[key]
-
-    def load_custom(self, name: str) -> np.ndarray:
-        """
-        Load custom data as numpy array. The custom data is expected to have
-        been stored in cuts ``custom`` field as an :class:`~lhotse.array.Array` or
-        :class:`~lhotse.array.TemporalArray` manifest.
-
-        .. note:: It works with Array manifests stored via attribute assignments,
-            e.g.: ``cut.my_custom_data = Array(...)``.
-
-        :param name: name of the custom attribute.
-        :return: a numpy array with the data.
-        """
-        from lhotse.array import Array, TemporalArray
-
-        value = self.custom.get(name)
-        if isinstance(value, Array):
-            # Array does not support slicing.
-            return value.load()
-        elif isinstance(value, TemporalArray):
-            # TemporalArray supports slicing.
-            return value.load(start=self.start, duration=self.duration)
-        elif isinstance(value, Recording):
-            # Recording supports slicing. Note: we will not slice the channels
-            # as cut.channels referes to cut.recording and not the custom field.
-            return value.load_audio(offset=self.start, duration=self.duration)
-        else:
-            raise ValueError(
-                f"To load {name}, the cut needs to have field {name} (or cut.custom['{name}']) "
-                f"defined, and its value has to be a manifest of type Array or TemporalArray."
-            )
-
-    def has_custom(self, name: str) -> bool:
-        """
-        Check if the Cut has a custom attribute with name ``name``.
-
-        :param name: name of the custom attribute.
-        :return: a boolean.
-        """
-        if self.custom is None:
-            return False
-        return name in self.custom
-
     @property
     def recording_id(self) -> str:
         return self.recording.id if self.has_recording else self.features.recording_id
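
The hunk above keeps the ``custom`` field declared on ``DataCut`` even though the methods moved into the mixin. One plausible reading of why, sketched below with hypothetical classes (``PlainMixin``, ``WithoutField``, ``WithField``), is that annotations on a non-dataclass base class are not picked up as dataclass fields. The inheriting dataclass therefore has to declare ``custom`` itself for it to appear in the constructor and in ``__dataclass_fields__``. This is an illustration of standard ``dataclasses`` behaviour, not a statement of the authors' exact motivation.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


class PlainMixin:
    """Hypothetical non-dataclass mixin: its annotations are not dataclass fields."""

    custom: Optional[Dict[str, Any]]

    def has_custom(self, name: str) -> bool:
        return self.custom is not None and name in self.custom


@dataclass
class WithoutField(PlainMixin):
    id: str


@dataclass
class WithField(PlainMixin):
    id: str
    # Re-declared here so that the dataclass machinery sees it.
    custom: Optional[Dict[str, Any]] = None


without_names = [f.name for f in fields(WithoutField)]
with_names = [f.name for f in fields(WithField)]
cut_like = WithField("cut-1", custom={"snr": 12.5})
```

Only ``WithField`` accepts ``custom`` as a constructor argument, which is the behaviour the mixin's methods depend on.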