# 研究工作台（Notebook）

- 不启动 HTTP 服务，也不调用 `app/api`；在 Notebook 进程内直接导入 `app` 包下的后端模块执行。
- 支持两种参数输入方式：`ipywidgets` 参数面板，或纯代码的 `params` 字典。
- 回测输出目录：`research/artifacts/backtests/<run_id>/`。


In [None]:
from __future__ import annotations

import json
import os
import pickle
import sys
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any

from IPython.display import Markdown, display

try:
    import pandas as pd
except ModuleNotFoundError as exc:
    raise ModuleNotFoundError("缺少依赖 pandas，请先安装研究工作台依赖。") from exc

try:
    import matplotlib.pyplot as plt
except ModuleNotFoundError as exc:
    raise ModuleNotFoundError("缺少依赖 matplotlib，请先安装研究工作台依赖。") from exc

try:
    import ipywidgets as widgets
except ModuleNotFoundError as exc:
    raise ModuleNotFoundError("缺少依赖 ipywidgets，请先安装研究工作台依赖。") from exc

def _looks_like_backend_root(path: Path) -> bool:
    return (
        (path / "app" / "__init__.py").exists()
        and (path / "app" / "backtest" / "services" / "runner.py").exists()
    )

def _find_backend_root() -> Path:
    """Locate the backend project root used for in-process imports.

    Walks from the current working directory up through its parents; at each
    level the directory itself and then its ``backtest`` subdirectory are
    tried. Two fixed container paths are checked last. Every candidate is
    resolved and tested at most once, in search order.

    Raises:
        FileNotFoundError: no candidate passes ``_looks_like_backend_root``.
    """
    here = Path.cwd().resolve()
    search_order: list[Path] = [
        option
        for level in (here, *here.parents)
        for option in (level, level / "backtest")
    ]
    search_order += [Path("/home/app/backtest"), Path("/workspace/backtest")]

    visited: set[Path] = set()
    for raw in search_order:
        resolved = raw.resolve()
        if resolved not in visited:
            visited.add(resolved)
            if _looks_like_backend_root(resolved):
                return resolved
    raise FileNotFoundError("未找到后端项目根目录（需包含 app/backtest/services/runner.py）。")

# Locate the backend checkout and put it at the front of sys.path so the `app`
# package below is importable in-process (no HTTP server is started).
# Order matters: this must run before the `from app ...` imports.
BACKEND_ROOT = _find_backend_root()
if str(BACKEND_ROOT) not in sys.path:
    sys.path.insert(0, str(BACKEND_ROOT))

from app import create_app
from app.backtest.services import runner as runner_service

# The extractor module is optional: if importing it fails for any reason, this
# notebook falls back to _fallback_extract_result / the on-disk JSON files.
# NOTE(review): `except Exception` also hides non-ImportError failures raised
# while importing the extractor — consider logging the exception.
try:
    from app.backtest.services import extractor as extractor_service
except Exception:
    extractor_service = None

# Flask app instance; within this notebook it is used for its config and to
# open app contexts. The config name comes from CONFIG_ENV (default "default").
FLASK_APP = create_app(os.environ.get("CONFIG_ENV", "default"))
display(Markdown(f"**Backend Root**: `{BACKEND_ROOT}`"))
display(Markdown("已加载后端模块（未启动 HTTP 服务）。"))


In [None]:
# Bar frequencies accepted by _normalize_params; any other value is rejected
# with a ValueError before a backtest is started.
ALLOWED_FREQUENCIES = {"1d", "1m"}

def _default_output_root() -> Path:
    cwd = Path.cwd().resolve()
    for p in [cwd, *cwd.parents]:
        research_dir = p / "research"
        if research_dir.exists():
            return (research_dir / "artifacts" / "backtests").resolve()
    return (cwd / "research" / "artifacts" / "backtests").resolve()

def _coerce_path(raw: Any, *, field_name: str, required: bool) -> Path | None:
    if raw is None or str(raw).strip() == "":
        if required:
            raise ValueError(f"参数 {field_name} 不能为空")
        return None
    return Path(str(raw).strip()).expanduser().resolve()

def _param_text(params: dict[str, Any], key: str, default: str = "") -> str:
    """Return ``params[key]`` as stripped text; a missing key or ``None`` yields ``default``."""
    value = params.get(key)
    # None must count as "unset": str(None) would otherwise become the text "None".
    return default if value is None else str(value).strip()


def _normalize_params(params: dict[str, Any]) -> dict[str, Any]:
    """Validate raw backtest parameters and return a normalized copy.

    Accepts the keys produced by the widget panel or the plain-code ``params``
    dict. ``None`` is treated like a missing value for every text field.

    Returns:
        dict with ``strategy_path`` / ``bundle_path`` / ``output_root`` as
        resolved ``Path`` objects, ``init_cash`` as ``int``, and the remaining
        fields as stripped strings (``frequency`` lower-cased).

    Raises:
        TypeError: ``params`` is not a dict.
        ValueError: a required field is empty or a value is invalid.
        FileNotFoundError: the strategy file or the bundle path does not exist.

    Side effect: creates ``output_root`` (with parents) if it is missing.
    """
    if not isinstance(params, dict):
        raise TypeError("params 必须是 dict")

    strategy_path = _coerce_path(params.get("strategy_path"), field_name="strategy_path", required=True)
    assert strategy_path is not None  # type narrowing only: required=True never returns None
    if not strategy_path.is_file():
        raise FileNotFoundError(f"策略文件不存在: {strategy_path}")

    start_date = _param_text(params, "start_date")
    end_date = _param_text(params, "end_date")
    if not start_date or not end_date:
        raise ValueError("start_date / end_date 不能为空，格式示例：2026-01-01")

    frequency = _param_text(params, "frequency", "1d").lower()
    if frequency not in ALLOWED_FREQUENCIES:
        raise ValueError(f"frequency 仅支持 {sorted(ALLOWED_FREQUENCIES)}，当前: {frequency}")

    raw_cash = params.get("init_cash", 1000000)
    # Reject values int() would silently coerce: bools and fractional floats.
    if isinstance(raw_cash, bool) or (isinstance(raw_cash, float) and not raw_cash.is_integer()):
        raise ValueError("init_cash 必须是整数")
    try:
        init_cash = int(raw_cash)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError("init_cash 必须是整数") from exc
    if init_cash <= 0:
        raise ValueError("init_cash 必须 > 0")

    benchmark = _param_text(params, "benchmark")
    if not benchmark:
        raise ValueError("benchmark 不能为空")

    symbol = _param_text(params, "symbol")

    # Explicit param, then env var, then app config. Whitespace-only values
    # count as unset so they fall through instead of failing as "empty".
    bundle_raw = (
        _param_text(params, "bundle_path")
        or os.environ.get("RQALPHA_BUNDLE_PATH", "").strip()
        or FLASK_APP.config.get("RQALPHA_BUNDLE_PATH")
    )
    bundle_path = _coerce_path(bundle_raw, field_name="bundle_path", required=True)
    assert bundle_path is not None  # type narrowing only
    if not bundle_path.exists():
        raise FileNotFoundError(
            f"bundle_path 不存在: {bundle_path}。请在参数里传 bundle_path，或设置环境变量 RQALPHA_BUNDLE_PATH。"
        )

    output_root = (
        _coerce_path(params.get("output_root"), field_name="output_root", required=False)
        or _default_output_root()
    )
    output_root.mkdir(parents=True, exist_ok=True)

    return {
        "strategy_path": strategy_path,
        "start_date": start_date,
        "end_date": end_date,
        "frequency": frequency,
        "init_cash": init_cash,
        "benchmark": benchmark,
        "symbol": symbol,
        "bundle_path": bundle_path,
        "output_root": output_root,
    }

def _fallback_extract_result(result_pkl: Path) -> dict[str, Any]:
    with result_pkl.open("rb") as f:
        raw = pickle.load(f)
    if not isinstance(raw, dict):
        raise ValueError("result.pkl 内容不是 dict，无法 fallback 解析")

    summary = raw.get("summary", {})
    trades = raw.get("trades", [])
    equity = {"dates": [], "nav": [], "returns": [], "benchmark_nav": []}
    portfolio = raw.get("portfolio")
    if hasattr(portfolio, "index") and hasattr(portfolio, "columns"):
        try:
            equity["dates"] = [str(v.date()) if hasattr(v, "date") else str(v) for v in portfolio.index]
            if "unit_net_value" in portfolio.columns:
                equity["nav"] = portfolio["unit_net_value"].tolist()
            if "returns" in portfolio.columns:
                equity["returns"] = portfolio["returns"].tolist()
        except Exception:
            pass

    if isinstance(trades, list) and trades and isinstance(trades[0], dict):
        trade_columns = list(trades[0].keys())
    else:
        trade_columns = []

    return {
        "summary": summary if isinstance(summary, dict) else {},
        "equity": equity,
        "trades": trades if isinstance(trades, list) else [],
        "trade_columns": trade_columns,
    }

def _run_backtest_shim(params: dict[str, Any]) -> dict[str, Any]:
    """Run one backtest in-process when the backend lacks ``runner.run_backtest``.

    Steps:
    1. Validate ``params`` via ``_normalize_params``.
    2. Create a fresh ``<output_root>/<run_id>/`` directory and copy the
       strategy file there as ``strategy.py``.
    3. Write ``config.yml`` and ``params.json``, then call
       ``runner_service.run_rqalpha``.
    4. Turn ``result.pkl`` into the payload saved at ``extracted.json``, via the
       backend extractor when available, otherwise ``_fallback_extract_result``.

    Note: the validated bundle path is written into
    ``FLASK_APP.config["RQALPHA_BUNDLE_PATH"]`` and stays there after the call.

    Returns:
        dict with ``run_id``, ``run_dir``, ``result_path``, ``log_path``,
        ``result`` (the extracted payload) and ``params`` (normalized, paths as str).

    Raises:
        RuntimeError: the rqalpha run returned a non-zero exit code.
        FileNotFoundError: the strategy/bundle is missing or no result.pkl was produced.
    """
    import shutil

    p = _normalize_params(params)
    # Format spec instead of a nested "..." literal: reusing the outer quote
    # type inside an f-string is a SyntaxError before Python 3.12 (PEP 701).
    run_id = f"nb_{datetime.now():%Y%m%d_%H%M%S}_{uuid.uuid4().hex[:8]}"
    run_dir = (p["output_root"] / run_id).resolve()
    run_dir.mkdir(parents=True, exist_ok=False)  # each run gets a brand-new directory

    # Byte-for-byte copy keeps the strategy exactly as written (encoding, newlines).
    shutil.copyfile(p["strategy_path"], run_dir / "strategy.py")

    # JSON-friendly view of the normalized params (Path -> str), reused for
    # params.json and for the returned context.
    params_view = {k: (str(v) if isinstance(v, Path) else v) for k, v in p.items()}

    with FLASK_APP.app_context():
        # Presumably read by runner_service from app config — confirm. The
        # override persists on FLASK_APP after this call returns.
        FLASK_APP.config["RQALPHA_BUNDLE_PATH"] = str(p["bundle_path"])
        cfg_text = runner_service.build_config_yaml(
            start_date=p["start_date"],
            end_date=p["end_date"],
            cash=p["init_cash"],
            benchmark=p["benchmark"],
            frequency=p["frequency"],
            output_file=str((run_dir / "result.pkl").resolve()),
        )
        (run_dir / "config.yml").write_text(cfg_text, encoding="utf-8")
        (run_dir / "params.json").write_text(
            json.dumps({**params_view, "run_id": run_id}, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        exit_code = runner_service.run_rqalpha(run_id, run_dir)

    log_path = run_dir / "run.log"
    if exit_code != 0:
        raise RuntimeError(f"回测执行失败，exit_code={exit_code}，请查看日志: {log_path}")

    result_pkl = run_dir / "result.pkl"
    if not result_pkl.exists():
        raise FileNotFoundError(f"未找到结果文件: {result_pkl}")

    extracted_path = run_dir / "extracted.json"
    if extractor_service is not None and hasattr(extractor_service, "extract_result"):
        result_payload = extractor_service.extract_result(result_pkl, extracted_path)
    else:
        result_payload = _fallback_extract_result(result_pkl)
        extracted_path.write_text(
            json.dumps(result_payload, ensure_ascii=False, indent=2, default=str),
            encoding="utf-8",
        )

    return {
        "run_id": run_id,
        "run_dir": str(run_dir),
        "result_path": str(extracted_path),
        "log_path": str(log_path),
        "result": result_payload,
        "params": params_view,
    }

# Prefer the backend's own runner.run_backtest; otherwise attach the notebook
# shim to the runner module so run_backtest() always has one entry point.
if callable(getattr(runner_service, "run_backtest", None)):
    display(Markdown("检测到后端原生 `runner.run_backtest`，Notebook 将直接调用。"))
else:
    runner_service.run_backtest = _run_backtest_shim
    display(Markdown("当前后端未提供 `runner.run_backtest`，Notebook 已安装兼容 shim。"))

def run_backtest(params: dict[str, Any]) -> dict[str, Any]:
    """Notebook entry point: call ``runner_service.run_backtest`` inside a Flask app context.

    The runner attribute is looked up at call time, so it resolves to either
    the backend's native implementation or the notebook's compatibility shim.
    """
    app_ctx = FLASK_APP.app_context()
    with app_ctx:
        result = runner_service.run_backtest(params)
    return result


## 参数区（widgets 版）

In [None]:
# Parameter panel: one widget per key accepted by run_backtest(params).
w_strategy_path = widgets.Text(description="strategy_path", placeholder="/abs/path/to/strategy.py", layout=widgets.Layout(width="100%"))
w_start_date = widgets.Text(description="start_date", value="2026-01-01")
w_end_date = widgets.Text(description="end_date", value="2026-01-31")
# Options come from ALLOWED_FREQUENCIES so the dropdown cannot drift from the
# validation applied in _normalize_params.
w_frequency = widgets.Dropdown(description="frequency", options=sorted(ALLOWED_FREQUENCIES), value="1d")
w_init_cash = widgets.IntText(description="init_cash", value=1000000)
w_benchmark = widgets.Text(description="benchmark", value="000300.XSHG")
w_symbol = widgets.Text(description="symbol", placeholder="可选")
# Pre-filled from the environment; an empty value is resolved later during validation.
w_bundle_path = widgets.Text(description="bundle_path", value=os.environ.get("RQALPHA_BUNDLE_PATH", ""), layout=widgets.Layout(width="100%"))
w_output_root = widgets.Text(description="output_root", value=str(_default_output_root()), layout=widgets.Layout(width="100%"))

btn_run = widgets.Button(description="运行回测", button_style="primary")
out = widgets.Output()  # receives the run button's printed output and tracebacks

def _collect_widget_params() -> dict[str, Any]:
    """Read the current panel values into a ``run_backtest`` params dict.

    Values are passed through as-is; no validation or conversion happens here.
    """
    field_widgets = (
        ("strategy_path", w_strategy_path),
        ("start_date", w_start_date),
        ("end_date", w_end_date),
        ("frequency", w_frequency),
        ("init_cash", w_init_cash),
        ("benchmark", w_benchmark),
        ("symbol", w_symbol),
        ("bundle_path", w_bundle_path),
        ("output_root", w_output_root),
    )
    return {name: widget.value for name, widget in field_widgets}

def _on_run_clicked(_):
    """Run-button callback: run a backtest from the panel values and report into ``out``.

    On success the run context is stored in the notebook global
    ``LAST_RUN_CONTEXT`` for the report cell. Any exception is printed together
    with its traceback instead of propagating out of the widget callback.
    """
    global LAST_RUN_CONTEXT
    with out:
        out.clear_output()
        try:
            run_ctx = run_backtest(_collect_widget_params())
            LAST_RUN_CONTEXT = run_ctx
            print(f"回测完成，run_id={run_ctx['run_id']}")
            print(f"输出目录: {run_ctx['run_dir']}")
            print("可继续运行展示单元格查看图表和表格。")
        except Exception as err:
            print(f"运行失败: {err}")
            traceback.print_exc()

btn_run.on_click(_on_run_clicked)

# Render the panel: path fields on their own rows, short fields grouped in
# horizontal rows, then the run button and its output area.
display(
    widgets.VBox(
        children=[
            w_strategy_path,
            widgets.HBox(children=[w_start_date, w_end_date, w_frequency]),
            widgets.HBox(children=[w_init_cash, w_benchmark]),
            widgets.HBox(children=[w_symbol]),
            w_bundle_path,
            w_output_root,
            btn_run,
            out,
        ]
    )
)


## 参数区（纯代码版）

In [None]:
# Plain-code parameters: edit the values below, then run the next cell.
params = dict(
    strategy_path="/abs/path/to/strategy.py",  # required: path to an existing strategy file
    start_date="2026-01-01",
    end_date="2026-01-31",
    frequency="1d",  # only "1d" / "1m" are accepted
    init_cash=1000000,
    benchmark="000300.XSHG",
    symbol="",  # optional
    bundle_path="",  # optional; with the notebook shim, empty falls back to env RQALPHA_BUNDLE_PATH, then the Flask app config
    output_root="",  # optional; empty defaults to research/artifacts/backtests
)
params


## 运行回测（调用 `run_backtest`）

In [None]:
# Run with the plain-code `params`. On failure, print a short message and
# re-raise so a "Run All" stops before the report cell.
RUN_CONTEXT = None
try:
    LAST_RUN_CONTEXT = RUN_CONTEXT = run_backtest(params)
    print(f"回测完成: {RUN_CONTEXT['run_id']}")
    print(f"输出目录: {RUN_CONTEXT['run_dir']}")
except Exception as err:
    print(f"回测失败: {err}")
    raise


## 读取结果（优先 extractor，fallback 读文件）

- 若 extractor 可用且 `result.pkl` 存在，优先调用 `app.backtest.services.extractor.extract_result`。
- 否则依次回退读取 `extracted.json` → `result.json` → `result.pkl`（使用内置的 fallback 解析）。

In [None]:
def load_result_payload(run_dir: Path | str) -> dict[str, Any]:
    """Load the extracted result payload for one run directory.

    Resolution order:
      1. backend ``extractor.extract_result(result.pkl, extracted.json)`` when
         the extractor module is available and ``result.pkl`` exists;
      2. ``extracted.json``;
      3. ``result.json``;
      4. ``_fallback_extract_result(result.pkl)``.

    Args:
        run_dir: run directory, as a ``Path`` or a path string (e.g. the
            ``run_dir`` entry of a run context, which is stored as ``str``).

    Raises:
        FileNotFoundError: none of extracted.json / result.json / result.pkl exists.
    """
    run_dir = Path(run_dir).resolve()
    result_pkl = run_dir / "result.pkl"
    extracted_json = run_dir / "extracted.json"
    result_json = run_dir / "result.json"

    if extractor_service is not None and hasattr(extractor_service, "extract_result") and result_pkl.exists():
        return extractor_service.extract_result(result_pkl, extracted_json)

    if extracted_json.exists():
        return json.loads(extracted_json.read_text(encoding="utf-8"))
    if result_json.exists():
        return json.loads(result_json.read_text(encoding="utf-8"))
    if result_pkl.exists():
        return _fallback_extract_result(result_pkl)

    raise FileNotFoundError(
        f"结果文件不存在（extracted.json/result.json/result.pkl 均未找到）: {run_dir}"
    )

def tail_text(path: Path, n: int = 60) -> str:
    """Return the last ``n`` lines of a text file, joined with ``"\\n"``.

    The file is decoded as UTF-8 with invalid bytes replaced. A missing file,
    or ``n <= 0``, yields ``""``. The guard is needed because ``lines[-0:]``
    would return the whole file.
    """
    if n <= 0 or not path.exists():
        return ""
    lines = path.read_text(encoding="utf-8", errors="replace").splitlines()
    return "\n".join(lines[-n:])


## 可视化与展示

In [None]:
def _plot_equity(equity: dict[str, Any], title: str) -> None:
    """Plot the strategy NAV curve (plus the benchmark NAV when it aligns), or print a placeholder."""

    def _as_list(key: str) -> list[Any]:
        value = equity.get(key)
        return value if isinstance(value, list) else []

    dates, nav, benchmark_nav = _as_list("dates"), _as_list("nav"), _as_list("benchmark_nav")
    if not nav:
        print("暂无净值数据")
        return

    plt.figure(figsize=(11, 4))
    use_dates = bool(dates) and len(dates) == len(nav)
    if use_dates:
        # Real datetimes give matplotlib a date axis with sparse ticks; raw
        # strings would become one categorical tick per bar. Keep the strings
        # if any label fails to parse.
        parsed = pd.to_datetime(dates, errors="coerce")
        x_values: Any = dates if parsed.isna().any() else parsed
    else:
        x_values = range(len(nav))
    plt.plot(x_values, nav, label="strategy")
    if benchmark_nav and len(benchmark_nav) == len(nav):
        plt.plot(x_values, benchmark_nav, label="benchmark")
        plt.legend()
    if use_dates:
        plt.xticks(rotation=45)
    plt.title(title)
    plt.tight_layout()
    plt.show()


def _display_metrics(summary: dict[str, Any]) -> None:
    """Show summary metrics as a two-column (metric, value) table."""
    if summary:
        metrics_df = pd.DataFrame([{"metric": k, "value": v} for k, v in summary.items()])
    else:
        metrics_df = pd.DataFrame(columns=["metric", "value"])
    display(metrics_df)


def _trades_frame(payload: dict[str, Any]) -> pd.DataFrame:
    """Build the trades table; list/tuple rows are labelled with ``trade_columns`` when widths match."""
    trades = payload.get("trades")
    if not isinstance(trades, list):
        return pd.DataFrame([])
    columns = payload.get("trade_columns")
    if (
        trades
        and isinstance(columns, list)
        and columns
        and all(isinstance(row, (list, tuple)) and len(row) == len(columns) for row in trades)
    ):
        return pd.DataFrame(trades, columns=columns)
    return pd.DataFrame(trades)


def _print_log_tail(log_path: Path, n: int = 60) -> None:
    """Print the last ``n`` log lines, or a notice when the log file is missing."""
    if log_path.exists():
        print(tail_text(log_path, n))
    else:
        print(f"日志文件不存在: {log_path}")


def show_backtest_report(run_context: dict[str, Any] | None = None) -> None:
    """Render one run's report: NAV chart, key metrics, first trades, and the log tail.

    Args:
        run_context: a context dict as returned by ``run_backtest`` (must
            contain ``run_dir``). Falls back to the notebook global
            ``LAST_RUN_CONTEXT`` when omitted or falsy.

    Raises:
        ValueError: no run context is available.
        FileNotFoundError: the run directory holds no result file (from ``load_result_payload``).
    """
    ctx = run_context or globals().get("LAST_RUN_CONTEXT")
    if not ctx:
        raise ValueError("没有可用的 run_context，请先运行回测。")

    run_dir = Path(ctx["run_dir"]).resolve()
    payload = load_result_payload(run_dir)
    payload = payload if isinstance(payload, dict) else {}

    summary = payload.get("summary")
    equity = payload.get("equity")

    display(Markdown("### 净值曲线"))
    _plot_equity(equity if isinstance(equity, dict) else {}, f"Net Value Curve - {run_dir.name}")

    display(Markdown("### 关键指标"))
    _display_metrics(summary if isinstance(summary, dict) else {})

    display(Markdown("### 交易明细（前 20 行）"))
    display(_trades_frame(payload).head(20))

    display(Markdown("### 回测日志（最后 60 行）"))
    _print_log_tail(run_dir / "run.log", 60)

# Render the report for the most recent run; with no argument, the function
# falls back to the notebook global LAST_RUN_CONTEXT.
show_backtest_report()
