# Choosing Number of Clusters

In [None]:
import os.path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Components of the S3 prefix that is listed and downloaded from below.
study_name = 'MASE_ChooseK_Study_FA_and_MD'
session_name = 'HCP_1200'
bundle_name = 'SLF_L'
aws_path = f's3://hcp-subbundle/{study_name}/{session_name}/{bundle_name}'
# Prefix of the result file names matched below, e.g. '<model_name>_silhouette_scores.npy'.
model_name = 'mase_kmeans_fa_r2_md_r2_is_mdf'
max_n_clusters = 9 # Silhouette Scores: key component matched when downloading the score files
best_n_cluster = 2 # Pair plots and Info: key component matched when downloading those files

In [None]:
def aws_download(aws_path, endswith_pattern):
    aws_files = !aws s3 ls --recursive {aws_path}
        
    remote_filenames = []
    local_filenames = []

    for file in aws_files:
        if file.split()[3].endswith(endswith_pattern):
            remote_filenames.append(file.split()[3])
            local_filenames.append(file.split()[3].replace('/', '_'))

    for remote_filename, local_filename in zip(remote_filenames, local_filenames):
        !aws s3 cp s3://hcp-subbundle/{remote_filename} {local_filename}
            
    return local_filenames

def remove_aws_downloads(local_filenames):
    """Delete local files that were previously downloaded.

    Parameters
    ----------
    local_filenames : iterable of str
        Paths of the files to delete.

    Files that no longer exist are skipped, so the clean-up can be re-run.
    """
    for file in local_filenames:
        try:
            # os.remove avoids the shell entirely, so names with spaces or
            # shell metacharacters are handled and no `rm` binary is required.
            os.remove(file)
        except FileNotFoundError:
            # Already gone: clean-up is best-effort.
            pass

## Silhouette Scores

#### Download Silhouette Scores

In [None]:
# Fetch every object under aws_path whose key ends with
# '<max_n_clusters>/<model_name>_silhouette_scores.npy'; the returned local
# names are used for loading and for clean-up later.
local_silhouette_score_filenames = aws_download(aws_path, f'{max_n_clusters}/{model_name}_silhouette_scores.npy')

#### Aggregate Silhouette Scores

In [None]:
# Subject IDs used to label each downloaded score file; '662551' is commented
# out, so a file for that subject matches no entry.
subjects = [
    '103818', '105923', '111312', '114823', '115320',
    '122317', '125525', '130518', '135528', '137128',
    '139839', '143325', '144226', '146129', '149337',
    '149741', '151526', '158035', '169343', '172332',
    '175439', '177746', '185442', '187547', '192439',
    '194140', '195041', '200109', '200614', '204521',
    '250427', '287248', '341834', '433839', '562345',
    '599671', '601127', '627549', '660951', # '662551', 
    '783462', '859671', '861456', '877168', '917255'
]

# Collect one row of scores per file and build the frame once at the end:
# DataFrame.append was removed in pandas 2.0 and copies the frame on every call.
score_rows = []
for fname in local_silhouette_score_filenames:
    # TODO figure out better way to do this
    subject = next((s for s in subjects if s in fname), None)
    if subject is None:
        # Without this guard an unmatched file would be labelled with whatever
        # subject the search loop ended on (the last ID in the list).
        print(f'skipping {fname}: no known subject ID in file name')
        continue

    #fname = f'MASE_ChooseK_Study_{session_name}_{bundle_name}_{subject}_{max_n_clusters}_{model_name}_silhouette_scores.npy'
    if os.path.exists(fname):
        score_rows.append(pd.Series(np.load(fname), name=subject))

# Rows are indexed by subject, columns by position within each score array.
df = pd.DataFrame(score_rows)

display(df)

# Long format: one (cluster_number, silhouette_score) pair per row for seaborn.
df1 = pd.melt(frame = df, var_name = 'cluster_number', value_name = 'silhouette_score')

# offset column index to correspond to clusters, clusters begin with two
df1['cluster_number'] = df1['cluster_number'] + 2

fig, ax = plt.subplots()
sns.lineplot(ax = ax, data = df1, x='cluster_number', y='silhouette_score', sort=False).set(
    title=f'MASE_ChooseK_Study_{session_name}_{bundle_name}\nn_subjects: {len(df)}'
)
plt.show()

#### Clean up

In [None]:
# Delete the local copies of the silhouette score files downloaded earlier.
remove_aws_downloads(local_silhouette_score_filenames)

## Pair Plots

`best_n_cluster` is chosen as the number of clusters with the maximal value in the aggregate silhouette scores.

There is code to generate an animated GIF, but it is much easier to inspect the pair plots individually.

Look for patterns and differences in the number of embedded components and the number of clusters.

#### Download Pair Plots

In [None]:
# Fetch every object under aws_path whose key ends with
# '<best_n_cluster>/<model_name>_pairplot.png'; the returned local names are
# kept for clean-up.
local_pairplot_filenames = aws_download(aws_path, f'{best_n_cluster}/{model_name}_pairplot.png')

#### Create Animated gif of pair plots

#### Clean up

In [None]:
# Delete the local copies of the pair plot images downloaded earlier.
remove_aws_downloads(local_pairplot_filenames)

## Info File

In [None]:
# Fetch every object under aws_path whose key ends with
# '<best_n_cluster>/<model_name>_info.pkl'; the returned local names are used
# for loading and for clean-up later.
local_info_filenames = aws_download(aws_path, f'{best_n_cluster}/{model_name}_info.pkl')

In [None]:
# Load each downloaded info pickle, keeping the download order.
# NOTE(review): unpickling executes code from the file; only load pickles
# from a trusted bucket.
info_dfs = [pd.read_pickle(info_path) for info_path in local_info_filenames]

In [None]:
# Stack the per-file info frames into a single table.
info_df = pd.concat(info_dfs)
# fix offset issue: shift the stored cluster count up by one
# NOTE(review): presumably the stored value is one less than the actual number of clusters - confirm
info_df['n_clusters selected'] += 1
display(info_df)

In [None]:
# Per subject: second entry of 'embed dimensions' minus the selected number of clusters.
diff = {
    subj: embed_dims[1] - n_selected
    for subj, embed_dims, n_selected in zip(
        info_df['subject'],
        info_df['embed dimensions'].tolist(),
        info_df['n_clusters selected'],
    )
}

# One bar per subject; labels are rotated so the IDs stay readable.
plt.figure(figsize=(20,5))
plt.bar(*zip(*diff.items()))
plt.xticks(rotation=45)
plt.title(f'MASE_ChooseK_Study_{session_name}_{bundle_name}\n difference number of components to number of clusters')
plt.show()

In [None]:
# Delete the local copies of the info pickles downloaded earlier.
remove_aws_downloads(local_info_filenames)