In [4]:
import pandas as pd
import numpy as np
import re
from config import *

# Sklearn
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.cluster import AgglomerativeClustering, DBSCAN, KMeans
from sklearn.metrics import silhouette_score
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

import umap

# Visual
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns

from itertools import combinations, permutations
from tqdm import tqdm

In [5]:
# Load the feature table and keep only rows with y == 1.
features = pd.read_parquet(FEATURES_DIR / "features.pq").query("y==1")
# Keep the identifying columns around so results can be mapped back to participants.
# NOTE: taken *before* dropna below, so it may contain rows missing from `features`;
# align by index label (e.g. features_with_info.loc[features_stdz.index]).
features_with_info = features
# Numeric feature matrix only, rows with any missing value removed.
features = features.drop(["participant_id", "y"], axis=1).dropna()
# Standardize each column (StandardScaler). Preserve the original row index so
# features_stdz stays label-aligned with `features` / `features_with_info`; a fresh
# RangeIndex would silently misalign any later index-based join.
features_stdz = pd.DataFrame(
    StandardScaler().fit_transform(features.to_numpy()),
    index=features.index,
    columns=features.columns,
)

In [6]:
# Score every candidate feature subset for each cluster count.
# With the default Euclidean metric, column order does not change distances, so
# unordered `combinations` are sufficient; `permutations` would refit every subset
# r! times for r > 1 (identical results for r == 1).
FEATURES_PER_COMBINATION = 1
MIN_CLUSTERS, MAX_CLUSTERS = 2, 10

column_combinations = list(combinations(features_stdz.columns, FEATURES_PER_COMBINATION))

# Parallel result lists consumed by the next cell.
silhouette_scores = []
silhouette_features = []
n_clusters = []

for n_cluster in range(MIN_CLUSTERS, MAX_CLUSTERS + 1):
    model = AgglomerativeClustering(n_clusters=n_cluster)
    for col_comb in tqdm(column_combinations, desc=f"n_clusters={n_cluster}"):
        X = features_stdz[list(col_comb)]

        model.fit(X)

        silhouette_scores.append(silhouette_score(X, model.labels_))
        silhouette_features.append(X.columns.to_list())
        n_clusters.append(n_cluster)


100%|██████████| 483/483 [00:00<00:00, 1251.58it/s]
100%|██████████| 483/483 [00:00<00:00, 1261.21it/s]
100%|██████████| 483/483 [00:00<00:00, 1252.22it/s]
100%|██████████| 483/483 [00:00<00:00, 1099.65it/s]
100%|██████████| 483/483 [00:00<00:00, 1251.31it/s]
100%|██████████| 483/483 [00:00<00:00, 1234.11it/s]
100%|██████████| 483/483 [00:00<00:00, 1219.21it/s]
100%|██████████| 483/483 [00:00<00:00, 1197.93it/s]
100%|██████████| 483/483 [00:00<00:00, 1217.13it/s]


['FITTS_LAW_mean_duration_sacc', 'EVIL_BASTARD_mean_duration_sacc'] : 0.8771598789776172 : 2

In [None]:

def get_max_sil_score(group):
    """Return the row of ``group`` whose ``silhouette_scores`` value is largest.

    Written as a ``groupby(...).apply`` callback. Ties resolve to the first row
    holding the maximal score (``Series.idxmax`` semantics).

    Parameters
    ----------
    group : pd.DataFrame
        Must contain a ``silhouette_scores`` column.

    Returns
    -------
    pd.Series
        The selected row, named by its index label.
    """
    winner_label = group['silhouette_scores'].idxmax()
    return group.loc[winner_label]


# Collect the sweep results, then keep, for every cluster count, the feature set with
# the highest silhouette score (first one on ties).
# Uses groupby(...).idxmax() + .loc instead of groupby(...).apply(get_max_sil_score):
# apply operating on the grouping column is deprecated in pandas, and once the
# grouping column is excluded the `n_clusters` column would vanish from the result,
# breaking the 3-column unpacking downstream.
sil_results = pd.DataFrame({
    'silhouette_features': silhouette_features,
    'silhouette_scores': silhouette_scores,
    'n_clusters': n_clusters,
})
best_idx = sil_results.groupby('n_clusters')['silhouette_scores'].idxmax()
best_all_comb_features = sil_results.loc[best_idx].reset_index(drop=True)
assert best_all_comb_features['n_clusters'].is_unique

# NOTE(review): scores of exactly 1.0 on single *_min features may reflect heavily
# tied / discrete values rather than real cluster structure — worth checking.
best_all_comb_features


  .groupby('n_clusters').apply(get_max_sil_score).reset_index(drop=True)


Unnamed: 0,silhouette_features,silhouette_scores,n_clusters
0,[REACTION_duration_min_sacc],1.0,2
1,[SMOOTH_PURSUITS_distance_to_fixpoint_y_min],1.0,3
2,[ANTI_SACCADE_duration_min_sacc],0.981818,4
3,[KING_DEVICK_duration_min_sacc],0.981818,5
4,[SHAPES_duration_min_sacc],0.945455,6
5,[REACTION_n_correct_trials],0.923333,7
6,[REACTION_n_correct_trials],0.972727,8
7,[REACTION_n_correct_trials],0.981818,9
8,[REACTION_prop_correct_trials],0.963636,10


In [8]:
def greedy_forward_selection(start_features, n_cluster, score_to_beat):
    """Greedily extend ``start_features`` while the silhouette score strictly improves.

    At each step every remaining column of ``features_stdz`` is tried as an
    addition; the candidate giving the highest silhouette score (Agglomerative
    clustering with ``n_cluster`` clusters) is kept only if it beats the current
    best score. Stops when no candidate improves the score or no columns remain.

    Parameters
    ----------
    start_features : list of str
        Initial feature set. Not modified; a copy is extended.
    n_cluster : int
        Number of clusters passed to ``AgglomerativeClustering``.
    score_to_beat : float
        Silhouette score of ``start_features`` alone.

    Returns
    -------
    tuple[list of str, float]
        The selected features and their silhouette score.
    """
    # Copy: the incoming list is the same object stored in best_all_comb_features
    # (and silhouette_features). Extending it in place would corrupt those tables
    # and make a re-run of this cell start from an already-extended feature set.
    selected = list(start_features)
    best_score = score_to_beat
    remaining = features_stdz.drop(selected, axis=1).columns.to_list()

    while remaining:
        candidate_scores = []
        for candidate in remaining:
            X = features_stdz[selected + [candidate]]
            labels = AgglomerativeClustering(n_clusters=n_cluster).fit(X).labels_
            candidate_scores.append(silhouette_score(X, labels))

        best_idx = int(np.argmax(candidate_scores))
        new_best_sil_score = candidate_scores[best_idx]
        new_feature = remaining[best_idx]
        logging.info(f"New best silhouette score is {new_best_sil_score}")

        if best_score < new_best_sil_score:
            logging.info(f"Adding new feature {new_feature} since it improved silhouette score (Old: {best_score} vs New: {new_best_sil_score})")
            best_score = new_best_sil_score
            selected.append(new_feature)
            # Same order as re-dropping the selected columns from features_stdz.
            remaining.remove(new_feature)
        else:
            logging.info("Not better, breaking")
            break

    return selected, best_score


chosen_features = []
chosen_n_clusters = []
chosen_silhouette_scores = []

# Columns are selected by name rather than unpacked positionally from iterrows().
for sil_chosen_features, best_sil_score, n_cluster in zip(
    best_all_comb_features['silhouette_features'],
    best_all_comb_features['silhouette_scores'],
    best_all_comb_features['n_clusters'],
):
    logging.info(f"Testing for {n_cluster} clusters")
    logging.info(f"Score to beat: {best_sil_score}")

    selected_features, selected_score = greedy_forward_selection(
        sil_chosen_features, n_cluster, best_sil_score
    )

    chosen_features.append(selected_features)
    chosen_n_clusters.append(n_cluster)
    chosen_silhouette_scores.append(selected_score)

    logging.info(f"Testing for {n_cluster} clusters done\n\n\n")


chosen_features_df = pd.DataFrame({
    'features': chosen_features,
    'silhouette_scores': chosen_silhouette_scores,
    'n_clusters': chosen_n_clusters,
})


2025-04-30 08:55:58,434 - INFO - 880880423.<module>:6 - Testing for 2 clusters
2025-04-30 08:55:58,435 - INFO - 880880423.<module>:7 - Score to beat: 1.0
2025-04-30 08:55:58,822 - INFO - 880880423.<module>:27 - New best silhouette score is 1.0
2025-04-30 08:55:58,823 - INFO - 880880423.<module>:34 - Not better, breaking
2025-04-30 08:55:58,823 - INFO - 880880423.<module>:43 - Testing for 2 clusters done



2025-04-30 08:55:58,823 - INFO - 880880423.<module>:6 - Testing for 3 clusters
2025-04-30 08:55:58,824 - INFO - 880880423.<module>:7 - Score to beat: 1.0
2025-04-30 08:55:59,210 - INFO - 880880423.<module>:27 - New best silhouette score is 1.0
2025-04-30 08:55:59,211 - INFO - 880880423.<module>:34 - Not better, breaking
2025-04-30 08:55:59,211 - INFO - 880880423.<module>:43 - Testing for 3 clusters done



2025-04-30 08:55:59,211 - INFO - 880880423.<module>:6 - Testing for 4 clusters
2025-04-30 08:55:59,211 - INFO - 880880423.<module>:7 - Score to beat: 0.9818181818181818
2025-04-30 

In [9]:
# Result of the greedy forward selection: final feature set and its silhouette score per cluster count.
chosen_features_df

Unnamed: 0,features,silhouette_scores,n_clusters
0,[REACTION_duration_min_sacc],1.0,2
1,[SMOOTH_PURSUITS_distance_to_fixpoint_y_min],1.0,3
2,[ANTI_SACCADE_duration_min_sacc],0.981818,4
3,[KING_DEVICK_duration_min_sacc],0.981818,5
4,[SHAPES_duration_min_sacc],0.945455,6
5,[REACTION_n_correct_trials],0.923333,7
6,[REACTION_n_correct_trials],0.972727,8
7,[REACTION_n_correct_trials],0.981818,9
8,[REACTION_prop_correct_trials],0.963636,10
