In [7]:
# data_prep_sp500.py
"""
Prepare S&P500 factor and price data for neutralization.

Outputs (in ./output/):
 - factors_merged.csv       : merged long-format factor table (columns: datetime, instrument, factor_<name>...)
 - df_ltsz.csv              : market-cap pivot table (index=UTC datetime, columns=code); raw close*shares_out floats, not log
 - industry_dummies.csv     : one-hot industry dummy variables per ticker (index=ticker)
 - ticker_industry_map.csv  : cached ticker -> industry/sector map fetched from yfinance
"""

import os
import glob
import pandas as pd
import numpy as np
import yfinance as yf
import argparse
from tqdm import tqdm

# -------------------------
# Config - change if needed
# -------------------------
# The defaults below are machine-specific absolute Windows paths. Each one can
# be overridden with an environment variable (SP500_FACTORS_FOLDER,
# SP500_PRICE_FILE, SP500_OUTPUT_DIR) so the script runs elsewhere without edits.

# Folder of wide factor CSVs: a 'date' column plus one column per ticker.
FACTORS_FOLDER = os.environ.get(
    "SP500_FACTORS_FOLDER",
    r"C:\Users\ns243\Documents\Academic\AI Master\Internship\Data\alpha158_processed",
)
# Daily price file (needs date, a ticker column, close and shares_out).
PRICE_FILE = os.environ.get(
    "SP500_PRICE_FILE",
    r"C:\Users\ns243\Documents\Academic\AI Master\Internship\Data\df_sp500.csv",
)
# Resolved to an absolute path relative to the current working directory.
OUTPUT_DIR = os.path.abspath(os.environ.get("SP500_OUTPUT_DIR", "./output"))
TICKER_INDUSTRY_CACHE = os.path.join(OUTPUT_DIR, "ticker_industry_map.csv")

# Created at import time so every writer below can assume the folder exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# -------------------------
# Helper functions
# -------------------------

def parse_timestamp_series(s, *, normalize=False):
    """Parse date/time values into timezone-aware UTC timestamps.

    Offset-aware inputs are converted to UTC; naive inputs are localized as
    UTC (pandas ``utc=True`` semantics). Entries that cannot be parsed become
    ``NaT`` instead of raising (``errors='coerce'``).

    Args:
        s: anything accepted by ``pd.to_datetime`` (typically a Series of strings).
        normalize: if True, truncate each timestamp to midnight UTC.
            Defaults to False, which leaves times untouched.

    Returns:
        The ``pd.to_datetime`` result: a tz-aware Series when ``s`` is a Series,
        a ``DatetimeIndex`` for list/array/Index input.
    """
    dt = pd.to_datetime(s, errors='coerce', utc=True)
    if normalize:
        # Series expose datetime methods through the .dt accessor; DatetimeIndex
        # and scalar Timestamps expose normalize() directly.
        dt = dt.dt.normalize() if isinstance(dt, pd.Series) else dt.normalize()
    return dt

def infer_factor_name(filepath):
    """Return the column name ``factor_<stem>`` for a factor file.

    ``<stem>`` is the file name with its directory and final extension removed,
    and with every space and hyphen replaced by an underscore.
    """
    stem, _extension = os.path.splitext(os.path.basename(filepath))
    return "factor_" + stem.translate(str.maketrans(" -", "__"))

# -------------------------
# 1) Merge factor CSVs
# -------------------------
def _load_factor_long(filepath):
    """Read one wide factor CSV and return its values as a Series keyed by (datetime, instrument)."""
    fname = infer_factor_name(filepath)
    df = pd.read_csv(filepath, dtype=str)  # read as str to avoid unexpected dtype issues
    if 'date' not in df.columns:
        raise ValueError(f"file {filepath} missing 'date' column")
    df['datetime'] = parse_timestamp_series(df['date'])
    df = df.drop(columns=['date'])
    # Unparseable dates become NaT. A null key does not identify an observation,
    # and pandas joins null keys to each other, so NaT rows from different files
    # would be cross-multiplied. Drop them here.
    df = df.dropna(subset=['datetime'])

    # wide -> long: every column other than 'datetime' is a ticker
    long_df = df.melt(id_vars=['datetime'], var_name='instrument', value_name=fname)
    long_df[fname] = pd.to_numeric(long_df[fname], errors='coerce')
    # A date repeated within one file would produce duplicate (datetime, instrument)
    # keys and duplicated output rows; keep only the last occurrence.
    long_df = long_df.drop_duplicates(subset=['datetime', 'instrument'], keep='last')
    return long_df.set_index(['datetime', 'instrument'])[fname]


def merge_factors(folder, *, out_dir=None):
    """
    Merge every factor CSV in ``folder`` into one long-format table and save it.

    Each factor CSV is assumed to have:
      - a 'date' column
      - remaining columns are tickers (one column per ticker)
    Each file becomes one column named by ``infer_factor_name``. Rows whose
    date cannot be parsed are dropped. For duplicated dates within a file, the
    last row wins.

    Args:
        folder: directory containing the ``*.csv`` factor files.
        out_dir: directory for ``factors_merged.csv``; defaults to OUTPUT_DIR.

    Returns:
        Path of the written CSV. Its columns are ['datetime', 'instrument',
        <factor columns>...], sorted by (datetime, instrument). Rows where every
        factor is NaN are omitted.

    Raises:
        FileNotFoundError: if ``folder`` contains no CSV files.
        ValueError: if a factor file lacks a 'date' column.
    """
    csv_files = sorted(glob.glob(os.path.join(folder, "*.csv")))
    if not csv_files:
        raise FileNotFoundError(f"No CSVs found in {folder}")
    print(f"Found {len(csv_files)} factor files. Merging...")

    factor_series = [_load_factor_long(fp) for fp in tqdm(csv_files, desc="Factor files")]

    print("Merging factor long tables...")
    # One index-aligned outer join over all factors, instead of a chain of
    # pairwise merges that re-copies the growing table for every file.
    merged = pd.concat(factor_series, axis=1, join='outer', sort=False)
    merged.index.names = ['datetime', 'instrument']
    merged = merged.reset_index()

    # Drop rows with no factor values at all, then order deterministically.
    factor_cols = [c for c in merged.columns if c not in ('datetime', 'instrument')]
    merged = merged.dropna(axis=0, how='all', subset=factor_cols)
    merged = merged.sort_values(['datetime', 'instrument']).reset_index(drop=True)

    target_dir = OUTPUT_DIR if out_dir is None else out_dir
    out_path = os.path.join(target_dir, "factors_merged.csv")
    merged.to_csv(out_path, index=False)
    print(f"Saved merged factors to {out_path}")
    return out_path

# -------------------------
# 2) Prepare price & market cap
# -------------------------
def prepare_price_marketcap(price_csv, *, out_dir=None):
    """
    Build a datetime x ticker market-cap table from the daily price file.

    Reads a price CSV with columns similar to:
    date, stock_code, open, high, low, close, factor, change, volume, money, shares_out
    Only 'date', a ticker column, 'close' and 'shares_out' are used. The ticker
    column may be named 'code', 'stock_code', 'ticker' or 'instrument'. It is
    renamed to 'code', and an existing 'code' column takes precedence.

    Market cap is close * shares_out (raw value, not log). Non-numeric
    close/shares_out values become NaN.

    Args:
        price_csv: path of the price CSV.
        out_dir: directory for ``df_ltsz.csv``; defaults to OUTPUT_DIR.

    Returns:
        Path of the written ``df_ltsz.csv`` (index: UTC datetime, columns: code).

    Raises:
        ValueError: if a required column is missing.
    """
    print("Reading price file...")
    df = pd.read_csv(price_csv)

    # Use an existing 'code' column as-is. Otherwise rename the first ticker
    # alias found. Renaming an alias while 'code' already exists would create
    # two 'code' columns and make the pivot below fail.
    if 'code' not in df.columns:
        for cand in ('stock_code', 'ticker', 'instrument'):
            if cand in df.columns:
                df = df.rename(columns={cand: 'code'})
                break

    # Validate everything up front, before any parsing work.
    if 'date' not in df.columns or 'code' not in df.columns:
        raise ValueError("price file must contain 'date' and a ticker column (stock_code/code/ticker/instrument)")
    if 'close' not in df.columns:
        raise ValueError("price file must contain 'close' column")
    if 'shares_out' not in df.columns:
        raise ValueError("price file must contain 'shares_out' column to compute market cap")

    df['datetime'] = parse_timestamp_series(df['date'])
    df = df.drop(columns=['date'])

    df['close'] = pd.to_numeric(df['close'], errors='coerce')
    df['shares_out'] = pd.to_numeric(df['shares_out'], errors='coerce')
    df['market_cap'] = df['close'] * df['shares_out']

    # Rows without a ticker or a parseable date cannot be placed in the pivot.
    df = df.dropna(subset=['code', 'datetime'])

    # aggfunc='first' keeps the first non-null value when a (datetime, code)
    # pair repeats. NOTE: pivot_table's default dropna=True omits columns whose
    # entries are all NaN (per pandas docs), so tickers with no computable
    # market cap do not appear in the output.
    df_pivot = df.pivot_table(index='datetime', columns='code', values='market_cap', aggfunc='first')

    target_dir = OUTPUT_DIR if out_dir is None else out_dir
    out_path = os.path.join(target_dir, "df_ltsz.csv")
    # raw market cap (not log); the tz-aware index is written in ISO format
    df_pivot.to_csv(out_path)
    print(f"Saved market cap pivot to {out_path} (values are market_cap, not log).")
    return out_path

# -------------------------
# 3) Fetch industries (yfinance)
# -------------------------
def fetch_industries_for_tickers(tickers, cache_path=TICKER_INDUSTRY_CACHE):
    """
    Look up the yfinance 'industry' and 'sector' strings for each ticker.

    Successful lookups are cached in ``cache_path``, so repeated runs only query
    tickers not seen before. A lookup that raises (e.g. a network error or rate
    limiting) is NOT written to the cache. The ticker is retried on the next run
    instead of being frozen as 'Unknown'. A lookup that succeeds but has no
    industry/sector is cached as 'Unknown'.

    Args:
        tickers: iterable of ticker symbols. NaN entries are ignored, and
            symbols are converted to upper-case strings.
        cache_path: CSV file used as the ticker -> (industry, sector) cache.

    Returns:
        DataFrame indexed by the requested tickers (upper-cased, de-duplicated,
        sorted) with columns ['industry', 'sector']. Missing values are 'Unknown'.
        Cache entries for tickers not requested are kept on disk but not returned.
    """
    if os.path.exists(cache_path):
        # dtype=str keeps symbols as text so lookups by string ticker match.
        df_cache = pd.read_csv(cache_path, index_col=0, dtype=str)
    else:
        df_cache = pd.DataFrame(columns=['industry', 'sector'])

    tickers = sorted({str(t).upper() for t in tickers if pd.notna(t)})
    new_needed = [t for t in tickers if t not in df_cache.index]
    print(f"{len(tickers)} unique tickers found; {len(new_needed)} missing from cache. Fetching...")

    # One request per ticker so a single bad symbol does not sink a batch.
    failed = []
    for t in tqdm(new_needed, desc="Fetching industry"):
        try:
            info = yf.Ticker(t).info
            industry = info.get('industry') or info.get('Industry') or None
            sector = info.get('sector') or info.get('Sector') or None
        except Exception:  # yfinance raises many error types; treat any as a failed lookup
            failed.append(t)
            continue
        df_cache.loc[t] = [industry, sector]

    # Lookups that succeeded without an industry/sector are stored as 'Unknown'.
    df_cache['industry'] = df_cache['industry'].fillna('Unknown')
    df_cache['sector'] = df_cache['sector'].fillna('Unknown')

    df_cache.to_csv(cache_path)
    print(f"Saved ticker industry cache to {cache_path}")
    if failed:
        print(f"Warning: industry lookup failed for {len(failed)} tickers "
              f"(not cached, will be retried next run): {failed[:10]}")

    # Restrict to the requested universe. Failed tickers are absent from the
    # cache and are reported as 'Unknown' here.
    return df_cache.reindex(tickers).fillna('Unknown')

def build_industry_dummies(industry_map_df, *, out_dir=None):
    """
    One-hot encode the 'industry' column of a ticker -> industry table and save it.

    Args:
        industry_map_df: DataFrame indexed by ticker with an 'industry' column.
            Missing industries are encoded as 'Unknown'. The frame is not modified.
        out_dir: directory for ``industry_dummies.csv``; defaults to OUTPUT_DIR.

    Returns:
        Path of the written CSV. Rows are indexed by ticker; columns are
        ``ind_<industry>`` holding 0/1 integers.
    """
    industries = industry_map_df['industry'].fillna('Unknown')
    # get_dummies on a Series keeps that Series' index (the tickers). dtype=int
    # writes explicit 0/1. pandas >= 2.0 defaults to bool, which would be
    # serialized as True/False text.
    dummies = pd.get_dummies(industries, prefix='ind', dtype=int)

    target_dir = OUTPUT_DIR if out_dir is None else out_dir
    out_path = os.path.join(target_dir, "industry_dummies.csv")
    dummies.to_csv(out_path)
    print(f"Saved industry dummies to {out_path}")
    return out_path

# -------------------------
# Main
# -------------------------
if __name__ == "__main__":
    print("Starting data preparation...")

    merged_factors_path = merge_factors(FACTORS_FOLDER)
    df_ltsz_path = prepare_price_marketcap(PRICE_FILE)

    # The ticker universe only needs the 'instrument' column. Reading just that
    # column as text avoids re-loading and date-parsing the whole merged table.
    instruments = pd.read_csv(merged_factors_path, usecols=['instrument'], dtype=str)['instrument']
    tickers = instruments.unique().tolist()

    industry_map = fetch_industries_for_tickers(tickers)
    industry_dummies_path = build_industry_dummies(industry_map)

    print(f"All done. Outputs are in the '{OUTPUT_DIR}' directory.")
    print("Files produced:")
    print(" -", merged_factors_path)
    print(" -", df_ltsz_path)
    print(" -", industry_dummies_path)
    print(" -", TICKER_INDUSTRY_CACHE)
    print("\nNotes:")
    print(" - If yfinance fetches fail for some tickers, check ticker format (e.g. BRK-B vs BRK.B or different Yahoo symbols).")
    print(" - The script caches industry lookups; if you need to refresh, delete ticker_industry_map.csv and re-run.")



Starting data preparation...
Found 101 factor files. Merging...


Factor files: 100%|██████████| 101/101 [28:13<00:00, 16.76s/it]  


Merging factor long tables...
Saved merged factors to c:\Users\ns243\Documents\Academic\AI Master\Internship\Codes\output\factors_merged.csv
Reading price file...
Saved market cap pivot to c:\Users\ns243\Documents\Academic\AI Master\Internship\Codes\output\df_ltsz.csv (values are market_cap, not log).
503 unique tickers found; 503 missing from cache. Fetching...


Fetching industry: 100%|██████████| 503/503 [07:02<00:00,  1.19it/s]

Saved ticker industry cache to c:\Users\ns243\Documents\Academic\AI Master\Internship\Codes\output\ticker_industry_map.csv
Saved industry dummies to c:\Users\ns243\Documents\Academic\AI Master\Internship\Codes\output\industry_dummies.csv
All done. Outputs are in the './output' directory.
Files produced:
 - c:\Users\ns243\Documents\Academic\AI Master\Internship\Codes\output\factors_merged.csv
 - c:\Users\ns243\Documents\Academic\AI Master\Internship\Codes\output\df_ltsz.csv
 - c:\Users\ns243\Documents\Academic\AI Master\Internship\Codes\output\industry_dummies.csv
 - c:\Users\ns243\Documents\Academic\AI Master\Internship\Codes\output\ticker_industry_map.csv

Notes:
 - If yfinance fetches fail for some tickers, check ticker format (e.g. BRK-B vs BRK.B or different Yahoo symbols).
 - The script caches industry lookups; if you need to refresh, delete ticker_industry_map.csv and re-run.



