In [14]:
from __future__ import annotations
import json
import os
import re
import unicodedata
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Dict, List, Tuple
import pandas as pd
import requests
import functools

# -----------------------------------------------------------------------------
#  Configuration
# -----------------------------------------------------------------------------
CAPTIDE_API_KEY: str = os.getenv("CAPTIDE_API_KEY", "YOUR_CAPTIDE_API_KEY")
HEADERS = {
    "X-API-Key": CAPTIDE_API_KEY,
    "Content-Type": "application/json",
    "Accept": "application/json",
}

TICKERS: List[str] = ["SNAP"]          # ← add more tickers if you need

# -----------------------------------------------------------------------------
#  Helpers: fiscal-period parsing & validation
# -----------------------------------------------------------------------------
_Q_RE  = re.compile(r"Q([1-4]) (\d{4})")
_FY_RE = re.compile(r"FY (\d{4})")

def is_valid_fiscal_period(fp: str) -> bool:
    """True for Q1–Q4 YYYY or FY YYYY with year > 2022."""
    m = _Q_RE.fullmatch(fp)
    if m:
        return int(m.group(2)) > 2022
    m = _FY_RE.fullmatch(fp)
    if m:
        return int(m.group(1)) > 2022
    return False

def parse_fiscal_period(fp: str) -> Tuple[int, int]:
    """
    Returns (year, quarter) where quarter==0 for FY.
    Raises ValueError on bad format.
    """
    m = _Q_RE.fullmatch(fp)
    if m:
        return int(m.group(2)), int(m.group(1))
    m = _FY_RE.fullmatch(fp)
    if m:
        return int(m.group(1)), 0
    raise ValueError(f"Bad fiscalPeriod: {fp!r}")

# -----------------------------------------------------------------------------
#  Fetch filing metadata
# -----------------------------------------------------------------------------
def fetch_documents(ticker: str) -> Tuple[str, List[Dict]]:
    url = f"https://rest-api.captide.co/api/v1/companies/ticker/{ticker}/documents"
    try:
        resp = requests.get(url, headers=HEADERS, timeout=60)
        resp.raise_for_status()
        docs = resp.json()
        valid = [
            {
                "ticker": doc["ticker"],
                "fiscalPeriod": doc["fiscalPeriod"],
                "sourceLink": doc["sourceLink"],
                "date": doc["date"],
            }
            for doc in docs
            if doc["sourceType"] in {"10-K", "10-Q"}
            and "fiscalPeriod" in doc
            and is_valid_fiscal_period(doc["fiscalPeriod"])
        ]
        return ticker, valid
    except Exception as exc:
        print(f"[warn] {ticker}: {exc}")
        return ticker, []

all_docs: List[Dict] = []
with ThreadPoolExecutor(max_workers=5) as ex:
    futures = [ex.submit(fetch_documents, tk) for tk in TICKERS]
    for fut in as_completed(futures):
        _, docs = fut.result()
        all_docs.extend(docs)

grouped_docs: Dict[Tuple[str, str], List[str]] = defaultdict(list)
for d in all_docs:
    grouped_docs[(d["ticker"], d["fiscalPeriod"])].append(d["sourceLink"])

# -----------------------------------------------------------------------------
#  Captide RAG helpers
# -----------------------------------------------------------------------------
BASE_PROMPT = (
    "Return a single valid JSON object with double-quoted keys and textual values. "
    "Each key should be one risk factor mentioned in the filing (Item 1A) and the "
    "value should be a brief explanation of that factor."
)

def parse_sse_response(sse_text: str) -> Dict[str, str]:
    """Captide streams answers as SSE; extract the final JSON object."""
    try:
        lines = [ln[6:] for ln in sse_text.splitlines() if ln.startswith("data: ")]
        for ln in lines:
            obj = json.loads(ln)
            if obj.get("type") == "full_answer":
                # remove any markup like [#ref]
                content = re.sub(r"\s*\[#\w+\]", "", obj["content"])
                m = re.search(r"\{.*\}", content, re.DOTALL)
                return json.loads(m.group(0)) if m else {}
    except Exception:
        pass
    return {}

def fetch_metrics_with_prompt(source_links: List[str], prompt: str) -> Dict[str, str]:
    payload = {"query": prompt, "sourceLink": source_links}
    r = requests.post(
        "https://rest-api.captide.co/api/v1/rag/agent-query-stream",
        json=payload,
        headers=HEADERS,
        timeout=120,
    )
    r.raise_for_status()
    return parse_sse_response(r.text)

# -----------------------------------------------------------------------------
#  Normalisation for set-ops
# -----------------------------------------------------------------------------
def normalise_factor(text: str) -> str:
    s = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode()
    s = re.sub(r"\s+", " ", s)
    s = re.sub(r"[^\w\s]", "", s)
    return s.strip().lower()

def diff_sets(prev: set[str], curr: set[str]) -> Tuple[List[str], List[str]]:
    return sorted(curr - prev), sorted(prev - curr)

# -----------------------------------------------------------------------------
#  Prompt builder that forces key-name continuity
# -----------------------------------------------------------------------------
def build_prompt(prev_keys: List[str] | None) -> str:
    p = BASE_PROMPT
    if prev_keys:
        p += (
            "\n\nIMPORTANT: The previous filing used the following exact key names. "
            "If the current filing discusses the same underlying risk factor, "
            "REUSE THE EXACT SAME KEY STRING. "
            "Only add new keys for brand-new factors, and omit keys that no longer appear.\n"
            f"{json.dumps(prev_keys, ensure_ascii=False)}\n"
        )
    return p

# -----------------------------------------------------------------------------
#  Chronological ordering and cached RAG calls
# -----------------------------------------------------------------------------
def chronologically_sorted_docs(
    grouped: Dict[Tuple[str, str], List[str]]
) -> List[Tuple[str, List[str]]]:
    parsed = [
        (fp, links, *parse_fiscal_period(fp))
        for (_, fp), links in grouped.items()
    ]
    parsed.sort(key=lambda t: (t[2], t[3]))  # year, quarter
    return [(fp, links) for fp, links, *_ in parsed]

@functools.lru_cache(maxsize=None)
def get_risk_factors(links_tuple: Tuple[str, ...], prompt: str) -> Dict[str, str]:
    return fetch_metrics_with_prompt(list(links_tuple), prompt)

# -----------------------------------------------------------------------------
#  Main diff loop
# -----------------------------------------------------------------------------
results: List[Dict] = []
prev_key_list: List[str] = []
prev_norm_set: set[str] = set()

for fiscal_period, links in chronologically_sorted_docs(grouped_docs):

    prompt     = build_prompt(prev_key_list)
    raw_dict   = get_risk_factors(tuple(links), prompt)  # {key: desc}

    this_keys      = list(raw_dict.keys())               # exact strings
    this_norm_set  = {normalise_factor(k) for k in this_keys}

    additions, removals = diff_sets(prev_norm_set, this_norm_set)

    results.append(
        {
            "fiscal_period": fiscal_period,
            "risk_factors": this_keys,
            "additions_vs_prev_qtr": additions,
            "removals_vs_prev_qtr": removals,
        }
    )

    prev_key_list = this_keys
    prev_norm_set = this_norm_set

# -----------------------------------------------------------------------------
#  Present results
# -----------------------------------------------------------------------------
df = (
    pd.DataFrame(results)
    .sort_values(
        "fiscal_period",
        key=lambda s: s.map(parse_fiscal_period),
        ascending=False,
    )
    .reset_index(drop=True)
)

print(df.to_markdown(index=False))

# Optional: save
# df.to_csv("risk_factor_change_log.csv", index=False)
print("\nDone – table printed above. Save to CSV with df.to_csv(...) if needed.")

| fiscal_period   | risk_factors                                                                                                                                                                                                                                                                                                                                                                        | additions_vs_prev_qtr                                                                                                                                                                                                                                                                                                                                                               | removals_vs_prev_qtr                                                                                                                                                                                                                    