In [1]:
import glob
import math
import polars as pl
import xgboost as xgb
from sklearn.metrics import r2_score

# 1) Gather all Parquet files
all_files = glob.glob(
    "/Users/nicky/Documents/codebase/jane-street-real-time-market-data-forecasting/train.parquet/partition_id=*/part-0.parquet"
)
print("Files found:", all_files)

# ------------------------------------------------------------
# Set your fraction for sampling. We'll pick 0.60 (60%).
SAMPLE_FRAC = 0.60

# Decide how you want to fill nulls. Options: "mean" or "median"
FILL_STRATEGY = "mean"  # or "median"
# ------------------------------------------------------------

# 2) Create a list of lazy frames and concatenate them
df_list = [pl.scan_parquet(f) for f in all_files]
df_lazy = pl.concat(df_list)  # still lazy

# 3) Collect to get an eager DataFrame
df_eager = df_lazy.collect()
print("Eager DataFrame shape:", df_eager.shape)

# 4) Sample 60% of rows (if you want strictly 60%, you can do random sampling)
#    But if you must preserve the time order, you could do a direct slice.
#    For random sampling:
n_to_sample = math.floor(df_eager.height * SAMPLE_FRAC)
df_sample = df_eager.sample(n=n_to_sample, with_replacement=False)
print(f"Sample shape (frac={SAMPLE_FRAC}):", df_sample.shape)

# 5) Fill Nulls with Mean or Median for numeric columns
#    We'll define which columns are numeric.
numeric_cols = []
for col, dtype in df_sample.schema.items():
    if dtype in (pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.Float32, pl.Float64):
        numeric_cols.append(col)

# Fill nulls column by column using Polars expressions
if FILL_STRATEGY == "mean":
    df_imputed = df_sample.with_columns([
        pl.col(col).fill_null(strategy="mean").alias(col)
        for col in numeric_cols
    ])
elif FILL_STRATEGY == "median":
    df_imputed = df_sample.with_columns([
        pl.col(col).fill_null(strategy="median").alias(col)
        for col in numeric_cols
    ])
else:
    raise ValueError("FILL_STRATEGY must be 'mean' or 'median'")

print("After imputation shape:", df_imputed.shape)

# 6) Time-based split (date_id cutoff = 1500)
date_cutoff = 1500
train_df = df_imputed.filter(pl.col("date_id") < date_cutoff)
val_df   = df_imputed.filter(pl.col("date_id") >= date_cutoff)

print("Train shape:", train_df.shape)
print("Val shape:  ", val_df.shape)

# 7) Exclude non-feature columns, define target
exclude_cols = [
    "date_id", "time_id", "symbol_id",
    "responder_0","responder_1","responder_2",
    "responder_3","responder_4","responder_5",
    "responder_6","responder_7","responder_8"
]
all_cols = df_imputed.columns
feature_cols = [c for c in all_cols if c not in exclude_cols]
target_col = "responder_6"

# 8) Convert to NumPy
X_train = train_df.select(feature_cols).to_numpy()
y_train = train_df.select(target_col).to_numpy().ravel()

X_val = val_df.select(feature_cols).to_numpy()
y_val = val_df.select(target_col).to_numpy().ravel()

print("X_train shape:", X_train.shape)
print("y_train shape:", y_train.shape)
print("X_val shape:  ", X_val.shape)
print("y_val shape:  ", y_val.shape)

# 9) Train an XGBoost model
reg = xgb.XGBRegressor(
    objective="reg:squarederror",
    tree_method="hist",  # more memory-friendly than "auto"
    max_depth=6,
    learning_rate=0.1,
    n_estimators=100
)

reg.fit(X_train, y_train)

# 10) Predict and evaluate
y_val_pred = reg.predict(X_val)
val_r2 = r2_score(y_val, y_val_pred)
print(f"R2 on {int(SAMPLE_FRAC*100)}% sample + {FILL_STRATEGY} imputation: {val_r2:.4f}")

Files found: ['/Users/nicky/Documents/codebase/jane-street-real-time-market-data-forecasting/train.parquet/partition_id=1/part-0.parquet', '/Users/nicky/Documents/codebase/jane-street-real-time-market-data-forecasting/train.parquet/partition_id=6/part-0.parquet', '/Users/nicky/Documents/codebase/jane-street-real-time-market-data-forecasting/train.parquet/partition_id=8/part-0.parquet', '/Users/nicky/Documents/codebase/jane-street-real-time-market-data-forecasting/train.parquet/partition_id=9/part-0.parquet', '/Users/nicky/Documents/codebase/jane-street-real-time-market-data-forecasting/train.parquet/partition_id=7/part-0.parquet', '/Users/nicky/Documents/codebase/jane-street-real-time-market-data-forecasting/train.parquet/partition_id=0/part-0.parquet', '/Users/nicky/Documents/codebase/jane-street-real-time-market-data-forecasting/train.parquet/partition_id=5/part-0.parquet', '/Users/nicky/Documents/codebase/jane-street-real-time-market-data-forecasting/train.parquet/partition_id=2/par

: 