In [1]:
from typing import Any, List, Dict, Optional
from dataclasses import dataclass, field
from pathlib import Path
import os.path as op

from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from scipy.stats import entropy
from pysankey2 import Sankey
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from cluster import reindex_clusters

In [2]:
# Join similarity matrices - should only need to run once
def join_similarities():
    sim_paths = Path('../data/similarity/')

    df_similarity = None
    for f in sim_paths.rglob('*h5'):
        dataset, atlas = str(f).split('/')[-1].split('_')[0:2]
        tdf = pd.read_hdf(f)
        tdf['dataset'], tdf['atlas'] = dataset, atlas

        if df_similarity is None:
            df_similarity = tdf
        else:
            df_similarity = pd.concat([df_similarity, tdf])

    pd.to_pickle(df_similarity, '../data/clustering/preprocessed_similarity.pkl')

In [3]:
@dataclass
class ClusteringFrames:
    definitions: Optional[pd.DataFrame] = None
    data: Optional[pd.DataFrame] = None
    subjects: Optional[pd.DataFrame] = None
    overlap: Optional[pd.DataFrame] = None
    plot: Optional[Any] = None
    combinations: List = field(default_factory=list)
    sorting: Dict = field(default_factory=lambda: {})
    plot_order: List = field(default_factory=lambda: ['aal', 'cc2', 'hox', 'des'])


# Utility function to slice dataframes
def get_df_slice(df, ds, at):
    return df[(df['dataset'] == ds) & (df['atlas'] == at)]


# Utility function to get and stack cluster signatures
def get_signatures(df, c):
    return np.stack([np.reshape(_, -1) for _ in get_df_slice(df, *c)['signature']])


def compare_clustering(cluster_definitions: pd.DataFrame, cluster_data: pd.DataFrame,
                       subject_similarity: pd.DataFrame, dataset: str, across: str="atlas",
                       plot: bool=True):
    if across != "atlas":
        raise NotImplementedError("Only across atlas comparisons are currently supported.")

    # Create data class for clustering informations
    cf = ClusteringFrames()

    # Reduce dataframe to only contain dataset of interest
    cf.definitions = cluster_definitions[cluster_definitions['dataset'] == dataset]
    cf.data = cluster_data[cluster_data['dataset'] == dataset]
    cf.subjects = subject_similarity[subject_similarity['dataset'] == dataset]
    
    # Set up some convenience variables and lambdas
    da = ['dataset', 'atlas']  # Regularly used for subsampling dataframe
    cois = ['subject', 'session', 'dataset', 'atlas']  # Used to subsample subject dataframe
    cf.combinations = cf.definitions.value_counts(da).index  # Is sorted in order of most->fewest clusters 

    # Extract info from reference parcellation (the one with the most clusters)
    dataset_atlas = cf.combinations[0]
    sig_ref = get_signatures(cf.definitions, dataset_atlas)
    cf.sorting[dataset_atlas[1]] = np.arange(len(sig_ref)) # dataset_atlas[1] is the atlas name

    # Set up subject matrix for cluster assignment, and populate with reference
    cf.subjects = cf.subjects[cois].groupby(['subject', 'session']).max()
    labels = get_df_slice(cf.data, *dataset_atlas)['labels'].values[0]
    cf.subjects[dataset_atlas[1]] = reindex_clusters(labels, order=cf.sorting[dataset_atlas[1]])

    # Match clusters from other parcellations
    for idx, da2 in enumerate(cf.combinations[1:]):
        # da2 = dataset_atlas_2
        sig_targ = get_signatures(cf.definitions, da2)

        # Compute similarity of clusters, and match them
        cost = cdist(sig_ref, sig_targ, metric='cosine')
        reorder = linear_sum_assignment(cost)
        cf.sorting[da2[1]] = np.argsort(reorder[1])

        # Assign cluster memberships (with updated indice) back to subjects/sessions
        labels = get_df_slice(cf.data, *da2)['labels'].values[0]
        cf.subjects[da2[1]] = reindex_clusters(labels, order=cf.sorting[da2[1]])

    # Create utility dataframes and perform Sankey plotting
    if plot:
        # Set colourmap for plot
        cm = ["#9d973f", "#b253c0", "#64ac48", "#6768cc", "#c67f40", "#5d94ce",
              "#d04a3e", "#4aac8b", "#d14788", "#bb7fc1", "#bb6271"]
        # From : http://medialab.github.io/iwanthue/
        
        # Rename columns and establish the (sequential) sorting of clusters
        col_names = {atlas: 'layer'+str(_+1)
                     for _, atlas in enumerate(cf.plot_order)}
        cords = {'layer'+str(_+1): np.arange(len(cf.sorting[atlas]))
                 for _, atlas in enumerate(cf.plot_order)}
        
        # Create a slimmed down dataframe for Sankey
        sankey_df = cf.subjects[cf.plot_order]
        sankey_df = sankey_df.rename(columns=col_names)
        sankey_df = sankey_df.reset_index().drop(columns=['subject', 'session'])

        color_dict = {_: cm[_] for _ in range(len(sig_ref))}
        
        # Plot
        cf.plot = Sankey(sankey_df, colorMode='global', stripColor='gray', colorDict=color_dict, layerLabels=cords)
        fig, ax = cf.plot.plot(figSize=(15, 15), fontSize=0)
        plt.title(dataset_atlas[0])
        fig.show()

    return cf

In [4]:
df_cluster_def = pd.read_pickle('../data/clustering/cluster_definitions.pkl')
df_cluster_dat = pd.read_pickle('../data/clustering/cluster_membership.pkl')
if not op.exists('../data/clustering/preprocessed_similarity.pkl'):
    join_similarities()
df_similarity = pd.read_pickle('../data/clustering/preprocessed_similarity.pkl')

datasets = sorted(df_cluster_def['dataset'].unique())
def cc_easy(dataset, **kwargs):
    return compare_clustering(df_cluster_def, df_cluster_dat, df_similarity, dataset, **kwargs)

In [5]:
cf = {}
for d in datasets:
    cf[d] = cc_easy(d, plot=False)

In [6]:
def compute_session_variation(clusterframe):
    tdf = []
    for atlas, labels in clusterframe.sorting.items():
        ents = []
        for name, group in clusterframe.subjects.groupby('subject'):
            counts = group.value_counts(normalize=True, subset=[atlas])
            bins = [0] * len(labels)
            for k, v in counts.items():
                bins[k[0]] = v
            ents += [entropy(bins)]
        ent_hat = np.mean(ents)
        ent_bar = ent_hat / entropy([1.0/len(labels)] * len(labels))
        tdf += [{
            "dataset": group.dataset,
            "atlas": atlas,
            "n_classes": len(labels),
            "entropy": ent_hat,
            "normalized_entropy": ent_bar
        }]
    return tdf


def compute_atlas_variation(clusterframe):
    tdf = []
    for idx, da1 in enumerate(clusterframe.combinations):
        for jdx, da2 in enumerate(clusterframe.combinations[idx+1:]):
            l1 = get_df_slice(clusterframe.data, *da1).labels.values[0]
            l2 = get_df_slice(clusterframe.data, *da2).labels.values[0]
            tmp = l1 == l2

            tdf += [{
                "dataset": da1[0],
                "atlas1": da1[1],
                "atlas2": da2[1],
                "ARI": adjusted_rand_score(l1, l2),
                "AMI": adjusted_mutual_info_score(l1, l2),
                "Percent Overlap": np.sum(tmp)/len(tmp)*100
            }]
    return tdf


tdf_s = []
tdf_a = []
# Perform some post-processing for the datasets all together
for d in datasets:
    # Extract the relevant clustering (data)frames collection
    tcf = cf[d]
    
    # Compute the variation in cluster membership across sessions
    is_trt = len(tcf.subjects.index.unique(level='subject')) < len(tcf.subjects.index)
    if is_trt:
        tdf_s += compute_session_variation(tcf)

    # Compute the variation in cluster membership across atlases
    tdf_a += compute_atlas_variation(tcf)

session_overlap = pd.DataFrame.from_dict(tdf_s)
atlas_overlap = pd.DataFrame.from_dict(tdf_a)

In [7]:
session_overlap.describe()

Unnamed: 0,n_classes,entropy,normalized_entropy
count,104.0,104.0,104.0
mean,4.442308,0.453783,0.335217
std,1.852979,0.208785,0.156454
min,2.0,0.077016,0.082248
25%,3.0,0.315067,0.243005
50%,4.0,0.421614,0.315465
75%,5.0,0.523252,0.388826
max,11.0,1.259972,0.957544


In [8]:
atlas_overlap.describe()

Unnamed: 0,ARI,AMI,Percent Overlap
count,162.0,162.0,162.0
mean,0.190582,0.232728,36.70651
std,0.133218,0.110417,18.385048
min,-0.070962,0.032306,5.769231
25%,0.095844,0.152377,21.75
50%,0.176414,0.224182,33.949153
75%,0.253249,0.300838,49.215909
max,0.742516,0.628475,93.333333
