diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 33e90770..a972bb66 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - run: pip install pandas && pip install -e ".[dev]" + - run: pip install pandas && pip install -e ".[dev,api,viz]" - name: Run tests with coverage run: | pytest tests/ -m "not slow and not benchmark" \ @@ -58,7 +58,7 @@ jobs: - uses: actions/setup-python@v5 with: python-version: "3.11" - - run: pip install pandas && pip install -e ".[dev]" + - run: pip install pandas && pip install -e ".[dev,api,viz]" - name: CLI smoke run: omega --help - name: Import smoke diff --git a/pyproject.toml b/pyproject.toml index 1d270168..2b90211f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ viz = ["matplotlib~=3.7"] dev = [ "pytest~=8.0", "pytest-cov~=5.0", + "pytest-timeout>=2.2", "ruff~=0.6", "mypy~=1.8", "types-PyYAML", diff --git a/src/omega_pbpk/adapters/__init__.py b/src/omega_pbpk/adapters/__init__.py index 6672166e..53a34816 100644 --- a/src/omega_pbpk/adapters/__init__.py +++ b/src/omega_pbpk/adapters/__init__.py @@ -1,3 +1,3 @@ -from omega_pbpk.adapters.yaml_loader import load_drug_spec, drug_to_spec, spec_to_drug +from omega_pbpk.adapters.yaml_loader import drug_to_spec, load_drug_spec, spec_to_drug __all__ = ["load_drug_spec", "drug_to_spec", "spec_to_drug"] diff --git a/src/omega_pbpk/adapters/population_adapter.py b/src/omega_pbpk/adapters/population_adapter.py index c3d07d60..e67020d8 100644 --- a/src/omega_pbpk/adapters/population_adapter.py +++ b/src/omega_pbpk/adapters/population_adapter.py @@ -1,4 +1,5 @@ """VirtualPopulation → PatientSpec 변환 어댑터.""" + from __future__ import annotations from omega_pbpk.contracts.patient_spec import PatientSpec diff --git a/src/omega_pbpk/adapters/yaml_loader.py b/src/omega_pbpk/adapters/yaml_loader.py index aea93a15..262487d3 100644 --- a/src/omega_pbpk/adapters/yaml_loader.py +++ b/src/omega_pbpk/adapters/yaml_loader.py @@ -8,9 +8,9 @@ spec = load_drug_spec("compounds/caffeine.yaml") spec.validate() """ + from __future__ import annotations -import warnings from pathlib import Path from typing import Literal @@ -34,9 +34,7 @@ def drug_to_spec(drug: Drug, param_source: str = "yaml") -> DrugSpec: """기존 Drug 객체 → DrugSpec 변환.""" - compound_type = _DRUG_TYPE_MAP.get( - getattr(drug, "drug_type", "neutral"), "neutral" - ) + compound_type = _DRUG_TYPE_MAP.get(getattr(drug, "drug_type", "neutral"), "neutral") return DrugSpec( name=drug.name, smiles=drug.smiles, @@ -61,8 +59,12 @@ def drug_to_spec(drug: Drug, param_source: str = "yaml") -> DrugSpec: def spec_to_drug(spec: DrugSpec) -> Drug: """DrugSpec → 기존 Drug 객체 역변환 (ODE 엔진 호환용).""" # compound_type → drug_type 역매핑 - _reverse_map = {"neutral": "neutral", "acid": "monoprotic_acid", - "base": "monoprotic_base", "zwitterion": "diprotic"} + _reverse_map = { + "neutral": "neutral", + "acid": "monoprotic_acid", + "base": "monoprotic_base", + "zwitterion": "diprotic", + } return Drug( name=spec.name, smiles=spec.smiles, @@ -96,6 +98,7 @@ def load_drug_spec(path: str | Path) -> DrugSpec: FileNotFoundError: YAML 파일 없음 """ from omega_pbpk.config import load_compound + drug = load_compound(path) return drug_to_spec(drug, param_source="yaml") diff --git a/src/omega_pbpk/api/app.py b/src/omega_pbpk/api/app.py index f6327e2e..324d6af9 100644 --- a/src/omega_pbpk/api/app.py +++ b/src/omega_pbpk/api/app.py @@ -24,6 +24,7 @@ try: from fastapi import FastAPI, HTTPException from fastapi.responses import RedirectResponse + HAS_FASTAPI = True except ImportError: HAS_FASTAPI = False @@ -37,8 +38,7 @@ if not HAS_FASTAPI: raise ImportError( - "FastAPI is required to use the API server. " - "Install with: pip install omega-pbpk[api]" + "FastAPI is required to use the API server. Install with: pip install omega-pbpk[api]" ) app = FastAPI( @@ -653,13 +653,14 @@ def train_surrogate(req: TrainSurrogateRequest) -> TrainSurrogateResponse: if req.n_samples < 10 or req.n_samples > 10000: raise HTTPException(status_code=422, detail="n_samples must be 10–10000") try: - from omega_pbpk.surrogate.data_generator import generate_training_data from omega_pbpk.surrogate import PKSurrogate + from omega_pbpk.surrogate.data_generator import generate_training_data data = generate_training_data(n_samples=min(req.n_samples, 100)) # MVP: 100으로 제한 model = PKSurrogate(n_input=data.n_params) model.train(data.X, data.y, epochs=min(req.epochs, 20)) import os + os.makedirs(req.output_dir, exist_ok=True) model.save(req.output_dir) return TrainSurrogateResponse( @@ -695,6 +696,8 @@ def validate(req: ValidateRequest) -> ValidateResponse: except Exception as exc: return ValidateResponse(mode="benchmark", passed=False, results={"error": str(exc)}) elif req.mode == "sanity": - return ValidateResponse(mode="sanity", passed=True, results={"message": "Sanity checks passed"}) + return ValidateResponse( + mode="sanity", passed=True, results={"message": "Sanity checks passed"} + ) else: raise HTTPException(status_code=422, detail=f"Unknown mode: {req.mode}") diff --git a/src/omega_pbpk/api/server.py b/src/omega_pbpk/api/server.py index b7bdba06..69e703fd 100644 --- a/src/omega_pbpk/api/server.py +++ b/src/omega_pbpk/api/server.py @@ -2,6 +2,7 @@ This module is kept for backward compatibility. """ + import warnings warnings.warn( diff --git a/src/omega_pbpk/cli.py b/src/omega_pbpk/cli.py index 6f0ce945..f4de3432 100644 --- a/src/omega_pbpk/cli.py +++ b/src/omega_pbpk/cli.py @@ -42,14 +42,14 @@ # 4-verb sub-apps (M3 CLI restructure) # --------------------------------------------------------------------------- simulate_app = typer.Typer(help="Run PBPK simulations.") -train_app = typer.Typer(help="Train ML surrogate or calibrate models.") +train_app = typer.Typer(help="Train ML surrogate or calibrate models.") validate_app = typer.Typer(help="Validate, benchmark, and QA.") -serve_app = typer.Typer(help="Start API server.") +serve_app = typer.Typer(help="Start API server.") app.add_typer(simulate_app, name="simulate") -app.add_typer(train_app, name="train") +app.add_typer(train_app, name="train") app.add_typer(validate_app, name="validate") -app.add_typer(serve_app, name="serve") +app.add_typer(serve_app, name="serve") logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logger = logging.getLogger("omega_pbpk") @@ -1408,10 +1408,8 @@ def serve( try: import uvicorn except ImportError: - typer.echo( - "API server requires [api] extras. Install with: pip install omega-pbpk[api]" - ) - raise SystemExit(1) + typer.echo("API server requires [api] extras. Install with: pip install omega-pbpk[api]") + raise SystemExit(1) from None typer.echo(f"Starting Omega PBPK API server on http://{host}:{port}") uvicorn.run("omega_pbpk.api.app:app", host=host, port=port, reload=reload) @@ -1435,6 +1433,7 @@ def run_tests() -> None: # --- simulate group ---------------------------------------------------------- + @simulate_app.command("single") def simulate_single( compound: str = typer.Argument(..., help="Compound name or YAML path"), @@ -1529,6 +1528,7 @@ def simulate_pgx( # --- train group ------------------------------------------------------------- + @train_app.command("surrogate") def train_surrogate_cmd( n_samples: int = typer.Option(500, help="Training samples"), @@ -1572,6 +1572,7 @@ def train_calibrate_cmd( # --- validate group ---------------------------------------------------------- + @validate_app.command("benchmark") def validate_benchmark_cmd( suite_dir: str = typer.Option("benchmarks", help="Path to benchmark suite directory."), @@ -1629,6 +1630,7 @@ def validate_sensitivity_cmd( # --- serve group ------------------------------------------------------------- + @serve_app.command("start") def serve_start( host: str = typer.Option("0.0.0.0", help="Host to bind the server to."), @@ -1639,10 +1641,8 @@ def serve_start( try: import uvicorn except ImportError: - typer.echo( - "API server requires [api] extras: pip install omega-pbpk[api]", err=True - ) - raise typer.Exit(1) + typer.echo("API server requires [api] extras: pip install omega-pbpk[api]", err=True) + raise typer.Exit(1) from None typer.echo(f"Starting Omega PBPK API server on http://{host}:{port}") uvicorn.run("omega_pbpk.api.app:app", host=host, port=port, reload=reload) diff --git a/src/omega_pbpk/contracts/drug_spec.py b/src/omega_pbpk/contracts/drug_spec.py index ef87e1b6..41500293 100644 --- a/src/omega_pbpk/contracts/drug_spec.py +++ b/src/omega_pbpk/contracts/drug_spec.py @@ -1,4 +1,5 @@ from __future__ import annotations + from dataclasses import dataclass, field from typing import Literal @@ -11,12 +12,12 @@ class DrugSpec: logP: float = 2.0 pka: list[float] = field(default_factory=lambda: [7.0]) compound_type: Literal["neutral", "acid", "base", "zwitterion"] = "neutral" - fup: float = 0.5 # fraction unbound in plasma - rbp: float = 1.0 # blood:plasma ratio + fup: float = 0.5 # fraction unbound in plasma + rbp: float = 1.0 # blood:plasma ratio clint_hepatic_L_per_h: float = 0.0 clint_gut_L_per_h: float = 0.0 clr_L_per_h: float = 0.0 - peff: float = 1.0 # effective permeability (cm/s × 10^-4) + peff: float = 1.0 # effective permeability (cm/s × 10^-4) solubility_mg_mL: float = 1.0 kp: dict[str, float] = field(default_factory=dict) permeability_limited: dict[str, dict[str, float]] = field(default_factory=dict) @@ -32,7 +33,9 @@ def validate(self) -> None: if self.mw <= 0: raise ValueError(f"mw must be > 0, got {self.mw}") if self.clint_hepatic_L_per_h < 0: - raise ValueError(f"clint_hepatic_L_per_h must be >= 0, got {self.clint_hepatic_L_per_h}") + raise ValueError( + f"clint_hepatic_L_per_h must be >= 0, got {self.clint_hepatic_L_per_h}" + ) if self.clr_L_per_h < 0: raise ValueError(f"clr_L_per_h must be >= 0, got {self.clr_L_per_h}") if self.peff < 0: diff --git a/src/omega_pbpk/contracts/patient_spec.py b/src/omega_pbpk/contracts/patient_spec.py index 65a84eb4..508b7974 100644 --- a/src/omega_pbpk/contracts/patient_spec.py +++ b/src/omega_pbpk/contracts/patient_spec.py @@ -1,4 +1,5 @@ from __future__ import annotations + from dataclasses import dataclass from typing import Literal @@ -12,7 +13,7 @@ class PatientSpec: gfr_mL_min: float = 125.0 cardiac_output_L_h: float = 390.0 child_pugh: Literal["normal", "A", "B", "C"] = "normal" - cyp3a4_activity: float = 1.0 # relative to EM + cyp3a4_activity: float = 1.0 # relative to EM cyp2d6_activity: float = 1.0 cyp2c9_activity: float = 1.0 hepatic_cl_factor: float = 1.0 diff --git a/src/omega_pbpk/contracts/simulation_io.py b/src/omega_pbpk/contracts/simulation_io.py index ec981491..22d3df0a 100644 --- a/src/omega_pbpk/contracts/simulation_io.py +++ b/src/omega_pbpk/contracts/simulation_io.py @@ -1,6 +1,8 @@ from __future__ import annotations + from dataclasses import dataclass, field from typing import Literal + import numpy as np from numpy.typing import NDArray @@ -38,12 +40,12 @@ class PKMetrics: @dataclass(frozen=True) class ADMEOutput: - Fa: float = 1.0 # fraction absorbed - Fg: float = 1.0 # fraction surviving gut - Fh: float = 1.0 # fraction surviving liver + Fa: float = 1.0 # fraction absorbed + Fg: float = 1.0 # fraction surviving gut + Fh: float = 1.0 # fraction surviving liver CLint: float = 0.0 # intrinsic clearance L/h - fu: float = 0.5 # fraction unbound - Vd: float = 30.0 # volume of distribution L + fu: float = 0.5 # fraction unbound + Vd: float = 30.0 # volume of distribution L confidence: Literal["low", "medium", "high"] = "low" diff --git a/src/omega_pbpk/engine/interface.py b/src/omega_pbpk/engine/interface.py index 27f5b934..5fdd6a3b 100644 --- a/src/omega_pbpk/engine/interface.py +++ b/src/omega_pbpk/engine/interface.py @@ -15,7 +15,7 @@ class SimulationResult: """ODE 시뮬레이션 결과.""" t: NDArray[np.float64] - amounts: NDArray[np.float64] # shape (n_states, n_timepoints) + amounts: NDArray[np.float64] # shape (n_states, n_timepoints) plasma_concentration: NDArray[np.float64] # shape (n_timepoints,) drug_name: str = "" route: str = "oral" diff --git a/src/omega_pbpk/plugins/__init__.py b/src/omega_pbpk/plugins/__init__.py index 3f4b4d92..0345b1e3 100644 --- a/src/omega_pbpk/plugins/__init__.py +++ b/src/omega_pbpk/plugins/__init__.py @@ -1,5 +1,5 @@ -from omega_pbpk.plugins.base import PluginBase, SurrogateModelPlugin from omega_pbpk.plugins.adme_plugin import ADMEPredictorPlugin +from omega_pbpk.plugins.base import PluginBase, SurrogateModelPlugin from omega_pbpk.plugins.heuristic_kp import HeuristicKpPlugin from omega_pbpk.plugins.parameter_net import ParameterNetPlugin diff --git a/src/omega_pbpk/plugins/adme_plugin.py b/src/omega_pbpk/plugins/adme_plugin.py index cda11528..b6b20bfd 100644 --- a/src/omega_pbpk/plugins/adme_plugin.py +++ b/src/omega_pbpk/plugins/adme_plugin.py @@ -1,7 +1,7 @@ from __future__ import annotations -from omega_pbpk.plugins.base import PluginBase from omega_pbpk.contracts import DrugSpec +from omega_pbpk.plugins.base import PluginBase class ADMEPredictorPlugin(PluginBase): @@ -12,6 +12,7 @@ class ADMEPredictorPlugin(PluginBase): def __init__(self) -> None: from omega_pbpk.prediction.adme_predictor import ADMEPredictor + self._predictor = ADMEPredictor() def predict(self, spec: DrugSpec) -> dict[str, float | dict]: @@ -21,13 +22,15 @@ def predict(self, spec: DrugSpec) -> dict[str, float | dict]: props = self._predictor.predict(smiles) else: # SMILES 없을 때: DrugSpec의 physicochemical 값으로 fallback - props = self._predictor.predict_from_dict({ - "mw": spec.mw, - "logP": spec.logP, - "fup": spec.fup, - "rbp": spec.rbp, - "peff": spec.peff, - }) + props = self._predictor.predict_from_dict( + { + "mw": spec.mw, + "logP": spec.logP, + "fup": spec.fup, + "rbp": spec.rbp, + "peff": spec.peff, + } + ) return { "fup": float(props.fup), diff --git a/src/omega_pbpk/plugins/heuristic_kp.py b/src/omega_pbpk/plugins/heuristic_kp.py index d2374d4e..4556d80c 100644 --- a/src/omega_pbpk/plugins/heuristic_kp.py +++ b/src/omega_pbpk/plugins/heuristic_kp.py @@ -2,8 +2,8 @@ from typing import Literal -from omega_pbpk.plugins.base import PluginBase from omega_pbpk.contracts import DrugSpec +from omega_pbpk.plugins.base import PluginBase class HeuristicKpPlugin(PluginBase): @@ -12,16 +12,15 @@ class HeuristicKpPlugin(PluginBase): name = "heuristic_kp" provides = frozenset({"kp"}) - def __init__( - self, method: Literal["poulin_theil", "rodgers_rowland"] = "poulin_theil" - ) -> None: + def __init__(self, method: Literal["poulin_theil", "rodgers_rowland"] = "poulin_theil") -> None: self.method = method def predict(self, spec: DrugSpec) -> dict[str, float | dict]: pka = spec.pka[0] if spec.pka else None if self.method == "rodgers_rowland": - from omega_pbpk.core.heuristics import rodgers_rowland_kp, _TISSUE_FACTORS + from omega_pbpk.core.heuristics import _TISSUE_FACTORS, rodgers_rowland_kp + kp_dict: dict[str, float] = { tissue: rodgers_rowland_kp( logP=spec.logP, @@ -35,6 +34,7 @@ def predict(self, spec: DrugSpec) -> dict[str, float | dict]: else: # poulin_theil (default) from omega_pbpk.core.heuristics import estimate_all_kp + kp_dict = estimate_all_kp( logP=spec.logP, pka=pka, diff --git a/src/omega_pbpk/plugins/parameter_net.py b/src/omega_pbpk/plugins/parameter_net.py index 1bdbc843..f968a56a 100644 --- a/src/omega_pbpk/plugins/parameter_net.py +++ b/src/omega_pbpk/plugins/parameter_net.py @@ -3,6 +3,7 @@ MVP 구현: PGx CYP scaling + 체중/연령 allometric scaling. 추후 ML NN으로 교체 가능한 PluginBase 구조. """ + from __future__ import annotations from omega_pbpk.contracts.drug_spec import DrugSpec diff --git a/src/omega_pbpk/surrogate/__init__.py b/src/omega_pbpk/surrogate/__init__.py index c22e5651..b52422e4 100644 --- a/src/omega_pbpk/surrogate/__init__.py +++ b/src/omega_pbpk/surrogate/__init__.py @@ -59,21 +59,33 @@ class PKSurrogate: # 18D extended feature lists PATIENT_FEATURES: ClassVar[list[str]] = [ - "body_weight_kg", "age_years", "sex_binary", - "gfr_mL_min", "cardiac_output_L_h", - "cyp3a4_activity", "cyp2d6_activity", - "hepatic_cl_factor", "renal_cl_factor", + "body_weight_kg", + "age_years", + "sex_binary", + "gfr_mL_min", + "cardiac_output_L_h", + "cyp3a4_activity", + "cyp2d6_activity", + "hepatic_cl_factor", + "renal_cl_factor", ] REGIMEN_FEATURES: ClassVar[list[str]] = [ - "dose_mg", "route_binary", "n_doses", + "dose_mg", + "route_binary", + "n_doses", ] FULL_FEATURES: ClassVar[list[str]] = ( ["logP", "fup", "clint_L_h", "mw", "rbp", "peff"] + [ - "body_weight_kg", "age_years", "sex_binary", - "gfr_mL_min", "cardiac_output_L_h", - "cyp3a4_activity", "cyp2d6_activity", - "hepatic_cl_factor", "renal_cl_factor", + "body_weight_kg", + "age_years", + "sex_binary", + "gfr_mL_min", + "cardiac_output_L_h", + "cyp3a4_activity", + "cyp2d6_activity", + "hepatic_cl_factor", + "renal_cl_factor", ] + ["dose_mg", "route_binary", "n_doses"] ) diff --git a/src/omega_pbpk/surrogate/data_generator.py b/src/omega_pbpk/surrogate/data_generator.py index 2e20d85a..f730d158 100644 --- a/src/omega_pbpk/surrogate/data_generator.py +++ b/src/omega_pbpk/surrogate/data_generator.py @@ -244,12 +244,12 @@ def _run_single_simulation( # Grid parameter ranges for generate_grid_dataset() — wider space, includes mw. # Feature order matches PKSurrogate.EXPECTED_FEATURES: [logP, fup, clint_L_h, mw, rbp, peff] GRID_PARAM_RANGES: dict[str, tuple[float, float]] = { - "logP": (-2.0, 6.0), # linear (already in log scale) - "fup": (0.01, 1.0), # log-uniform + "logP": (-2.0, 6.0), # linear (already in log scale) + "fup": (0.01, 1.0), # log-uniform "clint_L_h": (1.0, 500.0), # log-uniform → clint_hepatic_L_per_h - "mw": (100.0, 800.0), # log-uniform - "rbp": (0.5, 5.0), # log-uniform - "peff": (0.5, 10.0), # log-uniform + "mw": (100.0, 800.0), # log-uniform + "rbp": (0.5, 5.0), # log-uniform + "peff": (0.5, 10.0), # log-uniform } GRID_PARAM_NAMES: list[str] = list(GRID_PARAM_RANGES.keys()) @@ -358,12 +358,10 @@ def generate_grid_dataset( X[:, j] = np.exp(np.log(lo) + X_unit[:, j] * (np.log(hi) - np.log(lo))) all_params = [ - {name: float(X[i, j]) for j, name in enumerate(GRID_PARAM_NAMES)} - for i in range(n_samples) + {name: float(X[i, j]) for j, name in enumerate(GRID_PARAM_NAMES)} for i in range(n_samples) ] worker_args = [ - (params, dose_mg, route, body_weight, t_end_h, i) - for i, params in enumerate(all_params) + (params, dose_mg, route, body_weight, t_end_h, i) for i, params in enumerate(all_params) ] pk_rows: list[list[float] | None] = [None] * n_samples @@ -381,9 +379,7 @@ def generate_grid_dataset( max_workers = min(n_workers, n_samples) completed = 0 with ProcessPoolExecutor(max_workers=max_workers) as pool: - future_to_idx = { - pool.submit(_grid_sim_worker, args): args[-1] for args in worker_args - } + future_to_idx = {pool.submit(_grid_sim_worker, args): args[-1] for args in worker_args} for future in as_completed(future_to_idx): idx, row = future.result() if row is not None: diff --git a/src/omega_pbpk/surrogate/train.py b/src/omega_pbpk/surrogate/train.py index 675beb85..efa34825 100644 --- a/src/omega_pbpk/surrogate/train.py +++ b/src/omega_pbpk/surrogate/train.py @@ -284,15 +284,14 @@ def _r2(y_true: NDArray, y_pred: NDArray) -> float: return 1.0 - ss_res / (ss_tot + 1e-12) y_pred_test = model.predict(X_test) - auc_r2 = _r2(y_test[:, 1], y_pred_test[:, 1]) # AUC column + auc_r2 = _r2(y_test[:, 1], y_pred_test[:, 1]) # AUC column cmax_r2 = _r2(y_test[:, 0], y_pred_test[:, 0]) # Cmax column _Path(output_dir).mkdir(parents=True, exist_ok=True) model.save(output_dir) logger.info( - "train_with_grid_data: n_train=%d n_val=%d n_test=%d " - "loss=%.4f auc_r2=%.3f cmax_r2=%.3f", + "train_with_grid_data: n_train=%d n_val=%d n_test=%d loss=%.4f auc_r2=%.3f cmax_r2=%.3f", n_train, n_val, len(X_test), diff --git a/src/omega_pbpk/validation/__init__.py b/src/omega_pbpk/validation/__init__.py index a86dbbac..2ed0eb36 100644 --- a/src/omega_pbpk/validation/__init__.py +++ b/src/omega_pbpk/validation/__init__.py @@ -183,7 +183,9 @@ def oral_mass_balance_check( total_mass = float(np.sum(amounts[check_idx])) deviation_frac = abs(total_mass - dose_mg) / dose_mg - t_val = float(time_h[check_idx]) if time_h is not None and len(time_h) > abs(check_idx) else None + t_val = ( + float(time_h[check_idx]) if time_h is not None and len(time_h) > abs(check_idx) else None + ) if deviation_frac > tolerance_frac: t_info = f" at t={t_val:.2f}h" if t_val is not None else "" diff --git a/src/omega_pbpk/validation/_param_guard.py b/src/omega_pbpk/validation/_param_guard.py index 5e36af1f..ae84e208 100644 --- a/src/omega_pbpk/validation/_param_guard.py +++ b/src/omega_pbpk/validation/_param_guard.py @@ -10,13 +10,13 @@ # Physical plausibility bounds for PBPK parameters _BOUNDS: dict[str, tuple[float, float]] = { - "clint_hepatic_L_per_h": (0.0, 500.0), # max ~hepatic blood flow ceiling + "clint_hepatic_L_per_h": (0.0, 500.0), # max ~hepatic blood flow ceiling "clint_gut_L_per_h": (0.0, 200.0), "fup": (1e-4, 1.0), "rbp": (0.1, 20.0), "mw": (50.0, 2000.0), "logP": (-6.0, 12.0), - "peff": (1e-6, 50.0), # cm/s × 1e-4 scale typical range + "peff": (1e-6, 50.0), # cm/s × 1e-4 scale typical range "ka_per_h": (1e-4, 100.0), } @@ -35,9 +35,7 @@ class ParamViolation: upper: float def __str__(self) -> str: - return ( - f"{self.param}={self.value:.4g} outside [{self.lower:.4g}, {self.upper:.4g}]" - ) + return f"{self.param}={self.value:.4g} outside [{self.lower:.4g}, {self.upper:.4g}]" def check_drug_params( @@ -119,12 +117,12 @@ def check_drug_params( # validate_drug_params — legacy-compatible gateway (positional-arg style) # --------------------------------------------------------------------------- -CLR_MAX_L_H = 200.0 # Approximate maximum renal blood flow -CLINT_MAX_L_H = 5000.0 # Realistic maximum CLint +CLR_MAX_L_H = 200.0 # Approximate maximum renal blood flow +CLINT_MAX_L_H = 5000.0 # Realistic maximum CLint KP_MIN = 1e-4 KP_MAX = 1000.0 KA_MIN = 0.0 -KA_MAX = 100.0 # 1/h +KA_MAX = 100.0 # 1/h FUP_MIN = 1e-4 FUP_MAX = 1.0 RBP_MIN = 0.1 diff --git a/tests/conftest.py b/tests/conftest.py index a90db654..f378735f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,3 @@ -import pytest - - def pytest_configure(config): config.addinivalue_line("markers", "slow: slow tests requiring training") config.addinivalue_line("markers", "benchmark: benchmark tests") diff --git a/tests/integration/test_plugin_pipeline.py b/tests/integration/test_plugin_pipeline.py index 8328e1cf..40196d17 100644 --- a/tests/integration/test_plugin_pipeline.py +++ b/tests/integration/test_plugin_pipeline.py @@ -1,10 +1,12 @@ """Integration: Plugin → DrugSpec → SimulationEngine E2E.""" -import pytest + import numpy as np +import pytest + from omega_pbpk.adapters.yaml_loader import load_drug_spec_by_name -from omega_pbpk.plugins import HeuristicKpPlugin -from omega_pbpk.engine import WholeBodyPBPKEngine from omega_pbpk.contracts import PatientSpec, Regimen +from omega_pbpk.engine import WholeBodyPBPKEngine +from omega_pbpk.plugins import HeuristicKpPlugin @pytest.fixture(scope="module") diff --git a/tests/integration/test_surrogate_vs_ode.py b/tests/integration/test_surrogate_vs_ode.py index ff6474a5..cb191627 100644 --- a/tests/integration/test_surrogate_vs_ode.py +++ b/tests/integration/test_surrogate_vs_ode.py @@ -23,15 +23,15 @@ # Constants # --------------------------------------------------------------------------- -AUC_RE_MAX = 0.20 # surrogate AUC must be within 20% of ODE AUC +AUC_RE_MAX = 0.20 # surrogate AUC must be within 20% of ODE AUC CMAX_RE_MAX = 0.25 # surrogate Cmax must be within 25% of ODE Cmax # Canonical parameter set for test drugs (input feature order matches EXPECTED_FEATURES) # [logP, fup, clint_L_h, mw, rbp, peff] TEST_DRUGS = { - "caffeine": [0.07, 0.65, 2.0, 194.19, 0.8, 4.0], - "midazolam": [3.89, 0.03, 60.0, 325.77, 0.53, 2.0], - "propranolol":[3.48, 0.13, 70.0, 259.34, 0.81, 3.0], + "caffeine": [0.07, 0.65, 2.0, 194.19, 0.8, 4.0], + "midazolam": [3.89, 0.03, 60.0, 325.77, 0.53, 2.0], + "propranolol": [3.48, 0.13, 70.0, 259.34, 0.81, 3.0], } @@ -39,6 +39,7 @@ # Helpers # --------------------------------------------------------------------------- + def _run_ode(params: list[float], dose_mg: float = 100.0) -> tuple[float, float]: """Run ODE simulation for given params, return (Cmax, AUC).""" from omega_pbpk._compat import np_trapz @@ -65,7 +66,7 @@ def _run_ode(params: list[float], dose_mg: float = 100.0) -> tuple[float, float] return cmax, auc -def _train_surrogate_on_drugs(drug_params: list[list[float]], n_aug: int = 15) -> "PKSurrogate": # noqa: F821 +def _train_surrogate_on_drugs(drug_params: list[list[float]], n_aug: int = 15) -> PKSurrogate: # noqa: F821 """Train a minimal surrogate on the test drugs with augmentation.""" from omega_pbpk.surrogate import PKSurrogate from omega_pbpk.surrogate.train import build_training_dataset @@ -82,6 +83,7 @@ def _train_surrogate_on_drugs(drug_params: list[list[float]], n_aug: int = 15) - # Contract: surrogate EXPECTED_FEATURES matches ODE input convention # --------------------------------------------------------------------------- + class TestFeatureOrderContract: def test_surrogate_feature_order_matches_train_order(self): """EXPECTED_FEATURES must match the column order in train._params_to_array.""" @@ -90,8 +92,12 @@ def test_surrogate_feature_order_matches_train_order(self): expected = PKSurrogate.EXPECTED_FEATURES ref_drug = { - "logP": 2.0, "fup": 0.5, "clint": 10.0, - "mw": 300.0, "rbp": 1.0, "peff": 2.0, + "logP": 2.0, + "fup": 0.5, + "clint": 10.0, + "mw": 300.0, + "rbp": 1.0, + "peff": 2.0, } arr = _params_to_array(ref_drug) for i, feat_name in enumerate(expected): @@ -106,10 +112,12 @@ def test_surrogate_feature_order_matches_train_order(self): def test_expected_features_count_matches_default_n_input(self): from omega_pbpk.surrogate import PKSurrogate + assert len(PKSurrogate.EXPECTED_FEATURES) == PKSurrogate().n_input def test_expected_outputs_count_matches_default_n_output(self): from omega_pbpk.surrogate import PKSurrogate + assert len(PKSurrogate.EXPECTED_OUTPUTS) == PKSurrogate().n_output @@ -117,6 +125,7 @@ def test_expected_outputs_count_matches_default_n_output(self): # ODE: baseline outputs for reference drugs # --------------------------------------------------------------------------- + class TestODEBaseline: def test_ode_caffeine_positive_cmax_auc(self): cmax, auc = _run_ode(TEST_DRUGS["caffeine"]) @@ -139,6 +148,7 @@ def test_ode_outputs_finite(self): # Surrogate vs ODE agreement (slow — requires training) # --------------------------------------------------------------------------- + @pytest.mark.slow class TestSurrogateVsODE: @pytest.fixture(scope="class") diff --git a/tests/regression/test_golden_snapshot.py b/tests/regression/test_golden_snapshot.py index 8df68022..432b63f1 100644 --- a/tests/regression/test_golden_snapshot.py +++ b/tests/regression/test_golden_snapshot.py @@ -1,4 +1,5 @@ """Regression: ODE output must not drift from golden snapshot > ±10%.""" + from __future__ import annotations import json diff --git a/tests/test_adme_plugin.py b/tests/test_adme_plugin.py index 282ecfbf..c30eaf97 100644 --- a/tests/test_adme_plugin.py +++ b/tests/test_adme_plugin.py @@ -2,8 +2,6 @@ from __future__ import annotations -import pytest - from omega_pbpk.contracts.drug_spec import DrugSpec from omega_pbpk.plugins.adme_plugin import ADMEPredictorPlugin from omega_pbpk.plugins.heuristic_kp import HeuristicKpPlugin @@ -34,9 +32,11 @@ # Import guard (works without torch) # --------------------------------------------------------------------------- + def test_import_plugins(): """플러그인이 torch 없이도 import 성공해야 한다.""" from omega_pbpk.plugins import ADMEPredictorPlugin, HeuristicKpPlugin # noqa: F401 + assert ADMEPredictorPlugin is not None assert HeuristicKpPlugin is not None @@ -45,6 +45,7 @@ def test_import_plugins(): # ADMEPredictorPlugin # --------------------------------------------------------------------------- + class TestADMEPredictorPlugin: def setup_method(self): self.plugin = ADMEPredictorPlugin() @@ -96,6 +97,7 @@ def test_confidence_without_smiles(self): # HeuristicKpPlugin # --------------------------------------------------------------------------- + class TestHeuristicKpPlugin: def setup_method(self): self.plugin = HeuristicKpPlugin() diff --git a/tests/test_api_m3.py b/tests/test_api_m3.py index 02ab7c79..5aae91b8 100644 --- a/tests/test_api_m3.py +++ b/tests/test_api_m3.py @@ -1,10 +1,12 @@ """Tests for M3-2 FastAPI integration — new endpoints and validators.""" + from __future__ import annotations import pytest try: from omega_pbpk.api.app import app + HAS_FASTAPI = True except ImportError: HAS_FASTAPI = False @@ -13,6 +15,7 @@ if HAS_FASTAPI: from fastapi.testclient import TestClient + client = TestClient(app) else: client = None # type: ignore[assignment] diff --git a/tests/test_cli_m3.py b/tests/test_cli_m3.py index 1eb7caeb..00c6c5c4 100644 --- a/tests/test_cli_m3.py +++ b/tests/test_cli_m3.py @@ -36,9 +36,7 @@ def test_simulate_single_help(): def test_simulate_single_json(): - result = runner.invoke( - app, ["simulate", "single", "compounds/caffeine.yaml", "--json"] - ) + result = runner.invoke(app, ["simulate", "single", "compounds/caffeine.yaml", "--json"]) assert result.exit_code == 0, result.output data = json.loads(result.output) assert "Cmax" in data or "cmax" in data or len(data) > 0 diff --git a/tests/test_contracts.py b/tests/test_contracts.py index f6f06229..d976c47c 100644 --- a/tests/test_contracts.py +++ b/tests/test_contracts.py @@ -1,13 +1,14 @@ """Unit tests for omega_pbpk.contracts module.""" + import pytest from omega_pbpk.contracts import DrugSpec, PatientSpec, Regimen - # --------------------------------------------------------------------------- # DrugSpec.validate() # --------------------------------------------------------------------------- + class TestDrugSpecValidate: def test_valid_default(self): DrugSpec(name="test").validate() @@ -71,6 +72,7 @@ def test_rbp_negative_raises(self): # PatientSpec.validate() # --------------------------------------------------------------------------- + class TestPatientSpecValidate: def test_valid_default(self): PatientSpec().validate() @@ -106,6 +108,7 @@ def test_cardiac_output_zero_raises(self): # Regimen.validate() # --------------------------------------------------------------------------- + class TestRegimenValidate: def test_valid(self): Regimen(dose_mg=100).validate() diff --git a/tests/test_param_guard.py b/tests/test_param_guard.py index ecc8e9ea..21aa5b1f 100644 --- a/tests/test_param_guard.py +++ b/tests/test_param_guard.py @@ -1,4 +1,5 @@ import pytest + from omega_pbpk.validation._param_guard import validate_drug_params @@ -24,13 +25,21 @@ def test_fup_above_one_raises(): def test_negative_kp_raises(): with pytest.raises(ValueError, match="kp"): - validate_drug_params(clint_hepatic=0, clint_gut=0, clr=0, fup=0.5, rbp=1.0, - kp={"liver": -0.1}) + validate_drug_params( + clint_hepatic=0, clint_gut=0, clr=0, fup=0.5, rbp=1.0, kp={"liver": -0.1} + ) def test_valid_params_pass(): - validate_drug_params(clint_hepatic=10.0, clint_gut=5.0, clr=2.0, - fup=0.1, rbp=0.9, kp={"liver": 2.0, "muscle": 0.5}, ka=1.5) + validate_drug_params( + clint_hepatic=10.0, + clint_gut=5.0, + clr=2.0, + fup=0.1, + rbp=0.9, + kp={"liver": 2.0, "muscle": 0.5}, + ka=1.5, + ) def test_extreme_clint_raises(): diff --git a/tests/test_plugins_base.py b/tests/test_plugins_base.py index 22620b76..e5beb4d8 100644 --- a/tests/test_plugins_base.py +++ b/tests/test_plugins_base.py @@ -8,11 +8,11 @@ from omega_pbpk.engine import SimulationEngine, SimulationResult, WholeBodyPBPKEngine from omega_pbpk.plugins import PluginBase, SurrogateModelPlugin - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _caffeine_spec() -> DrugSpec: return DrugSpec( name="caffeine", @@ -37,6 +37,7 @@ def _oral_regimen(dose_mg: float = 200.0) -> Regimen: # PluginBase tests # --------------------------------------------------------------------------- + class _GoodPlugin(PluginBase): @property def name(self) -> str: @@ -94,6 +95,7 @@ def test_surrogate_model_plugin_is_protocol(): # SimulationEngine ABC tests # --------------------------------------------------------------------------- + def test_simulation_engine_cannot_be_instantiated(): with pytest.raises(TypeError): SimulationEngine() # type: ignore[abstract] @@ -108,18 +110,24 @@ def test_whole_body_pbpk_engine_is_instance_of_simulation_engine(): # Import tests # --------------------------------------------------------------------------- + def test_import_plugins(): from omega_pbpk.plugins import PluginBase, SurrogateModelPlugin # noqa: F401 def test_import_engine(): - from omega_pbpk.engine import SimulationEngine, SimulationResult, WholeBodyPBPKEngine # noqa: F401 + from omega_pbpk.engine import ( # noqa: F401 + SimulationEngine, + SimulationResult, + WholeBodyPBPKEngine, + ) # --------------------------------------------------------------------------- # WholeBodyPBPKEngine functional test # --------------------------------------------------------------------------- + def test_engine_run_caffeine_oral(): engine = WholeBodyPBPKEngine() drug = _caffeine_spec() diff --git a/tests/test_population_adapter.py b/tests/test_population_adapter.py index c1112465..100f14f3 100644 --- a/tests/test_population_adapter.py +++ b/tests/test_population_adapter.py @@ -1,4 +1,5 @@ """Tests for population_adapter and ParameterNetPlugin.""" + from __future__ import annotations import pytest @@ -11,11 +12,11 @@ from omega_pbpk.contracts.patient_spec import PatientSpec from omega_pbpk.plugins import ParameterNetPlugin - # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- + @pytest.fixture def base_drug() -> DrugSpec: return DrugSpec( @@ -41,6 +42,7 @@ def pm_patient() -> PatientSpec: # sample_patient_spec # --------------------------------------------------------------------------- + def test_sample_patient_spec_count(): specs = sample_patient_spec(n=5, seed=0) assert len(specs) == 5 @@ -74,6 +76,7 @@ def test_sample_patient_spec_positive_bw(): # subject_covariates_to_patient_spec field mapping # --------------------------------------------------------------------------- + def test_sex_mapping_male(): from omega_pbpk.population.physiology import SubjectCovariates @@ -150,6 +153,7 @@ def test_renal_cl_factor_normal_gfr(): # ParameterNetPlugin.predict() # --------------------------------------------------------------------------- + def test_parameter_net_plugin_provides_keys(em_patient, base_drug): plugin = ParameterNetPlugin(patient=em_patient) result = plugin.predict(base_drug) @@ -180,15 +184,14 @@ def test_parameter_net_apply_updates_drug(em_patient, base_drug): updated = plugin.apply(base_drug) assert isinstance(updated, DrugSpec) # Reference patient at 70 kg + EM: scales should be ~1.0 - assert updated.clint_hepatic_L_per_h == pytest.approx( - base_drug.clint_hepatic_L_per_h, rel=0.05 - ) + assert updated.clint_hepatic_L_per_h == pytest.approx(base_drug.clint_hepatic_L_per_h, rel=0.05) # --------------------------------------------------------------------------- # ParameterNetPlugin.apply_pgx() # --------------------------------------------------------------------------- + def test_apply_pgx_pm_em_um_differ(base_drug): plugin = ParameterNetPlugin() @@ -218,6 +221,8 @@ def test_apply_pgx_em_full_clint(base_drug): # Import check # --------------------------------------------------------------------------- + def test_parameter_net_plugin_importable(): from omega_pbpk.plugins import ParameterNetPlugin as PNP # noqa: F401 + assert PNP is ParameterNetPlugin diff --git a/tests/test_surrogate_extensions.py b/tests/test_surrogate_extensions.py index 039b9c10..cd34f5a4 100644 --- a/tests/test_surrogate_extensions.py +++ b/tests/test_surrogate_extensions.py @@ -1,5 +1,6 @@ import numpy as np import pytest + from omega_pbpk.surrogate import PKSurrogate from omega_pbpk.surrogate.data_generator import generate_training_data @@ -41,7 +42,7 @@ def test_conformal_calibration_coverage(): # test coverage on calibration set (should be ~90%) covered = 0 - for x, y_true_norm in zip(X_cal, y_cal): + for x, y_true_norm in zip(X_cal, y_cal, strict=False): y_true_orig = y_true_norm * model.y_std + model.y_mean y_pred, ci_lo, ci_hi = model.predict_conformal(x, alpha=0.10) if np.all(ci_lo <= y_true_orig) and np.all(y_true_orig <= ci_hi): diff --git a/tests/test_yaml_loader.py b/tests/test_yaml_loader.py index c1e6357b..61cb8f9d 100644 --- a/tests/test_yaml_loader.py +++ b/tests/test_yaml_loader.py @@ -1,17 +1,18 @@ """Tests for YAML → DrugSpec adapter (M0-3).""" + from __future__ import annotations -import pytest from pathlib import Path +import pytest + from omega_pbpk.adapters.yaml_loader import ( - load_drug_spec_by_name, - load_drug_spec, drug_to_spec, + load_drug_spec, + load_drug_spec_by_name, spec_to_drug, ) from omega_pbpk.contracts import DrugSpec -from omega_pbpk.drugs.drug import Drug REPO_ROOT = Path(__file__).parent.parent COMPOUNDS_DIR = REPO_ROOT / "compounds" diff --git a/tests/unit/plugins/test_param_plugin.py b/tests/unit/plugins/test_param_plugin.py index a69560c9..09e96088 100644 --- a/tests/unit/plugins/test_param_plugin.py +++ b/tests/unit/plugins/test_param_plugin.py @@ -10,7 +10,6 @@ from omega_pbpk.validation._param_guard import ParamViolation, check_drug_params - # --------------------------------------------------------------------------- # Fixtures: valid & invalid param sets # --------------------------------------------------------------------------- @@ -32,6 +31,7 @@ # Happy-path: all valid params pass without exception # --------------------------------------------------------------------------- + class TestValidParams: def test_valid_params_no_violations(self): violations = check_drug_params(**VALID_PARAMS, raise_on_violation=False) @@ -58,6 +58,7 @@ def test_no_params_no_violations(self): # fup boundary cases # --------------------------------------------------------------------------- + class TestFupBounds: def test_fup_zero_raises(self): with pytest.raises(ValueError, match="fup"): @@ -82,6 +83,7 @@ def test_fup_minimum_boundary_valid(self): # CLint boundary cases # --------------------------------------------------------------------------- + class TestClintBounds: def test_negative_clint_hepatic_raises(self): with pytest.raises(ValueError, match="clint_hepatic_L_per_h"): @@ -102,6 +104,7 @@ def test_clint_at_max_boundary_valid(self): # Kp boundary cases # --------------------------------------------------------------------------- + class TestKpBounds: def test_negative_kp_raises(self): with pytest.raises(ValueError, match="kp\\[liver\\]"): @@ -126,6 +129,7 @@ def test_kp_at_max_boundary_valid(self): # Multiple violations reported together # --------------------------------------------------------------------------- + class TestMultipleViolations: def test_multiple_violations_all_reported(self): violations = check_drug_params( @@ -156,6 +160,7 @@ def test_violation_str_format(self): # MW, logP, rbp bounds # --------------------------------------------------------------------------- + class TestOtherBounds: def test_low_mw_raises(self): with pytest.raises(ValueError, match="mw"): diff --git a/tests/unit/plugins/test_plugin_contracts.py b/tests/unit/plugins/test_plugin_contracts.py index c94ac54d..b4b6dd9f 100644 --- a/tests/unit/plugins/test_plugin_contracts.py +++ b/tests/unit/plugins/test_plugin_contracts.py @@ -1,8 +1,10 @@ """Plugin 공통 계약 단위 테스트.""" + import pytest + +from omega_pbpk.adapters.yaml_loader import load_drug_spec_by_name from omega_pbpk.plugins import ADMEPredictorPlugin, HeuristicKpPlugin from omega_pbpk.plugins.base import PluginBase -from omega_pbpk.adapters.yaml_loader import load_drug_spec_by_name @pytest.fixture(scope="module") @@ -60,6 +62,7 @@ def test_rodgers_rowland(self, caffeine_spec): class TestPluginBaseContract: def test_provides_mismatch_raises(self, caffeine_spec): """predict()가 provides와 다른 키 반환 시 TypeError.""" + class BrokenPlugin(PluginBase): name = "broken" provides = frozenset({"nonexistent_field"}) diff --git a/tests/unit/plugins/test_surrogate_plugin.py b/tests/unit/plugins/test_surrogate_plugin.py index 902a7fb2..f4dce44c 100644 --- a/tests/unit/plugins/test_surrogate_plugin.py +++ b/tests/unit/plugins/test_surrogate_plugin.py @@ -12,11 +12,11 @@ from omega_pbpk.surrogate import PKSurrogate - # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- + @pytest.fixture def untrained_surrogate() -> PKSurrogate: """Default PKSurrogate with He-initialized (random) weights.""" @@ -30,12 +30,14 @@ def trained_surrogate() -> PKSurrogate: n = 80 X = rng.uniform([[-2, 0.01, 0.1, 100, 0.5, 0.5]], [[6, 1.0, 200, 800, 5.0, 10.0]], (n, 6)) # Synthetic PK outputs: Cmax, AUC, Tmax, t_half (all positive) - y = np.column_stack([ - rng.uniform(0.01, 10, n), # Cmax - rng.uniform(0.1, 100, n), # AUC - rng.uniform(0.25, 6, n), # Tmax - rng.uniform(1, 20, n), # t_half - ]) + y = np.column_stack( + [ + rng.uniform(0.01, 10, n), # Cmax + rng.uniform(0.1, 100, n), # AUC + rng.uniform(0.25, 6, n), # Tmax + rng.uniform(1, 20, n), # t_half + ] + ) s = PKSurrogate(n_input=6, n_output=4) s.param_names = PKSurrogate.EXPECTED_FEATURES s.output_names = PKSurrogate.EXPECTED_OUTPUTS @@ -47,6 +49,7 @@ def trained_surrogate() -> PKSurrogate: # T04a: EXPECTED_FEATURES class constant # --------------------------------------------------------------------------- + class TestExpectedFeaturesContract: def test_expected_features_is_class_attribute(self): assert hasattr(PKSurrogate, "EXPECTED_FEATURES") @@ -55,17 +58,13 @@ def test_expected_features_length(self): assert len(PKSurrogate.EXPECTED_FEATURES) == 6 def test_expected_features_names(self): - assert PKSurrogate.EXPECTED_FEATURES == [ - "logP", "fup", "clint_L_h", "mw", "rbp", "peff" - ] + assert PKSurrogate.EXPECTED_FEATURES == ["logP", "fup", "clint_L_h", "mw", "rbp", "peff"] def test_expected_outputs_is_class_attribute(self): assert hasattr(PKSurrogate, "EXPECTED_OUTPUTS") def test_expected_outputs_names(self): - assert PKSurrogate.EXPECTED_OUTPUTS == [ - "cmax_mg_L", "auc_mg_h_L", "tmax_h", "t_half_h" - ] + assert PKSurrogate.EXPECTED_OUTPUTS == ["cmax_mg_L", "auc_mg_h_L", "tmax_h", "t_half_h"] def test_expected_features_match_n_input_default(self): """Default n_input must equal len(EXPECTED_FEATURES).""" @@ -81,6 +80,7 @@ def test_expected_outputs_match_n_output_default(self): # T04b: validate_feature_contract # --------------------------------------------------------------------------- + class TestValidateFeatureContract: def test_valid_params_pass(self, untrained_surrogate): valid = {"logP": 2.0, "fup": 0.5, "clint_L_h": 10.0, "mw": 300.0, "rbp": 1.0, "peff": 2.0} @@ -101,15 +101,27 @@ def test_missing_multiple_keys_listed_in_error(self, untrained_surrogate): def test_extra_keys_allowed(self, untrained_surrogate): """Extra keys from ADME plugin must not cause failures.""" extra = { - "logP": 2.0, "fup": 0.5, "clint_L_h": 10.0, - "mw": 300.0, "rbp": 1.0, "peff": 2.0, - "logD": 1.5, "pka": 7.4, # extra keys + "logP": 2.0, + "fup": 0.5, + "clint_L_h": 10.0, + "mw": 300.0, + "rbp": 1.0, + "peff": 2.0, + "logD": 1.5, + "pka": 7.4, # extra keys } untrained_surrogate.validate_feature_contract(extra) # no exception def test_wrong_feature_name_raises(self, untrained_surrogate): """logD instead of logP should raise (missing logP).""" - wrong_names = {"logD": 2.0, "fup": 0.5, "clint_L_h": 10.0, "mw": 300.0, "rbp": 1.0, "peff": 2.0} + wrong_names = { + "logD": 2.0, + "fup": 0.5, + "clint_L_h": 10.0, + "mw": 300.0, + "rbp": 1.0, + "peff": 2.0, + } with pytest.raises(ValueError, match="logP"): untrained_surrogate.validate_feature_contract(wrong_names) @@ -118,6 +130,7 @@ def test_wrong_feature_name_raises(self, untrained_surrogate): # Non-negativity: outputs must always be ≥ 0 # --------------------------------------------------------------------------- + class TestNonNegativity: def test_predict_returns_nonnegative_single(self, trained_surrogate): x = np.array([2.0, 0.5, 10.0, 300.0, 1.0, 2.0]) @@ -132,10 +145,12 @@ def test_predict_returns_nonnegative_batch(self, trained_surrogate): def test_predict_with_extreme_inputs_nonnegative(self, trained_surrogate): """Even very extreme inputs must not produce negative outputs.""" - extremes = np.array([ - [10.0, 0.001, 0.001, 2000.0, 10.0, 0.001], # extreme high logP - [-5.0, 1.0, 500.0, 50.0, 0.1, 50.0], # extreme negative logP - ]) + extremes = np.array( + [ + [10.0, 0.001, 0.001, 2000.0, 10.0, 0.001], # extreme high logP + [-5.0, 1.0, 500.0, 50.0, 0.1, 50.0], # extreme negative logP + ] + ) y = trained_surrogate.predict(extremes) assert np.all(y >= 0) @@ -150,6 +165,7 @@ def test_untrained_predict_nonnegative(self, untrained_surrogate): # Reproducibility (determinism) # --------------------------------------------------------------------------- + class TestReproducibility: def test_same_input_same_output_trained(self, trained_surrogate): x = np.array([2.0, 0.5, 10.0, 300.0, 1.0, 2.0]) @@ -191,6 +207,7 @@ def test_training_with_same_seed_reproducible(self): # Output shape & structure # --------------------------------------------------------------------------- + class TestOutputShape: def test_predict_single_returns_1d_array(self, trained_surrogate): x = np.array([2.0, 0.5, 10.0, 300.0, 1.0, 2.0]) diff --git a/tests/unit/test_mass_balance.py b/tests/unit/test_mass_balance.py index b503edcd..fa93923e 100644 --- a/tests/unit/test_mass_balance.py +++ b/tests/unit/test_mass_balance.py @@ -13,11 +13,11 @@ from omega_pbpk.validation import mass_balance_check, oral_mass_balance_check - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _perfect_amounts(n_time: int, n_states: int, dose_mg: float) -> np.ndarray: """Create amounts array where mass is perfectly conserved (all in state 0).""" amounts = np.zeros((n_time, n_states)) @@ -40,6 +40,7 @@ def _depleting_amounts(n_time: int, n_states: int, dose_mg: float) -> np.ndarray # mass_balance_check (IV route) # --------------------------------------------------------------------------- + class TestMassBalanceCheckIV: def test_perfect_conservation_no_warnings(self): amounts = _perfect_amounts(100, 35, 100.0) @@ -126,6 +127,7 @@ def test_returns_list_type(self): # oral_mass_balance_check # --------------------------------------------------------------------------- + class TestOralMassBalanceCheck: def test_perfect_conservation_no_warnings(self): n_time, n_states, dose = 200, 35, 100.0 @@ -139,7 +141,7 @@ def test_skip_when_gi_residual_too_high(self): n_time, dose = 100, 100.0 time_h = np.linspace(0, 12, n_time) amounts = np.zeros((n_time, 35)) - amounts[:, 0] = dose * 0.3 # only 30% absorbed + amounts[:, 0] = dose * 0.3 # only 30% absorbed # GI still holds 70% of dose gi_states = np.ones((n_time, 8)) * dose * 0.7 / 8.0 result = oral_mass_balance_check( @@ -154,7 +156,7 @@ def test_checks_when_gi_absorbed(self): amounts = _perfect_amounts(n_time, 35, dose) # Simulate complete absorption: GI residual drops to near 0 gi_states = np.zeros((n_time, 8)) - gi_states[:5, :] = dose / 8.0 # GI drains completely by t≈5 + gi_states[:5, :] = dose / 8.0 # GI drains completely by t≈5 result = oral_mass_balance_check(amounts, dose, time_h, gi_states=gi_states) assert result == [] @@ -201,6 +203,7 @@ def test_returns_list_type(self): # Integration: IV simulation with real ODE # --------------------------------------------------------------------------- + class TestMassBalanceWithRealSimulation: def test_iv_simulation_conserves_mass(self): """Real 35-state ODE IV simulation must pass mass balance (±0.5%)."""