From 83afef059ba1a29eb92bb6cb922f1f8e0ffd5965 Mon Sep 17 00:00:00 2001 From: Matt Zhang Date: Sat, 27 Jan 2024 15:12:16 -0500 Subject: [PATCH] Download worker refactor (#288) * ClippingSubsampler rewrite and bug fixes * More refactoring of ClippingSubsampler, plus a fix to _get_clip_intervals * Finished refactoring ClippingSubsampler * Final code changes * Added docstrings * Passed tests and linting * Made type annotations consistent with Python 3.8 * More annotation fixes * The Python 3.8 annotation needs a lot of hand-holding, it seems * Pylint has to cut it out, I swear to God * No real change, just relauching unit tests which failed due to connection timeouts * Linting issue * Another linting issue * Separated per-shard code from code that should only be executed once * Pulled ShardStatus parameters into their own data type * Cleaned up shard processing error handling * Cleaned up code * Bug fixes * Formatting * Fixed linting issues * Fixing more damn linting * Added a missing docstring * Unified SubsetWorker and DownloadWorker code * Bug fixes * Linting * Linting again * Forgot a docstring * Removed unnecessary manual thread handling * Removed unused import --------- Co-authored-by: iejMac Co-authored-by: Romain Beaumont --- video2dataset/subsamplers/__init__.py | 2 + video2dataset/types.py | 4 + video2dataset/workers/download_worker.py | 304 +++++++---------------- video2dataset/workers/subset_worker.py | 197 +++------------ video2dataset/workers/worker.py | 197 +++++++++++++++ 5 files changed, 330 insertions(+), 374 deletions(-) create mode 100644 video2dataset/workers/worker.py diff --git a/video2dataset/subsamplers/__init__.py b/video2dataset/subsamplers/__init__.py index 90e4cd58..53b4141f 100644 --- a/video2dataset/subsamplers/__init__.py +++ b/video2dataset/subsamplers/__init__.py @@ -12,3 +12,5 @@ from .optical_flow_subsampler import OpticalFlowSubsampler from .whisper_subsampler import WhisperSubsampler from .caption_subsampler import CaptionSubsampler + +from .subsampler import Subsampler diff --git a/video2dataset/types.py b/video2dataset/types.py index 77240d32..03ae7c1e 100644 --- a/video2dataset/types.py +++ b/video2dataset/types.py @@ -10,3 +10,7 @@ class EncodeFormats(TypedDict, total=False): class Streams(TypedDict, total=False): video: List[bytes] audio: List[bytes] + + +# TODO: make more structured +Metadata = dict diff --git a/video2dataset/workers/download_worker.py b/video2dataset/workers/download_worker.py index d008f921..3c1bdac4 100644 --- a/video2dataset/workers/download_worker.py +++ b/video2dataset/workers/download_worker.py @@ -1,29 +1,15 @@ """the downloader module handles the downloading""" - +import fsspec import math -import time +from multiprocessing.pool import ThreadPool import pyarrow as pa +import time import traceback - -import fsspec - -from multiprocessing.pool import ThreadPool -from threading import Semaphore -from typing import List, Any -import numpy as np +from typing import cast from video2dataset.data_reader import VideoDataReader -from video2dataset.logger import CappedCounter from video2dataset.logger import write_stats -from video2dataset.subsamplers import ( - ClippingSubsampler, - CutDetectionSubsampler, - FrameSubsampler, - FFProbeSubsampler, - NoOpSubsampler, - ResolutionSubsampler, - AudioRateSubsampler, -) +from video2dataset.workers.worker import ShardStatus, Streams, get_subsamplers, process_sample def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count): @@ -52,252 +38,154 @@ def __init__( self.save_caption = 
save_caption self.output_folder = output_folder self.column_list = column_list - self.encode_formats = encode_formats + self.input_encode_formats = encode_formats self.config = config - self.data_reader = VideoDataReader(encode_formats, tmp_dir, config["reading"]) - - self.clipping_subsampler = ClippingSubsampler( - 5, # oom_clip_count + self.url_indice = self.column_list.index("url") + self.caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None + self.oom_sample_per_shard = math.ceil(math.log10(self.config["storage"]["number_sample_per_shard"])) + self.subsamplers, self.output_encode_formats = get_subsamplers( + config, encode_formats, - **self.config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"], + do_clipping=("clips" in self.column_list), ) - need_keyframes = self.clipping_subsampler.precision == "keyframe_adjusted" - - self.ffprobe_subsampler = None - if "FFProbeSubsampler" in self.config["subsampling"] or need_keyframes: - self.ffprobe_subsampler = FFProbeSubsampler( - **self.config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"] - ) - self.ffprobe_subsampler.extract_keyframes |= need_keyframes - - self.cut_detector = None - self.cuts_are_clips = False - if "CutDetectionSubsampler" in self.config["subsampling"]: - if "args" in self.config["subsampling"]["CutDetectionSubsampler"]: - self.cut_detector = CutDetectionSubsampler( - **self.config["subsampling"]["CutDetectionSubsampler"]["args"] - ) - self.cuts_are_clips = self.config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) - - self.noop_subsampler = NoOpSubsampler() - - video_subsamplers: List[Any] = [] - if "ResolutionSubsampler" in self.config["subsampling"]: - video_subsamplers.append(ResolutionSubsampler(**self.config["subsampling"]["ResolutionSubsampler"]["args"])) - if "FrameSubsampler" in self.config["subsampling"]: - video_subsamplers.append(FrameSubsampler(**self.config["subsampling"]["FrameSubsampler"]["args"])) - - audio_subsamplers: List[Any] = [] - if "AudioRateSubsampler" in self.config["subsampling"]: - audio_subsamplers.append(AudioRateSubsampler(**self.config["subsampling"]["AudioRateSubsampler"]["args"])) - - self.subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers} def __call__( self, row, ): try: - self.download_shard(row) + shard_file, shard_id = row + self.process_shard(shard_file, shard_id) return (True, row) except Exception as err: # pylint: disable=broad-except traceback.print_exc() print(f"shard {row[0]} failed with error {err}") return (False, row) - def download_shard( + def get_shard_processors( self, - row, + shard_file: str, + shard_id: int, ): - """Function to start an video downloading in one process""" - - # shard_id, shard_file = row - shard_file, shard_id = row - start_time = time.time() + """Get objects for loading and writing data""" fs, shard_path = fsspec.core.url_to_fs(shard_file) + print(shard_path) with fs.open(shard_path, "rb") as f: df = pa.ipc.open_file(f).read_all() + schema = df.schema schema = df.schema schema = ( schema.append(pa.field("key", pa.string())) .append(pa.field("status", pa.string())) .append(pa.field("error_message", pa.string())) ) - + shard_sample_writer = self.sample_writer_class( + shard_id, + self.output_folder, + self.save_caption, + self.config["storage"]["oom_shard_count"], + schema, + self.output_encode_formats, + ) pydict = df.select(self.column_list).to_pydict() shard_to_dl = list(enumerate(zip(*(pydict[col] for col in self.column_list)))) - del 
pydict - del df - - status_dict = CappedCounter() - count = len(shard_to_dl) - successes = 0 - failed = { - "failed_to_download": 0, - "failed_to_subsample": 0, - } - bytes_downloaded = 0 - url_indice = self.column_list.index("url") - caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None - key_url_list = [(key, x[url_indice]) for key, x in shard_to_dl] + def rm_shard_path(): + fs.rm(shard_path) - semaphore = Semaphore(self.config["distribution"]["thread_count"]) + return shard_sample_writer, shard_to_dl, rm_shard_path - def data_generator(): - for e in key_url_list: - semaphore.acquire() # pylint: disable=(consider-using-with) - yield e + def process_shard( + self, + shard_file: str, + shard_id: int, + ): + """Function to start an video downloading in one process""" - loader = data_generator() + start_time = time.time() + shard_sample_writer, shard_to_dl, rm_shard_path = self.get_shard_processors(shard_file, shard_id) + shard_status = ShardStatus(count=len(shard_to_dl)) - # The subsamplers might change the output format, so we need to update the writer - writer_encode_formats = self.encode_formats.copy() - if self.subsamplers["audio"]: - writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_formats["audio"] - if self.subsamplers["video"]: - writer_encode_formats["video"] = self.subsamplers["video"][0].encode_formats["video"] + def data_generator(): + for key_and_url in [(key, x[self.url_indice]) for key, x in shard_to_dl]: + yield key_and_url - # give schema to writer - sample_writer = self.sample_writer_class( - shard_id, - self.output_folder, - self.save_caption, - self.config["storage"]["oom_shard_count"], - schema, - writer_encode_formats, - ) - oom_sample_per_shard = math.ceil(math.log10(self.config["storage"]["number_sample_per_shard"])) + data_reader_call_param_generator = data_generator() with ThreadPool(self.config["distribution"]["thread_count"]) as thread_pool: - for key, streams, yt_meta_dict, error_message in thread_pool.imap_unordered( + for key, streams, yt_meta_dict, shard_status.error_message in thread_pool.imap_unordered( self.data_reader, # pylint: disable=(unnecessary-lambda) - loader, + data_reader_call_param_generator, ): try: _, sample_data = shard_to_dl[key] str_key = compute_key( - key, shard_id, oom_sample_per_shard, self.config["storage"]["oom_shard_count"] + key, shard_id, self.oom_sample_per_shard, self.config["storage"]["oom_shard_count"] ) - meta = { + caption = sample_data[self.caption_indice] if self.caption_indice is not None else None + metadata = { **{self.column_list[i]: sample_data[i] for i in range(len(self.column_list))}, "key": str_key, "status": None, - "error_message": error_message, + "error_message": shard_status.error_message, "yt_meta_dict": yt_meta_dict, } - - if error_message is not None: - print(error_message) - if "[youtube]" in error_message: # video-specific error, remove videoID - error_message = "ERROR: [youtube]:" + error_message.split(":")[-1] - raise ValueError("failed_to_download") - - for stream in streams.values(): - bytes_downloaded += len(stream) - for mod in streams: - streams[mod] = [streams[mod]] - - if self.ffprobe_subsampler is not None: - streams, meta, error_message = self.ffprobe_subsampler(streams, meta) - if error_message is not None: - raise ValueError("failed_to_subsample") - - if self.config["storage"]["captions_are_subtitles"]: # create clips - # all langs have same start and end times - subtitles = 
meta["yt_meta_dict"]["subtitles"][list(meta["yt_meta_dict"]["subtitles"].keys())[0]] - meta["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] - elif self.cut_detector is not None: # apply cut detection to get clips - streams, cuts, error_message = self.cut_detector(streams) - - if error_message is not None: - raise ValueError("failed_to_subsample") - - meta["cuts"] = cuts - - if self.cuts_are_clips: - cuts = meta["cuts"]["cuts_original_fps"] - native_fps = meta["cuts"]["original_fps"] - meta["clips"] = (np.array(cuts) / native_fps).tolist() - - # 1 video -> many videos (either clipping or noop which does identity broadcasting) - broadcast_subsampler = ( - self.clipping_subsampler - if ( - "clips" in self.column_list - or self.config["storage"]["captions_are_subtitles"] - or self.cuts_are_clips - ) - else self.noop_subsampler - ) - subsampled_streams, metas, error_message = broadcast_subsampler(streams, meta) - - for modality in subsampled_streams: - for modality_subsampler in self.subsamplers[modality]: - subsampled_streams, metas, error_message = modality_subsampler(subsampled_streams, metas) - - if error_message is not None: - meta["clips"] = [] - raise ValueError("failed_to_subsample") - - successes += 1 - status = "success" - status_dict.increment(status) - subsampled_streams_list = [ - dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values()) - ] - for subsampled_streams, meta in zip(subsampled_streams_list, metas): - meta["status"] = status - - text_caption = sample_data[caption_indice] if caption_indice is not None else None - if self.config["storage"]["captions_are_subtitles"]: - text_caption = meta.get("clip_subtitles")[0]["lines"] - - sample_writer.write( - subsampled_streams, - meta["key"], - text_caption, - meta, - ) except Exception as err: # pylint: disable=broad-except - status = str(err) - if status.startswith("failed_to_"): - failed[status] += 1 - status_dict.increment(error_message) - meta["status"] = status - meta["error_message"] = error_message - sample_writer.write( - {}, - str_key, - sample_data[caption_indice] if caption_indice is not None else None, - meta, - ) - semaphore.release() - else: - traceback.print_exc() - print(f"Sample {key} failed to download: {err}") + traceback.print_exc() + print(f"Sample {key} failed to download: {err}") + return - semaphore.release() - - sample_writer.close() - thread_pool.terminate() - thread_pool.join() - del thread_pool + try: + if shard_status.error_message is not None: + print(shard_status.error_message) + if "[youtube]" in shard_status.error_message: # video-specific error, remove videoID + shard_status.error_message = "ERROR: [youtube]:" + shard_status.error_message.split(":")[-1] + raise ValueError + except Exception: # pylint: disable=broad-except + shard_status.failed["failed_to_download"] += 1 + shard_status.status_dict.increment(shard_status.error_message) + metadata["status"] = "failed_to_download" + metadata["error_message"] = shard_status.error_message + shard_sample_writer.write( + {}, + str_key, + sample_data[self.caption_indice] if self.caption_indice is not None else None, + metadata, + ) + return + + for stream in streams.values(): + shard_status.bytes_downloaded += len(stream) + for modality in streams: + streams[modality] = [streams[modality]] + + process_sample( + subsamplers=self.subsamplers, + shard_status=shard_status, + streams=cast(Streams, streams), + key=str_key, + caption=cast(str, caption), + metadata=metadata, + 
captions_are_subtitles=self.config["storage"]["captions_are_subtitles"], + shard_sample_writer=shard_sample_writer, + ) + shard_sample_writer.close() + rm_shard_path() end_time = time.time() + write_stats( self.output_folder, shard_id, - count, - successes, - failed["failed_to_download"], - failed["failed_to_subsample"], - bytes_downloaded, + shard_status.count, + shard_status.successes, + shard_status.failed["failed_to_download"], + shard_status.failed["failed_to_subsample"], + shard_status.bytes_downloaded, start_time, end_time, - status_dict, + shard_status.status_dict, self.config["storage"]["oom_shard_count"], ) - fs.rm(shard_path) diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py index ad4c9ebc..519aad3e 100644 --- a/video2dataset/workers/subset_worker.py +++ b/video2dataset/workers/subset_worker.py @@ -1,77 +1,16 @@ """creates a subset of an existing dataset inside the sample dimension""" -from dataclasses import dataclass, field -import time +import fsspec import json import pyarrow as pa +import time import traceback - -import fsspec -import numpy as np +from typing import Literal, cast import webdataset as wds -from typing import List, Any, Optional, Literal, cast from video2dataset.dataloader import get_video_dataset -from video2dataset.logger import CappedCounter, write_stats -from video2dataset.subsamplers import ( - ClippingSubsampler, - CutDetectionSubsampler, - FrameSubsampler, - FFProbeSubsampler, - NoOpSubsampler, - ResolutionSubsampler, - AudioRateSubsampler, -) +from video2dataset.logger import write_stats from video2dataset.types import EncodeFormats, Streams - - -def get_subsamplers(config: dict, encode_formats: EncodeFormats): - """Initialize all subsamplers using config""" - - clipping_subsampler = ClippingSubsampler( - 5, # oom_clip_count - encode_formats, - **config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"], - ) - need_keyframes = clipping_subsampler.precision == "keyframe_adjusted" - - cut_detection_subsampler = None - cuts_are_clips = False - if "CutDetectionSubsampler" in config["subsampling"]: - if "args" in config["subsampling"]["CutDetectionSubsampler"]: - cut_detection_subsampler = CutDetectionSubsampler(**config["subsampling"]["CutDetectionSubsampler"]["args"]) - cuts_are_clips = config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) - - broadcast_subsampler = ( - clipping_subsampler if (config["storage"]["captions_are_subtitles"] or cuts_are_clips) else NoOpSubsampler() - ) - - ffprobe_subsampler = None - if "FFProbeSubsampler" in config["subsampling"] or need_keyframes: - ffprobe_subsampler = FFProbeSubsampler(**config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"]) - ffprobe_subsampler.extract_keyframes |= need_keyframes - - video_subsamplers: List[Any] = [] - if "ResolutionSubsampler" in config["subsampling"]: - video_subsamplers.append(ResolutionSubsampler(**config["subsampling"]["ResolutionSubsampler"]["args"])) - if "FrameSubsampler" in config["subsampling"]: - video_subsamplers.append(FrameSubsampler(**config["subsampling"]["FrameSubsampler"]["args"])) - - audio_subsamplers: List[Any] = [] - if "AudioRateSubsampler" in config["subsampling"]: - audio_subsamplers.append(AudioRateSubsampler(**config["subsampling"]["AudioRateSubsampler"]["args"])) - - modal_subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers} - - return ffprobe_subsampler, modal_subsamplers, cut_detection_subsampler, cuts_are_clips, broadcast_subsampler - - -@dataclass 
-class ShardStatus: - successes: int = 0 - failed_to_subsample: int = 0 - status_dict: CappedCounter = field(default_factory=CappedCounter) - error_message: Optional[str] = None - count: int = 0 +from video2dataset.workers.worker import ShardStatus, get_subsamplers, process_sample class SubsetWorker: @@ -87,50 +26,31 @@ def __init__( self.sample_writer_class = sample_writer_class self.output_folder = output_folder self.config = config - ( - self.ffprobe_subsampler, - self.modal_subsamplers, - self.cut_detection_subsampler, - self.cuts_are_clips, - self.broadcast_subsampler, - ) = get_subsamplers(config, encode_formats) - - # set encoding formats self.input_encode_formats = encode_formats - self.output_encode_formats = self.input_encode_formats.copy() - if self.modal_subsamplers["audio"]: - assert ( - len({s.encode_format for s in self.modal_subsamplers["audio"]}) == 1 - ) # assert that all audio subsamplers have the same output format - self.output_encode_formats["audio"] = self.modal_subsamplers["audio"][0].encode_format - if self.modal_subsamplers["video"]: - assert ( - len({s.encode_format for s in self.modal_subsamplers["video"]}) == 1 - ) # assert that all video subsamplers have the same output format - self.output_encode_formats["video"] = self.modal_subsamplers["video"][0].encode_format + self.subsamplers, self.output_encode_formats = get_subsamplers(config, self.input_encode_formats) def __call__( self, row, ): try: - shard, shard_id = row - self.process_shard(shard, shard_id) + shard_file, shard_id = row + self.process_shard(shard_file, shard_id) return (True, row) except Exception as err: # pylint: disable=broad-except traceback.print_exc() - print(f"shard {row[0]} failed with error {err}") + print(f"shard_file {row[0]} failed with error {err}") return (False, row) def get_shard_processors( self, - shard: str, + shard_file: str, shard_id: int, ): """Get objects for loading and writing data""" try: - fs, shard_path = fsspec.core.url_to_fs(shard[: -len(".tar")] + ".parquet") + fs, shard_path = fsspec.core.url_to_fs(shard_file[: -len(".tar")] + ".parquet") with fs.open(shard_path, "rb") as f: df = pa.parquet.read_table(f) schema = df.schema @@ -150,7 +70,7 @@ def get_shard_processors( self.output_encode_formats, ) shard_dataloader = get_video_dataset( - urls=shard, + urls=shard_file, batch_size=1, decoder_kwargs={}, enforce_additional_keys=[], @@ -160,13 +80,13 @@ def get_shard_processors( def process_shard( self, - shard: str, + shard_file: str, shard_id: int, ): """Function to start an video processing in one process""" start_time = time.time() - shard_sample_writer, shard_dataloader = self.get_shard_processors(shard, shard_id) + shard_sample_writer, shard_dataloader = self.get_shard_processors(shard_file, shard_id) shard_status = ShardStatus() for sample in shard_dataloader: @@ -174,82 +94,27 @@ def process_shard( key = sample["__key__"] try: caption = sample.get("txt", b"").decode("utf-8") - meta = json.loads(sample.get("json", b"{}").decode("utf-8")) + metadata = json.loads(sample.get("json", b"{}").decode("utf-8")) except Exception as err: # pylint: disable=broad-except traceback.print_exc() print(f"Sample {key} failed to download: {err}") return - try: - streams: Streams = {} - for modality, encode_format in self.input_encode_formats.items(): - modality = cast(Literal["audio", "video"], modality) - streams[modality] = [sample[encode_format]] - - if self.ffprobe_subsampler is not None: - streams, meta, shard_status.error_message = self.ffprobe_subsampler(streams, meta) - 
assert shard_status.error_message is None - - if self.config["storage"]["captions_are_subtitles"]: # create clips - subtitles = meta["yt_meta_dict"]["subtitles"] - meta["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] - elif self.cut_detection_subsampler is not None: # apply cut detection to get clips - streams, cuts, shard_status.error_message = self.cut_detection_subsampler(streams) - assert shard_status.error_message is None - meta["cuts"] = cuts - assert cuts is not None - if self.cuts_are_clips: - meta["clips"] = (np.array(cuts["cuts_original_fps"]) / cuts["original_fps"]).tolist() - - # 1 video -> many videos (either clipping or noop which does identity broadcasting) - subsampled_streams, metas, shard_status.error_message = self.broadcast_subsampler(streams, meta) - if shard_status.error_message is not None: - meta["clips"] = [] - assert False - - for modality in list(subsampled_streams.keys()): - for modality_subsampler in self.modal_subsamplers[modality]: - subsampled_streams, metas, shard_status.error_message = modality_subsampler( - subsampled_streams, metas - ) - assert shard_status.error_message is None - - shard_status.successes += 1 - status = "success" - shard_status.status_dict.increment(status) - - subsampled_streams_list = [dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())] - if len(subsampled_streams_list) == 0: # no audio or video, just write meta - meta["status"] = status - shard_sample_writer.write( - {}, - key, - caption, - meta, - ) - continue - for subsampled_streams, meta in zip(subsampled_streams_list, metas): - meta["status"] = status - text_caption = caption - if self.config["storage"]["captions_are_subtitles"]: - text_caption = meta.get("clip_subtitles")[0]["lines"][0] - shard_sample_writer.write( - subsampled_streams, - meta["key"], - text_caption, - meta, - ) - except Exception: # pylint: disable=broad-except - shard_status.failed_to_subsample += 1 - shard_status.status_dict.increment(shard_status.error_message) - meta["status"] = "failed_to_subsample" - meta["error_message"] = shard_status.error_message - shard_sample_writer.write( - {}, - key, - caption, - meta, - ) + streams: Streams = {} + for modality, encode_format in self.input_encode_formats.items(): + modality = cast(Literal["audio", "video"], modality) + streams[modality] = [sample[encode_format]] + + process_sample( + subsamplers=self.subsamplers, + shard_status=shard_status, + streams=streams, + key=key, + caption=caption, + metadata=metadata, + captions_are_subtitles=self.config["storage"]["captions_are_subtitles"], + shard_sample_writer=shard_sample_writer, + ) shard_sample_writer.close() end_time = time.time() @@ -260,7 +125,7 @@ def process_shard( shard_status.count, shard_status.successes, 0, # failed to download - shard_status.failed_to_subsample, + shard_status.failed["failed_to_subsample"], 0, # bytes downloaded start_time, end_time, diff --git a/video2dataset/workers/worker.py b/video2dataset/workers/worker.py new file mode 100644 index 00000000..45650829 --- /dev/null +++ b/video2dataset/workers/worker.py @@ -0,0 +1,197 @@ +"""Standard worker for video2dataset.""" +from dataclasses import dataclass, field +import numpy as np +from typing import Any, List, Tuple, Optional + +from video2dataset.logger import CappedCounter +from video2dataset.subsamplers import ( + ClippingSubsampler, + CutDetectionSubsampler, + FrameSubsampler, + FFProbeSubsampler, + NoOpSubsampler, + ResolutionSubsampler, + AudioRateSubsampler, + Subsampler, +) 
+from video2dataset.types import EncodeFormats, Streams, Metadata + + +@dataclass +class ShardStatus: + """Shard processing status""" + + successes: int = 0 + failed: dict = field( + default_factory=lambda: { + "failed_to_download": 0, + "failed_to_subsample": 0, + } + ) + status_dict: CappedCounter = field(default_factory=CappedCounter) + error_message: Optional[str] = None + count: int = 0 + bytes_downloaded: int = 0 + + +@dataclass +class Subsamplers: + """Subsamplers used in processing""" + + ffprobe_subsampler: Optional[FFProbeSubsampler] = None + modal_subsamplers: dict = field(default_factory=dict) + cut_detection_subsampler: Optional[CutDetectionSubsampler] = None + cuts_are_clips: bool = False + broadcast_subsampler: Subsampler = field(default_factory=NoOpSubsampler) + + +def get_subsamplers( + config: dict, + input_encode_formats: EncodeFormats, + do_clipping: bool = False, +) -> Tuple[Subsamplers, EncodeFormats]: + """Initialize all subsamplers using config""" + + clipping_subsampler = ClippingSubsampler( + oom_clip_count=5, + encode_formats=input_encode_formats, + **config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"], + ) + need_keyframes = clipping_subsampler.precision == "keyframe_adjusted" + + cut_detection_subsampler = None + cuts_are_clips = False + if "CutDetectionSubsampler" in config["subsampling"]: + if "args" in config["subsampling"]["CutDetectionSubsampler"]: + cut_detection_subsampler = CutDetectionSubsampler(**config["subsampling"]["CutDetectionSubsampler"]["args"]) + cuts_are_clips = config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) + + broadcast_subsampler = ( + clipping_subsampler + if (do_clipping or config["storage"]["captions_are_subtitles"] or cuts_are_clips) + else NoOpSubsampler() + ) + + ffprobe_subsampler = None + if "FFProbeSubsampler" in config["subsampling"] or need_keyframes: + ffprobe_subsampler = FFProbeSubsampler(**config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"]) + ffprobe_subsampler.extract_keyframes |= need_keyframes + + video_subsamplers: List[Any] = [] + if "ResolutionSubsampler" in config["subsampling"]: + video_subsamplers.append(ResolutionSubsampler(**config["subsampling"]["ResolutionSubsampler"]["args"])) + if "FrameSubsampler" in config["subsampling"]: + video_subsamplers.append(FrameSubsampler(**config["subsampling"]["FrameSubsampler"]["args"])) + + audio_subsamplers: List[Any] = [] + if "AudioRateSubsampler" in config["subsampling"]: + audio_subsamplers.append(AudioRateSubsampler(**config["subsampling"]["AudioRateSubsampler"]["args"])) + + modal_subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers} + + # output encoding formats + output_encode_formats = input_encode_formats.copy() + if modal_subsamplers["audio"]: + assert ( + len({s.encode_format for s in modal_subsamplers["audio"]}) == 1 + ) # assert that all audio subsamplers have the same output format + output_encode_formats["audio"] = modal_subsamplers["audio"][0].encode_format + if modal_subsamplers["video"]: + assert ( + len({s.encode_format for s in modal_subsamplers["video"]}) == 1 + ) # assert that all video subsamplers have the same output format + output_encode_formats["video"] = modal_subsamplers["video"][0].encode_format + + return ( + Subsamplers( + ffprobe_subsampler=ffprobe_subsampler, + modal_subsamplers=modal_subsamplers, + cut_detection_subsampler=cut_detection_subsampler, + cuts_are_clips=cuts_are_clips, + broadcast_subsampler=broadcast_subsampler, + ), + output_encode_formats, + ) + 
+ +def process_sample( + subsamplers: Subsamplers, + shard_status: ShardStatus, + streams: Streams, + key: str, + caption: str, + metadata: Metadata, + captions_are_subtitles: bool, + shard_sample_writer: Any, # TODO: type correctly +): + """Process a single video""" + + try: + if subsamplers.ffprobe_subsampler is not None: + streams, metadata, shard_status.error_message = subsamplers.ffprobe_subsampler(streams, metadata) + assert shard_status.error_message is None + + if captions_are_subtitles: # create clips + subtitles = metadata["yt_meta_dict"]["subtitles"] + metadata["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] + elif subsamplers.cut_detection_subsampler is not None: # apply cut detection to get clips + streams, cuts, shard_status.error_message = subsamplers.cut_detection_subsampler(streams) + assert shard_status.error_message is None + metadata["cuts"] = cuts + assert cuts is not None + if subsamplers.cuts_are_clips: + metadata["clips"] = (np.array(cuts["cuts_original_fps"]) / cuts["original_fps"]).tolist() + + # 1 video -> many videos (either clipping or noop which does identity broadcasting) + subsampled_streams, metadatas, shard_status.error_message = subsamplers.broadcast_subsampler(streams, metadata) + if shard_status.error_message is not None: + metadata["clips"] = [] + assert False + + for modality in list(subsampled_streams.keys()): + for modality_subsampler in subsamplers.modal_subsamplers[modality]: + subsampled_streams, metadatas, shard_status.error_message = modality_subsampler( + subsampled_streams, metadatas + ) + assert shard_status.error_message is None + + shard_status.successes += 1 + status = "success" + shard_status.status_dict.increment(status) + + subsampled_streams_list = [dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())] + if len(subsampled_streams_list) == 0: # no audio or video, just write metadata + metadata["status"] = status + shard_sample_writer.write( + {}, + key, + caption, + metadata, + ) + return + for subsampled_streams, subsampled_metadata in zip(subsampled_streams_list, metadatas): + subsampled_metadata["status"] = status + text_caption = caption + if captions_are_subtitles: + clip_subtitles = subsampled_metadata.get("clip_subtitles") + first_clip_subtitles = clip_subtitles[0] if clip_subtitles else None + subtitle_lines = first_clip_subtitles["lines"] if first_clip_subtitles else None + text_caption = subtitle_lines[0] if subtitle_lines else text_caption + shard_sample_writer.write( + subsampled_streams, + subsampled_metadata["key"], + text_caption, + subsampled_metadata, + ) + except Exception as err: # pylint: disable=broad-except + print(err) + shard_status.failed["failed_to_subsample"] += 1 + shard_status.status_dict.increment(shard_status.error_message) + metadata["status"] = "failed_to_subsample" + metadata["error_message"] = shard_status.error_message + shard_sample_writer.write( + {}, + key, + caption, + metadata, + )
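
Usage sketch (illustrative, not part of the patch): the new video2dataset/workers/worker.py centralizes the per-shard plumbing that DownloadWorker and SubsetWorker previously duplicated. The snippet below shows, with assumed placeholder inputs, how the shared helpers are meant to be combined: get_subsamplers() is called once per worker to build the subsampler pipeline and derive the output encode formats, ShardStatus tracks per-shard counters, and process_sample() runs one sample through the pipeline and hands the results to a sample writer. The _PrintWriter class, the config values, and the byte payload are stand-ins for this sketch only; real workers pass the sample_writer_class instance they are constructed with.

from video2dataset.types import EncodeFormats, Streams
from video2dataset.workers.worker import ShardStatus, get_subsamplers, process_sample


class _PrintWriter:
    """Hypothetical writer used only for this sketch; real workers use sample_writer_class."""

    def write(self, streams, key, caption, metadata):
        print(key, metadata["status"])

    def close(self):
        pass


# Assumed minimal config; real runs read this from video2dataset's config files.
config = {
    "subsampling": {},
    "storage": {"captions_are_subtitles": False},
}
input_encode_formats: EncodeFormats = {"video": "mp4"}

# One-time setup per worker: build the subsampler pipeline and the output formats it produces.
subsamplers, output_encode_formats = get_subsamplers(config, input_encode_formats)

# Per-shard state shared by DownloadWorker and SubsetWorker.
shard_status = ShardStatus(count=1)

# Per-sample call: streams map each modality to a list of raw byte strings.
streams: Streams = {"video": [b"placeholder video bytes"]}
process_sample(
    subsamplers=subsamplers,
    shard_status=shard_status,
    streams=streams,
    key="000000000",
    caption="a sample caption",
    metadata={"key": "000000000", "status": None, "error_message": None},
    captions_are_subtitles=config["storage"]["captions_are_subtitles"],
    shard_sample_writer=_PrintWriter(),
)

print(shard_status.successes, shard_status.failed)

Because both workers now route every sample through the same process_sample(), success/failure accounting and writer behavior stay consistent between downloading and subsetting.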