In [39]:
from itertools import combinations
from tqdm import tqdm
from scipy import sparse
import plotly.express as px


In [3]:
import pandas as pd
import numpy as np
import chess
from tqdm import tqdm

# how many positions to sample
size = 1_000_000

# 1) Load your CSV and take a random 100 000
df = pd.read_csv('train.csv')  # must have columns 'FEN' and 'value'
df = (
    df
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
    .iloc[:size]
)

# 2) Feature extractor (768 dims: 12 piece‐types × 64 squares)
def extract_features(fen: str) -> np.ndarray:
    board = chess.Board(fen)
    arr = np.zeros(12 * 64, dtype=np.uint8)
    for sq, piece in board.piece_map().items():
        idx = (piece.piece_type - 1) + (0 if piece.color == chess.WHITE else 6)
        arr[idx * 64 + sq] = 1
    return arr

# 3) Build column names
piece_names = ['wP','wN','wB','wR','wQ','wK',
               'bP','bN','bB','bR','bQ','bK']
col_names = [
    f"{piece_names[i // 64]}{chess.square_name(i % 64)}"
    for i in range(12 * 64)
]

# 4) Featurize all positions
X = np.zeros((len(df), 12 * 64), dtype=np.uint8)
for i, fen in tqdm(enumerate(df['FEN']), total=len(df), desc="Featurizing"):
    X[i] = extract_features(fen)

# 5) Assemble into a DataFrame
X_df = pd.DataFrame(X, columns=col_names)
out_df = pd.concat([X_df, df['value'].reset_index(drop=True)], axis=1)

# 6) Save to Parquet
out_df.to_parquet('train.parquet', index=False)
print("Wrote train.parquet with shape", out_df.shape)


Featurizing: 100%|██████████| 1000000/1000000 [00:57<00:00, 17398.43it/s]


Wrote train.parquet with shape (1000000, 769)


In [50]:
np.sum(out_df["value"] == 1)

np.int64(482628)

In [51]:
np.sum(out_df["value"] == 0)

np.int64(57745)

In [52]:
np.sum(out_df["value"] == -1)

np.int64(459627)

In [21]:
import pandas as pd

feature_cols = [c for c in out_df.columns if c != 'value']
feature_sums = out_df[feature_cols].sum()
valid_feats = feature_sums[feature_sums > 0].index.tolist()
overall_mean = out_df['value'].mean()
value_when_1 = out_df[feature_cols].T.dot(out_df['value']) / feature_sums
effect = value_when_1 - overall_mean
score = feature_sums * effect


feat_stats = pd.DataFrame({
    'sum':        feature_sums,
    'avg_diff':   effect,
    'score':      score
})

# 8) Sort by |score| descending
feat_stats = feat_stats.reindex(feat_stats['score'].abs().sort_values(ascending=False).index)


In [None]:
feat_stats['t'] = np.sqrt(feat_stats['sum']) * feat_stats['avg_diff']
feat_stats = feat_stats.sort_values('t', key=lambda col: col.abs(), ascending=False)

In [61]:
feat_stats[abs(feat_stats["t"]) > 2]

Unnamed: 0,sum,avg_diff,score,t
bKg8,302870,-0.073349,-22215.312870,-40.366798
wKg1,332551,0.062661,20837.994449,36.134894
bKd8,10843,0.255981,2775.600157,26.655214
wPg2,595903,0.034443,20524.635097,26.588122
wPb2,540029,0.035872,19371.792971,26.360963
...,...,...,...,...
bKb1,235,0.134446,31.594765,2.061015
wBf2,5509,-0.027721,-152.712509,-2.057492
bNb4,8484,-0.021940,-186.140484,-2.020880
wRc3,3535,-0.033751,-119.308535,-2.006674


In [31]:
# 2) Build X as int64 so dot‐products accumulate correctly
X = sparse.csr_matrix(
    out_df[feature_cols].values.astype(np.int64),
    dtype=np.int64
)
y = out_df['value'].values.astype(np.float64)

# 3) Compute joint‐counts C and joint‐sum‐of‐y S
C_sparse = X.T.dot(X)          # int64 counts, up to ~100 000
Xy        = X.multiply(y[:,None])  # each row k scaled by y[k]
S_sparse = X.T.dot(Xy)         # float64 sums of y over co‐occurrences

# 4) Bring into dense arrays
C = C_sparse.toarray()
S = S_sparse.toarray()

# 5) Overall mean of y
overall_mean = y.mean()

# 6) For all (i,j) with C[i,j]>0:
#      mean_when_both = S / C
#      effect         = mean_when_both - overall_mean
#      score          = C * effect
mask = C > 0
mean_when_both = np.zeros_like(S)
mean_when_both[mask] = S[mask] / C[mask]

effect = mean_when_both - overall_mean
score  = C.astype(np.float64) * effect

# 7) Extract upper‐triangle (i<j), build DataFrame, filter & sort
i, j = np.triu_indices_from(score, k=1)
df_pairs = pd.DataFrame({
    'feat1':       np.array(feature_cols)[i],
    'feat2':       np.array(feature_cols)[j],
    'joint_count': C[i, j],
    'avg_diff':    effect[i, j],
    'score':       score[i, j],
})
df_pairs = df_pairs[df_pairs['joint_count'] > 0]
df_pairs = df_pairs.reindex(df_pairs['score'].abs().sort_values(ascending=False).index)

In [65]:
# map each feat1 and feat2 to its avg_diff in feat_stats
map_avg = feat_stats["avg_diff"]
df_pairs["avg_diff_adj"] = (
    df_pairs["avg_diff"]
    - df_pairs["feat1"].map(map_avg)
    - df_pairs["feat2"].map(map_avg)
)


In [66]:
df_pairs

Unnamed: 0,feat1,feat2,joint_count,avg_diff,score,t,avg_diff_adj
6871,wPb2,wPg2,396295,4.670717e-02,18509.818705,29.403083,-0.023608
10958,wPg2,wKg1,235861,7.375514e-02,17395.961139,35.819569,-0.023349
240241,bPf7,bKg8,204945,-8.324643e-02,-17060.939945,-37.686369,0.019386
9893,wPf2,wPg2,408074,4.057570e-02,16557.889926,25.920032,-0.026937
11710,wPh2,wKg1,214358,7.379501e-02,15818.551642,34.166215,-0.013951
...,...,...,...,...,...,...,...
97689,wBe2,wBe7,87,-1.249425e-05,-0.001087,-0.000117,-0.116477
9260,wPe2,wBh1,87,-1.249425e-05,-0.001087,-0.000117,0.044563
17447,wPh3,wNd3,1000,-1.000000e-06,-0.001000,-0.000032,0.020268
102822,wBe3,wKf2,1000,-1.000000e-06,-0.001000,-0.000032,-0.014457


In [62]:
df_pairs['t'] = np.sqrt(df_pairs['joint_count']) * df_pairs['avg_diff']

In [67]:
df_pairs['t_adj'] = np.sqrt(df_pairs['joint_count']) * df_pairs['avg_diff_adj']

In [70]:
df_pairs.sort_values("t_adj")

Unnamed: 0,feat1,feat2,joint_count,avg_diff,score,t,avg_diff_adj,t_adj
131921,wRf1,wKg1,182872,0.056306,10296.761128,24.078367,-0.046110,-19.718069
9893,wPf2,wPg2,408074,0.040576,16557.889926,25.920032,-0.026937,-17.207393
7049,wPb2,wRa1,403179,0.028544,11508.479821,18.124624,-0.027076,-17.192366
6870,wPb2,wPf2,363198,0.042170,15316.082802,25.414173,-0.026771,-16.133948
10578,wPf2,bQd8,273238,0.012265,3351.252762,6.411160,-0.030772,-16.085024
...,...,...,...,...,...,...,...,...
238586,bPb7,bPf7,352271,-0.026305,-9266.585271,-15.612817,0.022747,13.501150
238919,bPc7,bPf7,233313,0.012051,2811.567687,5.820752,0.028031,13.539888
236903,bPe6,bPc7,71023,0.049074,3485.399977,13.078355,0.054843,14.615749
239913,bPf7,bPg7,398821,-0.023773,-9481.281821,-15.013365,0.023681,14.954981


In [64]:
df_pairs[abs(df_pairs["t"]) > 2]

Unnamed: 0,feat1,feat2,joint_count,avg_diff,score,t
6871,wPb2,wPg2,396295,0.046707,18509.818705,29.403083
10958,wPg2,wKg1,235861,0.073755,17395.961139,35.819569
240241,bPf7,bKg8,204945,-0.083246,-17060.939945,-37.686369
9893,wPf2,wPg2,408074,0.040576,16557.889926,25.920032
11710,wPh2,wKg1,214358,0.073795,15818.551642,34.166215
...,...,...,...,...,...,...
49790,wNd1,bQa6,4,-1.023001,-4.092004,-2.046002
226021,bPf2,bBg2,4,-1.023001,-4.092004,-2.046002
244936,bNe1,bRh7,4,-1.023001,-4.092004,-2.046002
244947,bNe1,bQc1,4,-1.023001,-4.092004,-2.046002


In [44]:
fig = px.histogram(feat_stats, x='sum', nbins=700)
fig.update_layout(
    title='Histogram of joint_count',
    xaxis_title='joint_count',
    yaxis_title='Count'
)
fig.show()

In [46]:
df_pairs.sort_values("avg_diff").head(30)

Unnamed: 0,feat1,feat2,joint_count,avg_diff,score
137709,wRh2,bNf3,16,-1.023001,-16.368016
212618,wKc6,bQe4,3,-1.023001,-3.069003
267919,bBa4,bKe3,3,-1.023001,-3.069003
263206,bBf1,bBd8,3,-1.023001,-3.069003
263222,bBf1,bRd2,3,-1.023001,-3.069003
263229,bBf1,bRc3,3,-1.023001,-3.069003
263232,bBf1,bRf3,3,-1.023001,-3.069003
9784,wPe2,bQd3,3,-1.023001,-3.069003
288648,bQd3,bQc4,3,-1.023001,-3.069003
210367,wKf5,bNe1,3,-1.023001,-3.069003


In [29]:
feat_stats

Unnamed: 0,sum,avg_diff,score
bKg8,302870,-0.073349,-22215.312870
wKg1,332551,0.062661,20837.994449
wPg2,595903,0.034443,20524.635097
wPb2,540029,0.035872,19371.792971
wPf2,549718,0.033070,18178.936282
...,...,...,...
bPd8,0,,
bPe8,0,,
bPf8,0,,
bPg8,0,,


In [9]:
feat_stats.head(20)

Unnamed: 0,sum,avg_diff,score
bKg8,302870,-0.073349,-22215.31287
wKg1,332551,0.062661,20837.994449
wPg2,595903,0.034443,20524.635097
wPb2,540029,0.035872,19371.792971
wPf2,549718,0.03307,18178.936282
bPf7,576441,-0.029283,-16879.719441
wPh2,578401,0.025085,14509.198599
wPd4,258591,0.050822,13142.148409
wPa2,588217,0.020711,12182.420783
wRa1,538834,0.019749,10641.279166


In [4]:
out_df

Unnamed: 0,wPa1,wPb1,wPc1,wPd1,wPe1,wPf1,wPg1,wPh1,wPa2,wPb2,...,bKh7,bKa8,bKb8,bKc8,bKd8,bKe8,bKf8,bKg8,bKh8,value
0,0,0,0,0,0,0,0,0,1,1,...,0,0,0,0,0,1,0,0,0,1.0
1,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-1.0
2,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,1.0
3,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,1,0,0,0,1.0
4,0,0,0,0,0,0,0,0,0,1,...,0,0,0,0,0,0,0,0,0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
999995,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,-1.0
999996,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,1.0
999997,0,0,0,0,0,0,0,0,1,1,...,0,0,0,0,0,0,0,1,0,1.0
999998,0,0,0,0,0,0,0,0,1,1,...,0,0,0,0,0,0,0,1,0,-1.0
