Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 135 additions & 25 deletions skillopt_webui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@
import json
import os
import signal
import socket
import subprocess
import sys
import threading
import time
from pathlib import Path
from urllib.parse import urlparse

import gradio as gr
import yaml

from skillopt.config import flatten_config
from skillopt.config import load_config as load_merged_config

PROJECT_ROOT = Path(__file__).resolve().parent.parent


Expand All @@ -42,6 +46,131 @@ def config_to_display(cfg: dict) -> str:
return yaml.dump(cfg, default_flow_style=False, sort_keys=False)


def _can_connect_to_url(url: str, timeout: float = 0.5) -> bool:
parsed = urlparse(url)
host = parsed.hostname
if not host:
return False
port = parsed.port or (443 if parsed.scheme == "https" else 80)
try:
with socket.create_connection((host, port), timeout=timeout):
return True
except OSError:
return False


def _load_env_file(path: Path, env: dict[str, str]) -> None:
for line in path.read_text().splitlines():
line = line.strip()
if line.startswith("export "):
line = line[len("export "):].strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
env[key.strip()] = value.strip().strip("\"'")


def build_training_env() -> dict[str, str]:
"""Build the environment shared by preflight and the training subprocess."""
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"

dot_env = PROJECT_ROOT / ".env"
if dot_env.is_file():
_load_env_file(dot_env, env)

secrets_dir = PROJECT_ROOT / ".secrets"
if secrets_dir.is_dir():
for env_file in sorted(secrets_dir.glob("*.env")):
_load_env_file(env_file, env)

# Propagate OPTIMIZER_* to base AZURE_OPENAI_* when base is missing,
# so target/default endpoints inherit from optimizer config.
for suffix in (
"ENDPOINT", "API_VERSION", "AUTH_MODE", "MANAGED_IDENTITY_CLIENT_ID",
"AD_SCOPE", "API_KEY",
):
base_key = f"AZURE_OPENAI_{suffix}"
optimizer_key = f"OPTIMIZER_AZURE_OPENAI_{suffix}"
if not env.get(base_key) and env.get(optimizer_key):
env[base_key] = env[optimizer_key]
return env


def validate_training_config(
config_path: str,
overrides: dict,
env: dict[str, str] | None = None,
) -> str | None:
"""Return an actionable preflight error, or None when training can start."""
env = env or os.environ
cfg_options = [
f"{key}={value}" for key, value in overrides.items()
if value is not None and value != ""
]
try:
cfg = flatten_config(load_merged_config(str(PROJECT_ROOT / config_path), cfg_options))
except Exception as exc:
return f"❌ Invalid config: {exc}"

shared_endpoint = (
cfg.get("azure_openai_endpoint")
or cfg.get("azure_endpoint")
or env.get("AZURE_OPENAI_ENDPOINT")
)
missing_openai_roles = []
for role in ("optimizer", "target"):
if cfg.get(f"{role}_backend") != "openai_chat":
continue
role_endpoint = (
cfg.get(f"{role}_azure_openai_endpoint")
or env.get(f"{role.upper()}_AZURE_OPENAI_ENDPOINT")
or shared_endpoint
)
if not role_endpoint:
missing_openai_roles.append(role)
if missing_openai_roles:
configured_backend = cfg.get("model_backend")
detail = ""
if configured_backend in {"qwen", "qwen_chat"}:
detail = (
"\nNote: model.backend is qwen, but explicit optimizer_backend/"
"target_backend values are still openai_chat."
)
return (
"❌ Model backend is not ready: missing Azure/OpenAI-compatible endpoint "
f"for {', '.join(missing_openai_roles)}.\n"
"Set model.azure_openai_endpoint (or AZURE_OPENAI_ENDPOINT), or change "
"the role backends to the backend you intend to use."
f"{detail}"
)

qwen_failures = []
qwen_shared = (
cfg.get("qwen_chat_base_url")
or env.get("QWEN_CHAT_BASE_URL")
or "http://localhost:8000/v1"
)
for role in ("optimizer", "target"):
if cfg.get(f"{role}_backend") != "qwen_chat":
continue
base_url = (
cfg.get(f"{role}_qwen_chat_base_url")
or env.get(f"{role.upper()}_QWEN_CHAT_BASE_URL")
or qwen_shared
)
if not _can_connect_to_url(str(base_url)):
qwen_failures.append(f"{role}={base_url}")
if qwen_failures:
return (
"❌ Model backend is not ready: cannot connect to qwen_chat endpoint "
f"for {', '.join(qwen_failures)}.\n"
"Start your OpenAI-compatible Qwen/vLLM server, or set "
"model.qwen_chat_base_url / OPTIMIZER_QWEN_CHAT_BASE_URL / "
"TARGET_QWEN_CHAT_BASE_URL to the correct URL."
)
return None


# ─── Training process management ────────────────────────────────────────────

class TrainingManager:
Expand All @@ -63,6 +192,11 @@ def start(self, config_path: str, overrides: dict) -> str:
if self.running:
return "⚠️ Training already running. Stop it first."

env = build_training_env()
preflight_error = validate_training_config(config_path, overrides, env)
if preflight_error:
return preflight_error

cmd = [
sys.executable, "scripts/train.py",
"--config", config_path,
Expand All @@ -75,30 +209,6 @@ def start(self, config_path: str, overrides: dict) -> str:
cmd.append("--cfg-options")
cmd.extend(cfg_options)

env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
# Auto-load API credentials from .secrets/*.env
secrets_dir = PROJECT_ROOT / ".secrets"
if secrets_dir.is_dir():
for env_file in sorted(secrets_dir.glob("*.env")):
for line in env_file.read_text().splitlines():
line = line.strip()
if line and not line.startswith("#") and "=" in line:
k, v = line.split("=", 1)
env[k] = v
# Propagate OPTIMIZER_* to base AZURE_OPENAI_* when base is missing,
# so target/default endpoints inherit from optimizer config.
_propagate = [
("ENDPOINT", ""), ("API_VERSION", ""), ("AUTH_MODE", ""),
("MANAGED_IDENTITY_CLIENT_ID", ""), ("AD_SCOPE", ""),
("API_KEY", ""),
]
for suffix, _ in _propagate:
base_key = f"AZURE_OPENAI_{suffix}"
optimizer_key = f"OPTIMIZER_AZURE_OPENAI_{suffix}"
if not env.get(base_key) and env.get(optimizer_key):
env[base_key] = env[optimizer_key]

try:
proc = subprocess.Popen(
cmd,
Expand Down
89 changes: 89 additions & 0 deletions tests/test_webui_env_preflight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import pytest
import yaml

pytest.importorskip("gradio")

from skillopt_webui import app as webui_app


def _write_config(tmp_path, model):
config_path = tmp_path / "config.yaml"
config_path.write_text(
yaml.safe_dump({
"model": model,
"env": {"name": "searchqa"},
}),
encoding="utf-8",
)
return str(config_path)


def test_build_training_env_loads_project_dotenv(tmp_path, monkeypatch):
monkeypatch.setattr(webui_app, "PROJECT_ROOT", tmp_path)
(tmp_path / ".env").write_text(
"\n".join([
"export QWEN_CHAT_BASE_URL=http://qwen.example/v1",
"QWEN_CHAT_MODEL=test-model",
"QWEN_CHAT_API_KEY='secret-value'",
]),
encoding="utf-8",
)

env = webui_app.build_training_env()

assert env["QWEN_CHAT_BASE_URL"] == "http://qwen.example/v1"
assert env["QWEN_CHAT_MODEL"] == "test-model"
assert env["QWEN_CHAT_API_KEY"] == "secret-value"


def test_preflight_reports_missing_openai_chat_endpoint(tmp_path, monkeypatch):
monkeypatch.delenv("AZURE_OPENAI_ENDPOINT", raising=False)
monkeypatch.delenv("OPTIMIZER_AZURE_OPENAI_ENDPOINT", raising=False)
monkeypatch.delenv("TARGET_AZURE_OPENAI_ENDPOINT", raising=False)
config_path = _write_config(
tmp_path,
{
"backend": "qwen",
"optimizer_backend": "openai_chat",
"target_backend": "openai_chat",
},
)

error = webui_app.validate_training_config(config_path, {})

assert "missing Azure/OpenAI-compatible endpoint for optimizer, target" in error
assert "model.backend is qwen" in error


def test_preflight_reports_unreachable_qwen_endpoint(tmp_path, monkeypatch):
monkeypatch.setattr(webui_app, "_can_connect_to_url", lambda _url: False)
config_path = _write_config(
tmp_path,
{
"backend": "qwen",
"optimizer_backend": "qwen_chat",
"target_backend": "qwen_chat",
"qwen_chat_base_url": "http://127.0.0.1:9/v1",
},
)

error = webui_app.validate_training_config(config_path, {})

assert "cannot connect to qwen_chat endpoint" in error
assert "127.0.0.1:9" in error


def test_preflight_accepts_reachable_qwen_endpoint(tmp_path, monkeypatch):
seen_urls = []
monkeypatch.setattr(webui_app, "_can_connect_to_url", lambda url: seen_urls.append(url) or True)
config_path = _write_config(
tmp_path,
{
"optimizer_backend": "qwen_chat",
"target_backend": "qwen_chat",
"qwen_chat_base_url": "http://qwen.example/v1",
},
)

assert webui_app.validate_training_config(config_path, {}) is None
assert seen_urls == ["http://qwen.example/v1", "http://qwen.example/v1"]