In [34]:
import pandas as pd
import numpy as np
import re
from config import *

# Sklearn
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.cluster import AgglomerativeClustering, DBSCAN, KMeans
from sklearn.metrics import silhouette_score
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

import umap

# Visual
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns

from itertools import permutations 
from tqdm import tqdm

In [46]:
sig_features = ['ANTI_SACCADE_total_acceleration_magnitude_right_max',
 'ANTI_SACCADE_total_acceleration_magnitude_right_median',
 'FITTS_LAW_mean_amplitude_sacc',
 'FITTS_LAW_mean_duration_sacc',
 'FITTS_LAW_mean_duration_fix',
 'FITTS_LAW_avg_fixations_pr_second',
 'FITTS_LAW_std_fixations_pr_second',
 'FITTS_LAW_total_acceleration_magnitude_right_mean',
 'FITTS_LAW_total_acceleration_magnitude_right_median',
 'REACTION_reaction_time_avg',
 'REACTION_reaction_time_std',
 'REACTION_total_acceleration_magnitude_right_median',
 'EVIL_BASTARD_mean_duration_sacc',
 'EVIL_BASTARD_avg_fixations_pr_second',
 'EVIL_BASTARD_std_fixations_pr_second',
 'EVIL_BASTARD_total_acceleration_magnitude_right_max',
 'EVIL_BASTARD_total_acceleration_magnitude_right_median',
 'EVIL_BASTARD_distance_to_fixpoint_max',
 'EVIL_BASTARD_distance_to_fixpoint_x_max',
 'EVIL_BASTARD_distance_to_fixpoint_y_max',
 'SHAPES_total_acceleration_magnitude_right_mean',
 'SHAPES_total_acceleration_magnitude_right_median',
 'SMOOTH_PURSUITS_mean_peak_velocity_sacc',
 'SMOOTH_PURSUITS_mean_duration_sacc',
 'SMOOTH_PURSUITS_total_acceleration_magnitude_right_max',
 'SMOOTH_PURSUITS_total_acceleration_magnitude_right_median',
 'SMOOTH_PURSUITS_Var_total',
 'SMOOTH_PURSUITS_distance_to_fixpoint_max',
 'SMOOTH_PURSUITS_distance_to_fixpoint_x_std',
 'SMOOTH_PURSUITS_distance_to_fixpoint_y_max',
 'SMOOTH_PURSUITS_distance_to_fixpoint_y_std',
 'KING_DEVICK_mean_duration_fix',
 'KING_DEVICK_avg_time_elapsed_pr_trial',
 'KING_DEVICK_total_acceleration_magnitude_left_mean',
 'KING_DEVICK_total_acceleration_magnitude_right_mean',
 'KING_DEVICK_total_acceleration_magnitude_right_median']
features = pd.read_parquet(FEATURES_DIR / "features.pq").query("y==1")
features_with_info = features[sig_features + ["participant_id", 'y']]
features = features[sig_features]
features = features.dropna()
features_stdz = pd.DataFrame(StandardScaler().fit_transform(features.values), columns=features.columns)

In [71]:
column_combinations = list(permutations(features_stdz.columns, 3))

silhouette_scores = []
silhouette_features = []
n_clusters = []

for n_cluster in range(2,11):
    model = AgglomerativeClustering(n_clusters=n_cluster)
    for col_comb in tqdm(column_combinations):
        X = features_stdz[list(col_comb)]

        model.fit(X)
        
        sil_score = silhouette_score(X, model.labels_)
        
        silhouette_scores.append(sil_score)
        silhouette_features.append(X.columns.to_list())
        n_clusters.append(n_cluster)


100%|██████████| 42840/42840 [00:34<00:00, 1233.81it/s]
100%|██████████| 42840/42840 [00:34<00:00, 1227.75it/s]
100%|██████████| 42840/42840 [00:35<00:00, 1210.90it/s]
100%|██████████| 42840/42840 [00:35<00:00, 1215.90it/s]
100%|██████████| 42840/42840 [00:35<00:00, 1211.25it/s]
100%|██████████| 42840/42840 [00:35<00:00, 1202.52it/s]
100%|██████████| 42840/42840 [04:16<00:00, 166.99it/s] 
100%|██████████| 42840/42840 [00:34<00:00, 1228.70it/s]
100%|██████████| 42840/42840 [00:35<00:00, 1222.41it/s]


['FITTS_LAW_mean_duration_sacc', 'EVIL_BASTARD_mean_duration_sacc'] : 0.8771598789776172 : 2

In [72]:

def get_max_sil_score(group):
    return group.loc[group['silhouette_scores'].idxmax()]


best_all_comb_features = (pd.DataFrame({
    'silhouette_features': silhouette_features,
    'silhouette_scores': silhouette_scores,
    'n_clusters': n_clusters
})
.groupby('n_clusters').apply(get_max_sil_score).reset_index(drop=True)
)

best_all_comb_features


  .groupby('n_clusters').apply(get_max_sil_score).reset_index(drop=True)


Unnamed: 0,silhouette_features,silhouette_scores,n_clusters
0,"[FITTS_LAW_mean_duration_sacc, EVIL_BASTARD_me...",0.841656,2
1,"[FITTS_LAW_mean_duration_sacc, EVIL_BASTARD_me...",0.816192,3
2,"[EVIL_BASTARD_mean_duration_sacc, SMOOTH_PURSU...",0.721484,4
3,"[FITTS_LAW_mean_duration_sacc, EVIL_BASTARD_me...",0.624514,5
4,[ANTI_SACCADE_total_acceleration_magnitude_rig...,0.547859,6
5,[ANTI_SACCADE_total_acceleration_magnitude_rig...,0.524567,7
6,[ANTI_SACCADE_total_acceleration_magnitude_rig...,0.503386,8
7,"[EVIL_BASTARD_distance_to_fixpoint_y_max, SMOO...",0.501753,9
8,"[EVIL_BASTARD_distance_to_fixpoint_max, EVIL_B...",0.502728,10


In [73]:
chosen_features = []
chosen_n_clusters = []
chosen_silhouette_scores = []

for i, (sil_chosen_features, best_sil_score, n_cluster) in best_all_comb_features.iterrows():
    logging.info(f"Testing for {n_cluster} clusters")
    logging.info(f"Score to beat: {best_sil_score}")
    features_columns_left = features_stdz.drop(sil_chosen_features, axis=1).columns.to_list()
    while len(features_columns_left) > 0:
        new_sil_scores = []
        new_features = []
        
        for feature_column in features_columns_left:
            X = features_stdz[sil_chosen_features + [feature_column]]
            model = AgglomerativeClustering(n_clusters=n_cluster)
            model.fit(X)
            
            new_sil_score = silhouette_score(X, model.labels_)
            
            new_sil_scores.append(new_sil_score)
            new_features.append(feature_column)
        

        new_best_sil_score = np.max(new_sil_scores)
        new_best_sil_score_idx = np.argmax(new_sil_scores)
        new_feature = new_features[new_best_sil_score_idx]
        logging.info(f"New best silhouette score is {new_best_sil_score}")

        if best_sil_score < new_best_sil_score:
            logging.info(f"Adding new feature {feature_column} since it improved silhouette score (Old: {best_sil_score} vs New: {new_best_sil_score})")
            best_sil_score = new_best_sil_score
            sil_chosen_features += [new_feature]
        else:
            logging.info("Not better, breaking")
            break
        
        features_columns_left = features.drop(sil_chosen_features, axis=1).columns.to_list()
        
    chosen_features.append(sil_chosen_features)
    chosen_n_clusters.append(n_cluster)
    chosen_silhouette_scores.append(best_sil_score)
    
    logging.info(f"Testing for {n_cluster} clusters done\n\n\n")


        
chosen_features_df = (pd.DataFrame({
    'features': chosen_features,
    'silhouette_scores': chosen_silhouette_scores,
    'n_clusters': chosen_n_clusters
}))


2025-04-28 12:42:31,301 - INFO - 880880423.<module>:6 - Testing for 2 clusters
2025-04-28 12:42:31,302 - INFO - 880880423.<module>:7 - Score to beat: 0.841655887480316
2025-04-28 12:42:31,337 - INFO - 880880423.<module>:27 - New best silhouette score is 0.8083543080806903
2025-04-28 12:42:31,337 - INFO - 880880423.<module>:34 - Not better, breaking
2025-04-28 12:42:31,337 - INFO - 880880423.<module>:43 - Testing for 2 clusters done



2025-04-28 12:42:31,338 - INFO - 880880423.<module>:6 - Testing for 3 clusters
2025-04-28 12:42:31,338 - INFO - 880880423.<module>:7 - Score to beat: 0.816192066674989
2025-04-28 12:42:31,371 - INFO - 880880423.<module>:27 - New best silhouette score is 0.7594224033219499
2025-04-28 12:42:31,371 - INFO - 880880423.<module>:34 - Not better, breaking
2025-04-28 12:42:31,372 - INFO - 880880423.<module>:43 - Testing for 3 clusters done



2025-04-28 12:42:31,372 - INFO - 880880423.<module>:6 - Testing for 4 clusters
2025-04-28 12:42:31,372 - INFO - 880880423.

In [75]:
chosen_features_df

Unnamed: 0,features,silhouette_scores,n_clusters
0,"[FITTS_LAW_mean_duration_sacc, EVIL_BASTARD_me...",0.841656,2
1,"[FITTS_LAW_mean_duration_sacc, EVIL_BASTARD_me...",0.816192,3
2,"[EVIL_BASTARD_mean_duration_sacc, SMOOTH_PURSU...",0.721484,4
3,"[FITTS_LAW_mean_duration_sacc, EVIL_BASTARD_me...",0.624514,5
4,[ANTI_SACCADE_total_acceleration_magnitude_rig...,0.547859,6
5,[ANTI_SACCADE_total_acceleration_magnitude_rig...,0.524567,7
6,[ANTI_SACCADE_total_acceleration_magnitude_rig...,0.503386,8
7,"[EVIL_BASTARD_distance_to_fixpoint_y_max, SMOO...",0.501753,9
8,"[EVIL_BASTARD_distance_to_fixpoint_max, EVIL_B...",0.502728,10
