In [None]:
import os
import json
import time
import logging
import datetime as dt
from io import BytesIO
from typing import Any, Dict, List, Optional

import pandas as pd
import boto3
from pyspark.sql import DataFrame, Window
from pyspark.sql.functions import col, row_number, to_date, regexp_replace, quarter, when
from goldmansachs.amisp.data.spark import get_spark_context

# Set up root logging at INFO level; each record shows timestamp, logger
# name, level, and message. `logger` is this module's own logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Default column selection used by EarningsPreprocessor. When a preprocessor
# is built with final_columns=None instead, every column is kept.
DEFAULT_FINAL_COLUMNS: Optional[List[str]] = [
    # File / ingestion metadata
    "FILE_DATE",
    "SOURCE_TIMESTAMP",
    "TRANS_TIMESTAMP",
    "startDate",
    # Company and reporting period
    "companyId",
    "companyName",
    "companyTicker",
    "year",
    "quarter",
    # Event content
    "storyType",
    "eventTitle",
    "eventBody",
    "version",
]

class EarningsPreprocessor:
    """
    Generalized earnings data preprocessor.

    Bundles the DataFrame transformations applied when loading earnings data:
    FILE_DATE range filtering, (optional) latest-version filtering, quarter
    extraction, and final column selection.

    Args:
        final_columns: Columns kept by ``select_final_columns``. If None (or
            empty), all columns are retained. Defaults to
            ``DEFAULT_FINAL_COLUMNS``.
    """
    def __init__(self, final_columns: Optional[List[str]] = DEFAULT_FINAL_COLUMNS):
        # Store a private copy: the default is a shared module-level list, so
        # keeping a reference would let a mutation of one instance's
        # final_columns leak into the default and into every other instance.
        self.final_columns = list(final_columns) if final_columns is not None else None

    def filter_by_latest_version(self, df: DataFrame) -> DataFrame:
        """
        Keep 'Final' version if available; otherwise fallback to 'Preliminary'.

        Within each (companyTicker, year, quarter) group exactly one row is
        kept: 'Final' ranks first, 'Preliminary' second, any other version
        value last.

        Args:
            df: Input DataFrame with companyTicker, year, quarter and version
                columns.

        Returns:
            DataFrame with one row per (companyTicker, year, quarter).
        """
        # Lower rank == more preferred version.
        version_rank = (
            when(col("version") == "Final", 1)
            .when(col("version") == "Preliminary", 2)
            .otherwise(3)
        )
        window_spec = Window.partitionBy("companyTicker", "year", "quarter") \
                            .orderBy(version_rank)
        # NOTE(review): rows tied on version rank (e.g. two 'Final' rows in the
        # same group) have no tie-breaker, so which one gets row_num 1 is
        # presumably arbitrary -- confirm whether a secondary sort is needed.
        return (df.withColumn("row_num", row_number().over(window_spec))
                  .filter(col("row_num") == 1)
                  .drop("row_num"))

    def filter_date_range(self, df: DataFrame, start_date: dt.date, end_date: dt.date) -> DataFrame:
        """
        Filter DataFrame based on FILE_DATE range (inclusive on both ends).

        FILE_DATE is first converted with ``to_date`` using the pattern
        'yyyy/MM/dd'; the converted column replaces the original one.

        Args:
            df: Input DataFrame with a FILE_DATE column.
            start_date: First date to keep.
            end_date: Last date to keep.

        Returns:
            DataFrame restricted to start_date <= FILE_DATE <= end_date.
        """
        df = df.withColumn("FILE_DATE", to_date(col("FILE_DATE"), "yyyy/MM/dd"))
        return df.filter((col("FILE_DATE") >= start_date) & (col("FILE_DATE") <= end_date))

    def extract_quarter(self, df: DataFrame) -> DataFrame:
        """
        Compute the quarter from startDate and save it as 'quarter'.

        An existing 'quarter' column, if any, is overwritten.
        """
        return df.withColumn("quarter", quarter("startDate"))

    def select_final_columns(self, df: DataFrame) -> DataFrame:
        """
        If final_columns is provided, clean eventBody and select only those columns.
        Otherwise, return the DataFrame unchanged.

        Cleaning replaces every '=' and every '--' in eventBody with a single
        space; it only happens when 'eventBody' is one of the final columns.
        """
        if self.final_columns:
            if "eventBody" in self.final_columns:
                df = df.withColumn("eventBody", regexp_replace(col("eventBody"), r"(=|--)", " "))
            return df.select(*self.final_columns)
        return df

def load_earnings_data(data_type: str, start_date: dt.date, end_date: dt.date,
                       filter_latest: bool = True, max_executors: int = 8,
                       env: str = "research",
                       preprocessor: Optional[EarningsPreprocessor] = None) -> DataFrame:
    """
    Read the earnings parquet dataset and run it through the preprocessing steps.

    Steps, in order: optional latest-version filter, FILE_DATE range filter,
    quarter extraction, final column selection.

    Args:
        data_type: Label used only in the log messages.
        start_date: First FILE_DATE to keep (inclusive).
        end_date: Last FILE_DATE to keep (inclusive).
        filter_latest: When True, apply the latest-version filter first.
        max_executors: Forwarded to ``get_spark_context``.
        env: Forwarded to ``get_spark_context``.
        preprocessor: Custom EarningsPreprocessor; a default one is created
            when omitted.

    Returns:
        The processed DataFrame.

    Raises:
        Exception: Any error from the steps above is logged and re-raised.
    """
    # NOTE(review): PARQUET_BASE_PATH is not defined in the visible part of
    # this file -- confirm it is defined elsewhere before this runs.
    # NOTE(review): the latest-version filter partitions on 'quarter' before
    # extract_quarter runs, so it presumably relies on a 'quarter' column
    # already present in the parquet data -- confirm.
    try:
        session = get_spark_context(env=env, max_executors=max_executors)
        raw = session.read.parquet(PARQUET_BASE_PATH)
        steps = preprocessor or EarningsPreprocessor()

        versioned = steps.filter_by_latest_version(raw) if filter_latest else raw
        in_range = steps.filter_date_range(versioned, start_date, end_date)
        with_quarter = steps.extract_quarter(in_range)
        result = steps.select_final_columns(with_quarter)

        logger.info(f"Loaded {data_type} earnings data")
        return result
    except Exception as e:
        logger.error(f"Error loading {data_type} earnings data: {e}")
        raise

def load_earnings_call_data_ai_metrics(start_date: dt.date, end_date: dt.date,
                                         preprocessor: Optional[EarningsPreprocessor] = None) -> DataFrame:
    """
    Load earnings data for AI metrics (without latest version filtering).

    Thin wrapper over ``load_earnings_data`` with data_type "ai_metrics" and
    filter_latest disabled; all other settings use that function's defaults.
    """
    return load_earnings_data(
        data_type="ai_metrics",
        start_date=start_date,
        end_date=end_date,
        filter_latest=False,
        preprocessor=preprocessor,
    )

def load_earnings_call_data_vespa(start_date: dt.date, end_date: dt.date,
                                  preprocessor: Optional[EarningsPreprocessor] = None) -> DataFrame:
    """
    Load earnings data for Vespa ingestion (with latest version filtering).

    Thin wrapper over ``load_earnings_data`` with data_type "vespa",
    filter_latest enabled, and max_executors set to 1.
    """
    return load_earnings_data(
        data_type="vespa",
        start_date=start_date,
        end_date=end_date,
        filter_latest=True,
        max_executors=1,
        preprocessor=preprocessor,
    )