In [None]:
import pandas as pd
from google.cloud import bigquery
import numpy as np
import warnings

# Suppress pandas PerformanceWarning messages for the whole script.
warnings.simplefilter("ignore", pd.errors.PerformanceWarning)

# BigQuery location of the source table (fully qualified: project.dataset.table).
PROJECT_ID = "pivotal-glider-472219-r7"
TABLE_ID = f"{PROJECT_ID}.market_data.master_data"

# Fractions of weeks assigned to each split; validation receives the remainder.
TRAIN_RATIO = 0.80
TEST_RATIO = 0.15
VALIDATION_RATIO = 1.0 - TRAIN_RATIO - TEST_RATIO

def get_data_from_bigquery(project_id: str, table_id: str) -> pd.DataFrame:
    """
    Download the entire ``table_id`` table from BigQuery, ordered by ``date``.

    The ``fed_Funds`` and ``dgs_10`` columns are requested as ``float64``.

    NOTE: Make sure you are authenticated!
    Run `gcloud auth application-default login` in your terminal.

    Args:
        project_id: GCP project used to create the BigQuery client.
        table_id: Table name in ``project.dataset.table`` form (e.g. TABLE_ID).
            It is interpolated directly into the SQL text, so it must come from
            trusted configuration, never from user input.

    Returns:
        The query result as a DataFrame, or an empty DataFrame if creating the
        client or running/downloading the query fails (the error is printed).
    """
    print(f"Connecting to BigQuery project '{project_id}'...")

    # Query to select all data ordered by date
    query = f"""
    SELECT *
    FROM `{table_id}`
    ORDER BY date ASC
    """

    try:
        # Client construction is inside the try on purpose: missing or expired
        # application-default credentials are raised here, and they should get
        # the same friendly message (and empty-frame result) as a failed query.
        client = bigquery.Client(project=project_id)
        print("Downloading data... This may take a moment.")
        query_job = client.query(query)
        df = query_job.to_dataframe(
            create_bqstorage_client=True,
            dtypes={"fed_Funds": "float64", "dgs_10": "float64"},
        )
    except Exception as e:
        # Best-effort boundary: the caller treats an empty frame as "download failed".
        print(f"Error downloading data: {e}")
        print("Please ensure your project/table ID is correct and you have authenticated.")
        return pd.DataFrame()

    print(f"Successfully downloaded {len(df)} rows.")
    return df

def preprocess_data(df: pd.DataFrame) -> (pd.DataFrame, list):
    """
    Prepare the raw market DataFrame for ML.

    Steps:
    1. Copy the input so the caller's frame is never modified, convert 'date'
       to datetime and sort rows chronologically (the shift/fill logic below
       depends on row order).
    2. Create 'reference_price': the last valid 'sp500' close, forward-filled
       (e.g. Friday's close for Fri, Sat and Sun).
    3. Create 'target_price': the next valid 'sp500' close after the row
       (e.g. Monday's close for Fri, Sat and Sun).
    4. Drop rows where either price is still NaN (start/end of the dataset).
    5. Create the binary 'target': 1 if target_price >= reference_price, else 0.
    6. Forward-fill, then backward-fill, every remaining (feature) column.
    7. Rename 'reference_price' to 'sp500_last_close'.

    Args:
        df: Raw data containing at least 'date' and 'sp500' columns.

    Returns:
        (processed_df, helper_cols_to_drop). helper_cols_to_drop is
        ['sp500', 'target_price']: columns the caller should remove before
        training ('target_price' is what 'target' is computed from).
        An empty input yields (empty DataFrame, []).
    """
    if df.empty:
        return pd.DataFrame(), []

    print("Starting preprocessing...")

    # 1. Work on a copy so the new columns and the date conversion do not leak
    #    into the caller's DataFrame. A stable sort keeps the original order of
    #    rows that share a date.
    df = df.copy()
    df['date'] = pd.to_datetime(df['date'])
    df = df.sort_values('date', kind='mergesort')

    # 2. Last known market close: Friday's close for Fri, Sat and Sun.
    df['reference_price'] = df['sp500'].ffill()

    # 3. Next market close: shift(-1) looks one row ahead (Sat's NaN for Fri),
    #    then bfill pulls Monday's close back onto Fri, Sat and Sun.
    df['target_price'] = df['sp500'].shift(-1).bfill()

    # 4. Rows before the first close (no reference) or after the last close
    #    (no target) cannot be labelled.
    initial_rows = len(df)
    df = df.dropna(subset=['reference_price', 'target_price']).copy()
    if initial_rows > len(df):
        print(f"Dropped {initial_rows - len(df)} rows from start/end of dataset (no reference or target).")

    # 5. Target = 1 ('buy') if the next market close >= the last market close,
    #    Target = 0 ('sell') otherwise.
    df['target'] = (df['target_price'] >= df['reference_price']).astype(int)

    # 6. Fill gaps in feature columns (e.g. 'fed_Funds', 'dgs_10' on holidays).
    #    ffill carries the last reported value forward; the following bfill can
    #    then only touch leading NaNs, copying each column's first observed value
    #    backward onto the earliest rows.
    print("Filling missing feature data (ffill/bfill)...")

    core_cols = ['date', 'target', 'sp500', 'reference_price', 'target_price']
    feature_cols = [col for col in df.columns if col not in core_cols]

    if feature_cols:
        # Column assignment replaces the columns outright instead of writing
        # into the existing arrays in place (as .loc[:, cols] = ... does).
        df[feature_cols] = df[feature_cols].ffill().bfill()

    # 7. Final cleanup
    df = df.rename(columns={'reference_price': 'sp500_last_close'})

    helper_cols_to_drop = ['sp500', 'target_price']

    print(f"Preprocessing complete. Final dataset shape: {df.shape}")
    return df, helper_cols_to_drop

def split_data_chronologically_by_week(df: pd.DataFrame, train_ratio: float, test_ratio: float) -> (pd.DataFrame, pd.DataFrame, pd.DataFrame):
    """
    Split ``df`` into train/test/validation sets by whole chronological weeks.

    Rows are sorted by ``date`` and grouped into weeks using the calendar year
    of ``date`` together with the ``fiscal_week`` column. The oldest weeks go to
    train, the next block to test, and the most recent weeks to validation, so
    no later week is ever used for training.

    Args:
        df: Preprocessed data with a datetime ``date`` column and a
            ``fiscal_week`` column. It is not modified.
        train_ratio: Intended fraction of weeks for training. The validation
            fraction is ``1 - train_ratio - test_ratio``; training receives all
            weeks left after the (ceil-rounded) test and validation counts, so
            the realised train fraction can be slightly below ``train_ratio``.
        test_ratio: Fraction of weeks for the test set.

    Returns:
        ``(train_df, test_df, val_df)``. With fewer than 3 weeks, all rows are
        returned as training data and the other two frames are empty.

    Raises:
        ValueError: if a ratio is negative or ``train_ratio + test_ratio`` > 1.
    """
    print("Splitting data chronologically by week...")

    validation_ratio = 1.0 - train_ratio - test_ratio
    if train_ratio < 0 or test_ratio < 0 or validation_ratio < -1e-9:
        raise ValueError(
            f"Invalid split ratios: train={train_ratio}, test={test_ratio} "
            "(both must be non-negative and sum to at most 1)."
        )
    validation_ratio = max(0.0, validation_ratio)

    # sort_values returns a new frame, so no helper columns or reordering ever
    # reach the caller's DataFrame.
    df = df.sort_values('date', kind='mergesort')
    labels = df['date'].dt.year.astype(str) + '-' + df['fiscal_week'].astype(str)

    # One week = one contiguous run of identical labels. A label can recur
    # non-contiguously (e.g. if fiscal_week uses ISO-style numbering, the last
    # days of December may carry week 1/52 and share a label with early January
    # of the same year); numbering runs keeps such rows with their neighbours
    # instead of pulling them into an earlier split.
    run_starts = (labels != labels.shift()).to_numpy()
    week_ids = run_starts.cumsum()  # 1..n_weeks, non-decreasing over rows
    week_labels = labels[run_starts].tolist()
    n_weeks = len(week_labels)

    if n_weeks < 3:
        print(f"Warning: Only {n_weeks} unique weeks. Not enough data to split into train/test/val.")
        print("Returning all data as training set.")
        return df, pd.DataFrame(), pd.DataFrame()

    val_count = int(np.ceil(n_weeks * validation_ratio))
    test_count = int(np.ceil(n_weeks * test_ratio))
    if val_count + test_count >= n_weeks:
        # Ratios leave no week for training: fall back to one week each for
        # test and validation (n_weeks >= 3 guarantees at least one train week).
        val_count = 1
        test_count = 1
    train_count = n_weeks - val_count - test_count
    test_end = train_count + test_count

    train_weeks = week_labels[:train_count]
    test_weeks = week_labels[train_count:test_end]
    val_weeks = week_labels[test_end:]

    print(f"Total unique weeks: {n_weeks}")

    if train_weeks:
        print(f"  - Train weeks: {len(train_weeks)} (Ending: {train_weeks[-1]})")
    else:
        print("  - Train weeks: 0")

    if test_weeks:
        print(f"  - Test weeks: {len(test_weeks)} (Starting: {test_weeks[0]}, Ending: {test_weeks[-1]})")
    else:
        print("  - Test weeks: 0")

    if val_weeks:
        print(f"  - Validation weeks: {len(val_weeks)} (Starting: {val_weeks[0]})")
    else:
        print("  - Validation weeks: 0")

    # week_ids are 1-based, so week k is in train iff k <= train_count.
    train_df = df[week_ids <= train_count].copy()
    test_df = df[(week_ids > train_count) & (week_ids <= test_end)].copy()
    val_df = df[week_ids > test_end].copy()

    print("\nData split complete:")
    print(f"  - Train set shape: {train_df.shape}")
    print(f"  - Test set shape:  {test_df.shape}")
    print(f"  - Val set shape:   {val_df.shape}")

    return train_df, test_df, val_df

def main():
    """
    Entry point: download, preprocess, split and export the market data.

    Each split is written to its own CSV file (without the index) using a
    relative path, i.e. in the current working directory. The pipeline stops
    early, after printing a message, when either the download or the
    preprocessing step produces an empty DataFrame. Errors raised while writing
    the CSV files are printed rather than propagated.
    """
    raw_df = get_data_from_bigquery(PROJECT_ID, TABLE_ID)
    if raw_df.empty:
        print("Exiting due to data download failure.")
        return

    processed_df, helper_cols = preprocess_data(raw_df)
    if processed_df.empty:
        print("Exiting due to preprocessing failure.")
        return

    train_df, test_df, val_df = split_data_chronologically_by_week(processed_df, TRAIN_RATIO, TEST_RATIO)

    # Output filename -> split with helper columns removed (absent ones skipped).
    exports = {
        filename: frame.drop(columns=helper_cols, errors='ignore')
        for filename, frame in (
            ("train_data.csv", train_df),
            ("test_data.csv", test_df),
            ("validation_data.csv", val_df),
        )
    }

    try:
        print("\nSaving data to CSV files...")
        for filename, frame in exports.items():
            frame.to_csv(filename, index=False)
        print("Successfully saved train_data.csv, test_data.csv, and validation_data.csv")
    except Exception as err:
        print(f"Error saving files: {err}")

# Run the pipeline when this file is executed directly (not when imported).
if __name__ == "__main__":
    main()