In [0]:
# Install uv, then use it to add `databricks-langchain` and sync the project's dependencies.
# NOTE(review): nothing here is version-pinned. `uv add` presumably records the dependency in the
# project's pyproject.toml and `--active` presumably targets the currently active environment — confirm.
!pip install uv
!uv add databricks-langchain
!uv sync --active --quiet
# Restart the Python process (Databricks utility) so the packages installed above are importable
# in the following cells; kernel state from earlier cells is lost, which is why imports come next.
dbutils.library.restartPython()

In [0]:

import json
import os
from typing import Optional

import pandas as pd
import toml
from pyspark.sql import functions as F
from pyspark.sql.connect.dataframe import DataFrame


In [0]:
# Read the project configuration (e.g. CATALOG, FS_SCHEMA) from the TOML file.
env_vars = toml.load("../../conf/env_vars.toml")

# Export every entry as a string environment variable so code reading os.environ
# (such as SRAGMetrics) sees the same configuration.
os.environ.update({key: str(value) for key, value in env_vars.items()})

In [0]:
# Load both feature tables from the configured catalog/schema (srag first, then hospital).
srag_df, hospital_df = (
    spark.read.table(f'{env_vars["CATALOG"]}.{env_vars["FS_SCHEMA"]}.{table_name}')
    for table_name in ("srag_features", "hospital_features")
)

In [0]:
class SRAGMetrics:
    """SRAG case metrics computed from the feature-store Spark tables.

    Every metric reads ``self.df_srag`` and relies on its ``DT_NOTIFIC`` date
    column. Results are returned as pandas objects or plain Python numbers so
    that :meth:`run` can serialize them to JSON.
    """

    def __init__(
        self,
        df_srag: Optional["DataFrame"] = None,
        df_hospital: Optional["DataFrame"] = None,
    ):
        """Store the source DataFrames, loading the default tables for any not given.

        Args:
            df_srag: Spark DataFrame of SRAG cases with a ``DT_NOTIFIC`` column.
                Defaults to the ``{CATALOG}.{FS_SCHEMA}.srag_features`` table.
            df_hospital: Spark DataFrame of hospital data. Defaults to the
                ``{CATALOG}.{FS_SCHEMA}.hospital_features`` table (not used by
                any metric in this class yet).

        Raises:
            KeyError: if a default table must be loaded and ``CATALOG`` or
                ``FS_SCHEMA`` is not set in ``os.environ``.
        """
        # Only consult the environment / catalog for tables the caller did not supply.
        self.df_srag = df_srag if df_srag is not None else self._read_feature_table("srag_features")
        self.df_hospital = (
            df_hospital if df_hospital is not None else self._read_feature_table("hospital_features")
        )

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _read_feature_table(table: str) -> "DataFrame":
        """Load ``{CATALOG}.{FS_SCHEMA}.<table>`` using the catalog/schema from ``os.environ``."""
        catalog = os.environ["CATALOG"]
        schema = os.environ["FS_SCHEMA"]
        return spark.read.table(f"{catalog}.{schema}.{table}")

    @staticmethod
    def _resolve_period(start_date, end_date, default_span: pd.DateOffset) -> "tuple[str, str]":
        """Normalize an optional (start, end) period to inclusive 'yyyy-MM-dd' strings.

        ``end_date`` defaults to today and ``start_date`` to ``end_date - default_span``.
        Either bound may be a string or anything ``pd.to_datetime`` accepts.
        """
        end = pd.Timestamp.today() if end_date is None else pd.to_datetime(end_date)
        start = end - default_span if start_date is None else pd.to_datetime(start_date)
        # strftime drops any time-of-day so the bounds are whole dates.
        return start.strftime("%Y-%m-%d"), end.strftime("%Y-%m-%d")

    @staticmethod
    def _month_bounds(day) -> "tuple[str, str]":
        """Return the first and last day of ``day``'s calendar month as 'yyyy-MM-dd' strings."""
        month = pd.Timestamp(day).to_period("M")
        return month.start_time.strftime("%Y-%m-%d"), month.end_time.strftime("%Y-%m-%d")

    def _filter_period(self, start_date: str, end_date: str) -> "DataFrame":
        """Rows of ``df_srag`` whose ``DT_NOTIFIC`` lies in [start_date, end_date], both inclusive.

        NOTE(review): bounds are 'yyyy-MM-dd' strings; if DT_NOTIFIC is a timestamp
        rather than a date, rows after midnight of ``end_date`` are excluded — confirm
        the column type.
        """
        # Column.between(lo, hi) is (col >= lo) & (col <= hi).
        return self.df_srag.filter(F.col("DT_NOTIFIC").between(start_date, end_date))

    def _count_cases_in_month(self, day) -> int:
        """Number of cases whose ``DT_NOTIFIC`` falls in ``day``'s calendar month (0 if none)."""
        month_start, month_end = self._month_bounds(day)
        monthly = self.calculate_cases_per_month(start_date=month_start, end_date=month_end)
        # sum() rather than [0]: an empty month yields 0 instead of a KeyError.
        return int(monthly["count"].sum())

    # ----------------------------------------------------------------- metrics

    def calculate_cases_per_month(
        self,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Count cases per calendar month within [start_date, end_date].

        Args:
            start_date: Inclusive start date 'yyyy-MM-dd'. Defaults to 12 months before ``end_date``.
            end_date: Inclusive end date 'yyyy-MM-dd'. Defaults to today.

        Returns:
            pandas DataFrame with columns ['year_month', 'count'] ordered by
            ``year_month``, which is converted to pandas datetimes.
        """
        start_date, end_date = self._resolve_period(start_date, end_date, pd.DateOffset(months=12))
        cases_per_month = (
            self._filter_period(start_date, end_date)
            .withColumn("year_month", F.date_format("DT_NOTIFIC", "yyyy-MM"))
            .groupBy("year_month")
            .count()
            .orderBy("year_month")
        )
        cases_pd = cases_per_month.toPandas()
        cases_pd["year_month"] = pd.to_datetime(cases_pd["year_month"])
        return cases_pd

    def calculate_cases_per_month_variation_rate(
        self,
        cases_current_count: Optional[int] = None,
        cases_comparison_count: Optional[int] = None,
    ) -> Optional[float]:
        """Percentage change in case count between a current and a comparison period.

        By default compares the previous calendar month with the same calendar
        month one year earlier (13 months before today).

        Args:
            cases_current_count: Case count for the current period. If None, the
                cases of the previous calendar month are counted from ``df_srag``.
            cases_comparison_count: Case count for the comparison period. If None,
                the cases of the same calendar month one year earlier are counted.

        Returns:
            ``(current - comparison) / comparison * 100`` rounded to 2 decimals,
            or None when the comparison count is 0 (the rate is undefined).
        """
        today = pd.Timestamp.today()
        if cases_current_count is None:
            cases_current_count = self._count_cases_in_month(today - pd.DateOffset(months=1))
        if cases_comparison_count is None:
            cases_comparison_count = self._count_cases_in_month(today - pd.DateOffset(months=13))

        if cases_comparison_count == 0:
            return None
        variation = (cases_current_count - cases_comparison_count) / cases_comparison_count
        # Round the percentage itself (not the fraction) to avoid artifacts like 15.000000000000002;
        # float() also makes plain ints and numpy scalars behave the same.
        return round(float(variation) * 100, 2)

    def calculate_cases_per_day(
        self,
        start_date: Optional[str] = None,
        end_date: Optional[str] = None,
    ) -> pd.DataFrame:
        """Count cases per ``DT_NOTIFIC`` day within [start_date, end_date].

        Args:
            start_date: Inclusive start date 'yyyy-MM-dd'. Defaults to 30 days before ``end_date``.
            end_date: Inclusive end date 'yyyy-MM-dd'. Defaults to today.

        Returns:
            pandas DataFrame with columns ['DT_NOTIFIC', 'count'] ordered by
            ``DT_NOTIFIC``, which is converted to pandas datetimes.
        """
        start_date, end_date = self._resolve_period(start_date, end_date, pd.DateOffset(days=30))
        cases_per_day = (
            self._filter_period(start_date, end_date)
            .groupBy("DT_NOTIFIC")
            .count()
            .orderBy("DT_NOTIFIC")
        )
        cases_per_day_pd = cases_per_day.toPandas()
        cases_per_day_pd["DT_NOTIFIC"] = pd.to_datetime(cases_per_day_pd["DT_NOTIFIC"])
        return cases_per_day_pd

    # ----------------------------------------------------- agent-facing entry

    def run(self, query: Optional[str] = None) -> str:
        """Compute the agent-facing metrics and return them as an indented JSON string.

        Args:
            query: Currently unused; presumably accepted so the signature matches
                what the calling agent passes — confirm before removing.

        Returns:
            JSON object with key ``calculate_cases_per_month_variation_rate``
            (``null`` when the rate is undefined).
        """
        # TODO: add cases_per_month / hospitalization / summary metrics once implemented.
        results = {
            "calculate_cases_per_month_variation_rate": self.calculate_cases_per_month_variation_rate(),
        }
        return json.dumps(results, indent=2)