Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/templates/omni_text2audio.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ spec:
identifier: stabilityai/stable-audio-open-1.0

data:
text: Upbeat electronic music with synth pads and a driving beat
type: list
items:
- Upbeat electronic music with synth pads and a driving beat

omni:
output_format: wav
Expand Down
4 changes: 3 additions & 1 deletion examples/templates/omni_text2general.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ spec:
identifier: Qwen/Qwen3-Omni-30B-A3B-Instruct

data:
text: "Please read the following passage in a calm, professional tone: FlowMesh is a distributed GPU workflow engine."
type: list
items:
- "Please read the following passage in a calm, professional tone: FlowMesh is a distributed GPU workflow engine."

omni:
output_format: wav
Expand Down
4 changes: 3 additions & 1 deletion examples/templates/omni_text2speech.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ spec:
identifier: Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice

data:
text: "Hello, welcome to FlowMesh. This is a text-to-speech demo."
type: list
items:
- "Hello, welcome to FlowMesh. This is a text-to-speech demo."

omni:
output_format: wav
Expand Down
64 changes: 44 additions & 20 deletions src/worker/executors/omni_executor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@
import struct
import tempfile
import wave
from abc import abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar, cast

import yaml

from shared.tasks.specs import TaskSpecStrictBase
from shared.utils.parsing import to_bool, to_int

from ..config import WorkerConfig
from ..lifecycle import Lifecycle
from .base_executor import ExecutionError, Executor
from .base_executor import ExecutionError, Executor, ExecutorTask
from .mixins.inference import InferenceMixin
from .utils.checkpoints import maybe_upload_artifacts, maybe_upload_traces

try:
import numpy as np
Expand All @@ -38,13 +42,17 @@
logger = logging.getLogger(__name__)


class OmniExecutorBase(Executor):
class OmniExecutorBase(InferenceMixin, Executor):
"""Shared base for Omni-family executors.

Manages the ``_omni`` model handle and provides config / audio helpers.
Concrete subclasses implement ``prepare()`` and ``run()`` as usual.
Manages the ``_omni`` model handle, runs the generic ``run()`` shape
(task span + artifact / trace upload), and delegates the task body to
each subclass's ``_run_inner``. Subclasses set ``_TASK_SPEC_TYPE`` so
the base can call ``require_spec`` without knowing the concrete type.
"""

_TASK_SPEC_TYPE: ClassVar[type[TaskSpecStrictBase]]

def __init__(
self, config: WorkerConfig, lifecycle: Lifecycle | None = None
) -> None:
Expand All @@ -54,6 +62,30 @@ def __init__(
self._omni_spec: tuple[Any, ...] | None = None
self._stage_configs_tmp: Path | None = None

def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
    """Execute the shared Omni task shape.

    Resolves the subclass's strict spec type, opens a task span, delegates
    the actual work to ``_run_inner``, and then performs the best-effort
    artifact / trace uploads from ``out_dir`` before returning the result.
    """
    strict_spec = self.require_spec(task, self._TASK_SPEC_TYPE)
    resolved_dir = Path(out_dir).resolve()
    # Subclasses receive both the model object and its aliased dict form.
    payload = strict_spec.model_dump(by_alias=True)
    span = self._task_span(
        task.task_id, task.workflow_id, resolved_dir, owner_id=task.owner_id
    )
    with span:
        outcome = self._run_inner(task, strict_spec, payload, resolved_dir)
        # Uploads happen inside the span so their timing is attributed to the task.
        maybe_upload_artifacts(task, resolved_dir, logger=logger)
        maybe_upload_traces(task, resolved_dir, logger=logger)
    return outcome

@abstractmethod
def _run_inner(
    self,
    task: ExecutorTask,
    spec: TaskSpecStrictBase,
    spec_dict: dict[str, Any],
    out_dir: Path,
) -> dict[str, Any]:
    """Run the executor-specific task body.

    Called by ``run()`` inside an open task span; artifact and trace
    uploads are handled by the caller after this returns.

    Args:
        task: The executor task being processed.
        spec: The concrete strict spec resolved via ``_TASK_SPEC_TYPE``;
            subclasses ``assert isinstance(spec, ...)`` to narrow it.
        spec_dict: ``spec.model_dump(by_alias=True)`` — the aliased dict
            form of the same spec.
        out_dir: Resolved output directory for this task's results.

    Returns:
        The result payload that ``run()`` passes back to its caller.
    """
    raise NotImplementedError

# ── model lifecycle ──────────────────────────────────────────────────

def _close_omni(self) -> None:
Expand Down Expand Up @@ -160,21 +192,13 @@ def resolve_model_identifier(
value = default
return str(value).strip()

@staticmethod
def collect_text_inputs(spec_dict: dict[str, Any]) -> list[str]:
data = spec_dict.get("data") or {}
if not isinstance(data, dict):
data = {}
dtype = str(data.get("type") or "").strip().lower()
if dtype == "list":
items = data.get("items") or []
if not isinstance(items, list):
return []
return [str(item).strip() for item in items if str(item).strip()]
text_val = data.get("text") or data.get("prompt") or spec_dict.get("task")
if not isinstance(text_val, str) or not text_val.strip():
return []
return [text_val.strip()]
def _collect_text_inputs(self, spec: TaskSpecStrictBase, task_id: str) -> list[str]:
    """Return the prompt strings for *spec*, failing fast when missing.

    Delegates prompt collection to ``_collect_prompts_for_spec`` and
    raises ``ExecutionError`` when no prompts are present or when any
    prompt is not a string.
    """
    collected = self._collect_prompts_for_spec(spec, task_id).prompts
    if not collected:
        raise ExecutionError(f"{self.name} requires text input in spec.data.items.")
    for entry in collected:
        if not isinstance(entry, str):
            raise ExecutionError(f"{self.name} prompts must be strings.")
    return cast(list[str], collected)

@staticmethod
def resolve_save_path(
Expand Down
168 changes: 95 additions & 73 deletions src/worker/executors/omni_text2audio_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@
current_omni_platform = None
_HAS_OMNI_PLATFORM = False

from shared.schemas.governance import SpanType
from shared.tasks.specs import TaskSpecStrictBase
from shared.tasks.specs.omni import OmniText2AudioSpecStrict
from shared.utils.parsing import to_float, to_int

from .base_executor import ExecutionError, ExecutorTask
from .omni_executor_base import OmniExecutorBase, extract_multimodal_output
from .utils.checkpoints import artifact_ref, maybe_upload_artifacts
from .utils.checkpoints import artifact_ref

logger = logging.getLogger(__name__)

Expand All @@ -64,6 +66,7 @@ class OmniText2AudioExecutor(OmniExecutorBase):
"""Generate background music with Omni diffusion sampling."""

name = "omni_text2audio"
_TASK_SPEC_TYPE = OmniText2AudioSpecStrict

def prepare(self) -> None:
if torch is None:
Expand All @@ -75,20 +78,18 @@ def prepare(self) -> None:
"vllm_omni is not installed; cannot use omni_text2audio executor."
)

def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
spec = self.require_spec(task, OmniText2AudioSpecStrict)
spec_dict = spec.model_dump(by_alias=True)
out_dir = Path(out_dir).resolve()

prompts = self.collect_text_inputs(spec_dict)
if not prompts:
raise ExecutionError(
"omni_text2audio requires prompt text "
"in spec.data.text or spec.data.items."
)
def _run_inner(
self,
task: ExecutorTask,
spec: TaskSpecStrictBase,
spec_dict: dict[str, Any],
out_dir: Path,
) -> dict[str, Any]:
assert isinstance(spec, OmniText2AudioSpecStrict)
prompts = self._collect_text_inputs(spec, task.task_id)

cfg = _bgm_cfg(spec_dict)
output_format = str(cfg.get("output_format") or "wav").strip().lower() or "wav"
output_format = str(cfg.get("output_format") or "").strip().lower() or "wav"
Comment thread
kaiitunnz marked this conversation as resolved.
if output_format != "wav":
raise ExecutionError(
"omni_text2audio currently supports output_format='wav' only."
Expand All @@ -106,75 +107,97 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
base_seed = to_int(cfg.get("seed"), default=42)
negative_prompt = str(cfg.get("negative_prompt") or "Low quality.").strip()

self._ensure_omni(spec_dict)
with self._span(
"model load",
span_type=SpanType.COMPUTE,
attributes={"task_id": task.task_id, "prompt_count": len(prompts)},
):
self._ensure_omni(spec_dict)
omni = self._omni
if omni is None:
raise ExecutionError("Omni BGM model failed to initialize.")

generator_device = _resolve_generator_device()
per_prompt_outputs: list[tuple[int, str, Any]] = []
for prompt_idx, prompt in enumerate(prompts):
seed = base_seed + prompt_idx
torch_generator = torch.Generator(device=generator_device).manual_seed(seed)

omni_prompt: OmniTextPrompt = {"prompt": prompt}
if negative_prompt:
omni_prompt["negative_prompt"] = negative_prompt

sampling = OmniDiffusionSamplingParams(
generator=torch_generator,
generator_device=generator_device,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
num_outputs_per_prompt=num_waveforms,
extra_args={
"audio_start_in_s": audio_start,
"audio_end_in_s": audio_end,
},
)
try:
outputs = omni.generate(omni_prompt, sampling)
except Exception as exc:
raise ExecutionError(
f"omni_text2audio generation failed: {exc}"
) from exc
per_prompt_outputs.append((prompt_idx, prompt, outputs))
with self._span(
"generation",
span_type=SpanType.COMPUTE,
attributes={
"task_id": task.task_id,
"prompt_count": len(prompts),
"num_waveforms": num_waveforms,
"num_inference_steps": num_inference_steps,
},
):
for prompt_idx, prompt in enumerate(prompts):
seed = base_seed + prompt_idx
torch_generator = torch.Generator(device=generator_device).manual_seed(
seed
)

omni_prompt: OmniTextPrompt = {"prompt": prompt}
if negative_prompt:
omni_prompt["negative_prompt"] = negative_prompt

sampling = OmniDiffusionSamplingParams(
generator=torch_generator,
generator_device=generator_device,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
num_outputs_per_prompt=num_waveforms,
extra_args={
"audio_start_in_s": audio_start,
"audio_end_in_s": audio_end,
},
)
try:
outputs = omni.generate(omni_prompt, sampling)
except Exception as exc:
raise ExecutionError(
f"omni_text2audio generation failed: {exc}"
) from exc
per_prompt_outputs.append((prompt_idx, prompt, outputs))

artifacts_dir = out_dir / "artifacts"
items: list[dict[str, Any]] = []
global_index = 0
for prompt_idx, prompt, outputs in per_prompt_outputs:
extracted = _extract_audio_waveforms(outputs)
if not extracted:
raise ExecutionError(
"omni_text2audio completed but returned no audio output."
)

for local_idx, audio_entry in enumerate(extracted):
multi = len(prompts) * len(extracted) > 1
save_path = _resolve_bgm_save_path(
cfg,
out_dir,
index=global_index,
ext=output_format,
multi=multi,
)
save_path.parent.mkdir(parents=True, exist_ok=True)
_save_waveform(
audio_entry["waveform"], save_path, sample_rate=sample_rate
)
items.append(
{
"index": global_index,
"prompt_index": prompt_idx,
"waveform_index": local_idx,
"prompt": prompt,
"audio": artifact_ref(
self.relative_to(save_path, artifacts_dir)
),
}
)
global_index += 1
with self._span(
"output postprocessing",
span_type=SpanType.COMPUTE,
attributes={"task_id": task.task_id, "prompt_count": len(prompts)},
):
for prompt_idx, prompt, outputs in per_prompt_outputs:
extracted = _extract_audio_waveforms(outputs)
if not extracted:
raise ExecutionError(
"omni_text2audio completed but returned no audio output."
)

for local_idx, audio_entry in enumerate(extracted):
multi = len(prompts) * len(extracted) > 1
save_path = _resolve_bgm_save_path(
cfg,
out_dir,
index=global_index,
ext=output_format,
multi=multi,
)
save_path.parent.mkdir(parents=True, exist_ok=True)
_save_waveform(
audio_entry["waveform"], save_path, sample_rate=sample_rate
)
items.append(
{
"index": global_index,
"prompt_index": prompt_idx,
"waveform_index": local_idx,
"prompt": prompt,
"audio": artifact_ref(
self.relative_to(save_path, artifacts_dir)
),
}
)
global_index += 1

if not items:
raise ExecutionError("omni_text2audio produced no savable waveforms.")
Expand All @@ -194,7 +217,6 @@ def run(self, task: ExecutorTask, out_dir: Path) -> dict[str, Any]:
storyboard = spec_dict.get("storyboard")
if isinstance(storyboard, dict):
result["storyboard"] = dict(storyboard)
maybe_upload_artifacts(task, out_dir, logger=logger)
return result

# ── model ────────────────────────────────────────────────────────────
Expand Down
Loading