diff --git a/examples/templates/omni_text2audio.yaml b/examples/templates/omni_text2audio.yaml index 77d9711..f621def 100644 --- a/examples/templates/omni_text2audio.yaml +++ b/examples/templates/omni_text2audio.yaml @@ -18,7 +18,9 @@ spec: identifier: stabilityai/stable-audio-open-1.0 data: - text: Upbeat electronic music with synth pads and a driving beat + type: list + items: + - Upbeat electronic music with synth pads and a driving beat omni: output_format: wav diff --git a/examples/templates/omni_text2general.yaml b/examples/templates/omni_text2general.yaml index 569455a..32b29c3 100644 --- a/examples/templates/omni_text2general.yaml +++ b/examples/templates/omni_text2general.yaml @@ -18,7 +18,9 @@ spec: identifier: Qwen/Qwen3-Omni-30B-A3B-Instruct data: - text: "Please read the following passage in a calm, professional tone: FlowMesh is a distributed GPU workflow engine." + type: list + items: + - "Please read the following passage in a calm, professional tone: FlowMesh is a distributed GPU workflow engine." omni: output_format: wav diff --git a/examples/templates/omni_text2speech.yaml b/examples/templates/omni_text2speech.yaml index 9fc1003..e0df5c4 100644 --- a/examples/templates/omni_text2speech.yaml +++ b/examples/templates/omni_text2speech.yaml @@ -18,7 +18,9 @@ spec: identifier: Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice data: - text: "Hello, welcome to FlowMesh. This is a text-to-speech demo." + type: list + items: + - "Hello, welcome to FlowMesh. This is a text-to-speech demo." omni: output_format: wav diff --git a/src/worker/executors/omni_executor_base.py b/src/worker/executors/omni_executor_base.py index 8735390..ff2e04d 100644 --- a/src/worker/executors/omni_executor_base.py +++ b/src/worker/executors/omni_executor_base.py @@ -13,16 +13,20 @@ import struct import tempfile import wave +from abc import abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar, cast import yaml +from shared.tasks.specs import TaskSpecStrictBase from shared.utils.parsing import to_bool, to_int from ..config import WorkerConfig from ..lifecycle import Lifecycle -from .base_executor import ExecutionError, Executor +from .base_executor import ExecutionError, Executor, ExecutorTask +from .mixins.inference import InferenceMixin +from .utils.checkpoints import maybe_upload_artifacts, maybe_upload_traces try: import numpy as np @@ -38,13 +42,17 @@ logger = logging.getLogger(__name__) -class OmniExecutorBase(Executor): +class OmniExecutorBase(InferenceMixin, Executor): """Shared base for Omni-family executors. - Manages the ``_omni`` model handle and provides config / audio helpers. - Concrete subclasses implement ``prepare()`` and ``run()`` as usual. + Manages the ``_omni`` model handle, runs the generic ``run()`` shape + (task span + artifact / trace upload), and delegates the task body to + each subclass's ``_run_inner``. Subclasses set ``_TASK_SPEC_TYPE`` so + the base can call ``require_spec`` without knowing the concrete type. """ + _TASK_SPEC_TYPE: ClassVar[type[TaskSpecStrictBase]] + def __init__( self, config: WorkerConfig, lifecycle: Lifecycle | None = None ) -> None: @@ -54,6 +62,30 @@ def __init__( self._omni_spec: tuple[Any, ...] | None = None self._stage_configs_tmp: Path | None = None + def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + spec = self.require_spec(task, self._TASK_SPEC_TYPE) + spec_dict = spec.model_dump(by_alias=True) + out_dir = Path(out_dir).resolve() + with self._task_span( + task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id + ): + result = self._run_inner(task, spec, spec_dict, out_dir) + maybe_upload_artifacts(task, out_dir, logger=logger) + maybe_upload_traces(task, out_dir, logger=logger) + return result + + @abstractmethod + def _run_inner( + self, + task: ExecutorTask, + spec: TaskSpecStrictBase, + spec_dict: dict[str, Any], + out_dir: Path, + ) -> dict[str, Any]: + """Run the executor-specific body. ``spec`` is the concrete strict + spec; subclasses ``assert isinstance(spec, ...)`` to narrow.""" + raise NotImplementedError + # ── model lifecycle ────────────────────────────────────────────────── def _close_omni(self) -> None: @@ -160,21 +192,13 @@ def resolve_model_identifier( value = default return str(value).strip() - @staticmethod - def collect_text_inputs(spec_dict: dict[str, Any]) -> list[str]: - data = spec_dict.get("data") or {} - if not isinstance(data, dict): - data = {} - dtype = str(data.get("type") or "").strip().lower() - if dtype == "list": - items = data.get("items") or [] - if not isinstance(items, list): - return [] - return [str(item).strip() for item in items if str(item).strip()] - text_val = data.get("text") or data.get("prompt") or spec_dict.get("task") - if not isinstance(text_val, str) or not text_val.strip(): - return [] - return [text_val.strip()] + def _collect_text_inputs(self, spec: TaskSpecStrictBase, task_id: str) -> list[str]: + prompts = self._collect_prompts_for_spec(spec, task_id).prompts + if not prompts: + raise ExecutionError(f"{self.name} requires text input in spec.data.items.") + if not all(isinstance(p, str) for p in prompts): + raise ExecutionError(f"{self.name} prompts must be strings.") + return cast(list[str], prompts) @staticmethod def resolve_save_path( diff --git a/src/worker/executors/omni_text2audio_executor.py b/src/worker/executors/omni_text2audio_executor.py index eb805da..952dfb2 100644 --- a/src/worker/executors/omni_text2audio_executor.py +++ b/src/worker/executors/omni_text2audio_executor.py @@ -50,12 +50,14 @@ current_omni_platform = None _HAS_OMNI_PLATFORM = False +from shared.schemas.governance import SpanType +from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2AudioSpecStrict from shared.utils.parsing import to_float, to_int from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import OmniExecutorBase, extract_multimodal_output -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts +from .utils.checkpoints import artifact_ref logger = logging.getLogger(__name__) @@ -64,6 +66,7 @@ class OmniText2AudioExecutor(OmniExecutorBase): """Generate background music with Omni diffusion sampling.""" name = "omni_text2audio" + _TASK_SPEC_TYPE = OmniText2AudioSpecStrict def prepare(self) -> None: if torch is None: @@ -75,20 +78,18 @@ def prepare(self) -> None: "vllm_omni is not installed; cannot use omni_text2audio executor." ) - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: - spec = self.require_spec(task, OmniText2AudioSpecStrict) - spec_dict = spec.model_dump(by_alias=True) - out_dir = Path(out_dir).resolve() - - prompts = self.collect_text_inputs(spec_dict) - if not prompts: - raise ExecutionError( - "omni_text2audio requires prompt text " - "in spec.data.text or spec.data.items." - ) + def _run_inner( + self, + task: ExecutorTask, + spec: TaskSpecStrictBase, + spec_dict: dict[str, Any], + out_dir: Path, + ) -> dict[str, Any]: + assert isinstance(spec, OmniText2AudioSpecStrict) + prompts = self._collect_text_inputs(spec, task.task_id) cfg = _bgm_cfg(spec_dict) - output_format = str(cfg.get("output_format") or "wav").strip().lower() or "wav" + output_format = str(cfg.get("output_format") or "").strip().lower() or "wav" if output_format != "wav": raise ExecutionError( "omni_text2audio currently supports output_format='wav' only." @@ -106,75 +107,97 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: base_seed = to_int(cfg.get("seed"), default=42) negative_prompt = str(cfg.get("negative_prompt") or "Low quality.").strip() - self._ensure_omni(spec_dict) + with self._span( + "model load", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + self._ensure_omni(spec_dict) omni = self._omni if omni is None: raise ExecutionError("Omni BGM model failed to initialize.") generator_device = _resolve_generator_device() per_prompt_outputs: list[tuple[int, str, Any]] = [] - for prompt_idx, prompt in enumerate(prompts): - seed = base_seed + prompt_idx - torch_generator = torch.Generator(device=generator_device).manual_seed(seed) - - omni_prompt: OmniTextPrompt = {"prompt": prompt} - if negative_prompt: - omni_prompt["negative_prompt"] = negative_prompt - - sampling = OmniDiffusionSamplingParams( - generator=torch_generator, - generator_device=generator_device, - guidance_scale=guidance_scale, - num_inference_steps=num_inference_steps, - num_outputs_per_prompt=num_waveforms, - extra_args={ - "audio_start_in_s": audio_start, - "audio_end_in_s": audio_end, - }, - ) - try: - outputs = omni.generate(omni_prompt, sampling) - except Exception as exc: - raise ExecutionError( - f"omni_text2audio generation failed: {exc}" - ) from exc - per_prompt_outputs.append((prompt_idx, prompt, outputs)) + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={ + "task_id": task.task_id, + "prompt_count": len(prompts), + "num_waveforms": num_waveforms, + "num_inference_steps": num_inference_steps, + }, + ): + for prompt_idx, prompt in enumerate(prompts): + seed = base_seed + prompt_idx + torch_generator = torch.Generator(device=generator_device).manual_seed( + seed + ) + + omni_prompt: OmniTextPrompt = {"prompt": prompt} + if negative_prompt: + omni_prompt["negative_prompt"] = negative_prompt + + sampling = OmniDiffusionSamplingParams( + generator=torch_generator, + generator_device=generator_device, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + num_outputs_per_prompt=num_waveforms, + extra_args={ + "audio_start_in_s": audio_start, + "audio_end_in_s": audio_end, + }, + ) + try: + outputs = omni.generate(omni_prompt, sampling) + except Exception as exc: + raise ExecutionError( + f"omni_text2audio generation failed: {exc}" + ) from exc + per_prompt_outputs.append((prompt_idx, prompt, outputs)) artifacts_dir = out_dir / "artifacts" items: list[dict[str, Any]] = [] global_index = 0 - for prompt_idx, prompt, outputs in per_prompt_outputs: - extracted = _extract_audio_waveforms(outputs) - if not extracted: - raise ExecutionError( - "omni_text2audio completed but returned no audio output." - ) - - for local_idx, audio_entry in enumerate(extracted): - multi = len(prompts) * len(extracted) > 1 - save_path = _resolve_bgm_save_path( - cfg, - out_dir, - index=global_index, - ext=output_format, - multi=multi, - ) - save_path.parent.mkdir(parents=True, exist_ok=True) - _save_waveform( - audio_entry["waveform"], save_path, sample_rate=sample_rate - ) - items.append( - { - "index": global_index, - "prompt_index": prompt_idx, - "waveform_index": local_idx, - "prompt": prompt, - "audio": artifact_ref( - self.relative_to(save_path, artifacts_dir) - ), - } - ) - global_index += 1 + with self._span( + "output postprocessing", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + for prompt_idx, prompt, outputs in per_prompt_outputs: + extracted = _extract_audio_waveforms(outputs) + if not extracted: + raise ExecutionError( + "omni_text2audio completed but returned no audio output." + ) + + for local_idx, audio_entry in enumerate(extracted): + multi = len(prompts) * len(extracted) > 1 + save_path = _resolve_bgm_save_path( + cfg, + out_dir, + index=global_index, + ext=output_format, + multi=multi, + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + _save_waveform( + audio_entry["waveform"], save_path, sample_rate=sample_rate + ) + items.append( + { + "index": global_index, + "prompt_index": prompt_idx, + "waveform_index": local_idx, + "prompt": prompt, + "audio": artifact_ref( + self.relative_to(save_path, artifacts_dir) + ), + } + ) + global_index += 1 if not items: raise ExecutionError("omni_text2audio produced no savable waveforms.") @@ -194,7 +217,6 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: storyboard = spec_dict.get("storyboard") if isinstance(storyboard, dict): result["storyboard"] = dict(storyboard) - maybe_upload_artifacts(task, out_dir, logger=logger) return result # ── model ──────────────────────────────────────────────────────────── diff --git a/src/worker/executors/omni_text2general_executor.py b/src/worker/executors/omni_text2general_executor.py index b164c9d..5664af9 100644 --- a/src/worker/executors/omni_text2general_executor.py +++ b/src/worker/executors/omni_text2general_executor.py @@ -26,12 +26,14 @@ Omni = None _HAS_OMNI = False +from shared.schemas.governance import SpanType +from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2GeneralSpecStrict from shared.utils.parsing import as_list, to_bool, to_float, to_int, to_int_list from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import OmniExecutorBase, extract_audio_from_mm, save_audio -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts +from .utils.checkpoints import artifact_ref logger = logging.getLogger(__name__) @@ -46,6 +48,7 @@ class OmniText2GeneralExecutor(OmniExecutorBase): """Generate narration/speech audio using Qwen3-Omni through vllm_omni.Omni.""" name = "omni_text2general" + _TASK_SPEC_TYPE = OmniText2GeneralSpecStrict def prepare(self) -> None: if not _HAS_OMNI: @@ -58,20 +61,18 @@ def prepare(self) -> None: "omni_text2general requires SamplingParams from vllm." ) - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: - spec = self.require_spec(task, OmniText2GeneralSpecStrict) - spec_dict = spec.model_dump(by_alias=True) - out_dir = Path(out_dir).resolve() - - texts = self.collect_text_inputs(spec_dict) - if not texts: - raise ExecutionError( - "omni_text2general requires text input " - "in spec.data.text or spec.data.items." - ) + def _run_inner( + self, + task: ExecutorTask, + spec: TaskSpecStrictBase, + spec_dict: dict[str, Any], + out_dir: Path, + ) -> dict[str, Any]: + assert isinstance(spec, OmniText2GeneralSpecStrict) + texts = self._collect_text_inputs(spec, task.task_id) cfg = _narration_cfg(spec_dict) - output_format = str(cfg.get("output_format") or "wav").strip().lower() or "wav" + output_format = str(cfg.get("output_format") or "").strip().lower() or "wav" if output_format != "wav": raise ExecutionError( "omni_text2general currently supports output_format='wav' only." @@ -80,7 +81,12 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: output_modalities = _parse_modalities(cfg.get("modalities")) py_generator = to_bool(cfg.get("py_generator"), default=False) - self._ensure_omni(spec_dict) + with self._span( + "model load", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(texts)}, + ): + self._ensure_omni(spec_dict) if self._omni is None: raise ExecutionError("Omni model failed to initialize.") @@ -93,37 +99,46 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: ] sampling_params = _build_sampling_params(cfg) - try: - generator = self._omni.generate( - prompts, sampling_params, py_generator=py_generator - ) - except Exception as exc: - raise ExecutionError( - f"omni_text2general generation failed to start: {exc}" - ) from exc - audio_results: list[dict[str, Any]] = [] text_results: dict[str, str] = {} - for stage_outputs in generator: - final_type = ( - str(getattr(stage_outputs, "final_output_type", "")).strip().lower() - ) - request_outputs = as_list(getattr(stage_outputs, "request_output", None)) - if not request_outputs: - continue - if final_type == "text": - for req in request_outputs: - rid = _request_id(req, default_index=len(text_results) + 1) - text_out = _extract_text_output(req) - if text_out is not None: - text_results[rid] = text_out - continue - if final_type == "audio": - for req in request_outputs: - rid = _request_id(req, default_index=len(audio_results) + 1) - audio_obj = _extract_request_audio(req) - if audio_obj is not None: - audio_results.append({"request_id": rid, "audio": audio_obj}) + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + try: + generator = self._omni.generate( + prompts, sampling_params, py_generator=py_generator + ) + except Exception as exc: + raise ExecutionError( + f"omni_text2general generation failed to start: {exc}" + ) from exc + + for stage_outputs in generator: + final_type = ( + str(getattr(stage_outputs, "final_output_type", "")).strip().lower() + ) + request_outputs = as_list( + getattr(stage_outputs, "request_output", None) + ) + if not request_outputs: + continue + if final_type == "text": + for req in request_outputs: + rid = _request_id(req, default_index=len(text_results) + 1) + text_out = _extract_text_output(req) + if text_out is not None: + text_results[rid] = text_out + continue + if final_type == "audio": + for req in request_outputs: + rid = _request_id(req, default_index=len(audio_results) + 1) + audio_obj = _extract_request_audio(req) + if audio_obj is not None: + audio_results.append( + {"request_id": rid, "audio": audio_obj} + ) if not audio_results: raise ExecutionError( @@ -133,29 +148,34 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: artifacts_dir = out_dir / "artifacts" items: list[dict[str, Any]] = [] multi = len(audio_results) > 1 - for idx, entry in enumerate(audio_results): - rid = str(entry.get("request_id") or f"req_{idx + 1}") - audio_obj = entry.get("audio") - save_path = self.resolve_save_path( - cfg, - out_dir, - index=idx, - ext=output_format, - multi=multi, - default_prefix="narration", - ) - save_path.parent.mkdir(parents=True, exist_ok=True) - save_audio(audio_obj, save_path, sample_rate=sample_rate) - item: dict[str, Any] = { - "index": idx, - "request_id": rid, - "prompt": texts[idx] if idx < len(texts) else None, - "audio": artifact_ref(self.relative_to(save_path, artifacts_dir)), - } - text_out = text_results.get(rid) - if text_out: - item["text"] = text_out - items.append(item) + with self._span( + "output postprocessing", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "item_count": len(audio_results)}, + ): + for idx, entry in enumerate(audio_results): + rid = str(entry.get("request_id") or f"req_{idx + 1}") + audio_obj = entry.get("audio") + save_path = self.resolve_save_path( + cfg, + out_dir, + index=idx, + ext=output_format, + multi=multi, + default_prefix="narration", + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + save_audio(audio_obj, save_path, sample_rate=sample_rate) + item: dict[str, Any] = { + "index": idx, + "request_id": rid, + "prompt": texts[idx] if idx < len(texts) else None, + "audio": artifact_ref(self.relative_to(save_path, artifacts_dir)), + } + text_out = text_results.get(rid) + if text_out: + item["text"] = text_out + items.append(item) first = items[0]["audio"] if items else {} result: dict[str, Any] = { @@ -170,7 +190,6 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: storyboard = spec_dict.get("storyboard") if isinstance(storyboard, dict): result["storyboard"] = dict(storyboard) - maybe_upload_artifacts(task, out_dir, logger=logger) return result # ── model ──────────────────────────────────────────────────────────── diff --git a/src/worker/executors/omni_text2image_executor.py b/src/worker/executors/omni_text2image_executor.py index 25dbe22..e3b1f23 100644 --- a/src/worker/executors/omni_text2image_executor.py +++ b/src/worker/executors/omni_text2image_executor.py @@ -15,12 +15,14 @@ Omni = None _HAS_OMNI = False +from shared.schemas.governance import SpanType +from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2ImageSpecStrict from shared.utils.parsing import as_list from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import OmniExecutorBase -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts +from .utils.checkpoints import artifact_ref logger = logging.getLogger(__name__) @@ -29,6 +31,7 @@ class OmniText2ImageExecutor(OmniExecutorBase): """Generate images using vllm_omni.Omni.""" name = "omni_text2image" + _TASK_SPEC_TYPE = OmniText2ImageSpecStrict def prepare(self) -> None: if not _HAS_OMNI: @@ -36,43 +39,58 @@ def prepare(self) -> None: "vllm_omni is not installed; cannot use omni_text2image executor." ) - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: - spec = self.require_spec(task, OmniText2ImageSpecStrict) - spec_dict = spec.model_dump(by_alias=True) - out_dir = Path(out_dir).resolve() - - prompts = self.collect_text_inputs(spec_dict) - if not prompts: - raise ExecutionError( - "omni_text2image requires prompts " - "in spec.data.prompt or spec.data.items." - ) - - self._ensure_omni(spec_dict) + def _run_inner( + self, + task: ExecutorTask, + spec: TaskSpecStrictBase, + spec_dict: dict[str, Any], + out_dir: Path, + ) -> dict[str, Any]: + assert isinstance(spec, OmniText2ImageSpecStrict) + prompts = self._collect_text_inputs(spec, task.task_id) + + with self._span( + "model load", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + self._ensure_omni(spec_dict) cfg = self.omni_cfg(spec_dict, "omni:image generation", "omni_text2image") - fmt = str(cfg.get("output_format") or "png").strip().lower() or "png" + fmt = str(cfg.get("output_format") or "").strip().lower() or "png" artifacts_dir = out_dir / "artifacts" items: list[dict[str, Any]] = [] - images = self._generate_images(prompts) - for idx, (prompt, image) in enumerate(zip(prompts, images)): - save_path = self.resolve_save_path( - cfg, - out_dir, - index=idx, - ext=fmt, - multi=len(prompts) > 1, - default_prefix="generated_image", - ) - save_path.parent.mkdir(parents=True, exist_ok=True) - _save_image(image, save_path) - items.append( - { - "index": idx, - "prompt": prompt, - "image": artifact_ref(self.relative_to(save_path, artifacts_dir)), - } - ) + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + images = self._generate_images(prompts) + with self._span( + "output postprocessing", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "item_count": len(prompts)}, + ): + for idx, (prompt, image) in enumerate(zip(prompts, images)): + save_path = self.resolve_save_path( + cfg, + out_dir, + index=idx, + ext=fmt, + multi=len(prompts) > 1, + default_prefix="generated_image", + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + _save_image(image, save_path) + items.append( + { + "index": idx, + "prompt": prompt, + "image": artifact_ref( + self.relative_to(save_path, artifacts_dir) + ), + } + ) first = items[0]["image"] if items else {} result: dict[str, Any] = { @@ -83,7 +101,6 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: "image": first, "items": items, } - maybe_upload_artifacts(task, out_dir, logger=logger) return result # ── model ──────────────────────────────────────────────────────────── diff --git a/src/worker/executors/omni_text2speech_executor.py b/src/worker/executors/omni_text2speech_executor.py index 42fb134..5d7be1b 100644 --- a/src/worker/executors/omni_text2speech_executor.py +++ b/src/worker/executors/omni_text2speech_executor.py @@ -16,6 +16,8 @@ Omni = None _HAS_OMNI = False +from shared.schemas.governance import SpanType +from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2SpeechSpecStrict from shared.utils.parsing import as_list, to_int @@ -26,7 +28,7 @@ extract_multimodal_output, save_audio, ) -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts +from .utils.checkpoints import artifact_ref logger = logging.getLogger(__name__) @@ -35,6 +37,7 @@ class OmniText2SpeechExecutor(OmniExecutorBase): """Generate speech audio using vllm_omni.Omni.""" name = "omni_text2speech" + _TASK_SPEC_TYPE = OmniText2SpeechSpecStrict def prepare(self) -> None: if not _HAS_OMNI: @@ -42,44 +45,61 @@ def prepare(self) -> None: "vllm_omni is not installed; cannot use omni_text2speech executor." ) - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: - spec = self.require_spec(task, OmniText2SpeechSpecStrict) - spec_dict = spec.model_dump(by_alias=True) - out_dir = Path(out_dir).resolve() - - texts = self.collect_text_inputs(spec_dict) - if not texts: - raise ExecutionError( - "omni_text2speech requires text input " - "in spec.data.text or spec.data.items." - ) - - self._ensure_omni(spec_dict) + def _run_inner( + self, + task: ExecutorTask, + spec: TaskSpecStrictBase, + spec_dict: dict[str, Any], + out_dir: Path, + ) -> dict[str, Any]: + assert isinstance(spec, OmniText2SpeechSpecStrict) + texts = self._collect_text_inputs(spec, task.task_id) + + with self._span( + "model load", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(texts)}, + ): + self._ensure_omni(spec_dict) cfg = self.omni_cfg(spec_dict, "omni:tts", "omni_text2speech") - output_format = str(cfg.get("output_format") or "wav").strip().lower() or "wav" + output_format = str(cfg.get("output_format") or "").strip().lower() or "wav" sample_rate = to_int(cfg.get("sample_rate"), default=24000) - audio_objects = [self._generate_single(t, spec_dict=spec_dict) for t in texts] + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(texts)}, + ): + audio_objects = [ + self._generate_single(t, spec_dict=spec_dict) for t in texts + ] artifacts_dir = out_dir / "artifacts" items: list[dict[str, Any]] = [] - for idx, (text, audio_obj) in enumerate(zip(texts, audio_objects)): - save_path = self.resolve_save_path( - cfg, - out_dir, - index=idx, - ext=output_format, - multi=len(texts) > 1, - default_prefix="generated_tts", - ) - save_path.parent.mkdir(parents=True, exist_ok=True) - save_audio(audio_obj, save_path, sample_rate=sample_rate) - items.append( - { - "index": idx, - "text": text, - "audio": artifact_ref(self.relative_to(save_path, artifacts_dir)), - } - ) + with self._span( + "output postprocessing", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "item_count": len(texts)}, + ): + for idx, (text, audio_obj) in enumerate(zip(texts, audio_objects)): + save_path = self.resolve_save_path( + cfg, + out_dir, + index=idx, + ext=output_format, + multi=len(texts) > 1, + default_prefix="generated_tts", + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + save_audio(audio_obj, save_path, sample_rate=sample_rate) + items.append( + { + "index": idx, + "text": text, + "audio": artifact_ref( + self.relative_to(save_path, artifacts_dir) + ), + } + ) first = items[0]["audio"] if items else {} result: dict[str, Any] = { @@ -94,7 +114,6 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: storyboard = spec_dict.get("storyboard") if isinstance(storyboard, dict): result["storyboard"] = dict(storyboard) - maybe_upload_artifacts(task, out_dir, logger=logger) return result # ── model ──────────────────────────────────────────────────────────── diff --git a/tests/worker/test_omni_executor_inheritance.py b/tests/worker/test_omni_executor_inheritance.py new file mode 100644 index 0000000..89ac207 --- /dev/null +++ b/tests/worker/test_omni_executor_inheritance.py @@ -0,0 +1,36 @@ +"""Compile-time guard that omni executors sit on the mixin chain. + +The mixin chain is the same one the inference and training executors use; the +omni family was the last to predate it. This test fails loudly if a future +refactor regresses the base class. +""" + +import pytest + +pytest.importorskip("vllm_omni", reason="vllm-omni not installed") + +from worker.executors.mixins.data import DataMixin +from worker.executors.mixins.governance import GovernanceMixin +from worker.executors.mixins.inference import InferenceMixin +from worker.executors.omni_executor_base import OmniExecutorBase +from worker.executors.omni_text2audio_executor import OmniText2AudioExecutor +from worker.executors.omni_text2general_executor import OmniText2GeneralExecutor +from worker.executors.omni_text2image_executor import OmniText2ImageExecutor +from worker.executors.omni_text2speech_executor import OmniText2SpeechExecutor + + +@pytest.mark.parametrize( + "cls", + [ + OmniExecutorBase, + OmniText2AudioExecutor, + OmniText2GeneralExecutor, + OmniText2ImageExecutor, + OmniText2SpeechExecutor, + ], + ids=lambda c: c.__name__, +) +def test_omni_executors_use_mixin_chain(cls: type) -> None: + assert issubclass(cls, InferenceMixin) + assert issubclass(cls, DataMixin) + assert issubclass(cls, GovernanceMixin)