Skip to content

Commit

Permalink
Initial support for video (#1151)
Browse files Browse the repository at this point in the history
* Tutorial materials in main readme page

* Initial crude video support in AudioSource and Recording

* Add downsized test fixture video

* Support for loading video + audio at the same time

* Enforce consistent video and audio duration

* Support for changing video resolution

* Basic video support for most cut types

* Support for padded video MixedCuts

* Enforce audio duration and video duration to be consistent when creating Recording, solving appending/padding issues

* Add missing assertion

* Stricter tests for padding and appending video cuts

* Minimal set of utilities for PyTorch video dataloading

* Grid audio-visual speech corpus recipe + support videos with missing num frames in their header

* Skip video test for PyTorch < 2.0

* Fix issue with torchaudio.info usage
  • Loading branch information
pzelasko committed Sep 21, 2023
1 parent 3dde48d commit 7b60f86
Show file tree
Hide file tree
Showing 27 changed files with 1,487 additions and 53 deletions.
11 changes: 10 additions & 1 deletion docs/corpus.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Standard data preparation recipes
We provide a number of standard data preparation recipes. By that, we mean a collection of a Python function +
a CLI tool that create the manifests given a corpus directory.

.. list-table:: Currently supported corpora
.. list-table:: Currently supported audio corpora
:widths: 30 50
:header-rows: 1

Expand Down Expand Up @@ -193,6 +193,15 @@ a CLI tool that create the manifests given a corpus directory.
- :func:`lhotse.recipes.xbmu_amdo31`


.. list-table:: Currently supported video corpora
:widths: 30 50
:header-rows: 1

* - Corpus name
- Function
* - Grid Audio-Visual Speech Corpus
- :func:`lhotse.recipes.prepare_grid`

Adding new corpora
------------------

Expand Down
1 change: 1 addition & 0 deletions lhotse/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .utils import (
AudioLoadingError,
DurationMismatchError,
VideoInfo,
get_audio_duration_mismatch_tolerance,
null_result_on_audio_loading_error,
set_audio_duration_mismatch_tolerance,
Expand Down
161 changes: 137 additions & 24 deletions lhotse/audio/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
import numpy as np
import torch

from lhotse.audio.utils import AudioLoadingError, verbose_audio_loading_exceptions
from lhotse.audio.utils import (
AudioLoadingError,
VideoInfo,
verbose_audio_loading_exceptions,
)
from lhotse.augmentation import Resample
from lhotse.utils import Pathlike, Seconds, compute_num_samples

Expand Down Expand Up @@ -398,6 +402,7 @@ class LibsndfileCompatibleAudioInfo(NamedTuple):
frames: int
samplerate: int
duration: float
video: Optional[VideoInfo] = None


@lru_cache(maxsize=1)
Expand Down Expand Up @@ -426,7 +431,7 @@ def torchaudio_2_0_ffmpeg_enabled() -> bool:
from packaging import version

ver = version.parse(torchaudio.__version__)
if ver == version.parse("2.0.0"):
if ver >= version.parse("2.0"):
return os.environ.get("TORCHAUDIO_USE_BACKEND_DISPATCHER", "0") == "1"
if ver >= version.parse("2.1.0"):
return True
Expand Down Expand Up @@ -515,6 +520,101 @@ def torchaudio_info(
)


def torchaudio_ffmpeg_streamer_info(
path_or_fileobj: Union[Path, str, BytesIO]
) -> LibsndfileCompatibleAudioInfo:
from torchaudio.io import StreamReader

is_fileobj = not isinstance(path_or_fileobj, Path)
is_mpeg = not is_fileobj and any(
str(path_or_fileobj).endswith(ext) for ext in (".mp3", ".mp4", ".m4a")
)
if not is_fileobj:
path_or_fileobj = str(path_or_fileobj)
stream = StreamReader(path_or_fileobj)

# Collect the information about available video and audio streams.
num_streams = stream.num_src_streams
audio_streams = {}
video_streams = {}
for stream_idx in range(num_streams):
info = stream.get_src_stream_info(stream_idx)
if info.media_type == "video":
video_streams[stream_idx] = info
elif info.media_type == "audio":
audio_streams[stream_idx] = info
else:
raise RuntimeError(f"Unexpected media_type: {info}")

assert (
len(video_streams) < 2
), f"Lhotse currently does not support more than one video stream in a file (found {len(video_streams)})."
assert len(audio_streams) < 2, (
f"Lhotse currently does not support files with more than a single FFMPEG "
f"audio stream yet (found {len(audio_streams)}). "
f"Note that this is not the same as multi-channel which is generally supported."
)

meta = {}

if video_streams:
((video_stream_idx, video_stream),) = list(video_streams.items())
tot_frames = video_stream.num_frames

if tot_frames == 0: # num frames not available in header/metadata
stream.add_basic_video_stream(
round(video_stream.frame_rate), stream_index=video_stream_idx
)
for (chunk,) in stream.stream():
tot_frames += chunk.shape[0]
stream.remove_stream(0)

meta["video"] = VideoInfo(
fps=video_stream.frame_rate,
height=video_stream.height,
width=video_stream.width,
num_frames=tot_frames,
)

if audio_streams:
((audio_stream_idx, audio_stream),) = list(audio_streams.items())
stream.add_basic_audio_stream(
frames_per_chunk=int(audio_stream.sample_rate),
stream_index=audio_stream_idx,
)

def _try_read_num_samples():
if is_mpeg or is_fileobj:
# These cases often have insufficient or corrupted metadata, so we might need to scan
# the full audio stream to learn the actual number of frames. If video is available,
# we can quickly verify before performing the costly reading.
video_info = meta.get("video", None)
if video_info is not None:
audio_duration = audio_stream.num_frames / audio_stream.sample_rate
# for now 1ms tolerance
if abs(audio_duration - video_info.duration) < 1e-3:
return audio_stream.num_frames
return 0
else:
return audio_stream.num_frames

tot_samples = _try_read_num_samples()
if tot_samples == 0:
# There was a mismatch between video and audio duration in metadata,
# we'll have to read the file to figure it out.
for (chunk,) in stream.stream():
tot_samples += chunk.shape[0]

meta.update(
channels=audio_stream.num_channels,
frames=tot_samples,
samplerate=int(audio_stream.sample_rate),
duration=tot_samples / audio_stream.sample_rate,
)

return LibsndfileCompatibleAudioInfo(**meta)


def torchaudio_load(
path_or_fd: Pathlike, offset: Seconds = 0, duration: Optional[Seconds] = None
) -> Tuple[np.ndarray, int]:
Expand Down Expand Up @@ -554,9 +654,9 @@ def torchaudio_2_ffmpeg_load(
if offset > 0 or duration is not None:
audio_info = torchaudio.info(path_or_fd, backend="ffmpeg")
if offset > 0:
frame_offset = compute_num_samples(offset, audio_info.samplerate)
frame_offset = compute_num_samples(offset, audio_info.sample_rate)
if duration is not None:
num_frames = compute_num_samples(duration, audio_info.samplerate)
num_frames = compute_num_samples(duration, audio_info.sample_rate)
if isinstance(path_or_fd, IOBase):
# Set seek pointer to the beginning of the file as torchaudio.info
# might have left it at the end of the header
Expand Down Expand Up @@ -643,7 +743,7 @@ def audioread_info(path: Pathlike) -> LibsndfileCompatibleAudioInfo:
# We just read the file and compute the number of samples
# -- no other method seems fully reliable...
with audioread.audio_open(
path, backends=_available_audioread_backends()
str(path), backends=_available_audioread_backends()
) as input_file:
shape = audioread_load(input_file)[0].shape
if len(shape) == 1:
Expand Down Expand Up @@ -928,6 +1028,18 @@ def parse_channel_from_ffmpeg_output(ffmpeg_stderr: bytes) -> str:
)


def soundfile_info(path: Pathlike) -> LibsndfileCompatibleAudioInfo:
import soundfile as sf

info_ = sf.info(str(path))
return LibsndfileCompatibleAudioInfo(
channels=info_.channels,
frames=info_.frames,
samplerate=info_.samplerate,
duration=info_.duration,
)


def sph_info(path: Pathlike) -> LibsndfileCompatibleAudioInfo:
samples, sampling_rate = read_sph(path)
return LibsndfileCompatibleAudioInfo(
Expand Down Expand Up @@ -1035,6 +1147,15 @@ def info(

is_path = isinstance(path, (Path, str))

if is_path and Path(path).suffix.lower() == ".sph":
# We handle SPHERE as another special case because some old codecs (i.e. "shorten" codec)
# can't be handled by neither pysoundfile nor pyaudioread.
return sph_info(path)

if is_path and Path(path).suffix.lower() == ".opus":
# We handle OPUS as a special case because we might need to force a certain sampling rate.
return opus_info(path, force_opus_sampling_rate=force_opus_sampling_rate)

if force_read_audio:
# This is a reliable fallback for situations when the user knows that audio files do not
# have duration metadata in their headers.
Expand All @@ -1043,27 +1164,19 @@ def info(
assert (
is_path
), f"info(obj, force_read_audio=True) is not supported for object of type: {type(path)}"
return audioread_info(str(path))

if is_path and Path(path).suffix.lower() == ".opus":
# We handle OPUS as a special case because we might need to force a certain sampling rate.
return opus_info(path, force_opus_sampling_rate=force_opus_sampling_rate)

if is_path and Path(path).suffix.lower() == ".sph":
# We handle SPHERE as another special case because some old codecs (i.e. "shorten" codec)
# can't be handled by neither pysoundfile nor pyaudioread.
return sph_info(path)
return audioread_info(path)

try:
# Try to parse the file using torchaudio first.
return torchaudio_info(path)
if torchaudio_2_0_ffmpeg_enabled():
return torchaudio_ffmpeg_streamer_info(path)
else: # hacky but easy way to proceed...
raise Exception("Skipping - torchaudio ffmpeg streamer unavailable")
except:
try:
# Try to parse the file using pysoundfile as a fallback.
import soundfile as sf

return sf.info(str(path))
return torchaudio_info(path)
except:
# Try to parse the file using audioread as the last fallback.
return audioread_info(str(path))
# If both fail, then Python 3 will display both exception messages.
try:
return soundfile_info(path)
except:
return audioread_info(path)
# If all fail, then Python 3 will display all exception messages.
102 changes: 102 additions & 0 deletions lhotse/audio/mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional

import numpy as np
import torch

from lhotse.utils import Decibels, Seconds, compute_num_samples

Expand Down Expand Up @@ -173,3 +174,104 @@ def add_to_mix(

def audio_energy(audio: np.ndarray) -> float:
return float(np.average(audio**2))


class VideoMixer:
"""
Simple video "mixing" class that actually does not mix anything but supports concatenation.
"""

def __init__(
self,
base_video: torch.Tensor,
fps: float,
base_offset: Seconds = 0.0,
):
from intervaltree import IntervalTree

self.tracks = [base_video]
self.offsets = [compute_num_samples(base_offset, fps)]
self.fps = fps
self.dtype = self.tracks[0].dtype
self.tree = IntervalTree()
self.tree.addi(self.offsets[0], self.offsets[0] + base_video.shape[0])

def _pad_track(
self, video: torch.Tensor, offset: int, total: Optional[int] = None
) -> torch.Tensor:
if total is None:
total = video.shape[0] + offset
assert (
video.shape[0] + offset <= total
), f"{video.shape[0]} + {offset} <= {total}"
return torch.nn.functional.pad(
video,
(0, 0, 0, 0, 0, 0, offset, total - video.shape[0] - offset),
mode="constant",
value=0,
)
# return torch.from_numpy(
# np.pad(
# video.numpy(),
# pad_width=(
# (offset, total - video.shape[0] - offset),
# (0, 0),
# (0, 0),
# (0, 0),
# ),
# )
# )

@property
def num_frames_total(self) -> int:
longest = 0
for offset, video in zip(self.offsets, self.tracks):
longest = max(longest, offset + video.shape[0])
return longest

@property
def unmixed_video(self) -> List[torch.Tensor]:
"""
Return a list of numpy arrays with the shape (C, num_samples), where each track is
zero padded and scaled adequately to the offsets and SNR used in ``add_to_mix`` call.
"""
total = self.num_frames_total
return [
self._pad_track(track, offset=offset, total=total)
for offset, track in zip(self.offsets, self.tracks)
]

@property
def mixed_video(self) -> torch.Tensor:
"""
Return a numpy ndarray with the shape (num_channels, num_samples) - a mix of the tracks
supplied with ``add_to_mix`` calls.
"""
total = self.num_frames_total
mixed = self.tracks[0].new_zeros((total,) + self.tracks[0].shape[1:])
for offset, track in zip(self.offsets, self.tracks):
mixed[offset : offset + track.shape[0]] = track
return mixed

def add_to_mix(
self,
video: torch.Tensor,
offset: Seconds = 0.0,
):
if video.size == 0:
return # do nothing for empty arrays

assert offset >= 0.0, "Negative offset in mixing is not supported."
frame_offset = compute_num_samples(offset, self.fps)

from intervaltree import Interval

interval = Interval(frame_offset, frame_offset + video.shape[0])
assert not self.tree.overlaps(interval), (
f"Cannot add an overlapping video. Got {interval} while we "
f"have the following intervals: {self.tree.all_intervals()}"
)

self.tracks.append(video)
self.offsets.append(frame_offset)
self.tree.add(interval)
Loading

0 comments on commit 7b60f86

Please sign in to comment.