In [None]:
%pip install -U -qqq langchain_core langchain_databricks langchain_community

In [None]:
%restart_python

In [None]:
import os
import pandas as pd

from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_databricks import ChatDatabricks
from databricks.sdk import WorkspaceClient
from langchain_core.runnables import Runnable, RunnableSequence, RunnableLambda


# Configure workspace credentials for model serving.
# A WorkspaceClient is built with its default configuration, and its host plus
# a freshly created token are exported through environment variables.
# NOTE(review): presumably ChatDatabricks picks these variables up - confirm.
w = WorkspaceClient()
os.environ["DATABRICKS_HOST"] = w.config.host
os.environ["DATABRICKS_TOKEN"] = w.tokens.create(
    comment="for model serving",
    lifetime_seconds=1200,  # token is requested with a 1200-second lifetime
).token_value

# Chat model handle bound to the named serving endpoint.
llm = ChatDatabricks(
    endpoint="databricks-meta-llama-3-1-405b-instruct",
)

def format_context(df: pd.DataFrame) -> str:
    """Serialise *df* to a JSON string for use as prompt context.

    Each row becomes one JSON object (``orient='records'``) and the output
    is pretty-printed with an indent of 2.

    Args:
        df: The frame to serialise.

    Returns:
        The JSON text produced by ``DataFrame.to_json``.
    """
    json_options = {"orient": "records", "indent": 2}
    return df.to_json(**json_options)

def find_info(location: str) -> str:
  """Run the plan aggregation query and return its rows as a JSON string.

  The query joins ``snpdata`` to ``mapd_plan_directory`` (on plan type and
  contract number) and to ``cpsc_combined`` (on contract number/id and plan
  id), groups by specialty diseases, states, legal entity city and the
  year/month of the contract effective date, and counts plans, contracts and
  enrollments per group.

  Args:
      location: Accepted so the function can sit at the head of a chain that
          is invoked with a location, but currently NOT used: the query has
          no location filter.

  Returns:
      The query result serialised by ``format_context`` (a JSON string, not a
      DataFrame).
  """
  # NOTE(review): `location` is ignored, so every call returns the same
  # aggregate. Add a WHERE clause (with a bound parameter, not string
  # interpolation) once the column holding the location is confirmed.
  # NOTE(review): the aliases `distinct_contracts` and `distinct_enrollments`
  # are computed with a plain count(), not count(distinct ...) - confirm the
  # intended aggregation before relying on those columns.
  # The query has no placeholders, so it is a plain string, not an f-string.
  query = """
    SELECT
  specialty_diseases,
  states,
  legal_entity_city,
  extract(year from `dais-hackathon-2025`.mimilabs.cpsc_combined.contract_effective_date) as year,
  extract(month from `dais-hackathon-2025`.mimilabs.cpsc_combined.contract_effective_date) as month,
  count(distinct `dais-hackathon-2025`.mimilabs.snpdata.plan_id) as distinct_plans,
  count( `dais-hackathon-2025`.mimilabs.snpdata.contract_number) as distinct_contracts,
  count( `dais-hackathon-2025`.mimilabs.snpdata.plan_enrollment) as distinct_enrollments
FROM
  `dais-hackathon-2025`.mimilabs.snpdata
  inner join
  `dais-hackathon-2025`.mimilabs.mapd_plan_directory
  on
  `dais-hackathon-2025`.mimilabs.snpdata.plan_type = `dais-hackathon-2025`.mimilabs.mapd_plan_directory.plan_type
AND
  `dais-hackathon-2025`.mimilabs.snpdata.contract_number = `dais-hackathon-2025`.mimilabs.mapd_plan_directory.contract_number
  inner join
  `dais-hackathon-2025`.mimilabs.cpsc_combined
  on
  `dais-hackathon-2025`.mimilabs.snpdata.contract_number = `dais-hackathon-2025`.mimilabs.cpsc_combined.contract_id
  and
  `dais-hackathon-2025`.mimilabs.snpdata.plan_id = `dais-hackathon-2025`.mimilabs.cpsc_combined.plan_id
GROUP BY
    specialty_diseases,
    states,
    legal_entity_city,
    year,
    month

  """
  # `spark` is not defined in this file; presumably it is the session object
  # injected by the notebook runtime.
  return format_context(spark.sql(query).toPandas())
  

# === Agent A: ExtractionAgent ===
# Prompt with a single input variable, `context`, that receives the JSON data
# and asks the model to map each speciality disease to categories.
extraction_prompt = PromptTemplate.from_template(
  """
  You are a helpful healthcare assistant identifying what plans are available in my county. Categorise the speciality diseases into cardio, renal, mental, diabetic. One speciality disease can be mapped to multiple categories. Each speciality disease we map to at least one category. Return for each speciality disease the categories. Return just the list, not an explanation or summary.

  Here is the JSON data:
  {context}
  """
)

# Pipeline: fetch the data -> fill the prompt -> call the model -> plain text.
# `find_info` is wrapped explicitly so the first step is visibly a runnable.
fetch_context = RunnableLambda(find_info)
extraction_chain = fetch_context | extraction_prompt | llm | StrOutputParser()

def get_summary(location: str) -> str:
  """Placeholder for a summary query; returns the result as a JSON string.

  Args:
      location: Currently unused by the body.

  Returns:
      The query result serialised by ``format_context``.
  """
  # NOTE(review): the query text below is blank, so this function is
  # unfinished; presumably ``spark.sql`` rejects an empty statement - fill in
  # the SQL before calling it.
  query = f"""
  
  """
  return format_context(spark.sql(query).toPandas())

# === Agent B: SummaryAgent ===
# Prompt with a single input variable, `extracted_info`, asking the model to
# count the cardio categories in the extraction output.
summary_prompt = PromptTemplate.from_template(
    """
    Based on the extracted categories return how many cardio categories we have.

    Extracted Info:
    {extracted_info}
    """
)

# Pipeline: fill the prompt -> call the model -> plain text.
summary_chain = summary_prompt | llm | StrOutputParser()

# === Agentic Chain ===
# Extraction output (a plain string) is wrapped into the dict shape the
# summary prompt expects and then summarised exactly once. An earlier version
# piped `summary_chain` twice, which summarised the summary and cost an extra
# model call.
agentic_chain = (
    extraction_chain
    | (lambda extracted: {"extracted_info": extracted})
    | summary_chain
)

# Run the agentic chain; the argument is handed to the first step of
# `extraction_chain`.
result = agentic_chain.invoke("Autauga")
print(result)