From 9fb6d08410ec795d79fdb05aa717604968687246 Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Thu, 14 May 2026 09:01:55 +0000 Subject: [PATCH 1/6] refactor(omni): adopt mixin chain + emit per-phase spans (RFC #48) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - OmniExecutorBase inherits (InferenceMixin, Executor) so the four omni executors pick up GovernanceMixin / DataMixin / InferenceMixin from the same chain the inference and training executors use. - Each concrete omni executor wraps run() with self._task_span(...) so a 'task' root span is emitted with executor.name + workflow_id. - Inside run(), three per-phase compute spans are added — 'model load' (_ensure_omni), 'generation' (the omni.generate call(s)), and 'output postprocessing' (artifact save + items build) — mirroring the vllm executor's tracing shape. - New tests/worker/test_omni_executor_inheritance.py asserts the full mixin chain on each omni executor class as a compile-time guard against regression. Live e2e on a single GPU worker against all four omni templates (omni_text2{speech,image,audio,general}.yaml): each task reports ok=True with the expected artifact, and spans.jsonl contains the 'task' root plus 'model load' / 'generation' / 'output postprocessing' sub-spans. Signed-off-by: Zhengyuan Su --- src/worker/executors/omni_executor_base.py | 3 +- .../executors/omni_text2audio_executor.py | 149 +++++++++++------- .../executors/omni_text2general_executor.py | 134 +++++++++------- .../executors/omni_text2image_executor.py | 109 +++++++------ .../executors/omni_text2speech_executor.py | 119 ++++++++------ .../worker/test_omni_executor_inheritance.py | 36 +++++ 6 files changed, 346 insertions(+), 204 deletions(-) create mode 100644 tests/worker/test_omni_executor_inheritance.py diff --git a/src/worker/executors/omni_executor_base.py b/src/worker/executors/omni_executor_base.py index 87353907..8dfb15c4 100644 --- a/src/worker/executors/omni_executor_base.py +++ b/src/worker/executors/omni_executor_base.py @@ -23,6 +23,7 @@ from ..config import WorkerConfig from ..lifecycle import Lifecycle from .base_executor import ExecutionError, Executor +from .mixins.inference import InferenceMixin try: import numpy as np @@ -38,7 +39,7 @@ logger = logging.getLogger(__name__) -class OmniExecutorBase(Executor): +class OmniExecutorBase(InferenceMixin, Executor): """Shared base for Omni-family executors. Manages the ``_omni`` model handle and provides config / audio helpers. diff --git a/src/worker/executors/omni_text2audio_executor.py b/src/worker/executors/omni_text2audio_executor.py index eb805da4..be7fd20e 100644 --- a/src/worker/executors/omni_text2audio_executor.py +++ b/src/worker/executors/omni_text2audio_executor.py @@ -50,6 +50,7 @@ current_omni_platform = None _HAS_OMNI_PLATFORM = False +from shared.schemas.governance import SpanType from shared.tasks.specs.omni import OmniText2AudioSpecStrict from shared.utils.parsing import to_float, to_int @@ -80,6 +81,14 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: spec_dict = spec.model_dump(by_alias=True) out_dir = Path(out_dir).resolve() + with self._task_span( + task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id + ): + return self._run_inner(task, spec_dict, out_dir) + + def _run_inner( + self, task: ExecutorTask, spec_dict: dict[str, Any], out_dir: Path + ) -> dict[str, Any]: prompts = self.collect_text_inputs(spec_dict) if not prompts: raise ExecutionError( @@ -106,75 +115,97 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: base_seed = to_int(cfg.get("seed"), default=42) negative_prompt = str(cfg.get("negative_prompt") or "Low quality.").strip() - self._ensure_omni(spec_dict) + with self._span( + "model load", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + self._ensure_omni(spec_dict) omni = self._omni if omni is None: raise ExecutionError("Omni BGM model failed to initialize.") generator_device = _resolve_generator_device() per_prompt_outputs: list[tuple[int, str, Any]] = [] - for prompt_idx, prompt in enumerate(prompts): - seed = base_seed + prompt_idx - torch_generator = torch.Generator(device=generator_device).manual_seed(seed) - - omni_prompt: OmniTextPrompt = {"prompt": prompt} - if negative_prompt: - omni_prompt["negative_prompt"] = negative_prompt - - sampling = OmniDiffusionSamplingParams( - generator=torch_generator, - generator_device=generator_device, - guidance_scale=guidance_scale, - num_inference_steps=num_inference_steps, - num_outputs_per_prompt=num_waveforms, - extra_args={ - "audio_start_in_s": audio_start, - "audio_end_in_s": audio_end, - }, - ) - try: - outputs = omni.generate(omni_prompt, sampling) - except Exception as exc: - raise ExecutionError( - f"omni_text2audio generation failed: {exc}" - ) from exc - per_prompt_outputs.append((prompt_idx, prompt, outputs)) + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={ + "task_id": task.task_id, + "prompt_count": len(prompts), + "num_waveforms": num_waveforms, + "num_inference_steps": num_inference_steps, + }, + ): + for prompt_idx, prompt in enumerate(prompts): + seed = base_seed + prompt_idx + torch_generator = torch.Generator(device=generator_device).manual_seed( + seed + ) + + omni_prompt: OmniTextPrompt = {"prompt": prompt} + if negative_prompt: + omni_prompt["negative_prompt"] = negative_prompt + + sampling = OmniDiffusionSamplingParams( + generator=torch_generator, + generator_device=generator_device, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + num_outputs_per_prompt=num_waveforms, + extra_args={ + "audio_start_in_s": audio_start, + "audio_end_in_s": audio_end, + }, + ) + try: + outputs = omni.generate(omni_prompt, sampling) + except Exception as exc: + raise ExecutionError( + f"omni_text2audio generation failed: {exc}" + ) from exc + per_prompt_outputs.append((prompt_idx, prompt, outputs)) artifacts_dir = out_dir / "artifacts" items: list[dict[str, Any]] = [] global_index = 0 - for prompt_idx, prompt, outputs in per_prompt_outputs: - extracted = _extract_audio_waveforms(outputs) - if not extracted: - raise ExecutionError( - "omni_text2audio completed but returned no audio output." - ) - - for local_idx, audio_entry in enumerate(extracted): - multi = len(prompts) * len(extracted) > 1 - save_path = _resolve_bgm_save_path( - cfg, - out_dir, - index=global_index, - ext=output_format, - multi=multi, - ) - save_path.parent.mkdir(parents=True, exist_ok=True) - _save_waveform( - audio_entry["waveform"], save_path, sample_rate=sample_rate - ) - items.append( - { - "index": global_index, - "prompt_index": prompt_idx, - "waveform_index": local_idx, - "prompt": prompt, - "audio": artifact_ref( - self.relative_to(save_path, artifacts_dir) - ), - } - ) - global_index += 1 + with self._span( + "output postprocessing", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + for prompt_idx, prompt, outputs in per_prompt_outputs: + extracted = _extract_audio_waveforms(outputs) + if not extracted: + raise ExecutionError( + "omni_text2audio completed but returned no audio output." + ) + + for local_idx, audio_entry in enumerate(extracted): + multi = len(prompts) * len(extracted) > 1 + save_path = _resolve_bgm_save_path( + cfg, + out_dir, + index=global_index, + ext=output_format, + multi=multi, + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + _save_waveform( + audio_entry["waveform"], save_path, sample_rate=sample_rate + ) + items.append( + { + "index": global_index, + "prompt_index": prompt_idx, + "waveform_index": local_idx, + "prompt": prompt, + "audio": artifact_ref( + self.relative_to(save_path, artifacts_dir) + ), + } + ) + global_index += 1 if not items: raise ExecutionError("omni_text2audio produced no savable waveforms.") diff --git a/src/worker/executors/omni_text2general_executor.py b/src/worker/executors/omni_text2general_executor.py index b164c9d0..e9328a67 100644 --- a/src/worker/executors/omni_text2general_executor.py +++ b/src/worker/executors/omni_text2general_executor.py @@ -26,6 +26,7 @@ Omni = None _HAS_OMNI = False +from shared.schemas.governance import SpanType from shared.tasks.specs.omni import OmniText2GeneralSpecStrict from shared.utils.parsing import as_list, to_bool, to_float, to_int, to_int_list @@ -63,6 +64,14 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: spec_dict = spec.model_dump(by_alias=True) out_dir = Path(out_dir).resolve() + with self._task_span( + task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id + ): + return self._run_inner(task, spec_dict, out_dir) + + def _run_inner( + self, task: ExecutorTask, spec_dict: dict[str, Any], out_dir: Path + ) -> dict[str, Any]: texts = self.collect_text_inputs(spec_dict) if not texts: raise ExecutionError( @@ -80,7 +89,12 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: output_modalities = _parse_modalities(cfg.get("modalities")) py_generator = to_bool(cfg.get("py_generator"), default=False) - self._ensure_omni(spec_dict) + with self._span( + "model load", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(texts)}, + ): + self._ensure_omni(spec_dict) if self._omni is None: raise ExecutionError("Omni model failed to initialize.") @@ -93,37 +107,46 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: ] sampling_params = _build_sampling_params(cfg) - try: - generator = self._omni.generate( - prompts, sampling_params, py_generator=py_generator - ) - except Exception as exc: - raise ExecutionError( - f"omni_text2general generation failed to start: {exc}" - ) from exc - audio_results: list[dict[str, Any]] = [] text_results: dict[str, str] = {} - for stage_outputs in generator: - final_type = ( - str(getattr(stage_outputs, "final_output_type", "")).strip().lower() - ) - request_outputs = as_list(getattr(stage_outputs, "request_output", None)) - if not request_outputs: - continue - if final_type == "text": - for req in request_outputs: - rid = _request_id(req, default_index=len(text_results) + 1) - text_out = _extract_text_output(req) - if text_out is not None: - text_results[rid] = text_out - continue - if final_type == "audio": - for req in request_outputs: - rid = _request_id(req, default_index=len(audio_results) + 1) - audio_obj = _extract_request_audio(req) - if audio_obj is not None: - audio_results.append({"request_id": rid, "audio": audio_obj}) + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + try: + generator = self._omni.generate( + prompts, sampling_params, py_generator=py_generator + ) + except Exception as exc: + raise ExecutionError( + f"omni_text2general generation failed to start: {exc}" + ) from exc + + for stage_outputs in generator: + final_type = ( + str(getattr(stage_outputs, "final_output_type", "")).strip().lower() + ) + request_outputs = as_list( + getattr(stage_outputs, "request_output", None) + ) + if not request_outputs: + continue + if final_type == "text": + for req in request_outputs: + rid = _request_id(req, default_index=len(text_results) + 1) + text_out = _extract_text_output(req) + if text_out is not None: + text_results[rid] = text_out + continue + if final_type == "audio": + for req in request_outputs: + rid = _request_id(req, default_index=len(audio_results) + 1) + audio_obj = _extract_request_audio(req) + if audio_obj is not None: + audio_results.append( + {"request_id": rid, "audio": audio_obj} + ) if not audio_results: raise ExecutionError( @@ -133,29 +156,34 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: artifacts_dir = out_dir / "artifacts" items: list[dict[str, Any]] = [] multi = len(audio_results) > 1 - for idx, entry in enumerate(audio_results): - rid = str(entry.get("request_id") or f"req_{idx + 1}") - audio_obj = entry.get("audio") - save_path = self.resolve_save_path( - cfg, - out_dir, - index=idx, - ext=output_format, - multi=multi, - default_prefix="narration", - ) - save_path.parent.mkdir(parents=True, exist_ok=True) - save_audio(audio_obj, save_path, sample_rate=sample_rate) - item: dict[str, Any] = { - "index": idx, - "request_id": rid, - "prompt": texts[idx] if idx < len(texts) else None, - "audio": artifact_ref(self.relative_to(save_path, artifacts_dir)), - } - text_out = text_results.get(rid) - if text_out: - item["text"] = text_out - items.append(item) + with self._span( + "output postprocessing", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "item_count": len(audio_results)}, + ): + for idx, entry in enumerate(audio_results): + rid = str(entry.get("request_id") or f"req_{idx + 1}") + audio_obj = entry.get("audio") + save_path = self.resolve_save_path( + cfg, + out_dir, + index=idx, + ext=output_format, + multi=multi, + default_prefix="narration", + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + save_audio(audio_obj, save_path, sample_rate=sample_rate) + item: dict[str, Any] = { + "index": idx, + "request_id": rid, + "prompt": texts[idx] if idx < len(texts) else None, + "audio": artifact_ref(self.relative_to(save_path, artifacts_dir)), + } + text_out = text_results.get(rid) + if text_out: + item["text"] = text_out + items.append(item) first = items[0]["audio"] if items else {} result: dict[str, Any] = { diff --git a/src/worker/executors/omni_text2image_executor.py b/src/worker/executors/omni_text2image_executor.py index 25dbe223..015f78ac 100644 --- a/src/worker/executors/omni_text2image_executor.py +++ b/src/worker/executors/omni_text2image_executor.py @@ -15,6 +15,7 @@ Omni = None _HAS_OMNI = False +from shared.schemas.governance import SpanType from shared.tasks.specs.omni import OmniText2ImageSpecStrict from shared.utils.parsing import as_list @@ -41,50 +42,70 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: spec_dict = spec.model_dump(by_alias=True) out_dir = Path(out_dir).resolve() - prompts = self.collect_text_inputs(spec_dict) - if not prompts: - raise ExecutionError( - "omni_text2image requires prompts " - "in spec.data.prompt or spec.data.items." - ) - - self._ensure_omni(spec_dict) - cfg = self.omni_cfg(spec_dict, "omni:image generation", "omni_text2image") - fmt = str(cfg.get("output_format") or "png").strip().lower() or "png" - - artifacts_dir = out_dir / "artifacts" - items: list[dict[str, Any]] = [] - images = self._generate_images(prompts) - for idx, (prompt, image) in enumerate(zip(prompts, images)): - save_path = self.resolve_save_path( - cfg, - out_dir, - index=idx, - ext=fmt, - multi=len(prompts) > 1, - default_prefix="generated_image", - ) - save_path.parent.mkdir(parents=True, exist_ok=True) - _save_image(image, save_path) - items.append( - { - "index": idx, - "prompt": prompt, - "image": artifact_ref(self.relative_to(save_path, artifacts_dir)), - } - ) - - first = items[0]["image"] if items else {} - result: dict[str, Any] = { - "ok": True, - "executor": self.name, - "mode": "image", - "model": self._model_name, - "image": first, - "items": items, - } - maybe_upload_artifacts(task, out_dir, logger=logger) - return result + with self._task_span( + task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id + ): + prompts = self.collect_text_inputs(spec_dict) + if not prompts: + raise ExecutionError( + "omni_text2image requires prompts " + "in spec.data.prompt or spec.data.items." + ) + + with self._span( + "model load", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + self._ensure_omni(spec_dict) + cfg = self.omni_cfg(spec_dict, "omni:image generation", "omni_text2image") + fmt = str(cfg.get("output_format") or "png").strip().lower() or "png" + + artifacts_dir = out_dir / "artifacts" + items: list[dict[str, Any]] = [] + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + images = self._generate_images(prompts) + with self._span( + "output postprocessing", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "item_count": len(prompts)}, + ): + for idx, (prompt, image) in enumerate(zip(prompts, images)): + save_path = self.resolve_save_path( + cfg, + out_dir, + index=idx, + ext=fmt, + multi=len(prompts) > 1, + default_prefix="generated_image", + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + _save_image(image, save_path) + items.append( + { + "index": idx, + "prompt": prompt, + "image": artifact_ref( + self.relative_to(save_path, artifacts_dir) + ), + } + ) + + first = items[0]["image"] if items else {} + result: dict[str, Any] = { + "ok": True, + "executor": self.name, + "mode": "image", + "model": self._model_name, + "image": first, + "items": items, + } + maybe_upload_artifacts(task, out_dir, logger=logger) + return result # ── model ──────────────────────────────────────────────────────────── diff --git a/src/worker/executors/omni_text2speech_executor.py b/src/worker/executors/omni_text2speech_executor.py index 42fb1346..51903504 100644 --- a/src/worker/executors/omni_text2speech_executor.py +++ b/src/worker/executors/omni_text2speech_executor.py @@ -16,6 +16,7 @@ Omni = None _HAS_OMNI = False +from shared.schemas.governance import SpanType from shared.tasks.specs.omni import OmniText2SpeechSpecStrict from shared.utils.parsing import as_list, to_int @@ -47,55 +48,79 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: spec_dict = spec.model_dump(by_alias=True) out_dir = Path(out_dir).resolve() - texts = self.collect_text_inputs(spec_dict) - if not texts: - raise ExecutionError( - "omni_text2speech requires text input " - "in spec.data.text or spec.data.items." - ) + with self._task_span( + task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id + ): + texts = self.collect_text_inputs(spec_dict) + if not texts: + raise ExecutionError( + "omni_text2speech requires text input " + "in spec.data.text or spec.data.items." + ) - self._ensure_omni(spec_dict) - cfg = self.omni_cfg(spec_dict, "omni:tts", "omni_text2speech") - output_format = str(cfg.get("output_format") or "wav").strip().lower() or "wav" - sample_rate = to_int(cfg.get("sample_rate"), default=24000) - - audio_objects = [self._generate_single(t, spec_dict=spec_dict) for t in texts] - artifacts_dir = out_dir / "artifacts" - items: list[dict[str, Any]] = [] - for idx, (text, audio_obj) in enumerate(zip(texts, audio_objects)): - save_path = self.resolve_save_path( - cfg, - out_dir, - index=idx, - ext=output_format, - multi=len(texts) > 1, - default_prefix="generated_tts", + with self._span( + "model load", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(texts)}, + ): + self._ensure_omni(spec_dict) + cfg = self.omni_cfg(spec_dict, "omni:tts", "omni_text2speech") + output_format = ( + str(cfg.get("output_format") or "wav").strip().lower() or "wav" ) - save_path.parent.mkdir(parents=True, exist_ok=True) - save_audio(audio_obj, save_path, sample_rate=sample_rate) - items.append( - { - "index": idx, - "text": text, - "audio": artifact_ref(self.relative_to(save_path, artifacts_dir)), - } - ) - - first = items[0]["audio"] if items else {} - result: dict[str, Any] = { - "ok": True, - "executor": self.name, - "mode": "tts", - "model": self._model_name, - "audio": first, - "items": items, - "sample_rate": sample_rate, - } - storyboard = spec_dict.get("storyboard") - if isinstance(storyboard, dict): - result["storyboard"] = dict(storyboard) - maybe_upload_artifacts(task, out_dir, logger=logger) - return result + sample_rate = to_int(cfg.get("sample_rate"), default=24000) + + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(texts)}, + ): + audio_objects = [ + self._generate_single(t, spec_dict=spec_dict) for t in texts + ] + artifacts_dir = out_dir / "artifacts" + items: list[dict[str, Any]] = [] + with self._span( + "output postprocessing", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "item_count": len(texts)}, + ): + for idx, (text, audio_obj) in enumerate(zip(texts, audio_objects)): + save_path = self.resolve_save_path( + cfg, + out_dir, + index=idx, + ext=output_format, + multi=len(texts) > 1, + default_prefix="generated_tts", + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + save_audio(audio_obj, save_path, sample_rate=sample_rate) + items.append( + { + "index": idx, + "text": text, + "audio": artifact_ref( + self.relative_to(save_path, artifacts_dir) + ), + } + ) + + first = items[0]["audio"] if items else {} + result: dict[str, Any] = { + "ok": True, + "executor": self.name, + "mode": "tts", + "model": self._model_name, + "audio": first, + "items": items, + "sample_rate": sample_rate, + } + storyboard = spec_dict.get("storyboard") + if isinstance(storyboard, dict): + result["storyboard"] = dict(storyboard) + maybe_upload_artifacts(task, out_dir, logger=logger) + return result # ── model ──────────────────────────────────────────────────────────── diff --git a/tests/worker/test_omni_executor_inheritance.py b/tests/worker/test_omni_executor_inheritance.py new file mode 100644 index 00000000..89ac207d --- /dev/null +++ b/tests/worker/test_omni_executor_inheritance.py @@ -0,0 +1,36 @@ +"""Compile-time guard that omni executors sit on the mixin chain. + +The mixin chain is the same one the inference and training executors use; the +omni family was the last to predate it. This test fails loudly if a future +refactor regresses the base class. +""" + +import pytest + +pytest.importorskip("vllm_omni", reason="vllm-omni not installed") + +from worker.executors.mixins.data import DataMixin +from worker.executors.mixins.governance import GovernanceMixin +from worker.executors.mixins.inference import InferenceMixin +from worker.executors.omni_executor_base import OmniExecutorBase +from worker.executors.omni_text2audio_executor import OmniText2AudioExecutor +from worker.executors.omni_text2general_executor import OmniText2GeneralExecutor +from worker.executors.omni_text2image_executor import OmniText2ImageExecutor +from worker.executors.omni_text2speech_executor import OmniText2SpeechExecutor + + +@pytest.mark.parametrize( + "cls", + [ + OmniExecutorBase, + OmniText2AudioExecutor, + OmniText2GeneralExecutor, + OmniText2ImageExecutor, + OmniText2SpeechExecutor, + ], + ids=lambda c: c.__name__, +) +def test_omni_executors_use_mixin_chain(cls: type) -> None: + assert issubclass(cls, InferenceMixin) + assert issubclass(cls, DataMixin) + assert issubclass(cls, GovernanceMixin) From f8a45ed36058731b307418d701cca6645f23a5d8 Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Thu, 14 May 2026 10:08:32 +0000 Subject: [PATCH 2/6] refactor(omni): route prompts through DataMixin._collect_prompts_for_spec Replace the executor-local collect_text_inputs helper with the mixin's _collect_prompts_for_spec. Each omni executor now narrows PromptInput to str inline and raises ExecutionError if any item is not a string. Templates adopt the canonical data.type: list / items shape. Signed-off-by: Zhengyuan Su --- examples/templates/omni_text2audio.yaml | 4 +++- examples/templates/omni_text2general.yaml | 4 +++- examples/templates/omni_text2speech.yaml | 4 +++- src/worker/executors/omni_executor_base.py | 16 ---------------- .../executors/omni_text2audio_executor.py | 17 ++++++++++++----- .../executors/omni_text2general_executor.py | 17 ++++++++++++----- .../executors/omni_text2image_executor.py | 6 +++++- .../executors/omni_text2speech_executor.py | 9 ++++++--- 8 files changed, 44 insertions(+), 33 deletions(-) diff --git a/examples/templates/omni_text2audio.yaml b/examples/templates/omni_text2audio.yaml index 77d9711e..f621def4 100644 --- a/examples/templates/omni_text2audio.yaml +++ b/examples/templates/omni_text2audio.yaml @@ -18,7 +18,9 @@ spec: identifier: stabilityai/stable-audio-open-1.0 data: - text: Upbeat electronic music with synth pads and a driving beat + type: list + items: + - Upbeat electronic music with synth pads and a driving beat omni: output_format: wav diff --git a/examples/templates/omni_text2general.yaml b/examples/templates/omni_text2general.yaml index 569455a8..32b29c32 100644 --- a/examples/templates/omni_text2general.yaml +++ b/examples/templates/omni_text2general.yaml @@ -18,7 +18,9 @@ spec: identifier: Qwen/Qwen3-Omni-30B-A3B-Instruct data: - text: "Please read the following passage in a calm, professional tone: FlowMesh is a distributed GPU workflow engine." + type: list + items: + - "Please read the following passage in a calm, professional tone: FlowMesh is a distributed GPU workflow engine." omni: output_format: wav diff --git a/examples/templates/omni_text2speech.yaml b/examples/templates/omni_text2speech.yaml index 9fc10037..e0df5c4b 100644 --- a/examples/templates/omni_text2speech.yaml +++ b/examples/templates/omni_text2speech.yaml @@ -18,7 +18,9 @@ spec: identifier: Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice data: - text: "Hello, welcome to FlowMesh. This is a text-to-speech demo." + type: list + items: + - "Hello, welcome to FlowMesh. This is a text-to-speech demo." omni: output_format: wav diff --git a/src/worker/executors/omni_executor_base.py b/src/worker/executors/omni_executor_base.py index 8dfb15c4..a44d497e 100644 --- a/src/worker/executors/omni_executor_base.py +++ b/src/worker/executors/omni_executor_base.py @@ -161,22 +161,6 @@ def resolve_model_identifier( value = default return str(value).strip() - @staticmethod - def collect_text_inputs(spec_dict: dict[str, Any]) -> list[str]: - data = spec_dict.get("data") or {} - if not isinstance(data, dict): - data = {} - dtype = str(data.get("type") or "").strip().lower() - if dtype == "list": - items = data.get("items") or [] - if not isinstance(items, list): - return [] - return [str(item).strip() for item in items if str(item).strip()] - text_val = data.get("text") or data.get("prompt") or spec_dict.get("task") - if not isinstance(text_val, str) or not text_val.strip(): - return [] - return [text_val.strip()] - @staticmethod def resolve_save_path( cfg: dict[str, Any], diff --git a/src/worker/executors/omni_text2audio_executor.py b/src/worker/executors/omni_text2audio_executor.py index be7fd20e..6403d33d 100644 --- a/src/worker/executors/omni_text2audio_executor.py +++ b/src/worker/executors/omni_text2audio_executor.py @@ -84,16 +84,23 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: with self._task_span( task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id ): - return self._run_inner(task, spec_dict, out_dir) + return self._run_inner(task, spec, spec_dict, out_dir) def _run_inner( - self, task: ExecutorTask, spec_dict: dict[str, Any], out_dir: Path + self, + task: ExecutorTask, + spec: OmniText2AudioSpecStrict, + spec_dict: dict[str, Any], + out_dir: Path, ) -> dict[str, Any]: - prompts = self.collect_text_inputs(spec_dict) + prompts: list[str] = [] + for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: + if not isinstance(p, str): + raise ExecutionError("omni_text2audio prompts must be strings.") + prompts.append(p) if not prompts: raise ExecutionError( - "omni_text2audio requires prompt text " - "in spec.data.text or spec.data.items." + "omni_text2audio requires prompt text in spec.data.items." ) cfg = _bgm_cfg(spec_dict) diff --git a/src/worker/executors/omni_text2general_executor.py b/src/worker/executors/omni_text2general_executor.py index e9328a67..a783859b 100644 --- a/src/worker/executors/omni_text2general_executor.py +++ b/src/worker/executors/omni_text2general_executor.py @@ -67,16 +67,23 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: with self._task_span( task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id ): - return self._run_inner(task, spec_dict, out_dir) + return self._run_inner(task, spec, spec_dict, out_dir) def _run_inner( - self, task: ExecutorTask, spec_dict: dict[str, Any], out_dir: Path + self, + task: ExecutorTask, + spec: OmniText2GeneralSpecStrict, + spec_dict: dict[str, Any], + out_dir: Path, ) -> dict[str, Any]: - texts = self.collect_text_inputs(spec_dict) + texts: list[str] = [] + for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: + if not isinstance(p, str): + raise ExecutionError("omni_text2general prompts must be strings.") + texts.append(p) if not texts: raise ExecutionError( - "omni_text2general requires text input " - "in spec.data.text or spec.data.items." + "omni_text2general requires text input in spec.data.items." ) cfg = _narration_cfg(spec_dict) diff --git a/src/worker/executors/omni_text2image_executor.py b/src/worker/executors/omni_text2image_executor.py index 015f78ac..22927c27 100644 --- a/src/worker/executors/omni_text2image_executor.py +++ b/src/worker/executors/omni_text2image_executor.py @@ -45,7 +45,11 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: with self._task_span( task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id ): - prompts = self.collect_text_inputs(spec_dict) + prompts: list[str] = [] + for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: + if not isinstance(p, str): + raise ExecutionError("omni_text2image prompts must be strings.") + prompts.append(p) if not prompts: raise ExecutionError( "omni_text2image requires prompts " diff --git a/src/worker/executors/omni_text2speech_executor.py b/src/worker/executors/omni_text2speech_executor.py index 51903504..fa098a77 100644 --- a/src/worker/executors/omni_text2speech_executor.py +++ b/src/worker/executors/omni_text2speech_executor.py @@ -51,11 +51,14 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: with self._task_span( task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id ): - texts = self.collect_text_inputs(spec_dict) + texts: list[str] = [] + for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: + if not isinstance(p, str): + raise ExecutionError("omni_text2speech prompts must be strings.") + texts.append(p) if not texts: raise ExecutionError( - "omni_text2speech requires text input " - "in spec.data.text or spec.data.items." + "omni_text2speech requires text input in spec.data.items." ) with self._span( From bb9310b36501c25018f786a36a46cd7cd5a4487a Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Thu, 14 May 2026 10:42:32 +0000 Subject: [PATCH 3/6] fix(omni): upload artifacts + traces after task span closes Move maybe_upload_artifacts out of the task span and add a missing maybe_upload_traces call right after, matching the vllm and diffusers pattern. Without the trace upload, omni span JSONL stayed on remote workers and never reached the server's /traces endpoint in HTTP mode. Also extract _run_inner in the image and speech executors so the post-span fall-through is the same shape across all four. Signed-off-by: Zhengyuan Su --- .../executors/omni_text2audio_executor.py | 10 +- .../executors/omni_text2general_executor.py | 10 +- .../executors/omni_text2image_executor.py | 139 ++++++++-------- .../executors/omni_text2speech_executor.py | 149 ++++++++++-------- 4 files changed, 166 insertions(+), 142 deletions(-) diff --git a/src/worker/executors/omni_text2audio_executor.py b/src/worker/executors/omni_text2audio_executor.py index 6403d33d..74da9c7f 100644 --- a/src/worker/executors/omni_text2audio_executor.py +++ b/src/worker/executors/omni_text2audio_executor.py @@ -56,7 +56,7 @@ from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import OmniExecutorBase, extract_multimodal_output -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts +from .utils.checkpoints import artifact_ref, maybe_upload_artifacts, maybe_upload_traces logger = logging.getLogger(__name__) @@ -84,7 +84,10 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: with self._task_span( task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id ): - return self._run_inner(task, spec, spec_dict, out_dir) + result = self._run_inner(task, spec, spec_dict, out_dir) + maybe_upload_artifacts(task, out_dir, logger=logger) + maybe_upload_traces(task, out_dir, logger=logger) + return result def _run_inner( self, @@ -104,7 +107,7 @@ def _run_inner( ) cfg = _bgm_cfg(spec_dict) - output_format = str(cfg.get("output_format") or "wav").strip().lower() or "wav" + output_format = str(cfg.get("output_format", "wav")).strip().lower() if output_format != "wav": raise ExecutionError( "omni_text2audio currently supports output_format='wav' only." @@ -232,7 +235,6 @@ def _run_inner( storyboard = spec_dict.get("storyboard") if isinstance(storyboard, dict): result["storyboard"] = dict(storyboard) - maybe_upload_artifacts(task, out_dir, logger=logger) return result # ── model ──────────────────────────────────────────────────────────── diff --git a/src/worker/executors/omni_text2general_executor.py b/src/worker/executors/omni_text2general_executor.py index a783859b..9718383a 100644 --- a/src/worker/executors/omni_text2general_executor.py +++ b/src/worker/executors/omni_text2general_executor.py @@ -32,7 +32,7 @@ from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import OmniExecutorBase, extract_audio_from_mm, save_audio -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts +from .utils.checkpoints import artifact_ref, maybe_upload_artifacts, maybe_upload_traces logger = logging.getLogger(__name__) @@ -67,7 +67,10 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: with self._task_span( task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id ): - return self._run_inner(task, spec, spec_dict, out_dir) + result = self._run_inner(task, spec, spec_dict, out_dir) + maybe_upload_artifacts(task, out_dir, logger=logger) + maybe_upload_traces(task, out_dir, logger=logger) + return result def _run_inner( self, @@ -87,7 +90,7 @@ def _run_inner( ) cfg = _narration_cfg(spec_dict) - output_format = str(cfg.get("output_format") or "wav").strip().lower() or "wav" + output_format = str(cfg.get("output_format", "wav")).strip().lower() if output_format != "wav": raise ExecutionError( "omni_text2general currently supports output_format='wav' only." @@ -205,7 +208,6 @@ def _run_inner( storyboard = spec_dict.get("storyboard") if isinstance(storyboard, dict): result["storyboard"] = dict(storyboard) - maybe_upload_artifacts(task, out_dir, logger=logger) return result # ── model ──────────────────────────────────────────────────────────── diff --git a/src/worker/executors/omni_text2image_executor.py b/src/worker/executors/omni_text2image_executor.py index 22927c27..6f98b03d 100644 --- a/src/worker/executors/omni_text2image_executor.py +++ b/src/worker/executors/omni_text2image_executor.py @@ -21,7 +21,7 @@ from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import OmniExecutorBase -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts +from .utils.checkpoints import artifact_ref, maybe_upload_artifacts, maybe_upload_traces logger = logging.getLogger(__name__) @@ -45,71 +45,82 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: with self._task_span( task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id ): - prompts: list[str] = [] - for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: - if not isinstance(p, str): - raise ExecutionError("omni_text2image prompts must be strings.") - prompts.append(p) - if not prompts: - raise ExecutionError( - "omni_text2image requires prompts " - "in spec.data.prompt or spec.data.items." + result = self._run_inner(task, spec, spec_dict, out_dir) + maybe_upload_artifacts(task, out_dir, logger=logger) + maybe_upload_traces(task, out_dir, logger=logger) + return result + + def _run_inner( + self, + task: ExecutorTask, + spec: OmniText2ImageSpecStrict, + spec_dict: dict[str, Any], + out_dir: Path, + ) -> dict[str, Any]: + prompts: list[str] = [] + for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: + if not isinstance(p, str): + raise ExecutionError("omni_text2image prompts must be strings.") + prompts.append(p) + if not prompts: + raise ExecutionError( + "omni_text2image requires prompts " + "in spec.data.prompt or spec.data.items." + ) + + with self._span( + "model load", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + self._ensure_omni(spec_dict) + cfg = self.omni_cfg(spec_dict, "omni:image generation", "omni_text2image") + fmt = str(cfg.get("output_format", "png")).strip().lower() + + artifacts_dir = out_dir / "artifacts" + items: list[dict[str, Any]] = [] + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, + ): + images = self._generate_images(prompts) + with self._span( + "output postprocessing", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "item_count": len(prompts)}, + ): + for idx, (prompt, image) in enumerate(zip(prompts, images)): + save_path = self.resolve_save_path( + cfg, + out_dir, + index=idx, + ext=fmt, + multi=len(prompts) > 1, + default_prefix="generated_image", + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + _save_image(image, save_path) + items.append( + { + "index": idx, + "prompt": prompt, + "image": artifact_ref( + self.relative_to(save_path, artifacts_dir) + ), + } ) - with self._span( - "model load", - span_type=SpanType.COMPUTE, - attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, - ): - self._ensure_omni(spec_dict) - cfg = self.omni_cfg(spec_dict, "omni:image generation", "omni_text2image") - fmt = str(cfg.get("output_format") or "png").strip().lower() or "png" - - artifacts_dir = out_dir / "artifacts" - items: list[dict[str, Any]] = [] - with self._span( - "generation", - span_type=SpanType.COMPUTE, - attributes={"task_id": task.task_id, "prompt_count": len(prompts)}, - ): - images = self._generate_images(prompts) - with self._span( - "output postprocessing", - span_type=SpanType.COMPUTE, - attributes={"task_id": task.task_id, "item_count": len(prompts)}, - ): - for idx, (prompt, image) in enumerate(zip(prompts, images)): - save_path = self.resolve_save_path( - cfg, - out_dir, - index=idx, - ext=fmt, - multi=len(prompts) > 1, - default_prefix="generated_image", - ) - save_path.parent.mkdir(parents=True, exist_ok=True) - _save_image(image, save_path) - items.append( - { - "index": idx, - "prompt": prompt, - "image": artifact_ref( - self.relative_to(save_path, artifacts_dir) - ), - } - ) - - first = items[0]["image"] if items else {} - result: dict[str, Any] = { - "ok": True, - "executor": self.name, - "mode": "image", - "model": self._model_name, - "image": first, - "items": items, - } - maybe_upload_artifacts(task, out_dir, logger=logger) - return result + first = items[0]["image"] if items else {} + result: dict[str, Any] = { + "ok": True, + "executor": self.name, + "mode": "image", + "model": self._model_name, + "image": first, + "items": items, + } + return result # ── model ──────────────────────────────────────────────────────────── diff --git a/src/worker/executors/omni_text2speech_executor.py b/src/worker/executors/omni_text2speech_executor.py index fa098a77..692c2755 100644 --- a/src/worker/executors/omni_text2speech_executor.py +++ b/src/worker/executors/omni_text2speech_executor.py @@ -27,7 +27,7 @@ extract_multimodal_output, save_audio, ) -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts +from .utils.checkpoints import artifact_ref, maybe_upload_artifacts, maybe_upload_traces logger = logging.getLogger(__name__) @@ -51,79 +51,88 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: with self._task_span( task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id ): - texts: list[str] = [] - for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: - if not isinstance(p, str): - raise ExecutionError("omni_text2speech prompts must be strings.") - texts.append(p) - if not texts: - raise ExecutionError( - "omni_text2speech requires text input in spec.data.items." - ) + result = self._run_inner(task, spec, spec_dict, out_dir) + maybe_upload_artifacts(task, out_dir, logger=logger) + maybe_upload_traces(task, out_dir, logger=logger) + return result - with self._span( - "model load", - span_type=SpanType.COMPUTE, - attributes={"task_id": task.task_id, "prompt_count": len(texts)}, - ): - self._ensure_omni(spec_dict) - cfg = self.omni_cfg(spec_dict, "omni:tts", "omni_text2speech") - output_format = ( - str(cfg.get("output_format") or "wav").strip().lower() or "wav" + def _run_inner( + self, + task: ExecutorTask, + spec: OmniText2SpeechSpecStrict, + spec_dict: dict[str, Any], + out_dir: Path, + ) -> dict[str, Any]: + texts: list[str] = [] + for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: + if not isinstance(p, str): + raise ExecutionError("omni_text2speech prompts must be strings.") + texts.append(p) + if not texts: + raise ExecutionError( + "omni_text2speech requires text input in spec.data.items." ) - sample_rate = to_int(cfg.get("sample_rate"), default=24000) - with self._span( - "generation", - span_type=SpanType.COMPUTE, - attributes={"task_id": task.task_id, "prompt_count": len(texts)}, - ): - audio_objects = [ - self._generate_single(t, spec_dict=spec_dict) for t in texts - ] - artifacts_dir = out_dir / "artifacts" - items: list[dict[str, Any]] = [] - with self._span( - "output postprocessing", - span_type=SpanType.COMPUTE, - attributes={"task_id": task.task_id, "item_count": len(texts)}, - ): - for idx, (text, audio_obj) in enumerate(zip(texts, audio_objects)): - save_path = self.resolve_save_path( - cfg, - out_dir, - index=idx, - ext=output_format, - multi=len(texts) > 1, - default_prefix="generated_tts", - ) - save_path.parent.mkdir(parents=True, exist_ok=True) - save_audio(audio_obj, save_path, sample_rate=sample_rate) - items.append( - { - "index": idx, - "text": text, - "audio": artifact_ref( - self.relative_to(save_path, artifacts_dir) - ), - } - ) + with self._span( + "model load", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(texts)}, + ): + self._ensure_omni(spec_dict) + cfg = self.omni_cfg(spec_dict, "omni:tts", "omni_text2speech") + output_format = str(cfg.get("output_format", "wav")).strip().lower() + sample_rate = to_int(cfg.get("sample_rate"), default=24000) + + with self._span( + "generation", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "prompt_count": len(texts)}, + ): + audio_objects = [ + self._generate_single(t, spec_dict=spec_dict) for t in texts + ] + artifacts_dir = out_dir / "artifacts" + items: list[dict[str, Any]] = [] + with self._span( + "output postprocessing", + span_type=SpanType.COMPUTE, + attributes={"task_id": task.task_id, "item_count": len(texts)}, + ): + for idx, (text, audio_obj) in enumerate(zip(texts, audio_objects)): + save_path = self.resolve_save_path( + cfg, + out_dir, + index=idx, + ext=output_format, + multi=len(texts) > 1, + default_prefix="generated_tts", + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + save_audio(audio_obj, save_path, sample_rate=sample_rate) + items.append( + { + "index": idx, + "text": text, + "audio": artifact_ref( + self.relative_to(save_path, artifacts_dir) + ), + } + ) - first = items[0]["audio"] if items else {} - result: dict[str, Any] = { - "ok": True, - "executor": self.name, - "mode": "tts", - "model": self._model_name, - "audio": first, - "items": items, - "sample_rate": sample_rate, - } - storyboard = spec_dict.get("storyboard") - if isinstance(storyboard, dict): - result["storyboard"] = dict(storyboard) - maybe_upload_artifacts(task, out_dir, logger=logger) - return result + first = items[0]["audio"] if items else {} + result: dict[str, Any] = { + "ok": True, + "executor": self.name, + "mode": "tts", + "model": self._model_name, + "audio": first, + "items": items, + "sample_rate": sample_rate, + } + storyboard = spec_dict.get("storyboard") + if isinstance(storyboard, dict): + result["storyboard"] = dict(storyboard) + return result # ── model ──────────────────────────────────────────────────────────── From 188e3057c79236c18b560166f7c12ac032784ec6 Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Thu, 14 May 2026 10:51:31 +0000 Subject: [PATCH 4/6] small patch Signed-off-by: Zhengyuan Su --- src/worker/executors/omni_text2audio_executor.py | 2 +- src/worker/executors/omni_text2general_executor.py | 2 +- src/worker/executors/omni_text2image_executor.py | 2 +- src/worker/executors/omni_text2speech_executor.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/worker/executors/omni_text2audio_executor.py b/src/worker/executors/omni_text2audio_executor.py index 74da9c7f..3552d475 100644 --- a/src/worker/executors/omni_text2audio_executor.py +++ b/src/worker/executors/omni_text2audio_executor.py @@ -107,7 +107,7 @@ def _run_inner( ) cfg = _bgm_cfg(spec_dict) - output_format = str(cfg.get("output_format", "wav")).strip().lower() + output_format = str(cfg.get("output_format") or "").strip().lower() or "wav" if output_format != "wav": raise ExecutionError( "omni_text2audio currently supports output_format='wav' only." diff --git a/src/worker/executors/omni_text2general_executor.py b/src/worker/executors/omni_text2general_executor.py index 9718383a..e6caf2e1 100644 --- a/src/worker/executors/omni_text2general_executor.py +++ b/src/worker/executors/omni_text2general_executor.py @@ -90,7 +90,7 @@ def _run_inner( ) cfg = _narration_cfg(spec_dict) - output_format = str(cfg.get("output_format", "wav")).strip().lower() + output_format = str(cfg.get("output_format") or "").strip().lower() or "wav" if output_format != "wav": raise ExecutionError( "omni_text2general currently supports output_format='wav' only." diff --git a/src/worker/executors/omni_text2image_executor.py b/src/worker/executors/omni_text2image_executor.py index 6f98b03d..351192dc 100644 --- a/src/worker/executors/omni_text2image_executor.py +++ b/src/worker/executors/omni_text2image_executor.py @@ -75,7 +75,7 @@ def _run_inner( ): self._ensure_omni(spec_dict) cfg = self.omni_cfg(spec_dict, "omni:image generation", "omni_text2image") - fmt = str(cfg.get("output_format", "png")).strip().lower() + fmt = str(cfg.get("output_format") or "").strip().lower() or "png" artifacts_dir = out_dir / "artifacts" items: list[dict[str, Any]] = [] diff --git a/src/worker/executors/omni_text2speech_executor.py b/src/worker/executors/omni_text2speech_executor.py index 692c2755..f117eea0 100644 --- a/src/worker/executors/omni_text2speech_executor.py +++ b/src/worker/executors/omni_text2speech_executor.py @@ -80,7 +80,7 @@ def _run_inner( ): self._ensure_omni(spec_dict) cfg = self.omni_cfg(spec_dict, "omni:tts", "omni_text2speech") - output_format = str(cfg.get("output_format", "wav")).strip().lower() + output_format = str(cfg.get("output_format") or "").strip().lower() or "wav" sample_rate = to_int(cfg.get("sample_rate"), default=24000) with self._span( From 849e89c4938e19e567b064692d761e3f15049580 Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Fri, 15 May 2026 08:47:44 +0000 Subject: [PATCH 5/6] refactor(omni): hoist run() to OmniExecutorBase via _TASK_SPEC_TYPE Each concrete omni executor's run() did the same five things: resolve spec, dump dict, normalize out_dir, run the task span, upload artifacts and traces. Move that boilerplate to OmniExecutorBase.run() and let subclasses contribute via a _TASK_SPEC_TYPE class attribute plus an abstract _run_inner whose first line is `assert isinstance(spec, ...)` to recover the concrete type. Also adopt the cast(list[str], raw_prompts) form for the prompt-string narrowing in all four executors so the pattern reads identically. Signed-off-by: Zhengyuan Su --- src/worker/executors/omni_executor_base.py | 39 +++++++++++++++++-- .../executors/omni_text2audio_executor.py | 31 +++++---------- .../executors/omni_text2general_executor.py | 31 +++++---------- .../executors/omni_text2image_executor.py | 31 +++++---------- .../executors/omni_text2speech_executor.py | 31 +++++---------- 5 files changed, 75 insertions(+), 88 deletions(-) diff --git a/src/worker/executors/omni_executor_base.py b/src/worker/executors/omni_executor_base.py index a44d497e..87e40d35 100644 --- a/src/worker/executors/omni_executor_base.py +++ b/src/worker/executors/omni_executor_base.py @@ -13,17 +13,20 @@ import struct import tempfile import wave +from abc import abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar import yaml +from shared.tasks.specs import TaskSpecStrictBase from shared.utils.parsing import to_bool, to_int from ..config import WorkerConfig from ..lifecycle import Lifecycle -from .base_executor import ExecutionError, Executor +from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.inference import InferenceMixin +from .utils.checkpoints import maybe_upload_artifacts, maybe_upload_traces try: import numpy as np @@ -42,10 +45,14 @@ class OmniExecutorBase(InferenceMixin, Executor): """Shared base for Omni-family executors. - Manages the ``_omni`` model handle and provides config / audio helpers. - Concrete subclasses implement ``prepare()`` and ``run()`` as usual. + Manages the ``_omni`` model handle, runs the generic ``run()`` shape + (task span + artifact / trace upload), and delegates the task body to + each subclass's ``_run_inner``. Subclasses set ``_TASK_SPEC_TYPE`` so + the base can call ``require_spec`` without knowing the concrete type. """ + _TASK_SPEC_TYPE: ClassVar[type[TaskSpecStrictBase]] + def __init__( self, config: WorkerConfig, lifecycle: Lifecycle | None = None ) -> None: @@ -55,6 +62,30 @@ def __init__( self._omni_spec: tuple[Any, ...] | None = None self._stage_configs_tmp: Path | None = None + def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: + spec = self.require_spec(task, self._TASK_SPEC_TYPE) + spec_dict = spec.model_dump(by_alias=True) + out_dir = Path(out_dir).resolve() + with self._task_span( + task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id + ): + result = self._run_inner(task, spec, spec_dict, out_dir) + maybe_upload_artifacts(task, out_dir, logger=logger) + maybe_upload_traces(task, out_dir, logger=logger) + return result + + @abstractmethod + def _run_inner( + self, + task: ExecutorTask, + spec: TaskSpecStrictBase, + spec_dict: dict[str, Any], + out_dir: Path, + ) -> dict[str, Any]: + """Run the executor-specific body. ``spec`` is the concrete strict + spec; subclasses ``assert isinstance(spec, ...)`` to narrow.""" + raise NotImplementedError + # ── model lifecycle ────────────────────────────────────────────────── def _close_omni(self) -> None: diff --git a/src/worker/executors/omni_text2audio_executor.py b/src/worker/executors/omni_text2audio_executor.py index 3552d475..75dc711a 100644 --- a/src/worker/executors/omni_text2audio_executor.py +++ b/src/worker/executors/omni_text2audio_executor.py @@ -3,7 +3,7 @@ import logging import wave from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast try: import numpy as np @@ -51,12 +51,13 @@ _HAS_OMNI_PLATFORM = False from shared.schemas.governance import SpanType +from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2AudioSpecStrict from shared.utils.parsing import to_float, to_int from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import OmniExecutorBase, extract_multimodal_output -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts, maybe_upload_traces +from .utils.checkpoints import artifact_ref logger = logging.getLogger(__name__) @@ -65,6 +66,7 @@ class OmniText2AudioExecutor(OmniExecutorBase): """Generate background music with Omni diffusion sampling.""" name = "omni_text2audio" + _TASK_SPEC_TYPE = OmniText2AudioSpecStrict def prepare(self) -> None: if torch is None: @@ -76,31 +78,18 @@ def prepare(self) -> None: "vllm_omni is not installed; cannot use omni_text2audio executor." ) - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: - spec = self.require_spec(task, OmniText2AudioSpecStrict) - spec_dict = spec.model_dump(by_alias=True) - out_dir = Path(out_dir).resolve() - - with self._task_span( - task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id - ): - result = self._run_inner(task, spec, spec_dict, out_dir) - maybe_upload_artifacts(task, out_dir, logger=logger) - maybe_upload_traces(task, out_dir, logger=logger) - return result - def _run_inner( self, task: ExecutorTask, - spec: OmniText2AudioSpecStrict, + spec: TaskSpecStrictBase, spec_dict: dict[str, Any], out_dir: Path, ) -> dict[str, Any]: - prompts: list[str] = [] - for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: - if not isinstance(p, str): - raise ExecutionError("omni_text2audio prompts must be strings.") - prompts.append(p) + assert isinstance(spec, OmniText2AudioSpecStrict) + raw_prompts = self._collect_prompts_for_spec(spec, task.task_id).prompts + if not all(isinstance(p, str) for p in raw_prompts): + raise ExecutionError("omni_text2audio prompts must be strings.") + prompts = cast(list[str], raw_prompts) if not prompts: raise ExecutionError( "omni_text2audio requires prompt text in spec.data.items." diff --git a/src/worker/executors/omni_text2general_executor.py b/src/worker/executors/omni_text2general_executor.py index e6caf2e1..9e42fe87 100644 --- a/src/worker/executors/omni_text2general_executor.py +++ b/src/worker/executors/omni_text2general_executor.py @@ -2,7 +2,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast try: from vllm import SamplingParams @@ -27,12 +27,13 @@ _HAS_OMNI = False from shared.schemas.governance import SpanType +from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2GeneralSpecStrict from shared.utils.parsing import as_list, to_bool, to_float, to_int, to_int_list from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import OmniExecutorBase, extract_audio_from_mm, save_audio -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts, maybe_upload_traces +from .utils.checkpoints import artifact_ref logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ class OmniText2GeneralExecutor(OmniExecutorBase): """Generate narration/speech audio using Qwen3-Omni through vllm_omni.Omni.""" name = "omni_text2general" + _TASK_SPEC_TYPE = OmniText2GeneralSpecStrict def prepare(self) -> None: if not _HAS_OMNI: @@ -59,31 +61,18 @@ def prepare(self) -> None: "omni_text2general requires SamplingParams from vllm." ) - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: - spec = self.require_spec(task, OmniText2GeneralSpecStrict) - spec_dict = spec.model_dump(by_alias=True) - out_dir = Path(out_dir).resolve() - - with self._task_span( - task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id - ): - result = self._run_inner(task, spec, spec_dict, out_dir) - maybe_upload_artifacts(task, out_dir, logger=logger) - maybe_upload_traces(task, out_dir, logger=logger) - return result - def _run_inner( self, task: ExecutorTask, - spec: OmniText2GeneralSpecStrict, + spec: TaskSpecStrictBase, spec_dict: dict[str, Any], out_dir: Path, ) -> dict[str, Any]: - texts: list[str] = [] - for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: - if not isinstance(p, str): - raise ExecutionError("omni_text2general prompts must be strings.") - texts.append(p) + assert isinstance(spec, OmniText2GeneralSpecStrict) + raw_prompts = self._collect_prompts_for_spec(spec, task.task_id).prompts + if not all(isinstance(t, str) for t in raw_prompts): + raise ExecutionError("omni_text2general prompts must be strings.") + texts = cast(list[str], raw_prompts) if not texts: raise ExecutionError( "omni_text2general requires text input in spec.data.items." diff --git a/src/worker/executors/omni_text2image_executor.py b/src/worker/executors/omni_text2image_executor.py index 351192dc..0b382715 100644 --- a/src/worker/executors/omni_text2image_executor.py +++ b/src/worker/executors/omni_text2image_executor.py @@ -2,7 +2,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast try: from vllm_omni.entrypoints.omni import Omni @@ -16,12 +16,13 @@ _HAS_OMNI = False from shared.schemas.governance import SpanType +from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2ImageSpecStrict from shared.utils.parsing import as_list from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import OmniExecutorBase -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts, maybe_upload_traces +from .utils.checkpoints import artifact_ref logger = logging.getLogger(__name__) @@ -30,6 +31,7 @@ class OmniText2ImageExecutor(OmniExecutorBase): """Generate images using vllm_omni.Omni.""" name = "omni_text2image" + _TASK_SPEC_TYPE = OmniText2ImageSpecStrict def prepare(self) -> None: if not _HAS_OMNI: @@ -37,31 +39,18 @@ def prepare(self) -> None: "vllm_omni is not installed; cannot use omni_text2image executor." ) - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: - spec = self.require_spec(task, OmniText2ImageSpecStrict) - spec_dict = spec.model_dump(by_alias=True) - out_dir = Path(out_dir).resolve() - - with self._task_span( - task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id - ): - result = self._run_inner(task, spec, spec_dict, out_dir) - maybe_upload_artifacts(task, out_dir, logger=logger) - maybe_upload_traces(task, out_dir, logger=logger) - return result - def _run_inner( self, task: ExecutorTask, - spec: OmniText2ImageSpecStrict, + spec: TaskSpecStrictBase, spec_dict: dict[str, Any], out_dir: Path, ) -> dict[str, Any]: - prompts: list[str] = [] - for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: - if not isinstance(p, str): - raise ExecutionError("omni_text2image prompts must be strings.") - prompts.append(p) + assert isinstance(spec, OmniText2ImageSpecStrict) + raw_prompts = self._collect_prompts_for_spec(spec, task.task_id).prompts + if not all(isinstance(p, str) for p in raw_prompts): + raise ExecutionError("omni_text2image prompts must be strings.") + prompts = cast(list[str], raw_prompts) if not prompts: raise ExecutionError( "omni_text2image requires prompts " diff --git a/src/worker/executors/omni_text2speech_executor.py b/src/worker/executors/omni_text2speech_executor.py index f117eea0..bafedf14 100644 --- a/src/worker/executors/omni_text2speech_executor.py +++ b/src/worker/executors/omni_text2speech_executor.py @@ -3,7 +3,7 @@ import logging import os from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast try: from vllm_omni.entrypoints.omni import Omni @@ -17,6 +17,7 @@ _HAS_OMNI = False from shared.schemas.governance import SpanType +from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2SpeechSpecStrict from shared.utils.parsing import as_list, to_int @@ -27,7 +28,7 @@ extract_multimodal_output, save_audio, ) -from .utils.checkpoints import artifact_ref, maybe_upload_artifacts, maybe_upload_traces +from .utils.checkpoints import artifact_ref logger = logging.getLogger(__name__) @@ -36,6 +37,7 @@ class OmniText2SpeechExecutor(OmniExecutorBase): """Generate speech audio using vllm_omni.Omni.""" name = "omni_text2speech" + _TASK_SPEC_TYPE = OmniText2SpeechSpecStrict def prepare(self) -> None: if not _HAS_OMNI: @@ -43,31 +45,18 @@ def prepare(self) -> None: "vllm_omni is not installed; cannot use omni_text2speech executor." ) - def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]: - spec = self.require_spec(task, OmniText2SpeechSpecStrict) - spec_dict = spec.model_dump(by_alias=True) - out_dir = Path(out_dir).resolve() - - with self._task_span( - task.task_id, task.workflow_id, out_dir, owner_id=task.owner_id - ): - result = self._run_inner(task, spec, spec_dict, out_dir) - maybe_upload_artifacts(task, out_dir, logger=logger) - maybe_upload_traces(task, out_dir, logger=logger) - return result - def _run_inner( self, task: ExecutorTask, - spec: OmniText2SpeechSpecStrict, + spec: TaskSpecStrictBase, spec_dict: dict[str, Any], out_dir: Path, ) -> dict[str, Any]: - texts: list[str] = [] - for p in self._collect_prompts_for_spec(spec, task.task_id).prompts: - if not isinstance(p, str): - raise ExecutionError("omni_text2speech prompts must be strings.") - texts.append(p) + assert isinstance(spec, OmniText2SpeechSpecStrict) + raw_prompts = self._collect_prompts_for_spec(spec, task.task_id).prompts + if not all(isinstance(t, str) for t in raw_prompts): + raise ExecutionError("omni_text2speech prompts must be strings.") + texts = cast(list[str], raw_prompts) if not texts: raise ExecutionError( "omni_text2speech requires text input in spec.data.items." From 21b9149105ba443813782da01eafc9a974a14198 Mon Sep 17 00:00:00 2001 From: Noppanat Wadlom Date: Fri, 15 May 2026 09:20:21 +0000 Subject: [PATCH 6/6] refactor: introduce _collect_text_inputs as a shared helper Signed-off-by: Noppanat Wadlom --- src/worker/executors/omni_executor_base.py | 10 +++++++++- src/worker/executors/omni_text2audio_executor.py | 11 ++--------- src/worker/executors/omni_text2general_executor.py | 11 ++--------- src/worker/executors/omni_text2image_executor.py | 12 ++---------- src/worker/executors/omni_text2speech_executor.py | 11 ++--------- 5 files changed, 17 insertions(+), 38 deletions(-) diff --git a/src/worker/executors/omni_executor_base.py b/src/worker/executors/omni_executor_base.py index 87e40d35..ff2e04dc 100644 --- a/src/worker/executors/omni_executor_base.py +++ b/src/worker/executors/omni_executor_base.py @@ -15,7 +15,7 @@ import wave from abc import abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, cast import yaml @@ -192,6 +192,14 @@ def resolve_model_identifier( value = default return str(value).strip() + def _collect_text_inputs(self, spec: TaskSpecStrictBase, task_id: str) -> list[str]: + prompts = self._collect_prompts_for_spec(spec, task_id).prompts + if not prompts: + raise ExecutionError(f"{self.name} requires text input in spec.data.items.") + if not all(isinstance(p, str) for p in prompts): + raise ExecutionError(f"{self.name} prompts must be strings.") + return cast(list[str], prompts) + @staticmethod def resolve_save_path( cfg: dict[str, Any], diff --git a/src/worker/executors/omni_text2audio_executor.py b/src/worker/executors/omni_text2audio_executor.py index 75dc711a..952dfb28 100644 --- a/src/worker/executors/omni_text2audio_executor.py +++ b/src/worker/executors/omni_text2audio_executor.py @@ -3,7 +3,7 @@ import logging import wave from pathlib import Path -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any try: import numpy as np @@ -86,14 +86,7 @@ def _run_inner( out_dir: Path, ) -> dict[str, Any]: assert isinstance(spec, OmniText2AudioSpecStrict) - raw_prompts = self._collect_prompts_for_spec(spec, task.task_id).prompts - if not all(isinstance(p, str) for p in raw_prompts): - raise ExecutionError("omni_text2audio prompts must be strings.") - prompts = cast(list[str], raw_prompts) - if not prompts: - raise ExecutionError( - "omni_text2audio requires prompt text in spec.data.items." - ) + prompts = self._collect_text_inputs(spec, task.task_id) cfg = _bgm_cfg(spec_dict) output_format = str(cfg.get("output_format") or "").strip().lower() or "wav" diff --git a/src/worker/executors/omni_text2general_executor.py b/src/worker/executors/omni_text2general_executor.py index 9e42fe87..5664af99 100644 --- a/src/worker/executors/omni_text2general_executor.py +++ b/src/worker/executors/omni_text2general_executor.py @@ -2,7 +2,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any try: from vllm import SamplingParams @@ -69,14 +69,7 @@ def _run_inner( out_dir: Path, ) -> dict[str, Any]: assert isinstance(spec, OmniText2GeneralSpecStrict) - raw_prompts = self._collect_prompts_for_spec(spec, task.task_id).prompts - if not all(isinstance(t, str) for t in raw_prompts): - raise ExecutionError("omni_text2general prompts must be strings.") - texts = cast(list[str], raw_prompts) - if not texts: - raise ExecutionError( - "omni_text2general requires text input in spec.data.items." - ) + texts = self._collect_text_inputs(spec, task.task_id) cfg = _narration_cfg(spec_dict) output_format = str(cfg.get("output_format") or "").strip().lower() or "wav" diff --git a/src/worker/executors/omni_text2image_executor.py b/src/worker/executors/omni_text2image_executor.py index 0b382715..e3b1f234 100644 --- a/src/worker/executors/omni_text2image_executor.py +++ b/src/worker/executors/omni_text2image_executor.py @@ -2,7 +2,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any try: from vllm_omni.entrypoints.omni import Omni @@ -47,15 +47,7 @@ def _run_inner( out_dir: Path, ) -> dict[str, Any]: assert isinstance(spec, OmniText2ImageSpecStrict) - raw_prompts = self._collect_prompts_for_spec(spec, task.task_id).prompts - if not all(isinstance(p, str) for p in raw_prompts): - raise ExecutionError("omni_text2image prompts must be strings.") - prompts = cast(list[str], raw_prompts) - if not prompts: - raise ExecutionError( - "omni_text2image requires prompts " - "in spec.data.prompt or spec.data.items." - ) + prompts = self._collect_text_inputs(spec, task.task_id) with self._span( "model load", diff --git a/src/worker/executors/omni_text2speech_executor.py b/src/worker/executors/omni_text2speech_executor.py index bafedf14..5d7be1b7 100644 --- a/src/worker/executors/omni_text2speech_executor.py +++ b/src/worker/executors/omni_text2speech_executor.py @@ -3,7 +3,7 @@ import logging import os from pathlib import Path -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any try: from vllm_omni.entrypoints.omni import Omni @@ -53,14 +53,7 @@ def _run_inner( out_dir: Path, ) -> dict[str, Any]: assert isinstance(spec, OmniText2SpeechSpecStrict) - raw_prompts = self._collect_prompts_for_spec(spec, task.task_id).prompts - if not all(isinstance(t, str) for t in raw_prompts): - raise ExecutionError("omni_text2speech prompts must be strings.") - texts = cast(list[str], raw_prompts) - if not texts: - raise ExecutionError( - "omni_text2speech requires text input in spec.data.items." - ) + texts = self._collect_text_inputs(spec, task.task_id) with self._span( "model load",