# Trend Analysis Bot Demo


Usage Notes & Caveats
*   Enter your BigQuery project IDs, which are used for permissions, running queries, and storing the output tables
*   Bring your own data (modify the YAML file and the SQL generation functions)
*   This is a proof of concept (POC) that still has bugs to work out; use with caution
*   The bot's analysis can be incorrect, and it often makes unhelpful recommendations or interpretations of the data
*   Follow along with future iterations here: https://github.com/oscarhealth/trend-analyzer
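
When adapting the YAML file to your own data, note that the loader later in this notebook reads `name` and `expr` from every dimension and metric, and treats a missing `scope` as `descriptor`. The following is a minimal sketch of a pre-flight check on plain dictionaries shaped like those YAML entries. `validate_entries` and `VALID_SCOPES` are hypothetical names introduced here for illustration; they are not part of the notebook.

```python
VALID_SCOPES = {"descriptor", "norm", "both"}

def validate_entries(entries, kind="dimension"):
    """Return a list of problems found in YAML-style dimension/metric dicts.

    Hypothetical helper: `entries` mirrors the items under `dimensions:`
    or `metrics:` in the YAML file.
    """
    problems = []
    seen = set()
    for i, entry in enumerate(entries):
        label = entry.get("name", f"<{kind} #{i}>")
        # `name` and `expr` are the keys the loader indexes directly.
        for key in ("name", "expr"):
            if not entry.get(key):
                problems.append(f"{label}: missing required key '{key}'")
        # A missing scope falls back to "descriptor", as in the loader.
        scope = entry.get("scope", "descriptor")
        if scope not in VALID_SCOPES:
            problems.append(f"{label}: unknown scope '{scope}'")
        if label in seen:
            problems.append(f"{label}: duplicate name")
        seen.add(label)
    return problems

sample = [
    {"name": "state", "scope": "both", "expr": "state"},
    {"name": "claim_type", "scope": "claims", "expr": "claim_type"},
    {"name": "region", "scope": "both"},
]
for problem in validate_entries(sample):
    print(problem)
```

Running a check like this before generating SQL surfaces configuration mistakes without spending a BigQuery job on them.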




In [None]:
#@title Process Flowchart
import IPython.display  # import the submodule explicitly; display() below relies on it
from google.colab import output

# The HTML content containing the Mermaid diagram
html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Mermaid Diagram</title>
</head>
<body>

    <div class="mermaid" style="text-align: left;">
    graph TD
        subgraph "Phase 1: Setup and Data Prep"
            direction LR
            A[Start] --> B{Initialize<br>Notebook};
            B --> C[Install Dependencies<br>& Authenticate];
            C --> D[Define YAML<br>Configuration];
            D -- "Defines dimensions,<br>metrics, tables" --> E{Generate SQL};
            E --> F["Execute BigQuery SQL to<br>Generate Analysis Tables"];
            F --> G["Create Descriptor (Claims)<br>and Norm (Membership) Tables"];
        end

        subgraph "Phase 2: Agent & Tool Definition"
            direction LR
            H{Define<br>AI Agent} --> I["Set Agent's System<br>Prompt & Analysis Plan"];
            I --> J[Define Data<br>Access Functions];
            J -- Wraps --> K{Create<br>Agent Tools};
            K --> L[Data Analysis<br>Tools];
            K --> M[Reporting<br>Tools];
        end

        subgraph "Phase 3: Iterative Analysis Loop"
            direction LR
            N{Start Analysis<br>Loop} --> O{"Agent: Formulate<br>Hypothesis (PLAN)"};
            O --> P{"Agent: Select<br>Tool(s)"};
            P --> Q["Execute Tool Call<br>e.g., get_trend_data(...)"];
            Q --> R{Get Results<br>from BigQuery};
            R --> S{"Agent: Interpret<br>Results (REFLECT)"};
            S --> T{Update Report};
            T --> U[Append to<br>Google Doc];
            U --> V{More<br>Iterations?};
            V -- Yes --> O;
        end

        subgraph "Phase 4: Finalization"
            direction LR
            W{Synthesize<br>Findings} --> X["Generate Final<br>Summary &<br>Recommendations"];
            X --> Y["Write Final Report<br>to Google Doc"];
            Y --> Z[End];
        end

        %% Link the phases vertically
        G --> H;
        M --> N;
        L --> N;
        V -- No --> W;


        style B fill:#f9f,stroke:#333,stroke-width:2px
        style H fill:#f9f,stroke:#333,stroke-width:2px
        style N fill:#f9f,stroke:#333,stroke-width:2px
        style W fill:#f9f,stroke:#333,stroke-width:2px
    </div>

    <!-- Load Mermaid.js library -->
    <script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>

    <!-- Initialize Mermaid and render the diagram -->
    <script>
        mermaid.initialize({ startOnLoad: false });
        mermaid.run();
    </script>
</body>
</html>
"""

# Display the HTML in the Colab output
display(IPython.display.HTML(html_content))
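
The Phase 3 loop in the flowchart can be sketched in plain Python. Everything in this sketch is a stand-in: `plan`, `run_tool`, and `reflect` are hypothetical stubs for the agent and its tool calls, and the report is a Python list rather than a Google Doc. It only illustrates the control flow (plan, call a tool, reflect, append to the report, repeat for a fixed number of iterations, then summarize).

```python
def plan(history):
    """Stub for the agent's PLAN step: pick the next dimension to inspect."""
    candidates = ["state", "claim_type", "drug_class"]
    return candidates[len(history) % len(candidates)]

def run_tool(dimension):
    """Stub for a tool call such as get_trend_data(...); returns placeholder rows."""
    return [{"dimension": dimension}]

def reflect(dimension, rows):
    """Stub for the agent's REFLECT step: turn rows into a short note."""
    return f"Looked at {dimension}: {len(rows)} row(s) returned."

def analysis_loop(max_iterations=3):
    report = []                      # stands in for the Google Doc
    for _ in range(max_iterations):  # the "More iterations?" decision
        hypothesis = plan(report)
        rows = run_tool(hypothesis)
        report.append(reflect(hypothesis, rows))
    # Phase 4: a final synthesis step closes the report.
    report.append("Final summary: synthesized from the notes above.")
    return report

for line in analysis_loop():
    print(line)
```

A fixed iteration cap is the simplest stopping rule; an agent-driven stop condition would replace the `for` loop with a check on the agent's own output.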

# Authorization
Run this once

In [None]:
!pip install -q --upgrade google-cloud-bigquery vertexai google-cloud-aiplatform python-dotenv
!pip install -q openai-agents
!pip install -q python-docx
from google.colab import auth
from pathlib import Path

SCOPES = [
    "https://www.googleapis.com/auth/documents",
    "https://www.googleapis.com/auth/drive"
]

auth.authenticate_user()                       # OAuth flow in a popup

import google.auth, vertexai, os
CREDS, _ = google.auth.default(scopes=SCOPES)               # picks up the freshly-granted token

BILLING_PROJECT_ID_BQ = "oscaractuarial" # @param {"type":"string"}
REGION_BIGQUERY = "US" # @param {"type":"string"}
DESTINATION_PROJECT_ID_BQ = "oscaractuarial" # @param {"type":"string"}

# YAML configuration file for the cubes
DIM_PATH = "dimensions.yml" # @param {"type":"string"}



# Run this if you do not have a YAML configuration file

In [None]:
yaml_file = """
_meta:
  # Default aliases that the code will substitute if the expression
  # does not already contain a dot ("."):
  descriptor_alias: clc      # raw claim-lines snapshot
  norm_alias: m              # enrollment CTE that we join on

_targets:
  dataset: zone_djf_llm            # destination dataset only; project stays env-var
  descriptor_table: trend_llm_descriptor
  norm_table: trend_llm_norm

_years:
  2023:
    member_snapshot: oscaractuarial.zone_djf.llm_member_months_snap_cy2023
    claim_snapshot: oscaractuarial.zone_djf.llm_claim_lines_snap_cy2023
    months: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

  2024:
    member_snapshot: oscaractuarial.zone_djf.llm_member_months_snap_cy2024
    claim_snapshot: oscaractuarial.zone_djf.llm_claim_lines_snap_cy2024
    months: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

_norm_joins:

_filters:
  common:
    - "line_of_business = 'Individual'"
    - "plan_sponsor = 'Oscar'"
    - "tenant = 'oscar'"

  norm:
    - "NOT test_account"

  descriptor:
    - "NOT is_test_claim"
    - "final"

dimensions:
  # Basic identifying dimensions that appear in both tables
  - name: hios_id
    scope: both
    expr: hios_id
    description: "Health Insurance Oversight System (HIOS) plan identifier"

  - name: state
    scope: both
    expr: state
    description: "US state where the plan is offered"

  - name: plan_network_access_type
    scope: both
    expr: plan_network_access_type
    description: "Network access type (e.g., HMO, PPO)"

  - name: year
    scope: both
    expr: "EXTRACT(YEAR FROM {norm_alias}.month)"
    description: "Calendar year of the enrollment/claim"

  - name: geographic_reporting
    scope: both
    expr: |
      CASE
        WHEN SUBSTR({norm_alias}.hios_id,1,5) = '29341'
            THEN 'OHB (Columbus)'
        WHEN SUBSTR({norm_alias}.hios_id,1,5) = '45845'
            THEN 'OHC (Cleveland Clinic Product)'
        WHEN {norm_alias}.state = 'TX'
            THEN CONCAT({norm_alias}.state, '-', {norm_alias}.plan_network_access_type)
        ELSE {norm_alias}.state
      END
    description: "Oscar reporting market (OHB, OHC, TX-HN, or state)"

  - name: plan_metal
    scope: both
    expr: plan_metal
    description: "Plan metal level (e.g., Bronze, Silver) and CSR"

  - name: age_group
    scope: both
    expr: age_group
    description: "Age of the member grouped into age groups"

  - name: gender
    scope: both
    expr: gender
    description: "Gender of the member"

  # Claim-specific dimensions
  - name: claim_type
    scope: descriptor
    expr: claim_type
    description: "Type of claim: Facility, Professional, or RX"

  - name: major_service_category
    scope: descriptor
    expr: "COALESCE({descriptor_alias}.major_service_category, 'Unmapped')" #dave edit to a cleaned field using actuarial mapping
    description: "Highest level Health Cost Guidelines (HCG) service category"

  - name: provider_specialty
    scope: descriptor
    expr: |
      CASE
        WHEN {descriptor_alias}.claim_type = 'RX' THEN 'Pharmacy'
        ELSE {descriptor_alias}.specialty
      END
    description: "Specialty of the rendering provider (or Pharmacy for RX claims)"

  # - name: hcg_2_major_service_category_2
  #  scope: descriptor
  #  expr: "COALESCE({descriptor_alias}.hcg_2, 'Unmapped')"
  #  description: "Second level Health Cost Guidelines (HCG) service category"

  - name: detailed_service_category
    scope: descriptor
    expr: "COALESCE({descriptor_alias}.detailed_service_category, 'Unmapped')"  # dave edit to a cleaned field using actuarial mapping
    description: "Detailed Health Cost Guidelines (HCG) code and description"

  - name: ms_drg
    scope: descriptor
    expr: "{descriptor_alias}.ms_drg || ' ' || {descriptor_alias}.ms_drg_description"
    description: "Medicare Severity Diagnosis Related Group (MS-DRG) code and description"

  - name: ms_drg_mdc
    scope: descriptor
    expr: "{descriptor_alias}.ms_drg_mdc || ' ' || {descriptor_alias}.ms_drg_mdc_desc"
    description: "Major Diagnostic Category (MDC) for the MS-DRG"

  - name: cpt
    scope: descriptor
    expr: cpt
    description: "Current Procedural Terminology (CPT) code"

  - name: cpt_consumer_description
    scope: descriptor
    expr: cpt_consumer_description
    description: "Consumer-friendly description of the CPT code"

  - name: procedure_level_1
    scope: descriptor
    expr: procedure_level_1
    description: "CPT Classification Level 1, highest level"

  - name: procedure_level_2
    scope: descriptor
    expr: procedure_level_2
    description: "CPT Classification Level 2, more detailed level"

  - name: procedure_level_3
    scope: descriptor
    expr: procedure_level_3
    description: "CPT Classification Level 3, even more detailed level"

  - name: procedure_level_4
    scope: descriptor
    expr: procedure_level_4
    description: "CPT Classification Level 4, surgeries only"

  - name: procedure_level_5
    scope: descriptor
    expr: procedure_level_5
    description: "CPT Classification Level 5, most detailed level, surgeries only"

  - name: channel
    scope: descriptor
    expr: channel
    description: "Place of service (e.g. IP, OP, SNF, URG)"

  - name: drug_name
    scope: descriptor
    expr: drug_name
    description: "Name of the drug"

  - name: drug_class
    scope: descriptor
    expr: drug_class
    description: "Medispan GPI-4 therapeutic class of the drug"

  - name: drug_subclass
    scope: descriptor
    expr: drug_subclass
    description: "Medispan GPI-6 therapeutic subclass of the drug"

  - name: drug
    scope: descriptor
    expr: drug
    description: "Medispan GPI-8 name of the drug"

  - name: is_out_of_network
    scope: descriptor
    expr: is_oon
    description: "Indicator (1/0) if the claim is out of network"

  - name: best_contracting_entity_name
    scope: descriptor
    expr: best_contracting_entity_name
    description: "Name of the provider contracting entity"

  - name: provider_group_name
    scope: descriptor
    expr: provider_group_name
    description: "Provider group, use as alternative to best_contracting_entity_name"

  - name: ccsr_system_description
    scope: descriptor
    expr: ccsr_system_description
    description: "CCSR body system & description"

  - name: ccsr_description
    scope: descriptor
    expr: ccsr_description
    description: "CCSR category code & description"

  # Membership-specific dimensions
  - name: region
    scope: both
    expr: region
    description: "Region of the member, generally the city"

  - name: enrollment_length_continuous
    scope: both
    expr: enrollment_length_continuous
    description: "Continuous length of enrollment for the member"

  - name: clinical_segment
    scope: both
    expr: clinical_segment
    description: "Member's clinical complexity segment"

  - name: general_agency_name
    scope: both
    expr: general_agency_name
    description: "Name of the broker general agency which acquired the member"

  - name: broker_name
    scope: both
    expr: broker_name
    description: "Name of the broker that acquired the member"

  - name: sa_contracting_entity_name
    scope: both
    expr: sa_contracting_entity_name
    description: "Name of the member's attributed contracting entity"

  - name: new_member_in_period
    scope: both
    expr: "IF({norm_alias}.enrollment_length_continuous <= 5, 1, 0)"
    description: "Indicator (1/0) if member is new (enrolled for 5 months or less)"

  - name: member_called_oscar
    scope: both
    expr: "IF({norm_alias}.call_count > 0, 1, 0)"
    description: "Indicator (1/0) if member called Oscar"

  - name: member_used_app
    scope: both
    expr: "IF({norm_alias}.app_login_count > 0, 1, 0)"
    description: "Indicator (1/0) if member used the Oscar app"

  - name: member_had_web_login
    scope: both
    expr: "IF({norm_alias}.web_login_count > 0, 1, 0)"
    description: "Indicator (1/0) if member had a web login on Oscar's website"

  - name: member_visited_new_provider_ind
    scope: both
    expr: member_visited_new_provider_ind
    description: "1 if the member saw a 'new-to-them' provider in the month"

  - name: high_cost_member
    scope: both
    expr: high_cost_member
    description: "High cost flag: allowed ≥$100k in the month"

  - name: mutually_exclusive_hcc_condition
    scope: both
    expr: mutually_exclusive_hcc_condition
    description: "HHS HCC Chronic Condition Label (mutually exclusive)"

  - name: wisconsin_area_deprivation_index
    scope: both
    expr: wisconsin_area_deprivation_index
    description: "ADI (WI national block-group decile 1-10); 10 is the most socially deprived"

metrics:
  # Claim metrics (descriptor scope)
  - name: charges
    scope: descriptor
    requires: [clean_claim_status, charges]
    expr: |
      SUM(
        CASE
          WHEN clean_claim_status = 'PAID'
               THEN charges
          ELSE 0
        END
      )

  - name: denied_charges
    scope: descriptor
    requires: [clean_claim_status, charges]
    expr: |
      SUM(
        CASE
          WHEN clean_claim_status <> 'PAID'
               THEN charges
          ELSE 0
        END
      )

  - name: allowed
    scope: descriptor
    requires: [allowed]
    expr: "SUM(allowed)"

  - name: count_of_denied_claims
    scope: descriptor
    requires: [clean_claim_status]
    expr: "COUNTIF(clean_claim_status = 'DENIED')"

  - name: count_of_claims
    scope: descriptor
    requires: [claim_id]
    expr: "COUNT(DISTINCT claim_id)"

  - name: out_of_network_allowed
    scope: descriptor
    requires: [is_oon, allowed]
    expr: |
      SUM(
        CASE
          WHEN is_oon
               THEN allowed
          ELSE 0
        END
      )

  - name: utilization
    scope: descriptor
    requires: [utilization]
    expr: "SUM(utilization)"

  - name: units_days
    scope: descriptor
    requires: [hcg_units_days]
    expr: "SUM(hcg_units_days)"

  # Simplified avg days calculation
  - name: avg_days_service_to_paid
    scope: descriptor
    requires: [clean_claim_status, clean_claim_out, claim_from]
    expr: |
      AVG(
        CASE
          WHEN clean_claim_status = 'PAID'
          THEN DATE_DIFF(clean_claim_out, claim_from, DAY)
          ELSE NULL
        END
      )

  # Membership metrics (norm scope)
  - name: member_months
    scope: norm
    requires: [ra_mm]
    expr: "SUM(ra_mm)"

  - name: unique_members_enrolled
    scope: norm
    requires: [member_id]
    expr: "COUNT(DISTINCT member_id)"
"""

with open(DIM_PATH, "w") as f:
    f.write(yaml_file)
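
The `_meta` aliases in the YAML above are substituted into each `expr` before the SQL is assembled. Below is a simplified sketch of that rule as the YAML comments describe it: placeholders such as `{norm_alias}` are filled with `str.format`, and an expression that is just a bare column name gets an alias prefix. `qualify` is an illustrative name rather than a function from this notebook, and the sketch takes the alias as an explicit argument instead of deriving it from `scope`.

```python
import re

# A bare column name: no dot, no parentheses, no spaces.
BARE_IDENTIFIER = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
ALIASES = {"descriptor_alias": "clc", "norm_alias": "m"}

def qualify(expr, alias):
    """Fill alias placeholders, then prefix a bare column with `alias`."""
    resolved = expr.strip().format(**ALIASES)
    if BARE_IDENTIFIER.match(resolved):
        return f"{alias}.{resolved}"
    return resolved

print(qualify("claim_type", "clc"))
print(qualify("IF({norm_alias}.call_count > 0, 1, 0)", "m"))
print(qualify("COALESCE({descriptor_alias}.major_service_category, 'Unmapped')", "clc"))
```

Because `str.format` treats every `{...}` as a placeholder, an `expr` that contains literal braces would need them doubled (`{{` and `}}`).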

# Loading BigQuery and Google Docs utils

In [None]:
# ── BigQuery helpers ───────────────────────────────────────────

"""
Thin BigQuery helper with retry/back-off baked in.
"""
from __future__ import annotations
import logging
from typing import Union
import pandas as pd
from google.cloud import bigquery
from google.api_core import retry
from google.cloud.exceptions import GoogleCloudError
from google.protobuf.json_format import MessageToJson
import proto
import json
from copy import deepcopy
from vertexai.generative_models import Content, Part
import math

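# Exponential back-off for retriable errors
from google.api_core import exceptions as api_exceptions

_default_retry = retry.Retry(
    predicate=retry.if_exception_type(
        # transient server-side / rate-limit errors only; the generic
        # GoogleCloudError would also retry permanent 4xx failures
        api_exceptions.InternalServerError,
        api_exceptions.BadGateway,
        api_exceptions.ServiceUnavailable,
        api_exceptions.TooManyRequests,
    ),
    deadline=600,               # overall max 10 min
)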

_bq_client: bigquery.Client | None = None
_CREDS      = None

def get_client() -> bigquery.Client:
    global _bq_client, _CREDS
    if _bq_client is None:
        _CREDS, _ = google.auth.default(
            # makes colab pick up the correct billing / quota project
            quota_project_id=BILLING_PROJECT_ID_BQ
        )
        _bq_client = bigquery.Client(
            project=BILLING_PROJECT_ID_BQ,
            credentials=_CREDS,
            location=REGION_BIGQUERY
        )
    return _bq_client

def run_query(
    sql: str,
    job_config: bigquery.job.QueryJobConfig | None = None,
    max_results: int | None = None,
) -> Union[pd.DataFrame, dict]:
    """
    Executes `sql` with built-in retry.
    Returns pandas DataFrame or {"error": "..."} on failure.
    """
    #print("BQ QUERY:\n%s", sql)
    client = get_client()
    try:
        if job_config is None:
            job_config = bigquery.QueryJobConfig()

        # Pass the job_config to the query method
        job = client.query(sql, job_config=job_config, location=REGION_BIGQUERY, retry=_default_retry)
        result_iter = job.result()              # RowIterator
        df = result_iter.to_dataframe(
            create_bqstorage_client=True
        )
        if max_results is not None and len(df) > max_results:
            df = df.head(max_results)           # truncate in pandas
        print(f"BQ rows fetched: {len(df)}")
        return df
    except Exception as exc:
        print(f"BQ query failed: {exc}")
        return {"error": str(exc)}

# ────────────────────────────────────────────────────────────────────
#  CREATE THE DATA TABLE
# ────────────────────────────────────────────────────────────────────
def execute_creation_query(sql):
    job = get_client().query(sql, job_config=bigquery.QueryJobConfig(), location=REGION_BIGQUERY)
    return job.result()

from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload

docs = build("docs", "v1", credentials=CREDS)
drive = build("drive", "v3", credentials=CREDS)

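def create_or_open_doc(title: str) -> str:
    """Create a new Google Doc and return its ID (never reopens an existing one, despite the name)."""
    doc = docs.documents().create(body={"title": title}).execute()
    return doc["documentId"]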

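def append_text(doc_id: str, text: str):
    """Insert `text` at index 1, the start of the document body, so the newest entry appears first."""
    docs.documents().batchUpdate(
        documentId=doc_id,
        body={"requests": [
            {"insertText": {"location": {"index": 1}, "text": text + "\n\n"}}
        ]}).execute()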

def upload_png(local_path:str) -> str:
    meta = {"name": local_path, "mimeType": "image/png"}
    media = MediaFileUpload(local_path, mimetype="image/png")
    f = drive.files().create(body=meta, media_body=media, fields="id,webContentLink").execute()
    # make it publicly readable so Docs can fetch it
    drive.permissions().create(fileId=f["id"], body={"type": "anyone", "role": "reader"}).execute()
    return f["webContentLink"]

def insert_image(doc_id:str, image_url:str, pt_height=250):
    docs.documents().batchUpdate(
        documentId=doc_id,
        body={"requests":[
            {"insertInlineImage":{
                "location":{"index":1},
                "uri": image_url,
                "objectSize":{"height":{"magnitude":pt_height,"unit":"PT"}}
            }}
        ]}).execute()
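
The BigQuery helper above delegates retries to `google.api_core.retry.Retry`. For readers unfamiliar with exponential back-off, here is a library-free sketch of the general idea. The names `TransientError`, `backoff_delays`, and `call_with_backoff`, and the delay parameters, are illustrative assumptions; they are not the API or the default values of `google-api-core`.

```python
import time

class TransientError(Exception):
    """Stand-in for a retriable failure such as a 5xx response."""

def backoff_delays(initial=1.0, multiplier=2.0, maximum=60.0, retries=5):
    """Yield the wait before each retry, growing geometrically up to a cap."""
    delay = initial
    for _ in range(retries):
        yield min(delay, maximum)
        delay *= multiplier

def call_with_backoff(func, sleep=time.sleep, **delay_kwargs):
    """Call func(); on TransientError wait and retry until delays run out."""
    delays = backoff_delays(**delay_kwargs)
    while True:
        try:
            return func()
        except TransientError:
            delay = next(delays, None)
            if delay is None:
                raise            # retries exhausted: surface the last error
            sleep(delay)

# Demo: a function that fails twice, then succeeds. The sleep function is
# replaced by list.append so the waits are recorded instead of slept.
calls = {"n": 0}

def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise TransientError("temporary failure")
    return "ok"

waits = []
print(call_with_backoff(flaky, sleep=waits.append), waits)
```

Only errors classified as transient are retried here; anything else propagates immediately, which is the behaviour you generally want for permanent failures such as a SQL syntax error.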

# Dataset compilation helpers

In [None]:
# ── Dataset compilation helpers ───────────────────────────────────────────

from __future__ import annotations
import textwrap, time, yaml, google.auth
from pathlib import Path
from google.cloud import bigquery
from google.cloud.exceptions import NotFound
import re
from typing import List, Dict, Any, Set

# If the user wrote an unqualified column (no '.') prepend correct alias
BARE_IDENTIFIER = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")   # no dot, no paren

SQL_KEYWORDS = {
    "CASE", "WHEN", "ELSE", "END",
    "SELECT", "WHERE", "AND", "OR",
}

# ────────────────────────────────────────────────────────────────────
#  LOAD & NORMALISE DIMENSION REGISTRY
# ────────────────────────────────────────────────────────────────────
class Dim:
    def __init__(self, d: dict, aliases: dict[str, str]):
        self.name = d["name"]
        self.scope = d.get("scope", "descriptor")
        self.description = d.get("description", "No description available")
        raw_expr = textwrap.dedent(d["expr"]).strip().format(**aliases)

        first_tok = raw_expr.split()[0]
        if (BARE_IDENTIFIER.match(first_tok)
            and first_tok.upper() not in SQL_KEYWORDS):
            if self.scope == "norm":
                alias = aliases["norm_alias"]
            elif self.scope == "descriptor":
                alias = aliases["descriptor_alias"]
            else:  # scope == "both": the alias depends on where the dimension is used,
                # so leave the expression unqualified here and let
                # resolve_dim_expr_for_context() prefix it during SQL generation
                alias = None

            if alias:
                raw_expr = f"{alias}.{raw_expr}"

        self.expr = raw_expr
        self.raw_column = d["expr"].strip() if BARE_IDENTIFIER.match(d["expr"].strip()) else None

class Metric:
    def __init__(self, d: dict, aliases: dict[str, str]):
        self.name = d["name"]
        self.scope = d.get("scope", "descriptor")
        raw_expr = textwrap.dedent(d["expr"]).strip().format(**aliases)
        self.expr = raw_expr
        self.requires = [c.strip() for c in d.get("requires", [])]

def load_registry(path: Path):
    doc = yaml.safe_load(path.read_text())
    aliases = {
        "descriptor_alias": doc["_meta"]["descriptor_alias"],
        "norm_alias": doc["_meta"]["norm_alias"]
    }

    dims_all = [Dim(d, aliases) for d in doc["dimensions"]]
    metrics = [Metric(m, aliases) for m in doc["metrics"]]

    descriptor_dims = [d for d in dims_all if d.scope in ("descriptor", "both")]
    norm_dims = [d for d in dims_all if d.scope in ("norm", "both")]

    descriptor_metrics = [m for m in metrics if m.scope in ("descriptor", "both")]
    norm_metrics = [m for m in metrics if m.scope in ("norm", "both")]

    return descriptor_dims, norm_dims, descriptor_metrics, norm_metrics, aliases

def load_config(path: Path):
    doc = yaml.safe_load(path.read_text())

    targets = doc["_targets"]
    years_cfg = {int(y): cfg for y, cfg in doc["_years"].items()}
    filters = doc["_filters"]
    norm_joins_dict = doc.get("_norm_joins") or {}
    descriptor_joins_dict = doc.get("_descriptor_joins") or {}
    norm_joins = list(norm_joins_dict.values()) if norm_joins_dict else []
    descriptor_joins = list(descriptor_joins_dict.values()) if descriptor_joins_dict else []

    return targets, years_cfg, filters, norm_joins, descriptor_joins

# ────────────────────────────────────────────────────────────────────
#  SQL BUILDERS
# ────────────────────────────────────────────────────────────────────
def _where(filter_key: str, year: int, filters: dict, months: dict) -> str:
    parts = []
    parts += filters.get("common", [])
    parts += filters.get(filter_key, [])
    # year / month limits
    mm = ", ".join(map(str, months[year]))
    parts += [f"EXTRACT(YEAR FROM m.month) = {year}",
              f"EXTRACT(MONTH FROM m.month) IN ({mm})"]
    return " AND ".join(parts)

def resolve_dim_expr_for_context(dim: Dim, context_alias: str, aliases: dict) -> str:
    """Resolve dimension expression for specific context"""
    if dim.scope != "both" or not dim.raw_column:
        return dim.expr

    # For "both" scope with simple column names, use the context alias
    return f"{context_alias}.{dim.raw_column}"

def year_norm_cte(year: int, config: dict) -> str:
    """Build the normalization (membership) CTE for a given year, aggregated to one row per member (no window functions)."""
    descriptor_dims, norm_dims, descriptor_metrics, norm_metrics, aliases = config["registry"]
    targets, years, filters, norm_joins, descriptor_joins = config["config"]

    # Get columns that are needed for metrics but not already in dimensions
    raw_cols_norm = sorted({c for m in norm_metrics for c in m.requires})
    norm_dim_names = {d.name for d in norm_dims}
    extra_cols = [c for c in raw_cols_norm if c not in norm_dim_names]

    # Remove member_id and month from extra_cols since we'll handle them specially
    essential_cols = ['member_id', 'month']
    extra_cols = [c for c in extra_cols if c not in essential_cols]

    # Build select list - aggregate at member level to avoid month duplication
    norm_select_parts = []
    for d in norm_dims:
        # Skip month-related dimensions that would cause duplication
        if d.name in ['month']:
            continue
        expr = resolve_dim_expr_for_context(d, aliases["norm_alias"], aliases)
        # Use MAX for dimensions that should be the same across months for a member
        if d.raw_column and d.raw_column in ['hios_id', 'state', 'plan_network_access_type', 'plan_metal', 'age_group']:
            norm_select_parts.append(f"MAX({expr}) AS {d.name}")
        else:
            # For calculated dimensions, use ANY_VALUE since they should be consistent
            norm_select_parts.append(f"ANY_VALUE({expr}) AS {d.name}")

    norm_select_list = ",\n              ".join(norm_select_parts)

    # For metrics, we need to sum across months
    extra_select_parts = []
    for col in extra_cols:
        if col == 'ra_mm':
            extra_select_parts.append(f"SUM(m.{col}) AS {col}")
        else:
            extra_select_parts.append(f"MAX(m.{col}) AS {col}")  # Most other cols should be same across months

    extra_select = ",\n              ".join(extra_select_parts) if extra_select_parts else ""

    select_cols_parts = [norm_select_list]
    if extra_select:
        select_cols_parts.append(extra_select)
    select_cols = ",\n              ".join(select_cols_parts)

    where_sql = _where("norm", year, filters, {y: cfg["months"] for y, cfg in years.items()})

    # Build joins
    extra_join_sql = ""
    if norm_joins:
        join_snippets = []
        for j in norm_joins:
            formatted_join = textwrap.dedent(j).format(
                claim_snapshot=years[year]["claim_snapshot"],
                norm_alias=aliases["norm_alias"],
                descriptor_alias=aliases["descriptor_alias"],
            )
            join_snippets.append(formatted_join)
        extra_join_sql = "\n".join(join_snippets)

    member_snapshot = years[year]["member_snapshot"]

    return f"""
    norm_{year} AS (
      SELECT m.member_id,
             {select_cols}
      FROM {member_snapshot} m
      {extra_join_sql}
      WHERE {where_sql}
      GROUP BY m.member_id
    )""".strip()

def year_descriptor_cte(year: int, config: dict) -> str:
    """Build the descriptor (claims) CTE for a given year, joined to norm_{year} on member_id only to avoid claim duplication."""
    descriptor_dims, norm_dims, descriptor_metrics, norm_metrics, aliases = config["registry"]
    targets, years, filters, norm_joins, descriptor_joins = config["config"]

    # Get columns needed for descriptor metrics
    raw_cols_descriptor = sorted({c for m in descriptor_metrics for c in m.requires})

    # Build select lists
    descriptor_select_parts = []
    for d in descriptor_dims:
        if d.scope == "both":
            # For "both" scope dimensions, we need to get them from the member data
            # But we'll join on member_id only (not month) to avoid duplication
            expr = f"m.{d.name}"
        else:
            expr = d.expr
        descriptor_select_parts.append(f"{expr} AS {d.name}")

    descriptor_select_list = ",\n              ".join(descriptor_select_parts)
    extra_select = ",\n              ".join(f"clc.{c}" for c in raw_cols_descriptor)

    select_cols_parts = [descriptor_select_list]
    if extra_select:
        select_cols_parts.append(extra_select)
    select_cols = ",\n              ".join(select_cols_parts)

    # Build where clause
    where_parts = []
    where_parts += filters.get("common", [])
    where_parts += filters.get("descriptor", [])
    mm = ", ".join(map(str, years[year]["months"]))
    where_parts += [f"EXTRACT(YEAR FROM clc.month) = {year}",
                   f"EXTRACT(MONTH FROM clc.month) IN ({mm})"]
    mm_where = " AND ".join(where_parts)

    claim_snapshot = years[year]["claim_snapshot"]

    # Join on member_id only to avoid month-based duplication
    return f"""
    descriptor_{year} AS (
      SELECT {select_cols}
      FROM {claim_snapshot} clc
      INNER JOIN norm_{year} m ON clc.member_id = m.member_id
      WHERE {mm_where}
    )""".strip()

def build_descriptor_sql(config: dict) -> str:
    """Full CREATE OR REPLACE TABLE … AS … sql text."""
    descriptor_dims, norm_dims, descriptor_metrics, norm_metrics, aliases = config["registry"]
    targets, years, filters, norm_joins, descriptor_joins = config["config"]
    destination_project = config["destination_project"]

    ctes: list[str] = []

    # 1) Build CTEs for each year
    for yr in years:
        ctes.append(year_norm_cte(int(yr), config))
        ctes.append(year_descriptor_cte(int(yr), config))

    # 2) Union all years
    union_parts = [f"SELECT * FROM descriptor_{yr}" for yr in years]
    ctes.append(f"""
        descriptor_combined AS (
            {' UNION ALL '.join(union_parts)}
        )
    """.strip())

    ctes_sql = ",\n".join(ctes)

    # Build final select
    descriptor_metric_select = ",\n    ".join(f"{m.expr} AS {m.name}" for m in descriptor_metrics)
    descriptor_group_by = ", ".join(d.name for d in descriptor_dims)

    sql = f"""
    CREATE OR REPLACE TABLE {destination_project}.{targets['dataset']}.{targets['descriptor_table']} AS
    WITH
    {ctes_sql}
    SELECT
        {", ".join(d.name for d in descriptor_dims)},
        {descriptor_metric_select}
    FROM descriptor_combined clc
    GROUP BY {descriptor_group_by}
    """.strip()

    return sql

def build_norm_sql(config: dict) -> str:
    """Build the CREATE OR REPLACE TABLE statement for the norm (membership) table from the per-member yearly CTEs."""
    descriptor_dims, norm_dims, descriptor_metrics, norm_metrics, aliases = config["registry"]
    targets, years, filters, norm_joins, descriptor_joins = config["config"]
    destination_project = config["destination_project"]

    ctes = [year_norm_cte(int(yr), config) for yr in years]

    # Union all years
    union_parts = [f"SELECT * FROM norm_{yr}" for yr in years]
    ctes.append(f"""
        norm_combined AS (
            {' UNION ALL '.join(union_parts)}
        )
    """.strip())

    ctes_sql = ",\n".join(ctes)

    # Build final select
    norm_metric_select = ",\n    ".join(f"{m.expr} AS {m.name}" for m in norm_metrics)
    # Remove member_id from group by since we're now aggregating across members
    norm_group_by_dims = [d.name for d in norm_dims if d.name not in ['member_id', 'month']]
    norm_group_by = ", ".join(norm_group_by_dims)

    return f"""
        CREATE OR REPLACE TABLE {destination_project}.{targets['dataset']}.{targets['norm_table']} AS
        WITH
        {ctes_sql}
        SELECT
            {", ".join(d.name for d in norm_dims if d.name not in ['member_id', 'month'])},
            {norm_metric_select}
        FROM norm_combined m
        GROUP BY {norm_group_by}
    """.strip()

# ────────────────────────────────────────────────────────────────────
#  MAIN CREATION FUNCTION
# ────────────────────────────────────────────────────────────────────
def create_table(dim_path: Path, destination_project: str, region: str):
    """Create both descriptor and norm tables"""

    # Load configuration
    registry = load_registry(dim_path)
    config_data = load_config(dim_path)

    # Package everything together
    config = {
        "registry": registry,
        "config": config_data,
        "destination_project": destination_project
    }

    # Get BigQuery client (pin the job location to the configured region)
    client = bigquery.Client(project=destination_project, location=region)
    dataset_id = f"{destination_project}.{config_data[0]['dataset']}"

    # Create dataset if it doesn't exist
    try:
        client.get_dataset(dataset_id)
    except NotFound:
        ds = bigquery.Dataset(dataset_id)
        ds.location = region
        client.create_dataset(ds)
        time.sleep(3)

    # Build and execute SQL
    sql_descriptor = build_descriptor_sql(config)
    sql_norm = build_norm_sql(config)

    print("DESCRIPTOR SQL:")
    print(sql_descriptor)
    print("\nNORM SQL:")
    print(sql_norm)

    # Execute queries
    try:
        job1 = client.query(sql_descriptor)
        result1 = job1.result()
        print(f"✓ {config_data[0]['descriptor_table']} rebuilt.")

        job2 = client.query(sql_norm)
        result2 = job2.result()
        print(f"✓ {config_data[0]['norm_table']} rebuilt.")

    except Exception as e:
        print(f"Error executing queries: {e}")
        raise
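
The way `build_norm_sql` turns dimension and metric definitions into its final `SELECT ... GROUP BY` can be looked at in isolation. The sketch below is a simplified stand-in, not the notebook's implementation: the `Dim` and `Metric` classes, the `assemble_select` helper, and the sample `COUNT(*)` expression are hypothetical. Only the `name` and `expr` attributes and the exclusion of `member_id` and `month` are taken from the code above.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Dim:
    name: str  # column name as exposed by the combined CTE


@dataclass
class Metric:
    name: str  # output column name
    expr: str  # SQL aggregate expression (hypothetical sample values below)


def assemble_select(dims: List[Dim], metrics: List[Metric], source: str) -> str:
    """Build a SELECT that aggregates metrics over every non-member-level dimension."""
    # member_id and month are dropped so rows are aggregated across members and months
    group_cols = [d.name for d in dims if d.name not in ("member_id", "month")]
    if not group_cols:
        raise ValueError("At least one grouping dimension is required")
    metric_cols = [f"{m.expr} AS {m.name}" for m in metrics]
    select_list = ",\n    ".join(group_cols + metric_cols)
    return f"SELECT\n    {select_list}\nFROM {source}\nGROUP BY {', '.join(group_cols)}"


sql = assemble_select(
    dims=[Dim("member_id"), Dim("month"), Dim("year"), Dim("state")],
    metrics=[Metric("member_months", "COUNT(*)")],
    source="norm_combined",
)
print(sql)
```

Guarding against an empty grouping list is a choice made for this sketch; without it, the generated statement would end in a bare `GROUP BY`.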

# Creating the compiled data cubes, descriptor and norm

In [None]:
create_table(Path(DIM_PATH), BILLING_PROJECT_ID_BQ, REGION_BIGQUERY)

DESCRIPTOR SQL:
CREATE OR REPLACE TABLE oscaractuarial.zone_djf_llm.trend_llm_descriptor AS
    WITH
    norm_2023 AS (
      SELECT m.member_id,
             MAX(m.hios_id) AS hios_id,
              MAX(m.state) AS state,
              MAX(m.plan_network_access_type) AS plan_network_access_type,
              ANY_VALUE(EXTRACT(YEAR FROM m.month)) AS year,
              ANY_VALUE(CASE
  WHEN SUBSTR(m.hios_id,1,5) = '29341'
      THEN 'OHB (Columbus)'
  WHEN SUBSTR(m.hios_id,1,5) = '45845'
      THEN 'OHC (Cleveland Clinic Product)'
  WHEN m.state = 'TX'
      THEN CONCAT(m.state, '-', m.plan_network_access_type)
  ELSE m.state
END) AS geographic_reporting,
              MAX(m.plan_metal) AS plan_metal,
              MAX(m.age_group) AS age_group,
              ANY_VALUE(m.gender) AS gender,
              ANY_VALUE(m.region) AS region,
              ANY_VALUE(m.enrollment_length_continuous) AS enrollment_length_continuous,
              ANY_VALUE(m.clinical_segment) AS clinical_segment,

# Verifying the compiled data cubes

Show the two data tables we created

In [None]:
from google.cloud import bigquery
import pandas as pd

config = load_config(Path(DIM_PATH))
targets, years, filters, norm_joins, descriptor_joins = config

# The two tables we generated
tables = [targets['norm_table'], targets['descriptor_table']]

dataset_id = config[0]['dataset']

TABLE_FQN = [f"{DESTINATION_PROJECT_ID_BQ}.{dataset_id}.{table}" for table in tables]

client = bigquery.Client(project=DESTINATION_PROJECT_ID_BQ)

for table in TABLE_FQN:
    # 1️⃣  Print schema
    tbl = client.get_table(table)
    print(f"✅ Table {table} exists")
    print(f"• Total rows (cached metadata): {tbl.num_rows:,}\n")
    print("Schema:")
    for f in tbl.schema:
        print(f"  • {f.name:<35} {f.field_type} ({f.mode})")

    # ------------------------------------------------------------------
    # 2️⃣  Simple row count
    row_cnt = client.query(f"SELECT COUNT(*) AS row_count FROM `{table}`").to_dataframe()
    print("\nRow count from fresh query:")
    display(row_cnt)

    # ------------------------------------------------------------------
    # 3️⃣  Preview a few records
    preview = client.query(f"""
        SELECT *
        FROM `{table}`
        LIMIT 10
    """).to_dataframe()
    print("\nPreview of first 10 rows:")
    display(preview)

# Defining data access functions for the AI

In [None]:
# ── Query interface for cube tables ───────────────────────────────────────

import re
import json
from typing import List, Dict, Any, Set
from pathlib import Path
from google.cloud import bigquery


# Build dimension metadata
dimension_metadata = {}

descriptor_dims, norm_dims, descriptor_metrics, norm_metrics, aliases = load_registry(Path(DIM_PATH))
targets, years, filters, norm_joins, descriptor_joins = load_config(Path(DIM_PATH))

# Add descriptor dimensions
for d in descriptor_dims:
    dimension_metadata[d.name] = {
        "source_table": "descriptor",
        "original_column": d.name,
        "description": d.description
    }

# Add norm dimensions
for d in norm_dims:
    dimension_metadata[d.name] = {
        "source_table": "norm",
        "original_column": d.name,
        "description": d.description
    }

# Build table references
claim_table_id = f"{BILLING_PROJECT_ID_BQ}.{targets['dataset']}.{targets['descriptor_table']}"
membership_table = f"{BILLING_PROJECT_ID_BQ}.{targets['dataset']}.{targets['norm_table']}"

# Get metric names
descriptor_metric_names = [m.name for m in descriptor_metrics]
norm_metric_names = [m.name for m in norm_metrics]

# Year bounds
years_min, years_max = min(years), max(years)

# ────────────────────────────────────────────────────────────────────
#  QUERY HELPER FUNCTIONS
# ────────────────────────────────────────────────────────────────────

def _coerce_bool(col: str, v: Any):
    """Convert 'TRUE'/'FALSE' string to real bool when appropriate"""
    if col in ["in_network", "is_out_of_network", "member_called_oscar", "member_used_app", "member_had_web_login"]:
        if isinstance(v, str) and v.upper() in {"TRUE", "FALSE"}:
            return v.upper() == "TRUE"
    return v

def _format_value(v: Any) -> str:
    """Format values for SQL"""
    if isinstance(v, bool):
        return str(v).upper()
    if isinstance(v, (int, float)):
        return str(v)
    if v is None:
        return "NULL"
    # Escape backslashes first, then single quotes.
    # BigQuery (GoogleSQL) string literals use backslash escapes; a doubled '' is not an escape.
    return "'" + str(v).replace("\\", "\\\\").replace("'", "\\'") + "'"

# Allowed operators
_ALLOWED_OPS = {
    "=", "==", "EQUAL", "EQ",
    "!=", "<>", "<", ">", "<=", ">=",
    "IN", "NOT IN",
    "LIKE", "NOT LIKE",
    "IS NULL", "IS NOT NULL",
    "BETWEEN",
}

# Friendly aliases for null checks
_NULL_ALIASES = {
    "IS_NULL": "IS NULL",
    "IS_NOT_NULL": "IS NOT NULL",
    "ISNULL": "IS NULL",
    "ISNOTNULL": "IS NOT NULL",
}
_ALLOWED_OPS.update(_NULL_ALIASES.keys())


def _validate_dim(key: str) -> Dict[str, str]:
    """Validate and resolve dimension name"""
    if key not in dimension_metadata:
        raise KeyError(f"Unknown dimension: {key}")
    return dimension_metadata[key]

def _dims_for(source_table: str, dims: List[str]) -> Set[str]:
    """Get dimensions that belong to a specific source table"""
    return {
        d for d in dims
        if dimension_metadata[d]["source_table"] == source_table
    }

def get_trend_data(
    group_by_dimensions: List[str] | None = None,
    filters: List[dict] | None = None,
    top_n: int | None = None,
) -> Dict[str, Any]:
    """
    Returns yearly trend metrics from the YAML-defined descriptor and norm cubes.

    Parameters
    ----------
    group_by_dimensions : list[str]
        Dimension names (as in YAML) for GROUP BY.
    filters : list[dict]
        Each dict: {"dimension_name": str, "operator": str, "value": Any}.
        Supports =, !=, IN, NOT IN, IS NULL, IS NOT NULL, etc.
    top_n : int
        Return at most N rows per year (ordered by total_allowed DESC). Defaults to 100.

    Returns
    -------
    dict
        {"data": JSON string, "warning": …?}
        On error: {"error": "..."}
    """

    # Input validation and defaults
    user_supplied_top_n = top_n is not None
    top_n = int(top_n or 100)
    group_by_dimensions = group_by_dimensions or []
    filters = filters or []

    # Validate all dimension names
    try:
        for dim_name_iter in group_by_dimensions: # Renamed to avoid conflict
            _validate_dim(dim_name_iter)
        for f_iter in filters: # Renamed to avoid conflict
            _validate_dim(f_iter["dimension_name"])
    except KeyError as e:
        return {"error": f"Invalid dimension: {e}"}

    # Dimension processing for SQL generation
    memb_gb_cols = []
    claim_gb_cols = []
    for dim_name in group_by_dimensions:
        meta = dimension_metadata[dim_name]
        if meta["source_table"] == "norm":
            memb_gb_cols.append(meta["original_column"])
        elif meta["source_table"] == "descriptor":
            claim_gb_cols.append(meta["original_column"])
        # Handle 'both' scope dimensions if necessary, assuming they are treated as member dimensions for grouping
        # For simplicity, 'both' could be added to memb_gb_cols or handled based on specific logic if needed
        # Current YAML parsing adds 'both' to both norm_dims and descriptor_dims
        # If a 'both' dim is in group_by_dimensions, it will be in claim_table_id. We'll primarily use its member aspect for cohort definition.

    memb_filter_dim_cols = []
    claim_filter_dim_cols = []
    for f_spec in filters: # Renamed to avoid conflict
        dim_name = f_spec["dimension_name"]
        meta = dimension_metadata[dim_name]
        if meta["source_table"] == "norm":
            if meta["original_column"] not in memb_filter_dim_cols:
                 memb_filter_dim_cols.append(meta["original_column"])
        elif meta["source_table"] == "descriptor":
            if meta["original_column"] not in claim_filter_dim_cols:
                claim_filter_dim_cols.append(meta["original_column"])

    # Ensure 'year' is not duplicated if manually added to _cols lists
    memb_gb_cols = [col for col in memb_gb_cols if col != "year"]
    # claim_gb_cols typically won't include 'year' as it's handled separately.

    all_memb_grouping_cols = sorted(list(set(memb_gb_cols + memb_filter_dim_cols)))

    # Filter SQL string generation
    client = get_client()  # BigQuery client used to run the assembled query below

    year_filter_clauses = []
    if not any(f_spec.get("dimension_name") == "year" for f_spec in filters): # Renamed to avoid conflict
        year_filter_clauses.append(f"year BETWEEN {years_min} AND {years_max}")


    # Helper to build filter clauses for specific tables
    def build_filter_clauses(table_alias, filter_list):
        clauses = []
        for f_spec_inner in filter_list: # Renamed to avoid conflict
            dim_name = f_spec_inner["dimension_name"]
            meta = _validate_dim(dim_name)
            col_name = meta["original_column"]
            op_in = str(f_spec_inner["operator"]).upper()
            if op_in in {"=", "==", "EQ", "EQUAL"}:
                op = "="  # normalize equality aliases so the generated SQL is valid
            else:
                op = _NULL_ALIASES.get(op_in, op_in)

            if op not in _ALLOWED_OPS:
                # Callers are expected to catch this and turn it into an error response
                raise ValueError(f"Operator {op} not allowed.")

            val = _coerce_bool(col_name, f_spec_inner.get("value"))

            # Apply table alias
            aliased_col = f"{table_alias}.{col_name}"

            if op in ("IN", "NOT IN"):
                if isinstance(val, list) and val:
                    formatted_vals = ", ".join(_format_value(v_item) for v_item in val) # Renamed to avoid conflict
                    clauses.append(f"{aliased_col} {op} ({formatted_vals})")
                elif not val: # Empty list for IN/NOT IN can be problematic or mean "match nothing"/"match everything"
                    if op == "IN": clauses.append("FALSE") # Match nothing
                    else: clauses.append("TRUE") # Match everything
                else: # Single value for IN/NOT IN
                    clauses.append(f"{aliased_col} {op} ({_format_value(val)})")

            elif op in ("IS NULL", "IS NOT NULL"):
                clauses.append(f"{aliased_col} {op}")
            elif op == "BETWEEN" and isinstance(val, list) and len(val) == 2:
                clauses.append(f"{aliased_col} {op} {_format_value(val[0])} AND {_format_value(val[1])}")
            else:
                clauses.append(f"{aliased_col} {op} {_format_value(val)}")
        return clauses

    # Filters for claims_data CTE (alias 'c')
    # These include claim-dim filters and member-dim filters (as member dims are in claim_table_id)
    try:
        claim_cte_filter_clauses = build_filter_clauses("c", filters)
    except (KeyError, ValueError) as e:
        # Surface malformed filter specs as a tool error instead of raising
        return {"error": f"Invalid filter: {e}"}

    effective_claims_where_parts = claim_cte_filter_clauses + [f"c.{yclause}" for yclause in year_filter_clauses]
    effective_claims_where_sql = " AND ".join(effective_claims_where_parts) if effective_claims_where_parts else "TRUE"


    # Filters for membership_data and membership_parent_data CTEs (alias 'm')
    membership_cte_filter_clauses = []
    for f_item in filters: # Renamed variable
        meta = dimension_metadata[f_item["dimension_name"]]
        if meta["source_table"] == "norm": # Only apply norm filters to membership table
            membership_cte_filter_clauses.extend(build_filter_clauses("m", [f_item]))

    effective_membership_where_parts = membership_cte_filter_clauses + [f"m.{yclause}" for yclause in year_filter_clauses]
    effective_membership_where_sql = " AND ".join(effective_membership_where_parts) if effective_membership_where_parts else "TRUE"

    claims_data_group_by_cols_sql = ["c.year"] + \
                                   [f"c.{col}" for col in all_memb_grouping_cols] + \
                                   [f"c.{col}" for col in claim_gb_cols]
    claims_data_group_by_cols_sql = sorted(list(set(claims_data_group_by_cols_sql))) # Deduplicate

    claims_data_select_cols_sql = claims_data_group_by_cols_sql[:] # Select the grouping columns

    metric_sums_sql = [f"SUM(c.{m}) AS total_{m}" for m in descriptor_metric_names]
    claims_data_select_cols_sql.extend(metric_sums_sql)

    claims_data_cte_sql = f"""
      claims_data AS (
        SELECT
          {", ".join(claims_data_select_cols_sql)}
        FROM {claim_table_id} c
        WHERE {effective_claims_where_sql}
        GROUP BY {", ".join(claims_data_group_by_cols_sql)}
      )
    """

    membership_data_group_by_cols_sql = ["m.year"] + [f"m.{col}" for col in all_memb_grouping_cols]
    membership_data_group_by_cols_sql = sorted(list(set(membership_data_group_by_cols_sql)))

    membership_data_select_cols_sql = membership_data_group_by_cols_sql[:]
    membership_data_select_cols_sql.extend([
        "SUM(m.member_months) AS member_months",
        "SUM(m.unique_members_enrolled) AS unique_members_enrolled"
    ])

    membership_data_cte_sql = f"""
      membership_data AS (
        SELECT
          {", ".join(membership_data_select_cols_sql)}
        FROM {membership_table} m
        WHERE {effective_membership_where_sql}
        GROUP BY {", ".join(membership_data_group_by_cols_sql)}
      )
    """

    parent_membership_group_by_cols_sql = ["m.year"] + [f"m.{col}" for col in memb_gb_cols]
    parent_membership_group_by_cols_sql = sorted(list(set(parent_membership_group_by_cols_sql)))

    parent_membership_select_cols_sql = parent_membership_group_by_cols_sql[:]
    parent_membership_select_cols_sql.append("SUM(m.member_months) AS parent_member_months")

    membership_parent_data_cte_sql = f"""
      membership_parent_data AS (
        SELECT
          {", ".join(parent_membership_select_cols_sql)}
        FROM {membership_table} m
        WHERE {effective_membership_where_sql} -- Same filters as membership_data
        GROUP BY {", ".join(parent_membership_group_by_cols_sql)}
      )
    """

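    # Build FILTERS
    # NOTE: where_parts is never referenced when the final SQL is assembled below;
    # the filters that take effect are the ones applied inside the CTEs above.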
    alias_for_src = {"descriptor": "c", "norm": "m"}
    where_parts: List[str] = []

    for f in filters:
        dim = f["dimension_name"]
        meta = _validate_dim(dim)
        col = f"{meta['original_column']}"
        op_in = str(f["operator"]).upper()

        if op_in in {"=", "==", "EQ", "EQUAL"}:
            op = "="
        elif op_in in _NULL_ALIASES:
            op = _NULL_ALIASES[op_in]
        else:
            op = op_in

        if op not in _ALLOWED_OPS:
            return {"error": f"Operator {op} not allowed."}

        val = _coerce_bool(col, f.get("value"))
        alias = alias_for_src[meta["source_table"]]

        if op in ("IN", "NOT IN") and isinstance(val, list):
            in_vals = ", ".join(_format_value(v) for v in val)
            where_parts.append(f"{alias}.{col} {op} ({in_vals})")
        elif op in ("IS NULL", "IS NOT NULL"):
            where_parts.append(f"{alias}.{col} {op}")
        else:
            where_parts.append(f"{alias}.{col} {op} {_format_value(val)}")

    # Add year constraint if not specified
    if not any(f.get("dimension_name") == "year" for f in filters):
        where_parts.append(f"c.year BETWEEN {years_min} AND {years_max}")

    # Final SELECT statement construction
    final_select_list = ["cc.year"]

    # Add grouping dimensions to select list (user's original group_by_dimensions)
    for col in memb_gb_cols: # member dimensions from group_by_dimensions
        final_select_list.append(f"cc.{col}")
    for col in claim_gb_cols: # claim dimensions from group_by_dimensions
        final_select_list.append(f"cc.{col}")

    final_select_list = sorted(list(set(final_select_list))) # Deduplicate, year might be in memb_gb_cols

    # Add aggregated metrics from claims_data
    for m_name in descriptor_metric_names: # Renamed to avoid conflict
        final_select_list.append(f"cc.total_{m_name}")

    # Add metrics from membership_data and membership_parent_data
    final_select_list.append("mc.member_months")
    final_select_list.append("mc.unique_members_enrolled")
    final_select_list.append("mpc.parent_member_months")

    # Add derived PMPM and other metrics
    # Ensure total_allowed and total_utilization are present in descriptor_metric_names for these calcs
    safe_total_allowed = "cc.total_allowed" if "allowed" in descriptor_metric_names else "0"
    safe_total_util = "cc.total_utilization" if "utilization" in descriptor_metric_names else "0"
    safe_total_units_days = "cc.total_units_days" if "units_days" in descriptor_metric_names else "0"

    # Check if 'member_months' and 'parent_member_months' could be zero before division.
    # The NULLIF in SAFE_DIVIDE handles this.
    final_select_list.extend([
        f"ROUND(SAFE_DIVIDE({safe_total_allowed}, NULLIF(mc.member_months, 0)), 4) AS allowed_pmpm",
        f"ROUND(SAFE_DIVIDE({safe_total_allowed}, NULLIF(mpc.parent_member_months, 0)), 4) AS parent_allowed_pmpm",
        f"ROUND(SAFE_DIVIDE({safe_total_util} * 12000, NULLIF(mc.member_months, 0)), 6) AS utilization_pkpy",
        f"ROUND(SAFE_DIVIDE({safe_total_allowed}, NULLIF({safe_total_util}, 0)), 6) AS cost_per_service",
        f"ROUND(SAFE_DIVIDE({safe_total_units_days}, NULLIF({safe_total_util}, 0)), 6) AS length_of_stay"
    ])
    if "charges" in descriptor_metric_names:
        final_select_list.append(f"ROUND(SAFE_DIVIDE(cc.total_allowed, NULLIF(cc.total_charges, 0)), 2) AS allowed_to_billed_ratio")
    if "count_of_claims" in descriptor_metric_names and "avg_days_service_to_paid" in descriptor_metric_names: # avg_days needs total_sum_of_days and total_count
         pass


    # Join conditions
    # Join claims_data (cc) with membership_data (mc) on year and all_memb_grouping_cols
    join_cc_mc_conditions = ["cc.year = mc.year"] + \
                            [f"cc.{col} = mc.{col}" for col in all_memb_grouping_cols]

    # Join claims_data (cc) with membership_parent_data (mpc) on year and memb_gb_cols
    join_cc_mpc_conditions = ["cc.year = mpc.year"] + \
                             [f"cc.{col} = mpc.{col}" for col in memb_gb_cols]

    sql = f"""
    WITH
      {claims_data_cte_sql},
      {membership_data_cte_sql},
      {membership_parent_data_cte_sql}
    SELECT
      {", ".join(final_select_list)}
    FROM claims_data cc
    LEFT JOIN membership_data mc ON {" AND ".join(join_cc_mc_conditions)}
    LEFT JOIN membership_parent_data mpc ON {" AND ".join(join_cc_mpc_conditions)}
    QUALIFY ROW_NUMBER() OVER (PARTITION BY cc.year ORDER BY {safe_total_allowed} DESC) <= {top_n}
    ORDER BY cc.year, {safe_total_allowed} DESC
    """

    # debugging sql
    # print("tool call sql debugging ---------------")
    # print(sql)
    # print("end tool call sql debugging ---------------")

    try:
        query_job = client.query(sql)
        df = query_job.to_dataframe()

        result = {"data": df.to_json(orient="records", date_format="iso")}

        if not user_supplied_top_n and len(df) > 5_000:
            result["warning"] = (
                f"Result set contains {len(df)} rows. "
                "Consider using 'top_n' or adding filters."
            )
        return result

    except Exception as e:
        return {"error": f"Query execution failed: {str(e)}"}

def list_available_dimensions():
    """Return all available dimensions with their metadata"""
    return {"data": json.dumps(dimension_metadata, indent=2)}

def get_dimension_values(dimension_name: str):
    """Get distinct values for a given dimension from the appropriate cube table"""
    try:
        meta = _validate_dim(dimension_name)
    except KeyError:
        return {"error": f"Invalid dimension '{dimension_name}'."}

    # Determine which table to query based on the dimension's source
    if meta["source_table"] == "descriptor":
        table_id = claim_table_id
    elif meta["source_table"] == "norm":
        table_id = membership_table
    else:
        return {"error": f"Unknown source table '{meta['source_table']}' for dimension '{dimension_name}'."}

    sql = f"""
      SELECT DISTINCT `{meta['original_column']}` AS v
      FROM `{table_id}`
      WHERE `{meta['original_column']}` IS NOT NULL
      ORDER BY 1 LIMIT 500
    """

    try:
        query_job = client.query(sql)
        df = query_job.to_dataframe()

        # Convert to list, handling various data types
        values = df["v"].tolist()

        # Simple JSON serialization (expand if you need custom handling)
        try:
            return {"data": json.dumps(values, default=str)}
        except TypeError:
            # Fallback for non-serializable types
            return {"data": json.dumps([str(v) for v in values])}

    except Exception as e:
        return {"error": f"Query failed: {str(e)}"}
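
To make the filter contract of `get_trend_data` easier to follow, here is a standalone sketch of how one filter spec (`dimension_name`, `operator`, `value`) can be rendered as a WHERE-clause fragment. It is a simplified illustration rather than the notebook's code: `render_clause`, `quote_literal`, and `EQUALITY_ALIASES` are hypothetical names, it skips dimension validation, empty lists, and `BETWEEN`, and the backslash escaping of quotes is this sketch's choice for GoogleSQL string literals.

```python
from typing import Any, Dict

# Spellings of equality that should all render as "=" in SQL (assumed set)
EQUALITY_ALIASES = {"=", "==", "EQ", "EQUAL"}


def quote_literal(value: Any) -> str:
    """Render a Python value as a SQL literal (simplified)."""
    if value is None:
        return "NULL"
    if isinstance(value, bool):  # check bool before int: bool is a subclass of int
        return str(value).upper()
    if isinstance(value, (int, float)):
        return str(value)
    escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
    return f"'{escaped}'"


def render_clause(spec: Dict[str, Any], alias: str) -> str:
    """Turn one filter spec into a WHERE-clause fragment for the given table alias."""
    column = f"{alias}.{spec['dimension_name']}"
    op = str(spec["operator"]).upper()
    if op in EQUALITY_ALIASES:
        op = "="
    value = spec.get("value")
    if op in ("IN", "NOT IN"):
        # Accept a single value as a one-element list
        items = value if isinstance(value, list) else [value]
        return f"{column} {op} ({', '.join(quote_literal(v) for v in items)})"
    if op in ("IS NULL", "IS NOT NULL"):
        return f"{column} {op}"
    return f"{column} {op} {quote_literal(value)}"


eq_clause = render_clause({"dimension_name": "state", "operator": "==", "value": "FL"}, "c")
in_clause = render_clause(
    {"dimension_name": "plan_metal", "operator": "in", "value": ["Gold", "Silver"]}, "c"
)
print(eq_clause)
print(in_clause)
```

String interpolation keeps the sketch short. For values that come from an LLM or a user, BigQuery query parameters are generally the safer option.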

# Testing data access functions for the AI

In [None]:
# ── Usage Example ───────────────────────────────────────────

from pathlib import Path
from google.cloud import bigquery

# Example queries
# 1. Overall totals, no filters
print("=== TESTING: Overall trend data ===")
result = get_trend_data()
if "error" in result:
    print(f"Error: {result['error']}")
else:
    print("Success! Data returned.")
    print(result)
    if "warning" in result:
        print(f"Warning: {result['warning']}")

# 2. Split by geographic reporting and claim type
print("\n=== TESTING: Grouped by geographic_reporting and claim_type ===")
result = get_trend_data(
    group_by_dimensions=["geographic_reporting", "claim_type"],
    top_n=15
)
if "error" in result:
    print(f"Error: {result['error']}")
else:
    print("Success! Grouped data returned.")
    print(result)

# 3. Filter by specific state
print("\n=== TESTING: Filter by state ===")
result = get_trend_data(
    group_by_dimensions=["plan_metal", "claim_type"],
    filters=[
        {"dimension_name": "state", "operator": "=", "value": "NY"}
    ],
    top_n=10
)
if "error" in result:
    print(f"Error: {result['error']}")
else:
    print("Success! Filtered data returned.")
    print(result)

# 4. List available dimensions
print("\n=== TESTING: List dimensions ===")
result = list_available_dimensions()
print("Available dimensions loaded successfully")
print(result)

# 5. Get dimension values
print("\n=== TESTING: Get dimension values ===")
result = get_dimension_values("state")
if "error" in result:
    print(f"Error: {result['error']}")
else:
    print("Success! Dimension values returned.")
    print(result)

# 6. Test filtering by a member-only dimension
print("\n=== TESTING: overall allowed pmpm filtered by state ===")
result = get_trend_data(filters=[
        {"dimension_name": "state", "operator": "=", "value": "FL"}
    ])
if "error" in result:
    print(f"Error: {result['error']}")
else:
    print("Success! Data returned.")
    print(result)
    if "warning" in result:
        print(f"Warning: {result['warning']}")
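
The tests above print the raw result dict, whose `data` field is a JSON string. The sketch below shows one way to decode that payload and recompute the derived metrics as a sanity check. It assumes the payload contains `total_allowed`, `total_utilization`, and `member_months` columns (true only if the YAML defines `allowed` and `utilization` metrics); `derive_metrics` is a hypothetical helper and the sample numbers are made up.

```python
import json

# Illustrative payload shaped like get_trend_data's return value; the numbers are made up
sample_result = {
    "data": json.dumps([
        {"year": 2023, "total_allowed": 1200000.0, "total_utilization": 3000, "member_months": 4000},
        {"year": 2024, "total_allowed": 1500000.0, "total_utilization": 3300, "member_months": 4400},
    ])
}


def derive_metrics(result: dict) -> list:
    """Decode the JSON payload and recompute the derived metrics for each row."""
    if "error" in result:
        raise RuntimeError(result["error"])
    rows = json.loads(result["data"])
    out = []
    for row in rows:
        mm = row["member_months"]
        util = row["total_utilization"]
        out.append({
            "year": row["year"],
            # allowed dollars per member per month
            "allowed_pmpm": round(row["total_allowed"] / mm, 4) if mm else None,
            # services per 1,000 members per year (12 months x 1,000)
            "utilization_pkpy": round(util * 12000 / mm, 6) if mm else None,
            # average allowed dollars per service
            "cost_per_service": round(row["total_allowed"] / util, 6) if util else None,
        })
    return out


metrics = derive_metrics(sample_result)
for row in metrics:
    print(row)
```

Returning `None` for a zero denominator mirrors what `SAFE_DIVIDE` with `NULLIF` produces in the SQL, which yields `NULL` rather than an error.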

# Defining the AI agent

In [None]:
BASE_SYSTEM = """
    You are a seasoned expert in **health insurance medical economics**, specializing in
    **medical expense trend analysis** and the development of **cost of care management strategies**.
    You provide insightful, data-driven explanations for trends and identify
    actionable opportunities for affordability initiatives.

    Your goal is to explain the trend in chiropractic spending in Texas from 2023 to 2024.
    """

ANALYSIS_PLAN_TEMPLATE = """
    Analysis Protocol (Employing the Pyramid Principle):\n\n
    Your analysis will follow the Pyramid Principle, starting with the main finding (e.g., overall company trend)
    and then supporting it with successively more detailed layers of analysis. Your objective is to clearly
    communicate the story behind the medical expense trends, pinpoint the most significant drivers,
    and identify potential areas for trend management and affordability initiatives. Dig as deep as possible.

    Here are the principles of your analysis.
    1. You have two data sources available to you: all health insurance claims, and all members of the insurance plan.
    You can slice both sources by many different dimensions.
    2. The claims table has more dimensions available than the membership table:
    there are some dimensions that make sense to distinguish claims by, but they don't make sense to distinguish members by.
    For example, a claim clearly belongs to a particular provider - but a member does not clearly belong to that provider, because he could have seen other providers.
    3. We only care about comparing two specific periods: 2023 vs. 2024.
    We do not care about how the metrics changed over time.
    4. We care about spend per member per month. You can normalize all metrics from the claims table by dividing by member months from the membership table.
    So the drivers of the changes in spend may result from two different sources:
    a. either claims spend of a particular driver went up (e.g., left-handed members utilized more in 2024 than in 2023), or
    b. the mix between drivers changed (e.g., we have more left-handed members in 2024 than in 2023).

    You have the following tools available to you:
    - get_trend_data_tool: Get trend metrics with optional grouping, filtering, and top-N limits
    - list_available_dimensions_tool: See all available dimensions and their sources
    - get_dimension_values_tool: Get distinct values for any dimension
    - create_chart_tool: If a numeric comparison will be clearer as a picture, call this with a ChartSpec (x_dimension (label axis), y_metrics (numbers to graph), chart_type ("bar", "line", or "stacked_bar"), etc.)
    - write_google_doc_tool: Write your thoughts and conclusions into a Google Doc for the user to follow.

    Here is the plan you should follow for your analysis.
    1.  **High-Level Overview (The Apex of the Pyramid):**
        * Begin by reviewing the period-over-period trend.
        * State the overall trend clearly and concisely. This is your primary assertion.
    2.  **Iterative Drill-Down to Uncover Key Drivers (Building the Support):**
        * Decompose the total company pmpm trend by systematically exploring its components.
    At each step, identify and quantify the **largest contributing drivers** to the trend observed at the
    parent level before drilling further into those specific drivers.
            * **Significant Population Mix Shifts:** Monitor trends in different dimensions. Quantify how these shifts contribute
    to overall PMPM changes.
            * **Detailed Service Category Trends:** Examine the detailed services within each major category to understand what specifically is driving the trend. If a specific detailed service category is significant, drill in further to CPT codes, DRGs, etc.
            * **Operational Process Changes:** Use metrics like **percent_of_claims_denied** and **allowed_to_billed_ratio**. Significant changes in these metrics,
    especially when correlated with specific providers or service categories, can indicate operational
    inefficiencies, changes in claims processing, or provider billing practices that are impacting allowed spend.
            * **Out-of-Network Utilization Changes:** Pay attention to **is_out_of_network** trends and their impact on
    both utilization and cost per service. Significant shifts may point to network adequacy issues or changes
    in member steerage.
            * **Underlying Behavior Changes:** If you drilled all the way into a driver (e.g., certain claims activity), and activity is still up, then
            you may conclude that underlying behavior simply changed. See if you can find corroborating evidence.
            For example, after the recent pandemic, utilization increased because people had deferred care.
    3.  **Company-wide Summary & Key Drivers (The Base of the Pyramid - Conclusion & Recommendations):**
        * Synthesize your findings into a clear, concise summary of the high-level trends and the most
    significant sub-trends that explain them.
        * Clearly list the **key drivers** (e.g., specific service categories, geographies, population segments,
    provider entities, operational issues) and quantify their impact on the overall company and significant
    state-level trends.
        * Based on your findings, suggest potential **trend management / affordability initiatives**. These should be
    specific and linked to the drivers you've identified (e.g., 'Investigate contract terms with X provider
    due to a Y% increase in cost_per_service for Z procedures,' or 'Develop a targeted member outreach program
    for demographic A, which shows increasing utilization of high-cost service B').
        * Include any notable failures in your tool calls that should be fixed for future analyses. You can also
    request new tools or data dimensions that would enhance your analytical capabilities.
    4.  Iteration budget:
        * You can make as many tool calls as you want.
        * For number of iterations, please do not exceed {LIMIT} iterations. Your primary goal is to utilize this entire budget for exploration before the final summarization phase.
        * When you are NOT YET in the final 3 iterations (i.e., {LIMIT} minus current iteration > 3) → you MUST output PLAN + tool call(s) for further investigation. Do not attempt to summarize or conclude the analysis.
        * Only when you are 3 or fewer iterations away from the limit (i.e., {LIMIT} minus current iteration <= 3) → THEN and ONLY THEN should you stop initiating new exploratory tool calls and instead:
            1. PLAN summary (briefly state what you would have investigated further if you had more calls, or confirm you are now summarizing).
            2. (optional) one last data call if absolutely essential for your summary figures.
            3. Output your "Company-wide Summary & Key Drivers (including actionable recommendations)". When you provide this final comprehensive summary, begin it with one of the exact phrases: "FINAL REPORT AND ANALYSIS CONCLUDED", "ABSOLUTE FINAL SUMMARY AND RECOMMENDATIONS", or "## COMPANY-WIDE SUMMARY & KEY DRIVERS FINAL DOCUMENT".
        * Never finish early for any other reason. Prematurely outputting one of the above final markers or concluding the analysis before the designated final 3 iterations is a failure to follow instructions.
    5.  You need to drill deeply. Just going down one or two levels is rarely ever sufficient.
        * It is your job to fully explore the space of possible drivers, using tool calls in each iteration, until you are in the final 3-iteration summarization phase.
        * If you believe you have run out of avenues before your iteration limit (excluding the final 3 for summarization) is reached, consider that a strong sign that you did not drill down sufficiently or creatively enough - AND KEEP GOING by formulating new hypotheses or examining existing data from different angles.
        * We want to understand ALL possible drivers, which requires using the full iterative process.
        * If you have exhausted one line of investigation, and still have remaining iterations, go back to the top and drill through a different path.

    For example, if a particular type of utilization goes up, and you cannot tell why, then perhaps look at the type of people who are utilizing, and if anything is insightful or changed there.
    WHY is something happening, and WHY the WHY, and so on? THAT is what you need to find out with all your data and analytical firepower.

    For every step output *exactly* three sections, in this order, and do not emit the tool-call JSON until after the REFLECT stage.
    PLAN:
    <detailing your hypothesis and what you expect to find>
    REFLECT:
    <interpreting the results and explaining how they inform your next step or confirm/refute your hypothesis>
    <write your results, observations and conclusions into the Google Doc in each step by calling the tool>
    <create a chart to illustrate your work in each step by calling the tool>
    TOOL:
    <name + JSON args>

    Important: do NOT call too many tools in one iteration, at most 3. Then summarize, write to the document, and go to the next iteration.
    Then you can keep calling more tools!

    Other notes:
    * When grouping by a high-cardinality column, always include \"top_n\" ≤100 to avoid truncation.
    * The get_trend_data tool provides you with member month totals needed to understand the normalized metrics.
"""


def make_initial_prompt(limit: int) -> str:
    ANALYSIS_PLAN = ANALYSIS_PLAN_TEMPLATE.format(LIMIT=limit)
    return "\n\n".join(
        [
            BASE_SYSTEM,
            ANALYSIS_PLAN
        ]
    )
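
As a minimal sketch of what `make_initial_prompt` does with the `{LIMIT}` placeholder: `str.format` substitutes every occurrence and treats any other single braces as fields, so literal braces in a template would need doubling. The template text and the names `DEMO_TEMPLATE` and `demo_prompt` are hypothetical stand-ins, not the notebook's actual prompt.

```python
# Hypothetical miniature of a plan template; the real ANALYSIS_PLAN_TEMPLATE is much longer.
DEMO_TEMPLATE = (
    "You have {LIMIT} iterations.\n"
    "Summarize only when {LIMIT} minus current iteration <= 2.\n"
    'Tool args look like {{"top_n": 100}}.'  # doubled braces come out of .format() as literal braces
)


def demo_prompt(limit: int) -> str:
    """Fill every {LIMIT} placeholder, mirroring make_initial_prompt's use of str.format."""
    return DEMO_TEMPLATE.format(LIMIT=limit)


print(demo_prompt(30))
```

If a template ever gains literal JSON examples with single braces, `.format` would raise a `KeyError` or `ValueError`, which is the main thing to watch for when editing the prompt text.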

# Helper functions for the AI agent
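
The tool wrappers in the next cell appear to assume that each underlying data function (`get_trend_data`, `list_available_dimensions`, `get_dimension_values`) returns a dict carrying either an `error` message or a `data` payload serialized as a JSON string, optionally with a `warning`. The sketch below illustrates that contract with a hypothetical stub (`fake_get_trend_data`) and helper (`summarize_result`); neither is part of the notebook, and the rows are made-up demo values.

```python
import json
from typing import Any, Dict


def fake_get_trend_data() -> Dict[str, Any]:
    """Hypothetical stand-in returning the envelope shape the wrappers check for."""
    rows = [
        {"year": 2023, "allowed_pmpm": 100.0},  # made-up demo values
        {"year": 2024, "allowed_pmpm": 110.0},
    ]
    return {"data": json.dumps(rows), "warning": "demo data only"}


def summarize_result(result: Dict[str, Any]) -> str:
    """Turn an envelope into a short text reply, checking for an error first, then data."""
    if "error" in result:
        return f"Error: {result['error']}"
    if "data" not in result:
        return "No data returned"
    rows = json.loads(result["data"])  # the payload is a JSON string, not a list
    reply = f"{len(rows)} rows returned"
    if "warning" in result:
        reply += f" (warning: {result['warning']})"
    return reply


print(summarize_result(fake_get_trend_data()))
print(summarize_result({"error": "unknown dimension"}))
```

Returning a string in every branch, including failures, keeps the agent's tool call from raising and lets the model read the error and adjust its next call.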

In [None]:
import asyncio
import io
import json
import math
import os
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Union

import matplotlib.pyplot as plt
import pandas as pd
from agents import Agent, Runner, function_tool, RunContextWrapper, RunConfig
from agents.exceptions import MaxTurnsExceeded, ModelBehaviorError, UserError
from docx import Document
from docx.shared import Inches
from pydantic import BaseModel, Field

DOC_PATH = Path("trend_analysis.docx")   # single file for all content

def get_document() -> Document:
    """Load existing docx or create a new one."""
    if DOC_PATH.exists():
        return Document(str(DOC_PATH))
    return Document()                    # starts new empty doc

# ───────────────────────────────
# Helper Functions for Display
# ───────────────────────────────
def format_agent_thoughts(text: str, max_width: int = 80) -> str:
    """Format agent thoughts for better readability"""
    lines = text.split('\n')
    formatted_lines = []

    for line in lines:
        if len(line) <= max_width:
            formatted_lines.append(f"   {line}")
        else:
            # Wrap long lines
            words = line.split(' ')
            current_line = "   "
            for word in words:
                if len(current_line) + len(word) + 1 <= max_width:
                    current_line += word + " "
                else:
                    formatted_lines.append(current_line.rstrip())
                    current_line = f"   {word} "
            if current_line.strip():
                formatted_lines.append(current_line.rstrip())

    return '\n'.join(formatted_lines)

# ───────────────────────────────
# Pydantic Models for Strict Schemas
# ───────────────────────────────
class FilterCondition(BaseModel):
    """A single filter condition for data queries"""
    dimension_name: str
    operator: Literal["=", "!=", ">", ">=", "<", "<=", "IN", "NOT IN", "LIKE", "NOT LIKE", "IS NULL", "IS NOT NULL", "BETWEEN"]
    value: Optional[Union[str, int, float, bool, List[Union[str, int, float]]]] = None

# ───────────────────────────────
# Context and Buffer Classes
# ───────────────────────────────
@dataclass
class AnalysisContext:
    """Context object - no need to store anything since functions are global"""
    pass

class ReportBuffer:
    """Buffer for collecting analysis report content"""
    def __init__(self):
        self._lines: List[str] = []

    def add(self, content: str):
        """Add content to the report"""
        if content and content.strip():
            self._lines.append(content + "\n\n")

    def get_content(self) -> str:
        """Get the full report content"""
        return "".join(self._lines)

# ───────────────────────────────
# Tool Function Wrappers (calling your original functions)
# ───────────────────────────────
@function_tool
async def get_trend_data_tool(
    group_by_dimensions: Optional[List[str]] = None,
    filters: Optional[List[FilterCondition]] = None,
    top_n: Optional[int] = None
) -> str:
    """
    Yearly claim trend metrics 2023 → 2024 with optional grouping, filters and top-n reduction.

    Output always includes:
    • allowed_pmpm – service-level denominator
    • parent_allowed_pmpm – parent-level denominator
    • utilization_pkpy, cost_per_service, etc.

    Args:
        group_by_dimensions: List of dimension names to group by
        filters: List of filter conditions with dimension_name, operator, and value
        top_n: Return only the N highest-volume groups
    """
    #print(f"🔧 TOOL CALL: get_trend_data_tool")
    #print(f"   📊 Grouping: {group_by_dimensions or 'None'}")
    #print(f"   🔍 Filters: {len(filters or [])} conditions")
    #print(f"   📈 Top N: {top_n or 'All'}")

    try:
        # Convert Pydantic models to dicts for your original function
        filters_dict = None
        if filters:
            filters_dict = [
                {
                    "dimension_name": f.dimension_name,
                    "operator": f.operator,
                    "value": f.value
                }
                for f in filters
            ]

        # Call your original function directly
        result = get_trend_data(
            group_by_dimensions=group_by_dimensions,
            filters=filters_dict,
            top_n=top_n
        )

        # Handle error case
        if "error" in result:
            print(f"   ❌ Error: {result['error']}")
            return f"Error processing trend data: {result['error']}"

        # Parse the JSON data from your function
        if "data" in result:
            data_json = result["data"]
            # Parse to get row count and sample data for summary
            try:
                data_list = json.loads(data_json)
                row_count = len(data_list)
                print(f"   ✅ Success: {row_count} rows returned")

                # Show sample of key metrics if available
                if data_list and len(data_list) > 0:
                    sample_row = data_list[0]
                    if 'allowed_pmpm' in sample_row:
                        print(f"   💰 Sample PMPM: ${sample_row.get('allowed_pmpm', 'N/A')}")
                    if 'year' in sample_row:
                        print(f"   📅 Years: {sorted(set(row.get('year', 'N/A') for row in data_list))}")

                # Create a more readable summary
                response = f"""
                TREND DATA ANALYSIS RESULTS:

                Configuration:
                - Grouping: {', '.join(group_by_dimensions) if group_by_dimensions else 'None'}
                - Filters Applied: {len(filters or [])}
                - Top N Limit: {top_n or 'All records'}
                - Total Rows Returned: {row_count:,}

                Data Sample (first few rows):
                {json.dumps(data_list[:3], indent=2) if data_list else 'No data'}

                Full Dataset:
                {data_json}
                """
            except (json.JSONDecodeError, TypeError):
                print(f"   ⚠️  JSON parsing failed, returning raw data")
                # Fallback if JSON parsing fails
                response = f"""
                TREND DATA ANALYSIS RESULTS:

                Configuration:
                - Grouping: {', '.join(group_by_dimensions) if group_by_dimensions else 'None'}
                - Filters Applied: {len(filters or [])}
                - Top N Limit: {top_n or 'All records'}

                Raw Data:
                {data_json}
                """

            # Add warning if present
            if "warning" in result:
                print(f"   ⚠️  Warning: {result['warning']}")
                response += f"\n⚠️  Warning: {result['warning']}"
        else:
            print(f"   ❌ No data returned")
            response = "No data returned from trend analysis"

        # Mark this as truncatable content for history compression
        response += "\n<!-- TRUNCATABLE_DATA_BLOCK -->"

        return response

    except Exception as e:
        print(f"   ❌ Exception: {str(e)}")
        return f"Error processing trend data: {str(e)}"


@function_tool
async def list_available_dimensions_tool() -> str:
    """
    Lists dimension keys & descriptions available for analysis.

    Returns a formatted list of all available dimensions that can be used
    for grouping and filtering in the trend analysis.
    """
    #print(f"🔧 TOOL CALL: list_available_dimensions_tool")

    try:
        # Call your original function directly
        result = list_available_dimensions()

        if "data" in result:
            # Your function returns JSON string, parse it for better formatting
            dimensions_data = json.loads(result["data"])
            print(f"   ✅ Success: {len(dimensions_data)} dimensions found")

            response = "AVAILABLE DIMENSIONS:\n\n"

            # Group by source table for better organization
            descriptor_dims = []
            norm_dims = []

            for dim_name, metadata in dimensions_data.items():
                source = metadata.get("source_table", "unknown")
                description = metadata.get("description", "No description available")
                dim_info = f"• {dim_name}: {description}"

                if source == "descriptor":
                    descriptor_dims.append(dim_info)
                elif source == "norm":
                    norm_dims.append(dim_info)

            print(f"   📊 Descriptor dims: {len(descriptor_dims)}, Norm dims: {len(norm_dims)}")

            if descriptor_dims:
                response += "DESCRIPTOR (Claims) Dimensions:\n"
                response += "\n".join(descriptor_dims) + "\n\n"

            if norm_dims:
                response += "NORM (Membership) Dimensions:\n"
                response += "\n".join(norm_dims) + "\n\n"

            response += f"Total: {len(dimensions_data)} dimensions available"
        else:
            print(f"   ❌ No dimension data available")
            response = "No dimension data available"

        return response

    except Exception as e:
        print(f"   ❌ Exception: {str(e)}")
        return f"Error listing dimensions: {str(e)}"


@function_tool
async def get_dimension_values_tool(dimension_name: str) -> str:
    """
    Get distinct non-null values for a specific dimension.

    Args:
        dimension_name: The name of the dimension to get values for
    """
    #print(f"🔧 TOOL CALL: get_dimension_values_tool")
    #print(f"   🏷️  Dimension: {dimension_name}")

    try:
        # Call your original function directly
        result = get_dimension_values(dimension_name)

        # Handle error case
        if "error" in result:
            print(f"   ❌ Error: {result['error']}")
            return f"Error getting dimension values for '{dimension_name}': {result['error']}"

        # Parse the data from your function
        if "data" in result:
            values_data = json.loads(result["data"]) if isinstance(result["data"], str) else result["data"]

            if isinstance(values_data, list):
                values = values_data
            elif isinstance(values_data, dict):
                # If your function returns a dict, extract the values
                values = list(values_data.keys()) if values_data else []
            else:
                values = [str(values_data)]

            print(f"   ✅ Success: {len(values)} distinct values found")
            print(f"   📝 Sample values: {values[:5]}")

            response = f"VALUES FOR DIMENSION '{dimension_name}':\n\n"

            # Show first 20 values with numbering
            display_values = values[:20]
            for i, value in enumerate(display_values, 1):
                response += f"{i:2d}. {value}\n"

            if len(values) > 20:
                response += f"... and {len(values) - 20} more values\n"

            response += f"\nTotal: {len(values)} distinct values"

            # If there are many values, also show the full list
            if len(values) > 20:
                response += f"\n\nComplete list: {json.dumps(values)}"

            # Mark as truncatable if the list is long
            if len(values) > 10:
                response += "\n<!-- TRUNCATABLE_DATA_BLOCK -->"
        else:
            print(f"   ❌ No values found")
            response = f"No values found for dimension '{dimension_name}'"

        return response

    except Exception as e:
        print(f"   ❌ Exception: {str(e)}")
        return f"Error getting dimension values for '{dimension_name}': {str(e)}"

# ───────────────────────────────
# Local .docx Writing (exposed to the agent as the "Google Doc" tool)
# ───────────────────────────────


class ChartSpec(BaseModel):
    title: str
    x_dimension: str
    y_metrics: List[str]
    agg: Literal["sum", "mean", "count"] = "sum"
    chart_type: Literal["bar", "line", "stacked_bar"] = "bar"
    hue_dimension: Optional[str] = None
    top_n: Optional[int] = None
    height_px: int = 250

@function_tool
async def write_google_doc_tool(text: str) -> str:
    """
    Append plain text paragraph(s) to local docx.

    Args:
        text: The text to append to the docx
    """
    doc = get_document()
    for para in text.split("\n\n"):
        doc.add_paragraph(para)
    doc.save(str(DOC_PATH))
    return f"Wrote {len(text)} chars to {DOC_PATH}"


@function_tool
async def create_chart_tool(df_json: str, spec: ChartSpec) -> str:
    """
    Draw a chart based on data in a dataframe: aggregate, plot, and insert the PNG into the local docx.

    Args:
        df_json: The data to draw in the chart, as a JSON string
        spec: The chart specification
    """
    # Wrap in StringIO: passing a literal JSON string to read_json is deprecated in recent pandas
    df = pd.read_json(io.StringIO(df_json))

    # optional top-n truncation
    if spec.top_n:
        df = df.nlargest(spec.top_n, spec.y_metrics[0])

    # group & aggregate
    agg_fn = {"sum": "sum", "mean": "mean", "count": "count"}[spec.agg]
    if spec.hue_dimension:
        pivot = (df
                 .groupby([spec.x_dimension, spec.hue_dimension])[spec.y_metrics]
                 .agg(agg_fn)
                 .unstack(spec.hue_dimension)
                 .fillna(0))
    else:
        pivot = df.groupby(spec.x_dimension)[spec.y_metrics].agg(agg_fn)

    # plot
    kind = "bar" if spec.chart_type in {"bar", "stacked_bar"} else "line"
    ax = pivot.plot(kind=kind,
                    stacked=(spec.chart_type == "stacked_bar"),
                    figsize=(6, 4))                           # ~ 400 px
    ax.set_title(spec.title)
    ax.set_xlabel(spec.x_dimension)
    plt.tight_layout()

    # save PNG to tmp
    png_path = "/tmp/chart.png"
    plt.savefig(png_path, dpi=150, bbox_inches="tight")
    plt.close()

    # append picture to docx
    doc = get_document()
    doc.add_paragraph(spec.title, style="Heading 3")
    doc.add_picture(png_path, width=Inches(6))
    doc.save(str(DOC_PATH))

    return f'Chart “{spec.title}” inserted into {DOC_PATH}'

# ───────────────────────────────
# History Compression
# ───────────────────────────────
def compress_conversation_history(
    messages: List[Dict[str, Any]],
    max_tokens: int,
    placeholder: str = "[data block truncated]"
) -> List[Dict[str, Any]]:
    """
    Compress conversation history by truncating large data blocks.
    """

    def rough_tokens(content: Any) -> int:
        """Rough token estimation: ~4 characters per token"""
        # Content may be None or a non-string (e.g. a list of parts); coerce before measuring
        return math.ceil(len(str(content or "")) / 4)

    # Make a deep copy to avoid mutating original
    compressed_messages = deepcopy(messages)

    # Calculate total tokens
    total_tokens = sum(
        rough_tokens(msg.get("content", ""))
        for msg in compressed_messages
    )

    if total_tokens <= max_tokens:
        return compressed_messages

    # Find messages with large data blocks (tool responses)
    large_data_candidates = []
    for i, msg in enumerate(compressed_messages):
        content = msg.get("content", "")
        if (isinstance(content, str) and
            "<!-- TRUNCATABLE_DATA_BLOCK -->" in content and
            msg.get("role") == "tool"):
            tokens = rough_tokens(content)
            large_data_candidates.append((i, tokens))

    # Sort by size, largest first
    large_data_candidates.sort(key=lambda x: x[1], reverse=True)

    # Truncate largest data blocks until we're under the limit
    for idx, original_size in large_data_candidates:
        compressed_messages[idx]["content"] = placeholder
        placeholder_tokens = rough_tokens(placeholder)
        total_tokens -= (original_size - placeholder_tokens)

        if total_tokens <= max_tokens:
            break

    return compressed_messages
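
To illustrate the truncation strategy used by `compress_conversation_history`, here is a condensed, self-contained sketch: it estimates tokens at roughly 4 characters each and replaces the largest marked tool messages first until the total fits. The function name `shrink_history` and the sample messages are hypothetical, and the `role`/`content` message shape is an assumption that may not match what the Agents SDK's `to_input_list()` actually returns.

```python
import math
from copy import deepcopy
from typing import Any, Dict, List

MARKER = "<!-- TRUNCATABLE_DATA_BLOCK -->"


def rough_tokens(text: str) -> int:
    """Rough token estimate: about 4 characters per token."""
    return math.ceil(len(text) / 4)


def shrink_history(
    messages: List[Dict[str, Any]],
    max_tokens: int,
    placeholder: str = "[data block truncated]",
) -> List[Dict[str, Any]]:
    """Replace the largest marked tool messages until the estimate fits max_tokens."""
    out = deepcopy(messages)  # never mutate the caller's list
    total = sum(rough_tokens(m.get("content", "")) for m in out)
    candidates = sorted(
        (i for i, m in enumerate(out)
         if m.get("role") == "tool" and MARKER in m.get("content", "")),
        key=lambda i: rough_tokens(out[i]["content"]),
        reverse=True,  # largest blocks first
    )
    for i in candidates:
        if total <= max_tokens:
            break
        total -= rough_tokens(out[i]["content"]) - rough_tokens(placeholder)
        out[i]["content"] = placeholder
    return out


# Hypothetical history: one small user message and two marked tool outputs.
history = [
    {"role": "user", "content": "Begin the analysis."},
    {"role": "tool", "content": "x" * 4000 + "\n" + MARKER},
    {"role": "tool", "content": "y" * 400 + "\n" + MARKER},
]
compressed = shrink_history(history, max_tokens=200)
print([len(m["content"]) for m in compressed])
```

In this example only the 4,000-character block is replaced, because dropping it alone already brings the estimate under the 200-token budget; the smaller block and the original `history` list are left untouched.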

# Main AI agent code

In [None]:
# ───────────────────────────────
# Configuration Constants
# ───────────────────────────────
ITERATION_LIMIT = 30
TEMPERATURE = 0.3
MAX_OUTPUT_TOK = 4096
MAX_CONTEXT_TOK = 800000
SUMMARY_HEADER = "## Final Summary"
RUN_TS = datetime.now().strftime("%Y%m%d_%H%M%S")

model_name = "o3" # @param {"type":"string"}
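
With `ITERATION_LIMIT = 30`, the rule stated in the analysis plan (`{LIMIT}` minus current iteration <= 2) reserves iterations 28 through 30 for summarization. The sketch below spells out that arithmetic; `phase_for` is a hypothetical helper, not a function defined in this notebook.

```python
def phase_for(iteration: int, limit: int) -> str:
    """Classify an iteration using the plan's rule: limit - iteration <= 2 means summarize."""
    remaining_after_this_one = limit - iteration
    return "summarize" if remaining_after_this_one <= 2 else "explore"


LIMIT = 30  # assumed to match ITERATION_LIMIT above
summary_iterations = [i for i in range(1, LIMIT + 1) if phase_for(i, LIMIT) == "summarize"]
print(summary_iterations)  # [28, 29, 30]
```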

In [None]:
# ───────────────────────────────
# Agent Creation and Main Loop
# ───────────────────────────────
async def create_analysis_agent() -> Agent:
    """Create the main analysis agent with tools and instructions"""

    initial_prompt = make_initial_prompt(ITERATION_LIMIT)

    agent = Agent(
        name="Trend Decomposition Agent",
        instructions=initial_prompt,
        model=model_name,
        tools=[get_trend_data_tool, list_available_dimensions_tool, get_dimension_values_tool, write_google_doc_tool, create_chart_tool]
    )

    return agent

async def run_once_streamed(agent, user_msg, *, max_turns=8):
    result = Runner.run_streamed(
        agent, input=user_msg, max_turns=max_turns,
        run_config=RunConfig(tracing_disabled=True),
    )

    async for ev in result.stream_events():
        if ev.type != "run_item_stream_event":
            continue
        it = ev.item

        # 1️⃣  Model reasoning
        if it.type == "reasoning_item":
            thought = "\n".join(
                s.text for s in it.raw_item.summary
                if s.type == "summary_text"          # each fragment is a PLAN/REFLECT chunk
            )
            if thought:
                print("\n🧠  THOUGHT\n" + thought + "\n")

        # 2️⃣  Tool call
        elif it.type == "tool_call_item":
            print(f"🔧  TOOL → {it.raw_item.name} "
                  f"{json.dumps(it.raw_item.arguments)}")

        # 3️⃣  Tool result
        elif it.type == "tool_call_output_item":
            # the output item may not expose the originating call as `tool_call`;
            # when the attribute is missing, the name falls back to "<unknown>"
            tname = getattr(it, "tool_call", None)
            tname = tname.name if tname else "<unknown>"
            print(f"📬  RESULT ({tname}): "
                  f"{str(it.output)[:500]}…")

        # 4️⃣  Plain assistant text (rare in streamed loops)
        elif it.type == "message_output_item":
            from agents.items import ItemHelpers           # helper in the SDK
            print("\n💬  ASSISTANT:\n"
                  + ItemHelpers.text_message_output(it)     # safe extractor
                  + "\n")

    return result

async def run_analysis_loop() -> str:
    """
    Run the main analysis loop with detailed progress tracking and iteration control.
    """
    report = ReportBuffer()

    # Create the analysis agent
    agent = await create_analysis_agent()

    # Start the analysis
    print("🔍 Starting claims trend analysis...")
    print(f"🔄 Max iterations: {ITERATION_LIMIT}")
    print("-" * 50)


    # try to get the agent to use its full budgeted iterations
    _COMPLETION_MARKERS = ["FINAL REPORT AND ANALYSIS CONCLUDED", "ABSOLUTE FINAL SUMMARY AND RECOMMENDATIONS", "## COMPANY-WIDE SUMMARY & KEY DRIVERS FINAL DOCUMENT"]

    try:
        # Initial run to get the agent started
        print("🤖 ITERATION 1: Initial exploration")

        result = await run_once_streamed(agent, user_msg="Begin the analysis.", max_turns=ITERATION_LIMIT)

        if result.final_output:
            print(f"💭 AGENT THOUGHTS:")
            print(format_agent_thoughts(result.final_output))
            print("-" * 60)
            report.add(f"## Iteration 1\n{result.final_output}")

            # Continue with iterative analysis
            conversation_history = result.to_input_list()

            for iteration in range(2, ITERATION_LIMIT + 1):
                print(f"\n🤖 ITERATION {iteration}: Continuing analysis")

                # Check if agent indicated completion
                if any(marker in result.final_output for marker in _COMPLETION_MARKERS): # Case-sensitive check
                    print(f"✅ Agent indicated analysis completion with a specific marker at iteration {iteration-1}")
                    break

                # Continue the conversation
                current_iteration_for_prompt = iteration
                iterations_left_for_agent = ITERATION_LIMIT - current_iteration_for_prompt

                if iterations_left_for_agent < 3: # e.g., Limit 50. Iter 48 (2 left), 49 (1 left), 50 (0 left)
                    next_prompt = (
                        f"You are on step {current_iteration_for_prompt} of {ITERATION_LIMIT}. "
                        f"There are {iterations_left_for_agent} iterations remaining after this one for your final summarization. "
                        "Focus on providing your 'Company-wide Summary & Key Drivers'. If absolutely necessary, make one final data call. "
                        "Remember to start your final report with one of the specific completion markers."
                    )
                else:
                    next_prompt = (
                        f"Continue your deep-dive analysis. You are on step {current_iteration_for_prompt} of {ITERATION_LIMIT}. "
                        f"There are {iterations_left_for_agent} iterations remaining. "
                        "DO NOT CONCLUDE THE ANALYSIS OR PROVIDE A FINAL SUMMARY YET. "
                        "Your objective is to continue exploring new hypotheses, drilling down into data, and making further tool calls. "
                        "Generate a new PLAN for investigation and the corresponding TOOL call(s)."
                    )
                conversation_history.append({"role": "user", "content": next_prompt})

                try:
                    result = await run_once_streamed(agent,
                                 user_msg=conversation_history,
                                 max_turns=ITERATION_LIMIT)

                    if result.final_output:
                        print(f"💭 AGENT THOUGHTS:")
                        print(format_agent_thoughts(result.final_output))
                        print("-" * 60)
                        report.add(f"## Iteration {iteration}\n{result.final_output}")

                        # Update conversation history for next iteration
                        conversation_history = result.to_input_list()

                        # Check for completion again
                        if any(marker in result.final_output for marker in _COMPLETION_MARKERS): # Case-sensitive check
                            print(f"✅ Agent indicated analysis completion with a specific marker at iteration {iteration}")
                            break
                    else:
                        print(f"⚠️  No output from iteration {iteration}")
                        break

                except Exception as e:
                    print(f"❌ Error in iteration {iteration}: {e}")
                    break

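            print(f"✅ Analysis completed after {iteration} iterations")

            # Save the report
            report_content = report.get_content()
            # Measure the accumulated report; result.final_output can be None when the last iteration produced no output
            print(f"📄 Report length: {len(report_content)} characters")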
            report_path = f"claims-trend-analysis-{RUN_TS}.md"

            # Save to Colab files
            with open(report_path, 'w', encoding='utf-8') as f:
                f.write(report_content)

            print(f"💾 Report saved to: {report_path}")
            return report_content
        else:
            print("⚠️  No initial output received from analysis")
            return "Analysis started but no initial output was generated."

    except MaxTurnsExceeded:
        print(f"⚠️  Analysis reached maximum turns")
        return "Analysis reached maximum iteration limit."
    except Exception as e:
        print(f"❌ Analysis failed: {e}")
        raise

# ───────────────────────────────
# Main Execution Function
# ───────────────────────────────
async def run_claims_analysis():
    """
    Main function to run the complete claims analysis.
    Call this function to start the analysis.
    """

    # Verify OpenAI API key
    if not os.getenv("OPENAI_API_KEY"):
        print("❌ Error: OPENAI_API_KEY environment variable not set")
        print("Please set it using:")
        print("   import os")
        print("   os.environ['OPENAI_API_KEY'] = 'your-api-key-here'")
        return None

    # Verify required functions exist
    try:
        # Test that your functions are available
        dims_result = list_available_dimensions()
        if "error" in dims_result:
            print(f"❌ Error testing list_available_dimensions: {dims_result['error']}")
            return None
        print("✅ Data analysis functions are available")
    except Exception as e:
        print(f"❌ Error: Required analysis functions not available: {e}")
        print("Make sure get_trend_data, list_available_dimensions, and get_dimension_values are defined")
        return None

    print("🚀 Starting Claims Trend Analysis with OpenAI Agent SDK")
    print("=" * 60)

    try:
        report = await run_analysis_loop()

        print("\n" + "=" * 60)
        print("📋 ANALYSIS SUMMARY")
        print("=" * 60)

        # Show a preview of the report
        preview = report[:500] + "..." if len(report) > 500 else report
        print(preview)

        return report

    except Exception as e:
        print(f"\n❌ Analysis failed with error: {e}")
        return None

# ───────────────────────────────
# Usage Instructions
# ───────────────────────────────
print("🏥 Claims Trend Analysis Agent Setup Complete!")
print("\nTo run the analysis, execute:")
print("   report = await run_claims_analysis()")
print("\nOr for synchronous execution:")
print("   import asyncio")
print("   report = asyncio.run(run_claims_analysis())")

🏥 Claims Trend Analysis Agent Setup Complete!

To run the analysis, execute:
   report = await run_claims_analysis()

Or for synchronous execution:
   import asyncio
   report = asyncio.run(run_claims_analysis())


# Running the AI agent
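
Colab and Jupyter already run an event loop, so a top-level `await` works in a cell while `asyncio.run(...)` would typically raise `RuntimeError` there; in a plain Python script the reverse holds. A minimal sketch of that distinction, using a hypothetical stand-in coroutine (`fake_analysis`) in place of `run_claims_analysis()`:

```python
import asyncio
from typing import Optional


async def fake_analysis() -> str:
    """Hypothetical stand-in for run_claims_analysis()."""
    await asyncio.sleep(0)
    return "report text"


def in_running_loop() -> bool:
    """Return True when called from inside an already running event loop."""
    try:
        asyncio.get_running_loop()
        return True
    except RuntimeError:
        return False


report: Optional[str] = None
if not in_running_loop():
    # Plain script: no loop is running yet, so asyncio.run() can start one.
    report = asyncio.run(fake_analysis())
else:
    # Notebook cell: a loop is already running; use `report = await fake_analysis()` instead.
    print("Event loop already running; use top-level await in the cell.")
```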

In [None]:
import os
import asyncio
os.environ['OPENAI_API_KEY'] = 'your_key_here' # @param {"type":"string"}

report = await run_claims_analysis()

✅ Data analysis functions are available
🚀 Starting Claims Trend Analysis with OpenAI Agent SDK
🔍 Starting claims trend analysis...
🔄 Max iterations: 30
--------------------------------------------------
🤖 ITERATION 1: Initial exploration
   ✅ Success: 45 dimensions found
   📊 Descriptor dims: 23, Norm dims: 22
🔧  TOOL → list_available_dimensions_tool "{}"
📬  RESULT (<unknown>): AVAILABLE DIMENSIONS:

DESCRIPTOR (Claims) Dimensions:
• claim_type: Type of claim: Facility, Professional, or RX
• major_service_category: Highest level Health Cost Guidelines (HCG) service category
• provider_specialty: Specialty of the rendering provider (or Pharmacy for RX claims)
• detailed_service_category: Detailed Health Cost Guidelines (HCG) code and description
• ms_drg: Medicare Severity Diagnosis Related Group (MS-DRG) code and description
• ms_drg_mdc: Major Diagnostic Category (MDC…
   ✅ Success: 14 rows returned
   💰 Sample PMPM: $110.188
   📅 Years: [2023, 2024]
🔧  TOOL → get_trend_data_tool "{\"

🔧  TOOL → create_chart_tool "{\"df_json\":\"[{\\\"group\\\":\\\"PROVIDER XXXXXX\\\",\\\"year\\\":2023,\\\"pmpm\\\":0.4275},{\\\"group\\\":\\\"PROVIDER XXXXXX\\\",\\\"year\\\":2024,\\\"pmpm\\\":0.0261},{\\\"group\\\":\\\"Other\\\",\\\"year\\\":2023,\\\"pmpm\\\":0.0216},{\\\"group\\\":\\\"Other\\\",\\\"year\\\":2024,\\\"pmpm\\\":0.0275}]\",\"spec\":{\"title\":\"Texas Chiropractic Allowed PMPM \u2013 PROVIDER XXXXXX vs Other Providers\",\"x_dimension\":\"year\",\"y_metrics\":[\"pmpm\"],\"agg\":\"sum\",\"chart_type\":\"bar\",\"hue_dimension\":\"group\",\"top_n\":null,\"height_px\":300}}"
📬  RESULT (<unknown>): Chart “Texas Chiropractic Allowed PMPM – PROVIDER XXXXXX vs Other Providers” inserted into trend_analysis.docx…
🔧  TOOL → write_google_doc_tool "{\"text\":\"Added bar chart showing PROVIDER XXXXXX PMPM collapsing while other providers flat.\\n\"}"
📬  RESULT (<unknown>): Wrote 77 chars to trend_analysis.docx…

💬  ASSISTANT:
PLAN:
We've robustly traced drivers; further optional deep di