In [None]:
#!/usr/bin/env python3
"""
Finance Research Copilot — Agentic AI (LangChain + LangGraph)
=============================================================

This version adds the six feature areas listed below and codifies edge-case behavior.

✔ Popular upgrades now included
-------------------------------
1) **Live data**
   - `yfinance_prices` tool for OHLCV series (no API key) and simple fundamentals proxy (market cap, PE if available).
   - Planned: optional SEC/EDGAR headline scrape (best-effort) with graceful fallback — not yet implemented in this file.

2) **Richer tools**
   - `fx_convert` tool using Yahoo FX pairs (fallback to cached static rate)
   - `pe_ev_ebitda` tool to compute valuation metrics given inputs
   - `compare_companies` tool: compare N companies from the CSV demo (extensible to fundamentals)

3) **Better retrieval**
   - Per-company tagging in vector store; filterable retrieval.
   - PDF earnings-call loaders supported; citations surfaced in the brief.

4) **Ops polish**
   - Structured logging, optional LangSmith tracing (env based), simple in-memory caching.
   - Deterministic seeds where possible.

5) **UI**
   - **FastAPI** endpoint (`--api`) exposing `/analyze?query=...`.
   - Option to **emit a Streamlit app** file (`--write-streamlit`) for quick UI; run with `streamlit run app_streamlit.py`.

6) **Domain ready**
   - Finance defaults out of the box; prompts and tools can be retargeted for Telecom/BFSI/GenAI PMO by editing them (there are no CLI flags for this yet).

Edge-case behavior
------------------
- If **news search has no results** → show a **“No recent items”** section in the report.
- If **CSV company/quarter not found** → **prompt for clarification** (the report includes a clear note).

Self-tests
----------
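`--self-test` covers calculators, CSV lookup, vectorstore retrieval, FX conversion, and comparison logic without calling an LLM. It is not fully offline: FX conversion tries Yahoo first (falling back to a static rate), and the vector store uses OpenAI embeddings when `OPENAI_API_KEY` is set (otherwise a local Sentence-Transformers model, which may need a one-time download).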

Run it
------
```bash
python -m venv .venv && source .venv/bin/activate
pip install -U pip
pip install -r requirements.txt

# Offline checks
python agent_finance_copilot.py --self-test

# Full run (needs OPENAI_API_KEY for LLM; yfinance works without keys)
export OPENAI_API_KEY=sk-...
python agent_finance_copilot.py --query "Analyze TCS vs Infosys for Q1 FY26; include FX impact and price trend"

# FastAPI (dev)
python agent_finance_copilot.py --api --host 0.0.0.0 --port 8000
# Then GET /analyze?query=...

# Generate a Streamlit UI file
python agent_finance_copilot.py --write-streamlit
streamlit run app_streamlit.py
```
"""

from __future__ import annotations

import argparse
import ast
import csv
import json
import logging
import math
import operator
import os
import random
import re
import sys
import textwrap
from dataclasses import dataclass
from datetime import datetime, timezone
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypedDict

# =============================================================
# Compat imports: support both new (0.2+) and older LangChain
# =============================================================
# Messages
try:
    from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage  # new style
except Exception:  # pragma: no cover
    from langchain.schema import AIMessage, HumanMessage, SystemMessage, ToolMessage  # old style

# Tools decorator
try:
    from langchain_core.tools import tool  # new style
except Exception:  # pragma: no cover
    from langchain.tools import tool  # old style

# LLMs / Embeddings providers
ChatOpenAI = None
OpenAIEmbeddings = None
try:  # preferred in modern projects
    from langchain_openai import ChatOpenAI as _ChatOpenAI, OpenAIEmbeddings as _OpenAIEmbeddings
    ChatOpenAI = _ChatOpenAI
    OpenAIEmbeddings = _OpenAIEmbeddings
except Exception:  # pragma: no cover - monolithic fallback
    try:
        from langchain.chat_models import ChatOpenAI as _ChatOpenAI
        from langchain.embeddings.openai import OpenAIEmbeddings as _OpenAIEmbeddings
        ChatOpenAI = _ChatOpenAI
        OpenAIEmbeddings = _OpenAIEmbeddings
    except Exception:
        pass

# Vector store
try:
    from langchain_community.vectorstores import Chroma  # new style
except Exception:  # pragma: no cover
    from langchain.vectorstores import Chroma  # old style

# Document loaders
try:
    from langchain_community.document_loaders import TextLoader, CSVLoader, PyPDFLoader  # new style
except Exception:  # pragma: no cover
    from langchain.document_loaders import TextLoader, CSVLoader, PyPDFLoader  # old style

# Text splitters
try:
    from langchain_text_splitters import RecursiveCharacterTextSplitter  # new package
except Exception:  # pragma: no cover
    try:
        from langchain.text_splitter import RecursiveCharacterTextSplitter  # old location
    except Exception as _e:
        raise ImportError(
            "Could not import RecursiveCharacterTextSplitter. Install `langchain-text-splitters` or upgrade langchain."
        ) from _e

# Embeddings fallback (offline)
HuggingFaceEmbeddings = None
try:
    from langchain_community.embeddings import HuggingFaceEmbeddings  # modern community
except Exception:
    try:
        from langchain.embeddings import HuggingFaceEmbeddings  # old monolith
    except Exception:
        HuggingFaceEmbeddings = None

# LangGraph
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode

# Optional utility (news search) and market data
from duckduckgo_search import DDGS  # lightweight search (no API key)
import yfinance as yf  # market & FX data without keys

# FastAPI (optional API)
try:
    from fastapi import FastAPI
    from fastapi.responses import JSONResponse
    FASTAPI_AVAILABLE = True
except Exception:
    FASTAPI_AVAILABLE = False

# -----------------------
# Config & Logging
# -----------------------
# Resolve the project root. `__file__` is undefined when this code is executed
# inside a Jupyter/IPython cell (the "In [None]:" marker at the top of this source
# suggests a notebook export), so fall back to the current working directory
# there instead of failing with NameError.
try:
    ROOT = Path(__file__).parent
except NameError:
    ROOT = Path.cwd()
DATA_DIR = ROOT / "data"  # demo corpus scanned by load_documents()
DB_DIR = ROOT / "chroma_db"  # persisted Chroma vector store
MEM_DIR = ROOT / "memory"  # append-only session log directory
MEM_DIR.mkdir(parents=True, exist_ok=True)
MEM_FILE = MEM_DIR / "session_memory.jsonl"

# Seed the stdlib RNG so any `random` use in this process is reproducible.
RANDOM_SEED = 37
random.seed(RANDOM_SEED)

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("finance_copilot")


def hrule(title: str = "") -> str:
    """Return a console banner: an 80-char '=' rule, the title, then the rule again.

    Each rule is prefixed with its own newline, so the output starts with a
    newline and a blank line separates the title from the closing rule.
    """
    rule = "\n" + "=" * 80
    return "\n".join([rule, title, rule]) + "\n"


# -----------------------
# LLMs & Embeddings
# -----------------------

def make_llm(model: str = "gpt-4o-mini", temperature: float = 0.2):
    """Build the chat model used by every LLM-backed graph node.

    Args:
        model: OpenAI chat model name forwarded to ``ChatOpenAI``.
        temperature: Sampling temperature.

    Raises:
        ImportError: if no ``ChatOpenAI`` implementation could be imported.
    """
    if ChatOpenAI is not None:
        return ChatOpenAI(model=model, temperature=temperature)
    raise ImportError(
        "ChatOpenAI not available. Install `langchain-openai` (modern) or a compatible `langchain` version."
    )


def make_embeddings(model: str = "text-embedding-3-large"):
    """Pick an embeddings backend.

    Preference order:
      1. ``OpenAIEmbeddings(model=model)`` when ``OPENAI_API_KEY`` is set (non-empty)
         and ``OpenAIEmbeddings`` was importable.
      2. Local ``HuggingFaceEmbeddings`` with sentence-transformers/all-MiniLM-L6-v2.

    Raises:
        ImportError: if neither backend is usable.
    """
    use_openai = bool(os.environ.get("OPENAI_API_KEY")) and OpenAIEmbeddings is not None
    if use_openai:
        return OpenAIEmbeddings(model=model)
    if HuggingFaceEmbeddings is None:
        raise ImportError("No embeddings backend available. Install `langchain-openai` or `sentence-transformers`.")
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


# -----------------------
# Demo Data Bootstrap
# -----------------------

def bootstrap_demo_corpus() -> None:
    """Seed DATA_DIR with a tiny demo corpus, never overwriting existing files.

    Files written only when absent:
      - it_services_q_results.csv: synthetic quarterly numbers for four IT firms
      - sector_notes.md: short retrievable sector notes
      - mock_investor_update.txt: one-line text stand-in for a PDF
    """
    DATA_DIR.mkdir(parents=True, exist_ok=True)

    # Illustrative synthetic values, one row per company.
    csv_body = """company,quarter,fy,rev_inr_cr,profit_inr_cr,yoy_rev_growth_pct
TCS,Q1,26,64000,12800,7.5
Infosys,Q1,26,38000,7200,5.2
HCLTech,Q1,26,28000,4200,6.1
Wipro,Q1,26,22000,3100,3.9
"""
    notes_body = textwrap.dedent(
        """
        # India IT Services Sector — Quick Notes (FY26 Q1)

        - Demand steady in BFSI and healthcare; telecom muted.
        - Cost optimization continues; vendor consolidation favors top-3 players.
        - GenAI pilots moving to production in customer support and code modernization.
        - Currency tailwinds mixed; cross-currency impact ~(-0.4%) for Q1.
        - Risks: prolonged US slowdown, pricing pressure, large deal ramp-downs.
        - Opportunities: cloud modernization, vendor consolidation, GenAI productivity deals.
        """
    ).strip()
    # Plain-text stub standing in for a PDF (the PDF loader still works if one is added).
    stub_body = "Investor Update: Tier-1 IT firms report stable margins; deal pipeline healthy; GenAI backlog building."

    seed_files = (
        ("it_services_q_results.csv", csv_body),
        ("sector_notes.md", notes_body),
        ("mock_investor_update.txt", stub_body),
    )
    for file_name, body in seed_files:
        target = DATA_DIR / file_name
        if not target.exists():
            target.write_text(body)


# -----------------------
# Document Ingestion & Vector Store
# -----------------------

def load_documents() -> List[Any]:
    """Load and chunk every supported file under DATA_DIR for the vector store.

    Supported suffixes: ``.md``/``.txt`` (TextLoader), ``.csv`` (CSVLoader) and
    ``.pdf`` (PyPDFLoader); other files are ignored. Each document is split into
    800-char chunks (120 overlap). Every chunk carries ``metadata["company"]``:
    the first known company whose name appears in the file name
    (case-insensitive), otherwise ``"general"``. Files that fail to load are
    logged and skipped.

    Returns:
        The chunked LangChain documents.
    """
    known_companies = ("TCS", "Infosys", "HCLTech", "Wipro")
    loader_by_suffix = {
        ".md": TextLoader,
        ".txt": TextLoader,
        ".csv": CSVLoader,
        ".pdf": PyPDFLoader,
    }
    splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=120)
    docs: List[Any] = []
    for path in DATA_DIR.glob("**/*"):
        if path.is_dir():
            continue
        loader_cls = loader_by_suffix.get(path.suffix.lower())
        if loader_cls is None:
            continue
        lowered = path.name.lower()
        # Use an explicit "general" tag rather than None: some Chroma releases
        # reject None metadata values, which would fail the whole ingestion.
        company = next((c for c in known_companies if c.lower() in lowered), "general")
        try:
            for d in loader_cls(str(path)).load():
                d.metadata["company"] = company
                docs.extend(splitter.split_documents([d]))
        except Exception as e:
            logger.warning(f"Skipping {path.name}: {e}")
    return docs


@lru_cache(maxsize=1)
def build_or_load_vectorstore(emb_model: str = "text-embedding-3-large") -> Chroma:
    """Return the persisted "finance_copilot" Chroma collection, building it on first use.

    If DB_DIR exists and contains anything, the collection is reopened from
    there; otherwise the demo corpus is loaded, embedded and persisted to
    DB_DIR. The result is memoized (cache size 1, keyed by ``emb_model``).

    NOTE(review): make_embeddings picks OpenAI or a local model depending on
    OPENAI_API_KEY; reopening a store built with the other backend would likely
    fail on vector-dimension mismatch — confirm, and clear DB_DIR if so.
    """
    embeddings = make_embeddings(model=emb_model)
    has_persisted = DB_DIR.exists() and any(DB_DIR.iterdir())
    if not has_persisted:
        return Chroma.from_documents(
            load_documents(),
            embeddings,
            collection_name="finance_copilot",
            persist_directory=str(DB_DIR),
        )
    return Chroma(
        collection_name="finance_copilot",
        persist_directory=str(DB_DIR),
        embedding_function=embeddings,
    )


# -----------------------
# Tools (LangChain)
# -----------------------

# Arithmetic operators accepted by calc(); anything else in the parsed expression is rejected.
_CALC_BINOPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_CALC_UNARYOPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
# Upper bound (in bits) on an int ** int result, so inputs such as "9**9**9" are
# rejected instead of hanging the process while a gigantic integer is computed.
_CALC_MAX_POW_BITS = 10_000


def _calc_eval(node: ast.AST) -> Any:
    """Recursively evaluate a numbers-and-operators-only expression AST."""
    if isinstance(node, ast.Expression):
        return _calc_eval(node.body)
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.UnaryOp) and type(node.op) in _CALC_UNARYOPS:
        return _CALC_UNARYOPS[type(node.op)](_calc_eval(node.operand))
    if isinstance(node, ast.BinOp) and type(node.op) in _CALC_BINOPS:
        left = _calc_eval(node.left)
        right = _calc_eval(node.right)
        if (
            isinstance(node.op, ast.Pow)
            and isinstance(left, int)
            and isinstance(right, int)
            and right > 0
            and abs(left) > 1
            and abs(left).bit_length() * right > _CALC_MAX_POW_BITS
        ):
            raise ValueError("exponent too large")
        return _CALC_BINOPS[type(node.op)](left, right)
    raise ValueError("unsupported expression")


@tool("calc")
def calc(expression: str) -> str:
    """Safely evaluate a simple math expression. Supports +,-,*,/,//,%,**,(), and decimals.
    Example: "(64000-38000)/38000".
    """
    # The expression comes from the model, so it is parsed into an AST and only
    # numeric literals and arithmetic operators are evaluated (no eval()).
    if not re.fullmatch(r"[0-9+\-*/().% ]+", expression):
        return "Error: unsupported characters in expression."
    try:
        # strip(): builtin eval() tolerated leading spaces; ast.parse would not.
        tree = ast.parse(expression.strip(), filename="<string>", mode="eval")
        return str(_calc_eval(tree))
    except Exception as e:
        return f"Error: {e}"


@tool("ddg_news")
def ddg_news(query: str, max_results: int = 6) -> str:
    """DuckDuckGo news search. Returns JSON list[{title, href, snippet}]."""
    # Search failures are reported in-band as one "(search error)" item appended
    # after any items already collected, rather than raised to the caller.
    hits: List[Dict[str, str]] = []
    try:
        with DDGS() as search:
            for item in search.news(query, max_results=max_results):
                title, link, body = item.get("title", ""), item.get("url", ""), item.get("body", "")
                hits.append({"title": title, "href": link, "snippet": body})
    except Exception as err:
        hits.append({"title": "(search error)", "href": "", "snippet": str(err)})
    return json.dumps(hits, ensure_ascii=False)


@tool("tabular_lookup")
def tabular_lookup(company: str, quarter: str = "Q1", fy: str = "26") -> str:
    """Lookup demo quarterly metrics from CSV (fy like "26" or "FY26"). Returns JSON dict or {} if not found."""
    # Matching is case-insensitive and ignores surrounding whitespace; a leading
    # "FY" on the fiscal year is ignored so "FY26" matches the CSV's "26".
    # csv.reader handles quoting, and blank/short rows are skipped via .get()
    # instead of raising KeyError; an empty file yields {} instead of IndexError.
    csv_path = DATA_DIR / "it_services_q_results.csv"
    if not csv_path.exists():
        return json.dumps({})

    def _norm_fy(value: Any) -> str:
        text = str(value).strip().upper()
        return text[2:].strip() if text.startswith("FY") else text

    want_company = str(company).strip().lower()
    want_quarter = str(quarter).strip().upper()
    want_fy = _norm_fy(fy)
    with csv_path.open(newline="") as fh:
        reader = csv.reader(fh)
        header = [h.strip() for h in next(reader, [])]
        for values in reader:
            row = dict(zip(header, (v.strip() for v in values)))
            if (
                row.get("company", "").lower() == want_company
                and row.get("quarter", "").upper() == want_quarter
                and _norm_fy(row.get("fy", "")) == want_fy
            ):
                return json.dumps(row)
    return json.dumps({})


@tool("yfinance_prices")
def yfinance_prices(ticker: str, period: str = "6mo", interval: str = "1d") -> str:
    """Fetch OHLCV with yfinance. Returns JSON dict {info, prices:[{date,open,high,low,close,volume}]}"""
    # yfinance's `fast_info` is a FastInfo object, not a dict, so json.dumps cannot
    # serialize it (the call would fail and return {"error": ...}). Copy a fixed set
    # of its fields into a plain dict instead. The attribute names assume the
    # yfinance>=0.2 FastInfo API; missing or failing fields are simply omitted.
    info_fields = (
        "currency",
        "exchange",
        "last_price",
        "previous_close",
        "market_cap",
        "shares",
        "year_high",
        "year_low",
        "fifty_day_average",
        "two_hundred_day_average",
    )

    def _num(value: Any) -> Optional[float]:
        """Coerce to float: None/0 -> 0.0, NaN/inf/unparseable -> None (keeps the JSON valid)."""
        try:
            num = float(value or 0)
        except (TypeError, ValueError):
            return None
        return num if math.isfinite(num) else None

    def _scalar(value: Any) -> Any:
        """Pass through None/bool/str/int; coerce anything else with _num."""
        if value is None or isinstance(value, (bool, str, int)):
            return value
        return _num(value)

    try:
        t = yf.Ticker(ticker)
        hist = t.history(period=period, interval=interval)
        series = []
        for dt, row in hist.iterrows():
            # Volume can be NaN (e.g. FX/index tickers); int(NaN) would raise.
            volume = _num(row.get("Volume", 0))
            series.append({
                "date": dt.strftime("%Y-%m-%d"),
                "open": _num(row.get("Open", 0)),
                "high": _num(row.get("High", 0)),
                "low": _num(row.get("Low", 0)),
                "close": _num(row.get("Close", 0)),
                "volume": int(volume) if volume is not None else None,
            })
        info: Dict[str, Any] = {}
        try:
            fast = t.fast_info if hasattr(t, "fast_info") else None
        except Exception:
            fast = None
        if fast is not None:
            for field in info_fields:
                try:
                    info[field] = _scalar(getattr(fast, field))
                except Exception:
                    continue  # individual fields may need extra requests and can fail
        return json.dumps({"info": info, "prices": series})
    except Exception as e:
        return json.dumps({"error": str(e)})


@tool("fx_convert")
def fx_convert(amount: float, from_ccy: str = "USD", to_ccy: str = "INR") -> str:
    """Convert currency using Yahoo FX pair (e.g., USDINR=X). If no live rate, USD/INR falls back to a static 83.0 INR per USD; other pairs return rate=null. Result includes 'source'."""
    # Unknown pairs must not silently "convert" at 1.0, so only USD<->INR has a
    # static fallback; everything else reports rate=None plus an error message.
    src = str(from_ccy).strip().upper()
    dst = str(to_ccy).strip().upper()
    if src == dst:
        return json.dumps({"rate": 1.0, "converted": amount * 1.0, "source": "identity"})

    rate: Optional[float] = None
    try:
        data = yf.Ticker(f"{src}{dst}=X").history(period="5d", interval="1d")
        if not data.empty:
            # Skip NaN closes; an all-NaN column raises IndexError, handled below.
            last = float(data["Close"].dropna().iloc[-1])
            if math.isfinite(last) and last > 0:
                rate = last
    except Exception:
        rate = None  # network/lookup failure: fall through to the static table

    source = "live"
    if rate is None:
        static_rates = {("USD", "INR"): 83.0, ("INR", "USD"): 1 / 83.0}
        rate = static_rates.get((src, dst))
        source = "static_fallback" if rate is not None else "unavailable"

    payload: Dict[str, Any] = {
        "rate": rate,
        "converted": amount * rate if rate is not None else None,
        "source": source,
    }
    if rate is None:
        payload["error"] = f"No FX rate available for {src}->{dst}"
    return json.dumps(payload)


@tool("pe_ev_ebitda")
def pe_ev_ebitda(market_cap: float, net_debt: float, ebitda: float, net_income: float) -> str:
    """Compute PE and EV/EBITDA (EV = market_cap + net_debt). Returns {pe, ev_ebitda, notes}; a ratio is null (not meaningful) when its denominator is zero or negative."""

    def _ratio(numerator: Any, denominator: Any) -> Optional[float]:
        # Multiples on zero/negative earnings or EBITDA are conventionally reported
        # as "not meaningful"; return None rather than a misleading negative number.
        try:
            num, den = float(numerator), float(denominator)
        except (TypeError, ValueError):
            return None
        if not (math.isfinite(num) and math.isfinite(den)) or den <= 0:
            return None
        return num / den

    notes: List[str] = []
    pe = _ratio(market_cap, net_income)
    if pe is None:
        notes.append("pe: not meaningful (needs numeric market_cap and positive, finite net_income)")
    try:
        ev: Optional[float] = float(market_cap) + float(net_debt)
    except (TypeError, ValueError):
        ev = None
    ev_e = _ratio(ev, ebitda) if ev is not None else None
    if ev_e is None:
        notes.append("ev_ebitda: not meaningful (needs numeric market_cap/net_debt and positive, finite ebitda)")
    return json.dumps({"pe": pe, "ev_ebitda": ev_e, "notes": notes})


@tool("compare_companies")
def compare_companies(companies: List[str], quarter: str = "Q1", fy: str = "26") -> str:
    """Compare multiple companies from CSV demo and compute simple growth ranks. Returns JSON list."""
    # Output: {"items": [...], "rank_yoy": [...]}. Each item is either the parsed
    # metrics, {"company", "missing": True}, or {"company", "parse_error": True};
    # rank_yoy lists companies with metrics by descending YoY revenue growth.
    items: List[Dict[str, Any]] = []
    for name in companies:
        found = json.loads(tabular_lookup.invoke({"company": name, "quarter": quarter, "fy": fy}) or "{}")
        if not found:
            items.append({"company": name, "missing": True})
            continue
        try:
            entry = {
                "company": found["company"],
                "rev_inr_cr": float(found["rev_inr_cr"]),
                "profit_inr_cr": float(found["profit_inr_cr"]),
                "yoy_rev_growth_pct": float(found["yoy_rev_growth_pct"]),
            }
        except Exception:
            entry = {"company": name, "parse_error": True}
        items.append(entry)
    ranked = sorted(
        (item for item in items if item.get("yoy_rev_growth_pct") is not None),
        key=lambda item: item.get("yoy_rev_growth_pct", 0),
        reverse=True,
    )
    return json.dumps({"items": items, "rank_yoy": [item["company"] for item in ranked]})


# Tools offered to the model via `bind_tools` in node_tool_use.
TOOLS = [calc, ddg_news, tabular_lookup, yfinance_prices, fx_convert, pe_ev_ebitda, compare_companies]
# NOTE(review): TOOL_NODE is not referenced by build_graph() in this file; tool
# calls are executed manually inside node_tool_use instead.
TOOL_NODE = ToolNode(tools=TOOLS)


# -----------------------
# LangGraph State
# -----------------------
class GraphState(TypedDict):
    """Shared state passed between the LangGraph nodes; each node returns an updated copy."""

    query: str  # the user's research question
    plan: str  # numbered plan text produced by node_planner
    context_snippets: List[str]  # retrieved chunks tagged "[company | source]" by node_retrieve
    tool_calls: List[str]  # router's suggested tool names, then "TOOL name(args) -> result" traces from node_tool_use
    analysis: str  # interim note from node_tool_use
    report: str  # final research brief from node_analyze_and_write
    guard_feedback: str  # QA corrections (or "LGTM") from node_guardrail
    news_empty: bool  # set by node_tool_use from ddg_news results; drives the "No recent items" section
    missing_inputs: List[str]  # CSV lookups that found no row; listed as needed clarifications


# -----------------------
# Nodes
# -----------------------

def node_planner(state: GraphState) -> GraphState:
    """Ask the LLM for a short numbered research plan and store it under ``plan``."""
    instructions = (
        "You are a senior equity/sector analyst. Break the user query into a numbered plan: "
        "1) clarify intent, 2) data/tools to fetch (prices, FX, CSV, news), 3) retrieval queries, 4) calculations, "
        "5) risks & opportunities, 6) citations, 7) output format. Keep ≤130 words."
    )
    reply = make_llm().invoke([SystemMessage(content=instructions), HumanMessage(content=state["query"])])
    updated = dict(state)
    updated["plan"] = reply.content
    return updated


# Whole-word keyword patterns for node_router. Plain substring checks misfired
# badly: "ev" matched "revenue"/"every" and "pe" matched "expect"/"performance",
# so almost every query was routed to calc.
_ROUTER_NEWS_RE = re.compile(r"\b(?:news|latest|today|headlines?)\b")
_ROUTER_COMPARE_RE = re.compile(r"\b(?:vs|versus)\b")
_ROUTER_CALC_RE = re.compile(
    r"%|\b(?:growth\w*|cagr|difference\w*|yoy|compute\w*|calculate\w*|p/e|pe|ebitda|ev)\b"
)
_ROUTER_TABULAR_RE = re.compile(
    r"\b(?:revenue\w*|profit\w*|results\w*|numbers\w*)\b"
    r"|(?<![a-z0-9])q[1-4](?![0-9])"
    r"|(?<![a-z])fy\s?(?:20)?\d{2}(?![0-9])"
)


def node_router(state: GraphState) -> GraphState:
    """Heuristically pre-select tools from keywords in the query.

    Sets ``tool_calls`` to the ordered subset of ["ddg_news", "calc",
    "tabular_lookup"] whose keywords occur as whole words (a few allow suffixes,
    e.g. "revenues"); a quarter/fiscal year such as "Q1"/"FY26" counts as a
    tabular hint, and "vs"/"versus" implies both news and a tabular lookup.
    """
    q = state["query"].lower()
    compare = bool(_ROUTER_COMPARE_RE.search(q))
    wanted = [
        ("ddg_news", compare or bool(_ROUTER_NEWS_RE.search(q))),
        ("calc", bool(_ROUTER_CALC_RE.search(q))),
        ("tabular_lookup", compare or bool(_ROUTER_TABULAR_RE.search(q))),
    ]
    return {**state, "tool_calls": [name for name, needed in wanted if needed]}


def node_retrieve(state: GraphState) -> GraphState:
    """Fetch the top-4 vector-store chunks for the query and store them as tagged snippets.

    Each snippet is ``"[<company or 'general'> | <source>] "`` followed by up to
    400 characters of the chunk text.
    """
    db = build_or_load_vectorstore()
    retriever = db.as_retriever(search_kwargs={"k": 4})
    q = f"{state['query']}\nContext needed: sector risks, demand trends, genAI themes."
    # Retrievers are Runnables; `get_relevant_documents` is deprecated in favor of `invoke`.
    docs = retriever.invoke(q)
    snippets = []
    for d in docs:
        src = d.metadata.get("source", "")
        comp = d.metadata.get("company")
        # Keep a space between the tag and the text so the snippet stays readable.
        snippets.append(f"[{comp or 'general'} | {src}] {d.page_content[:400]}")
    return {**state, "context_snippets": snippets}


def _run_tool_call(name: str, args: Dict[str, Any]) -> str:
    """Invoke the TOOLS entry registered as `name`; unknown names give "{}", failures a JSON error."""
    registry = {t.name: t for t in TOOLS}
    fn = registry.get(name)
    if fn is None:
        return "{}"
    try:
        return str(fn.invoke(args))
    except Exception as e:
        # Bad model-supplied arguments (schema validation) or a tool bug must not
        # abort the whole graph; report the failure back to the model instead.
        logger.warning(f"Tool {name} failed: {e}")
        return json.dumps({"error": f"{type(e).__name__}: {e}"})


def _news_is_empty(result: str) -> bool:
    """True when a ddg_news result has no real items (empty list, only search errors, or a tool error)."""
    try:
        payload = json.loads(result)
    except ValueError:
        return False
    if isinstance(payload, dict):
        return "error" in payload
    if not isinstance(payload, list):
        return False
    return all(isinstance(i, dict) and i.get("title") == "(search error)" for i in payload)


def _csv_misses(name: str, args: Dict[str, Any], result: str) -> List[str]:
    """Describe CSV lookups inside a tabular_lookup / compare_companies result that found no row."""
    try:
        payload = json.loads(result or "{}")
    except ValueError:
        payload = None
    if name == "tabular_lookup":
        if not payload or (isinstance(payload, dict) and "error" in payload):
            return [f"CSV missing: {args}"]
        return []
    if name == "compare_companies":
        if not isinstance(payload, dict) or "error" in payload:
            return [f"CSV missing: {args}"]
        quarter = args.get("quarter", "Q1")
        fy = args.get("fy", "26")
        return [
            f"CSV missing: {dict(company=item.get('company'), quarter=quarter, fy=fy)}"
            for item in payload.get("items", [])
            if isinstance(item, dict) and (item.get("missing") or item.get("parse_error"))
        ]
    return []


def node_tool_use(state: GraphState) -> GraphState:
    """Let the LLM request tools once, run those calls, then ask it for an interim note.

    Returns the state with:
      - ``analysis``: the model's follow-up note after seeing the tool results
      - ``tool_calls``: the router's entries plus one "TOOL name(args) -> result..." trace per call
      - ``news_empty``: True only if news was searched and every search came back empty
      - ``missing_inputs``: CSV lookups (including compare_companies members) that found no row
    """
    llm = make_llm()
    llm_with_tools = llm.bind_tools(TOOLS)

    tool_context = (
        "Tools: calc(expr), ddg_news(query,max_results), tabular_lookup(company,quarter,fy), yfinance_prices(ticker,period,interval),\n"
        "fx_convert(amount,from_ccy,to_ccy), pe_ev_ebitda(market_cap,net_debt,ebitda,net_income), compare_companies(companies,quarter,fy).\n"
        "If news returns empty, set news_empty=true. If CSV lookup misses, add to missing_inputs and suggest clarification."
    )

    msgs = [
        SystemMessage(content=tool_context + " Return concise JSON per tool call, then a 2-3 line interim note."),
        HumanMessage(content=f"User query: {state['query']}\nPlan: {state['plan']}")
    ]

    first = llm_with_tools.invoke(msgs)
    tool_traces: List[str] = []
    tool_msgs: List[Any] = []
    missing_inputs: List[str] = []
    news_calls = 0
    empty_news_calls = 0

    for idx, tc in enumerate(getattr(first, "tool_calls", None) or []):
        name = tc.get("name")
        args = tc.get("args") or {}
        result = _run_tool_call(name, args)
        if name == "ddg_news":
            news_calls += 1
            if _news_is_empty(result):
                empty_news_calls += 1
        missing_inputs.extend(_csv_misses(name, args, result))
        tool_traces.append(f"TOOL {name}({args}) -> {result[:260]}...")
        # A missing/None id would fail ToolMessage validation; use a unique fallback.
        tool_msgs.append(ToolMessage(tool_call_id=tc.get("id") or f"{name}-{idx}", name=name, content=result))

    follow = llm.invoke(msgs + [first] + tool_msgs)
    # "No recent items" is only correct if no news search returned anything.
    news_empty = news_calls > 0 and empty_news_calls == news_calls

    return {
        **state,
        "analysis": follow.content,
        "tool_calls": state["tool_calls"] + tool_traces,
        "news_empty": news_empty,
        "missing_inputs": missing_inputs,
    }


def node_analyze_and_write(state: GraphState) -> GraphState:
    """Have the LLM write the final research brief from RAG snippets, the tool trace and edge-case flags."""
    writer = make_llm(temperature=0.2)
    instructions = (
        "Compose an equity research brief with sections: SUMMARY, PRICE ACTION, KEY METRICS, DRIVERS, RISKS, OPPORTUNITIES,"
        " FX IMPACT (if relevant), NEWS DIGEST, COMPARISON (if requested), ACTIONABLE INSIGHTS, CITATIONS."
        " If state.news_empty is true, include a NEWS DIGEST section with 'No recent items'."
        " If state.missing_inputs has entries, add a 'NEEDED CLARIFICATIONS' section listing them."
        " Use bullet points, 350-500 words, and include short citations like [source: <filename or url>]."
    )
    rag_context = "\n\n".join(state.get("context_snippets", []))
    trace_text = "\n".join(state.get("tool_calls", []))
    news_flag = state.get("news_empty", False)
    missing = state.get("missing_inputs", [])
    user_block = (
        f"User query: {state['query']}\n\nContext from RAG:\n{rag_context}\n\nTool trace:\n{trace_text}\n\n"
        f"Flags: news_empty={news_flag}, missing_inputs={missing}\n"
        "Write the final brief now."
    )
    brief = writer.invoke([SystemMessage(content=instructions), HumanMessage(content=user_block)])
    return {**state, "report": brief.content}


def node_guardrail(state: GraphState) -> GraphState:
    """Run a deterministic (temperature 0) LLM QA pass over the report; store the verdict as ``guard_feedback``."""
    reviewer = make_llm(temperature=0)
    instructions = (
        "You are a research QA assistant. Check the brief for unsupported claims, unclear sources, and missing assumptions."
        " Reply with a short bullet list of corrections or 'LGTM' if fine."
    )
    verdict = reviewer.invoke([SystemMessage(content=instructions), HumanMessage(content=state["report"])])
    return {**state, "guard_feedback": verdict.content}


def node_memory(state: GraphState) -> GraphState:
    """Append one JSON line describing this run to MEM_FILE and return the state unchanged.

    Stored fields: a UTC ISO timestamp ending in "Z", the query and plan, the
    first 1200 chars of ``analysis`` (as key_points), the first 2000 chars of
    ``report`` (as summary), and the news_empty / missing_inputs flags.
    """
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now" with
    # its "+00:00" offset rewritten to "Z" produces the same string format.
    timestamp = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
    rec = {
        "timestamp": timestamp,
        "query": state["query"],
        "plan": state.get("plan", ""),
        "key_points": state.get("analysis", "")[:1200],
        "summary": state.get("report", "")[:2000],
        "news_empty": state.get("news_empty", False),
        "missing_inputs": state.get("missing_inputs", []),
    }
    with open(MEM_FILE, "a", encoding="utf-8") as f:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    return state


# -----------------------
# Graph Wiring
# -----------------------

def build_graph():
    """Compile the linear pipeline planner → router → retrieve → tool_use → write → guard → memory → END."""
    stages = [
        ("planner", node_planner),
        ("router", node_router),
        ("retrieve", node_retrieve),
        ("tool_use", node_tool_use),
        ("write", node_analyze_and_write),
        ("guard", node_guardrail),
        ("memory", node_memory),
    ]
    graph = StateGraph(GraphState)
    for stage_name, fn in stages:
        graph.add_node(stage_name, fn)

    graph.set_entry_point(stages[0][0])
    # Chain each stage to the next, then terminate after the last one.
    for (src, _), (dst, _) in zip(stages, stages[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(stages[-1][0], END)

    return graph.compile()


# -----------------------
# Self-tests (no LLM/API required)
# -----------------------

def run_self_tests() -> int:
    """Run the built-in smoke tests; return 0 if all pass, 1 if a check fails.

    No LLM is called, but two checks may use the network: the vector store uses
    OpenAI embeddings when OPENAI_API_KEY is set (otherwise a local
    Sentence-Transformers model), and fx_convert queries Yahoo before falling back.
    """

    def check(condition: bool, message: str) -> None:
        # Raise explicitly: bare `assert` statements are stripped under `python -O`,
        # which would let every check pass vacuously.
        if not condition:
            raise AssertionError(message)

    print(hrule("SELF-TESTS"))
    bootstrap_demo_corpus()

    try:
        # calc
        out1 = calc.invoke({"expression": "(64000-38000)/38000"})
        val1 = float(out1)
        check(0.68 < val1 < 0.69, f"calc unexpected: {out1}")
        print("[OK] calc arithmetic")

        # tabular lookup: present & missing
        row = json.loads(tabular_lookup.invoke({"company": "TCS", "quarter": "Q1", "fy": "26"}) or "{}")
        check(row.get("rev_inr_cr") == "64000", f"tabular_lookup unexpected row: {row}")
        print("[OK] tabular_lookup present")

        missing = json.loads(tabular_lookup.invoke({"company": "FooCorp"}) or "{}")
        check(missing == {}, "expected {} for missing company")
        print("[OK] tabular_lookup missing → {}")

        # vector store build & retrieval (OpenAI or HF embeddings)
        db = build_or_load_vectorstore()
        retriever = db.as_retriever(search_kwargs={"k": 2})
        docs = retriever.invoke("genAI pilots production support")
        check(len(docs) >= 1, "retriever returned no docs")
        print("[OK] vectorstore retrieval")

        # FX convert (uses the static fallback if offline)
        fx = json.loads(fx_convert.invoke({"amount": 10, "from_ccy": "USD", "to_ccy": "INR"}))
        check(fx.get("converted") is not None, f"fx_convert returned no conversion: {fx}")
        print("[OK] fx_convert")

        # compare companies, including an unknown one
        comp = json.loads(compare_companies.invoke({"companies": ["TCS", "Infosys", "NoName"]}))
        check("items" in comp and any(i.get("missing") for i in comp["items"]), "compare should include missing")
        print("[OK] compare_companies with missing handling")
    except AssertionError as e:
        print(f"[FAIL] {e}")
        return 1

    print(hrule("SELF-TESTS PASSED"))
    return 0


# -----------------------
# FastAPI server (optional)
# -----------------------

def build_app_api():
    """Create the FastAPI app exposing ``GET /analyze?query=...``.

    The graph is compiled once here and reused for every request. Responses are
    JSON ``{plan, report, guard}``; any failure is logged server-side and
    returned as ``{"error": ...}`` with HTTP 500.

    Raises:
        RuntimeError: if FastAPI is not installed.
    """
    if not FASTAPI_AVAILABLE:
        raise RuntimeError("FastAPI not installed. `pip install fastapi uvicorn`.")
    api = FastAPI(title="Finance Research Copilot API")
    # Compile once instead of on every request.
    graph = build_graph()

    @api.get("/analyze")
    def analyze(query: str):
        try:
            bootstrap_demo_corpus()
            initial: GraphState = {
                "query": query,
                "plan": "",
                "context_snippets": [],
                "tool_calls": [],
                "analysis": "",
                "report": "",
                "guard_feedback": "",
                "news_empty": False,
                "missing_inputs": [],
            }
            final: GraphState = graph.invoke(initial)
            return JSONResponse({
                "plan": final.get("plan"),
                "report": final.get("report"),
                "guard": final.get("guard_feedback"),
            })
        except Exception as e:
            # Previously only the client saw the error; keep a server-side traceback too.
            logger.exception("analyze request failed")
            return JSONResponse({"error": str(e)}, status_code=500)

    return api


# -----------------------
# Streamlit generator (writes a file for convenience)
# -----------------------
# Source of the optional Streamlit UI, written verbatim to app_streamlit.py by
# `--write-streamlit` (see main()). The UI sends the query to the FastAPI
# /analyze endpoint and renders the returned "plan", "report" and "guard" fields.
STREAMLIT_TEMPLATE = """
import streamlit as st
import requests

st.set_page_config(page_title="Finance Research Copilot", layout="wide")

st.title("Finance Research Copilot — LangChain + LangGraph")
query = st.text_area("Enter your query", value="Analyze TCS vs Infosys for Q1 FY26; include FX impact and price trend")
api_url = st.text_input("API URL", value="http://localhost:8000/analyze")

if st.button("Run Analysis"):
    with st.spinner("Calling API..."):
        try:
            r = requests.get(api_url, params={"query": query}, timeout=120)
            if r.ok:
                data = r.json()
                st.subheader("Plan")
                st.code(data.get("plan", ""))
                st.subheader("Research Brief")
                st.write(data.get("report", ""))
                st.subheader("QA Feedback")
                st.write(data.get("guard", ""))
            else:
                st.error(f"API error: {r.status_code} {r.text}")
        except Exception as e:
            st.error(str(e))
"""


# -----------------------
# CLI Entrypoint
# -----------------------

def main():
    """CLI entry point.

    Modes:
      - ``--query``: run the agent once and print plan, trace, snippets, brief and QA feedback
      - ``--self-test``: run the built-in smoke tests and exit with their status
      - ``--api``: serve the FastAPI app with uvicorn on ``--host``/``--port``
      - ``--write-streamlit``: write app_streamlit.py (may be combined with any mode)
    Also writes requirements.txt next to this file if it is absent.
    """
    parser = argparse.ArgumentParser(description="Finance Research Copilot — Agentic AI (LangChain + LangGraph)")
    # Not `required=True`: --api and --write-streamlit are standalone modes too, and a
    # required group made argparse reject the documented `--api --host ... --port ...`.
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--query", help="User research question / task")
    group.add_argument("--self-test", action="store_true", help="Run built-in tests (no API keys required)")
    parser.add_argument("--api", action="store_true", help="Run FastAPI server instead of CLI analysis")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--write-streamlit", action="store_true", help="Write a Streamlit UI file")
    args = parser.parse_args()
    if args.query is None and not (args.self_test or args.api or args.write_streamlit):
        parser.error("one of the arguments --query --self-test --api --write-streamlit is required")

    # Write requirements.txt if absent (supports split & monolithic LangChain)
    requirements = textwrap.dedent(
        """
        langchain>=0.1.0
        langgraph>=0.2.13
        langchain-openai>=0.1.23
        langchain-community>=0.2.10
        langchain-text-splitters>=0.2.2
        chromadb>=0.5.3
        pypdf>=4.2.0
        duckduckgo-search>=6.2.10
        sentence-transformers>=2.7.0
        tiktoken>=0.7.0
        numpy>=1.26.0
        pandas>=2.2.2
        yfinance>=0.2.40
        fastapi>=0.111.0
        uvicorn>=0.30.0
        requests>=2.32.0
        """
    ).strip()

    req_path = ROOT / "requirements.txt"
    try:
        if not req_path.exists():
            req_path.write_text(requirements)
    except Exception as e:
        logger.warning(f"Could not write requirements.txt: {e}")

    if args.write_streamlit:
        (ROOT / "app_streamlit.py").write_text(STREAMLIT_TEMPLATE)
        print("Wrote app_streamlit.py — run with: streamlit run app_streamlit.py")

    if args.self_test:
        sys.exit(run_self_tests())

    if args.api:
        if not FASTAPI_AVAILABLE:
            print("FastAPI not installed. `pip install fastapi uvicorn`.")
            sys.exit(1)
        from uvicorn import run as uvicorn_run
        app = build_app_api()
        uvicorn_run(app, host=args.host, port=args.port)
        return

    if args.query is None:
        # Only --write-streamlit was requested; nothing left to run.
        return

    # Full agentic run
    bootstrap_demo_corpus()

    app = build_graph()
    initial: GraphState = {
        "query": args.query,
        "plan": "",
        "context_snippets": [],
        "tool_calls": [],
        "analysis": "",
        "report": "",
        "guard_feedback": "",
        "news_empty": False,
        "missing_inputs": [],
    }

    print(hrule("START RUN"))
    final: GraphState = app.invoke(initial)

    print(hrule("PLAN"))
    print(final.get("plan", ""))

    print(hrule("TOOL TRACE (truncated)"))
    for t in final.get("tool_calls", [])[:14]:
        print("-", t)

    print(hrule("CONTEXT SNIPPETS"))
    for s in final.get("context_snippets", [])[:4]:
        print("*", s[:240], "...")

    print(hrule("RESEARCH BRIEF"))
    print(final.get("report", ""))

    print(hrule("GUARDRAIL FEEDBACK"))
    print(final.get("guard_feedback", ""))

    print(hrule("DONE"))


# Script entry point.
# NOTE(review): if this runs inside a notebook cell (the "In [None]:" marker at the
# top suggests a notebook export), parse_args() would presumably see the kernel's
# own argv — invoke main() from the command line instead.
if __name__ == "__main__":
    main()
