In [0]:
%run /Workspace/Users/jorgegarciaotero@gmail.com/tfm_databricks/config/database_connector

In [0]:
from pyspark.sql.window import Window
import pyspark.sql.functions as F


#### FUNCIONES

In [0]:
def check_null_values(df, excluded_cols=("date", "symbol")) -> None:
    """
    Displays, per column, how many rows are missing (null, or NaN for floating-point
    columns). Only columns with at least one missing value are listed, sorted from
    most to fewest missing.

    ARGS:
        df: Spark DataFrame
        excluded_cols: column names to skip (default: 'date' and 'symbol')

    RETURNS:
        None (the result is rendered with `display`)
    """
    from pyspark.sql.types import DoubleType, FloatType

    excluded = set(excluded_cols or ())
    fields = [f for f in df.schema.fields if f.name not in excluded]
    if not fields:
        # Nothing to check; reshaping zero columns would produce an invalid query.
        print("No columns to check for null values.")
        return

    def _quoted(name):
        # Backtick-quote so names containing dots, spaces, etc. resolve literally.
        return "`" + name.replace("`", "``") + "`"

    def _is_missing(field):
        column = F.col(_quoted(field.name))
        # isnan is only meaningful for floating-point types; on date/timestamp/boolean
        # columns it fails analysis, and other types cannot hold NaN anyway.
        if isinstance(field.dataType, (DoubleType, FloatType)):
            return column.isNull() | F.isnan(column)
        return column.isNull()

    # Single row holding the missing count of every column, under positional aliases
    # so arbitrary column names never need to be embedded in a SQL string.
    counts_row = df.select([
        F.count(F.when(_is_missing(field), 1)).alias(f"_n{i}")
        for i, field in enumerate(fields)
    ])

    # Reshape to long format: one (column, nulls) row per checked column.
    long_counts = counts_row.select(
        F.explode(F.array(*[
            F.struct(
                F.lit(field.name).alias("column"),
                F.col(f"_n{i}").alias("nulls"),
            )
            for i, field in enumerate(fields)
        ])).alias("entry")
    ).select("entry.column", "entry.nulls")

    # Keep only columns that actually have missing values, worst first.
    result = long_counts.filter(F.col("nulls") > 0).orderBy(F.desc("nulls"))

    display(result)


In [0]:
def remove_initial_days_per_symbol(df, min_days=20, partition_col="symbol", order_col="date"):
    """
    Removes the first `min_days` rows per symbol based on date order.

    Args:
        df (DataFrame): Spark DataFrame with at least [partition_col, order_col]
            (by default ['symbol', 'date'])
        min_days (int): Number of initial rows to drop per symbol. Values <= 0
            keep every row.
        partition_col (str): Column identifying each series (default 'symbol').
        order_col (str): Column defining row order within a series (default 'date').

    Returns:
        DataFrame: DataFrame with the initial rows of each series removed and the
        same columns as the input.

    Note:
        Rows that tie on `order_col` within a series are numbered in an unspecified
        order, so which tied rows are dropped is not deterministic.
    """
    # Pick a temp column name that cannot overwrite (and then drop) an existing column.
    rank_col = "__row_num"
    while rank_col in df.columns:
        rank_col = "_" + rank_col

    w = Window.partitionBy(partition_col).orderBy(order_col)
    return (
        df.withColumn(rank_col, F.row_number().over(w))
          .filter(F.col(rank_col) > min_days)
          .drop(rank_col)
    )


#### MAIN

In [0]:
# Create the input widgets (with default values) and read their current values.
dbutils.widgets.text("storage_account", "smartwalletjorge", "Storage Account")
dbutils.widgets.text("container", "smart-wallet-dl", "Container")
dbutils.widgets.text("database", "smart_wallet", "Database")
# The "date" widget must be declared too: dbutils.widgets.get on an undefined
# widget raises when the notebook runs interactively (no job parameter supplied).
dbutils.widgets.text("date", "", "Date")

storage_account = dbutils.widgets.get("storage_account")
container = dbutils.widgets.get("container")
database_name = dbutils.widgets.get("database")
# A blank or whitespace-only date is normalized to None
# (presumably "no date filter" for read_table_from_path; confirm in DatabaseConnector).
date_value = (dbutils.widgets.get("date") or "").strip() or None

db_connector = DatabaseConnector()
print(f"database_name :{database_name}")

df = db_connector.read_table_from_path(container, database_name, "stock_data_parquet", date_value, "parquet")
    

##### 1. Esquema y nulos

In [0]:
# Row count of the loaded dataset, followed by its column schema.
print("Counts : {}".format(df.count()))
df.printSchema()


In [0]:
# Per-column missing-value report; only columns with at least one null/NaN are listed.
check_null_values(df=df)

#### Target counts

In [0]:
from pyspark.sql.functions import col, when

# A row is labeled 1 when its forward return strictly exceeds this threshold
# (0.1 == +10% for a simple fractional return such as ret_next_3m).
TARGET_RETURN_THRESHOLD = 0.1
TARGET_HORIZONS = ["3m", "6m", "1y"]

for horizon in TARGET_HORIZONS:
    ret = col(f"ret_next_{horizon}")
    # Rows without a usable forward return (null, presumably at the end of each
    # symbol's history where no future price exists, or NaN) stay null instead of
    # being counted as negatives. The NaN guard matters: Spark orders NaN above every
    # other double, so `NaN > 0.1` would otherwise label the row 1.
    has_ret = ret.isNotNull() & ~F.isnan(ret)
    df = df.withColumn(
        f"target_{horizon}",
        when(has_ret & (ret > TARGET_RETURN_THRESHOLD), 1).when(has_ret, 0),
    )

# Class balance per horizon; a null group counts rows with no usable forward return.
for horizon in TARGET_HORIZONS:
    target = f"target_{horizon}"
    print(f"=== {target} ===")
    df.groupBy(target).count().orderBy(target).show()


In [0]:
from pyspark.sql import DataFrame, functions as F
from pyspark.sql.window import Window
from typing import List

def verify_future_return_3m(
    df: DataFrame,
    symbol: str,
    reference_dates: List[str],
    lead_rows: int = 63,
) -> DataFrame:
    """
    Cross-checks the stored 3-month forward price/return against a manual recomputation.

    Given a Spark DataFrame that already contains columns:
      - price_lead_3m (Double): the stored "lead-close" at ~63 trading days ahead
      - ret_next_3m  (Double): the stored return (price_lead_3m - close_v) / close_v

    this function computes "close_3m_manual" as lead(close_v, lead_rows) over a
    window ordered by date (restricted to `symbol`) and compares:
      * price_lead_3m   (precomputed)
      * close_3m_manual (close_v `lead_rows` rows ahead)
      * ret_next_3m     (precomputed)
      * ret_3m_manual   (computed as (close_3m_manual - close_v) / close_v)

    The lead is computed over the symbol's full history before filtering to the
    reference dates, so the look-ahead is not truncated by that filter.

    Args:
    ----
    df : DataFrame
        A Spark DataFrame that contains at least:
          - symbol (string)
          - date (date or string)
          - close_v (Double)
          - price_lead_3m (Double)
          - ret_next_3m  (Double)

    symbol : str
        The ticker to filter for (e.g. "AAPL").

    reference_dates : List[str]
        Date strings (e.g. ["2023-01-03", "2023-02-01"]) to inspect.

    lead_rows : int, default 63
        Number of rows ahead used for the manual lead; should match how
        price_lead_3m was built.

    Returns:
    -------
    DataFrame
        Ordered by date, with at most one row per reference date (assuming one
        row per symbol/date). Reference dates with no row for the symbol, such as
        non-trading days, are simply absent. Columns:
          - date
          - close_v
          - price_lead_3m       (precomputed)
          - close_3m_manual     (computed via lead; null if fewer than
                                 `lead_rows` later rows exist)
          - ret_3m_manual       (computed on the fly)
          - ret_next_3m         (precomputed)
    """
    # 1) Window over the symbol's rows in chronological order.
    window_spec = Window.partitionBy("symbol").orderBy("date")

    # 2) Keep the desired symbol and add the manual lead(close_v, lead_rows) column.
    df_symbol = (
        df
          .filter(F.col("symbol") == symbol)
          .withColumn("close_3m_manual", F.lead("close_v", lead_rows).over(window_spec))
    )

    # 3) Restrict to the reference dates, compute the manual return, and sort for display.
    result = (
        df_symbol
          .filter(F.col("date").isin(reference_dates))
          .select(
              F.col("date"),
              F.col("close_v"),
              F.col("price_lead_3m"),
              F.col("close_3m_manual"),
              # manual return: (close_3m_manual - close_v) / close_v
              ((F.col("close_3m_manual") - F.col("close_v")) / F.col("close_v"))
                  .alias("ret_3m_manual"),
              F.col("ret_next_3m")
          )
          .orderBy("date")
    )

    return result


In [0]:
# Assumes `df` already contains the precomputed price_lead_3m and ret_next_3m columns.
# Reference dates must be trading days present for the symbol: 2023-03-04 is a
# Saturday, so it will most likely match no row (reported below).
check_symbol = "AAPL"
reference_dates = ["2023-03-04", "2023-04-04"]
df_check = verify_future_return_3m(df, symbol=check_symbol, reference_dates=reference_dates)

df_check.show(truncate=False)

# Report reference dates that matched no row instead of silently dropping them.
found_dates = {
    row["d"]
    for row in df_check.select(F.date_format("date", "yyyy-MM-dd").alias("d")).collect()
}
missing_dates = [d for d in reference_dates if d not in found_dates]
if missing_dates:
    print(f"No {check_symbol} rows found for reference dates: {missing_dates}")


In [0]:

# Stored forward prices for AAPL next to the daily close, in chronological order.
df_apple = df.where(df["symbol"] == "AAPL")

# Keep only the columns needed for a visual check of the lead prices.
df_apple_sel = df_apple.select(
    "date", "close_v", "price_lead_3m", "price_lead_6m", "price_lead_1y"
)

df_apple_ord = df_apple_sel.sort("date")

display(df_apple_ord)
