# Training an XGBoost model to predict fixation time on graphs

In [None]:
import pandas as pd
import numpy as np
import os
from pathlib import Path
import matplotlib.pyplot as plt
import xgboost as xgb
import shap
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, root_mean_squared_error

# --- Configuration ---
# Simulation batch whose graph-statistics CSV the next cell loads.
BATCH_NAME = 'batch_big_batch_test'  # Update this if needed
# All paths are resolved relative to the current working directory.
ROOT = Path.cwd()
DATA_PATH = ROOT.joinpath("simulation_data", f"{BATCH_NAME}_graph_statistics.csv")

print("Setup Complete. Ready to load data.")

In [None]:
# 1. Load Data
# Raw simulation/metadata columns that are never used as model inputs.
drop_raw = ['graph6_string', 'branching', 'depth', 'n_rods', 'rods_length', 'rod_length', 'seed', 'n_grouped']
df = pd.read_csv(DATA_PATH).drop(columns=drop_raw, errors='ignore')

# Keep only rows with r = 1.1. Compare with a tolerance instead of `==`:
# floats read back from CSV (or produced by upstream arithmetic) need not be
# bit-identical to the literal 1.1, and an exact match would silently drop them.
R_VALUE = 1.1
df = df[np.isclose(df['r'], R_VALUE)]

# 2. Define Features vs Target
# We want to predict 'median_steps'
target_col = 'median_steps'

# Rows without a target value cannot be used to train or score the models
# (LinearRegression.fit raises on NaN targets), so drop them up front.
df = df.dropna(subset=[target_col])
if df.empty:
    raise ValueError(
        f"No usable rows (r ~= {R_VALUE} with a non-missing '{target_col}') in {DATA_PATH}"
    )

# Columns that are ID identifiers, leakage (answers), or carry no information
drop_cols = [
    'wl_hash', 'graph_name',        # IDs (not features)
    'prob_fixation',                # Different prediction task
    'median_steps', 'mean_steps', 'std_steps',  # The Answers (Leakage)
    'category', 'graph_type',       # Strings (Drop for now to keep it simple)
    'r',                            # Constant after the r filter above
]

# 3. Create X and y
# We use select_dtypes to ensure we only pass numbers to the models
X = df.drop(columns=drop_cols, errors='ignore').select_dtypes(include=[np.number])
y = df[target_col]

# 4. Handle Missing Values
# LinearRegression cannot handle NaN, so missing features get a -1 sentinel.
# NOTE(review): assumes -1 is never a legitimate feature value -- confirm.
X = X.fillna(-1)

print("Data Loaded.")
print(f"Features (X): {X.shape}")
print(f"Target (y): {y.shape}")
print("Feature list:", list(X.columns))

In [None]:
# 1. Split Data once, so both models are trained and scored on the same rows
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# --- Model A: Linear Regression Baseline ---
print("\n--- Linear Regression Results ---")
lr = LinearRegression()
lr.fit(X_train, y_train)
lr_preds = lr.predict(X_test)

print(f"LR R^2: {r2_score(y_test, lr_preds):.4f}")
print(f"LR RMSE: {root_mean_squared_error(y_test, lr_preds):.2f}")

# Optional: Show which features drive the linear model.
# A raw coefficient is "target change per unit of the feature", so its size
# depends on the feature's scale and is not comparable across features.
# Multiplying by the feature's training-set std gives the predicted change
# for a one-std move in that feature, which *is* comparable.
coef_df = pd.DataFrame({
    'feature': X_train.columns,
    'coefficient': lr.coef_,
    'std_coefficient': lr.coef_ * X_train.std().to_numpy(),
})
print("\nTop 5 Features (Linear Regression, ranked by |coefficient x feature std|):")
print(coef_df.sort_values(by='std_coefficient', key=abs, ascending=False).head(5))

In [None]:

# --- Model B: XGBoost ---
print("\n--- XGBoost Results ---")
# Sanity-check the training matrix shape before fitting.
print(f"Training on shape: {X_train.shape}")

# Boosting starts from the mean training target (base_score) rather than
# letting XGBoost pick the initial prediction.
mean_val = y_train.mean()
xgb_params = {
    'n_estimators': 500,
    'learning_rate': 0.05,
    'max_depth': 6,
    'objective': 'reg:squarederror',
    'n_jobs': 1,
    'random_state': 42,
    'base_score': mean_val,
}
xgb_model = xgb.XGBRegressor(**xgb_params)

xgb_model.fit(X_train, y_train)
xgb_preds = xgb_model.predict(X_test)

# Held-out metrics, same format as the linear baseline.
xgb_r2 = r2_score(y_test, xgb_preds)
print(f"XGB R^2: {xgb_r2:.4f}")
xgb_rmse = root_mean_squared_error(y_test, xgb_preds)
print(f"XGB RMSE: {xgb_rmse:.2f}")


In [None]:

# # --- Interpretation (SHAP) ---
# # See what actually drives the non-linear model
# explainer = shap.TreeExplainer(xgb_model)
# shap_values = explainer(X_test)

# plt.figure()
# shap.summary_plot(shap_values, X_test, show=False)
# plt.title("What drives Fixation Time? (XGBoost)")
# plt.show()

In [None]:
# --- Interpretation (SHAP) ---

# 1. Use the model-agnostic PermutationExplainer.
# It is given the .predict FUNCTION rather than the model object, which avoids
# the version conflict hit with the tree-specific explainer.
# X_test serves both as the background data and as the rows being explained.
explainer = shap.PermutationExplainer(xgb_model.predict, X_test)

# 2. Calculate SHAP values: one value per (test row, feature), estimated by
# evaluating the prediction function over sampled feature permutations.
shap_values = explainer(X_test)

# Plot 1: Summary
plt.figure()
shap.summary_plot(shap_values, X_test, show=False)
plt.title("Feature Importance (Permutation)")
plt.show()

# Plot 2: Dependence plot for the feature with the largest mean |SHAP| value
mean_abs_shap = np.abs(shap_values.values).mean(axis=0)
top_feature_idx = mean_abs_shap.argmax()
top_feature_name = X_test.columns[top_feature_idx]

print(f"Plotting dependence for top feature: {top_feature_name}")
# Draw on an explicit Axes: dependence_plot opens its own figure when `ax` is
# None, so a bare plt.figure() beforehand would leave an empty extra figure.
fig, ax = plt.subplots(figsize=(7.5, 5))
shap.dependence_plot(top_feature_name, shap_values.values, X_test, ax=ax, show=False)
plt.show()