## Subbundle Model Analysis

Reload imports before each cell is executed.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline

**NOTE**: For more detailed output, set the logger level to `logging.DEBUG`.

In [None]:
from subbundle_model_analysis_utils import fetch_model_data, make_bundle_dict, ClusterType
from identify_subbundles import *
from visualizations import *

import logging
logger = logging.getLogger('subbundle')
logger.setLevel(logging.INFO)

**NOTE**: Assumes clustering models exist for each `{subject}` and `{session_name}` for each `{bundle_name}`. 

For each `{experiment_name}` (a combination of feature selection and embedding, choice of clustering algorithm, and the corresponding model hyperparameters), results are saved to:

> s3://hcp-subbundle/{experiment_name}/{session_name}/{bundle_name}/{subject}/{n_clusters}/
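
As a rough illustration of the layout above, the prefix for a single model can be assembled from its parts. The helper name `results_prefix` is hypothetical and is not part of the notebook's utility modules; only the path template comes from the text above.

```python
def results_prefix(experiment_name, session_name, bundle_name, subject,
                   n_clusters, bucket='hcp-subbundle'):
    """Build the S3 prefix for one model (hypothetical helper, for illustration)."""
    return (
        f's3://{bucket}/{experiment_name}/{session_name}/'
        f'{bundle_name}/{subject}/{n_clusters}/'
    )


print(results_prefix('MASE_FA_Sklearn_KMeans', 'HCP_1200', 'SLF_L', '103818', 2))
```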

### Constants

Constants from pyAFQ and HCP dataset

In [None]:
# list of pyAFQ bundle identifiers
BUNDLE_NAMES = [
    'ATR_L', 'ATR_R',
    'CGC_L', 'CGC_R',
    'CST_L', 'CST_R',
    'IFO_L', 'IFO_R',
    'ILF_L', 'ILF_R',
    'SLF_L', 'SLF_R',
    'ARC_L', 'ARC_R',
    'UNC_L', 'UNC_R',
    'FA', 'FP'
]

# list of HCP test-retest subject identifiers
SUBJECTS = [
    '103818', '105923', '111312', '114823', '115320',
    '122317', '125525', '130518', '135528', '137128',
    '139839', '143325', '144226', '146129', '149337',
    '149741', '151526', '158035', '169343', '172332',
    '175439', '177746', '185442', '187547', '192439',
    '194140', '195041', '200109', '200614', '204521',
    '250427', '287248', '341834', '433839', '562345',
    '599671', '601127', '627549', '660951', # '662551', 
    '783462', '859671', '861456', '877168', '917255'
]

# list of HCP test and retest session names
SESSION_NAMES = ['HCP_1200', 'HCP_Retest']

### Experiment and Model Metadata

dictionary of information passed to helper functions

**NOTE**:
- One bundle at a time
- Experiment was run for only `ARC_L`, `ARC_R`, `SLF_L`, and `SLF_R`
- Experiment was run for 2-4 clusters

In [None]:
BUNDLE_NAME = 'SLF_L'
print('bundle', BUNDLE_NAME)

**NOTE**: Filtering removes clusters for some subjects, and the impact on the analysis is a concern.

It is less than ideal that the affected subjects differ completely between bundles.

**TODO**: Automatically detect and remove affected subjects.

```
for d in */ ; do echo "$d" ; rsync --dry-run --verbose --recursive --existing --ignore-existing --delete-after $d/HCP_1200/ $d/HCP_Retest/ | grep clean | awk '{print $NF}'; done
```

```
for d in */ ; do echo "$d" ; rsync --dry-run --verbose --recursive --existing --ignore-existing --delete-after $d/HCP_Retest/ $d/HCP_1200/ | grep clean | awk '{print $NF}'; done
```
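
One possible way to approach the TODO in Python, sketched under assumptions: the input `clean_clusters` is a hypothetical mapping from subject to the set of clean cluster ids found in each session (for example, collected by listing the local artifact directories). Neither the function nor the mapping exists in the notebook's helper modules.

```python
def find_subjects_missing_clusters(clean_clusters, session_names, n_clusters):
    """Return subjects with fewer than n_clusters clean clusters in any session.

    clean_clusters is a hypothetical mapping:
        {subject: {session_name: set of clean cluster ids}}
    """
    flagged = []
    for subject, by_session in clean_clusters.items():
        for session_name in session_names:
            # a missing session counts as zero clean clusters
            if len(by_session.get(session_name, set())) < n_clusters:
                flagged.append(subject)
                break
    return sorted(flagged)


# illustrative values only, not results from the experiment
clean_clusters = {
    '103818': {'HCP_1200': {0, 1}, 'HCP_Retest': {0, 1}},
    '125525': {'HCP_1200': {0, 1}, 'HCP_Retest': {1}},
}
print(find_subjects_missing_clusters(clean_clusters, ['HCP_1200', 'HCP_Retest'], 2))
```

The returned list could then take the place of the hard-coded exclusion lists.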

**Current strategy: excluding subjects.** 

Effect: Reduces $N$ subjects reported.

**Alternative strategies:**
- Adjust the filter threshold to be less aggressive.
- Skip the cluster.
  This shortens the list of clusters and mislabels the remaining ones, so profiles would be calculated incorrectly.
- No-op cluster.
  Add an empty placeholder list or a list of all zeros; profiles would likely be calculated incorrectly.
- Move to or replace the model with one using $N-1$ clusters.
  Costly in time and code changes. Reduces $N$ subjects reported.
- Skip the model.
- Skip the subject.

To assess the effect of the alternative strategies, one could look at the smaller subset of subjects that present the issue.

Filtering removes one of the clusters for the subjects listed below.

For now, do not include them in the analysis.

In [None]:
excluded_subjects = []

if BUNDLE_NAME == 'SLF_L':
    excluded_subjects = ['125525', '195041', '200109', '599671']
elif BUNDLE_NAME == 'SLF_R':
    excluded_subjects = ['122317', '137128', '149741', '187547', '660951']
elif BUNDLE_NAME == 'ARC_L':
    excluded_subjects = ['287248']
elif BUNDLE_NAME == 'ARC_R':
    excluded_subjects = ['135528', '144226', '917255']

print('excluded subjects:', excluded_subjects)

In [None]:
remove_excluded_subjects = True
inspect_excluded_subjects = False

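if remove_excluded_subjects:
    print('removing excluded subjects')
    # rebuild the list instead of calling SUBJECTS.remove() so the cell can be re-run safely
    SUBJECTS = [subject for subject in SUBJECTS if subject not in excluded_subjects]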
elif inspect_excluded_subjects:
    print('inspecting excluded subjects')
    SUBJECTS = excluded_subjects

In [None]:
from os.path import join

import random

metadata = {}

# experiment
metadata['experiment_name'] = 'MASE_FA_Sklearn_KMeans'

metadata['experiment_output_dir'] = join('subbundles', metadata['experiment_name'])

metadata['experiment_bundles'] = [BUNDLE_NAME]

metadata['experiment_subjects'] = SUBJECTS 
print('subjects', metadata['experiment_subjects'])

metadata['experiment_sessions'] = SESSION_NAMES
metadata['experiment_test_session'] = metadata['experiment_sessions'][0]
metadata['experiment_retest_session'] = metadata['experiment_sessions'][1]

metadata['experiment_range_n_clusters'] = [2, 3, 4] 
metadata['experiment_bundle_dict'] = make_bundle_dict(metadata)

# model
metadata['model_name'] = 'mase_kmeans_fa_r2_is_mdf'
metadata['model_scalars'] = [Scalars.DTI_FA]

# analysis
metadata['n_points'] = 100
metadata['algorithm'] = Algorithm.MUNKRES
metadata['bundle_name'] = BUNDLE_NAME
metadata['n_clusters'] = 2

Remove local analysis artifacts

### Pipeline

#### Set up local directory

In [None]:
model_data = fetch_model_data(metadata)

#### Identify a `consensus_subject` and appropriately relabel clusters

   Generates the following local artifacts in:
    
   `{experiment_name}/{bundle_name}/{subject}/{session}/{n_clusters}/`
   
  - `{target}_{algorithm}_labels.npy` 

     cluster labels for `subject`, using `target` as the consensus subject and `algorithm` as the labeling algorithm

  - `{subject}_{bundle_name}_{cluster_id}_MNI.trk` 

     cleaned cluster tractogram in MNI space

     **NOTE:** `cluster_id` is the original cluster label from the model.

  - `{subject}_{bundle_name}_{cluster_id}_MNI_density_map.nii.gz` 

     density map for the cleaned cluster tractogram in MNI space, used to calculate the weighted dice coefficient.
     
     _optional_: only generated when using `Algorithm.MAXDICE` or `Algorithm.MUNKRES`

build `cluster_info` dict:
- `cluster_info[n_clusters]`
  - `cluster_info[n_clusters]['consensus_subject']`
  - `cluster_info[n_clusters][session_name]`
     - `cluster_info[n_clusters][session_name]['centroids']`
     - `cluster_info[n_clusters][session_name]['tractograms_filenames']`
     - `cluster_info[n_clusters][session_name]['tractograms']`


**NOTE**: some clusters are removed as a result of filtering and cleaning

Save some computation time by reusing previously identified consensus subjects

In [None]:
consensus_subjects = {}
consensus_subjects[2] = {}
consensus_subjects[3] = {}
consensus_subjects[4] = {}

if BUNDLE_NAME == 'SLF_L':
    consensus_subjects[2]['consensus_subject'] = '187547'
    consensus_subjects[3]['consensus_subject'] = '660951'
    consensus_subjects[4]['consensus_subject'] = '139839'
elif BUNDLE_NAME == 'SLF_R':
    consensus_subjects[2]['consensus_subject'] = '250427'
    consensus_subjects[3]['consensus_subject'] = '783462'
    consensus_subjects[4]['consensus_subject'] = '172332'
elif BUNDLE_NAME == 'ARC_L':
    consensus_subjects[2]['consensus_subject'] = '859671'
elif BUNDLE_NAME == 'ARC_R':
    consensus_subjects[2]['consensus_subject'] = '115320'
    
print('consensus subjects:', consensus_subjects)

In [None]:
cluster_info = get_cluster_info(metadata, consensus_subjects)

##### Relabel retest clusters based on `consensus_subject`

- labels are aligned across test-retest for the `consensus_subject` before relabeling retest subjects

   `{experiment_name}/{bundle_name}/{consensus_subject}/HCP_Retest/{cluster_number}/consensus_mdf_labels.npy`
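
With `Algorithm.MUNKRES`, relabeling can be viewed as an assignment problem: pair each consensus cluster with one subject cluster so that the total similarity is as large as possible. The sketch below solves that problem by exhaustive search, which is practical for 2-4 clusters; it is not the implementation in `identify_subbundles`, and the similarity values are made up. For larger problems, `scipy.optimize.linear_sum_assignment` solves the same problem efficiently.

```python
from itertools import permutations


def match_labels(similarity):
    """Return the assignment that maximizes total similarity.

    similarity[i][j] is the overlap between consensus cluster i and
    subject cluster j. The result maps consensus index i to subject
    cluster index result[i].
    """
    n = len(similarity)
    return max(
        permutations(range(n)),
        key=lambda perm: sum(similarity[i][perm[i]] for i in range(n)),
    )


# illustrative similarity matrix (e.g. pairwise dice), not real data
similarity = [
    [0.10, 0.80, 0.20],
    [0.70, 0.15, 0.30],
    [0.25, 0.20, 0.60],
]
print(match_labels(similarity))
```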

##### Reliability

In [None]:
bundle_dice_coeffs = get_bundle_dice_coefficients(metadata)

In [None]:
cluster_dice_coeffs = {}

for n_clusters in metadata['experiment_range_n_clusters']:
    cluster_dice_coeffs[n_clusters] = get_cluster_dice_coefficients(
        metadata,
        cluster_info,
        n_clusters
    )

In [None]:
bundle_afq_profiles = get_bundle_afq_profiles(metadata)

In [None]:
cluster_afq_profiles = {}

for n_clusters in metadata['experiment_range_n_clusters']:
    cluster_afq_profiles[n_clusters] = get_cluster_afq_profiles(
        metadata,
        n_clusters,
        cluster_info[n_clusters]['consensus_subject']
    )

#### Population Visualizations

**WARNING** interactive plotly has been crashing, so plotly is used to generate PNGs instead

##### bundle

plot bundle FA profiles

In [None]:
display_population_bundle_profiles(
    metadata,
    bundle_afq_profiles[metadata['model_scalars'][0]]
)

bundle streamline count statistics

In [None]:
display_bundle_streamline_stats(metadata, model_data)

bundle weighted dice coefficient statistics
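
For reference, one common definition of the weighted dice coefficient sums the density values of both maps over the voxels they share and divides by the total density of both maps. The sketch below implements that definition on flattened density maps; whether `get_bundle_dice_coefficients` uses exactly this formula is an assumption, and the example values are made up.

```python
def weighted_dice(density_a, density_b):
    """Weighted dice between two density maps given as equal-length flat sequences."""
    # voxels where both maps have streamlines
    overlap = [(a, b) for a, b in zip(density_a, density_b) if a > 0 and b > 0]
    numerator = sum(a + b for a, b in overlap)
    denominator = sum(density_a) + sum(density_b)
    return numerator / denominator if denominator else 0.0


# illustrative streamline counts per voxel
test_map = [0, 2, 3, 1, 0]
retest_map = [1, 2, 2, 0, 0]
print(weighted_dice(test_map, retest_map))
```

Identical maps give 1.0 and maps with no shared voxels give 0.0, so voxels with high streamline density dominate the score.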

In [None]:
display_bundle_dice_coeff_stats(metadata, bundle_dice_coeffs)

bundle FA profile test-retest reliability statistics
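
The consensus tables later in this notebook caption profile reliability as a Pearson r, so a plausible reading is that reliability is the correlation between a subject's test and retest profiles. The sketch below computes Pearson r for two profiles sampled at the same nodes; the function is illustrative and is not taken from the helper modules.

```python
from math import sqrt


def pearson_r(test_profile, retest_profile):
    """Pearson correlation between two equal-length profiles."""
    n = len(test_profile)
    mean_test = sum(test_profile) / n
    mean_retest = sum(retest_profile) / n
    covariance = sum(
        (a - mean_test) * (b - mean_retest)
        for a, b in zip(test_profile, retest_profile)
    )
    var_test = sum((a - mean_test) ** 2 for a in test_profile)
    var_retest = sum((b - mean_retest) ** 2 for b in retest_profile)
    return covariance / sqrt(var_test * var_retest)


# illustrative FA values at four nodes
print(pearson_r([0.40, 0.45, 0.55, 0.50], [0.42, 0.44, 0.57, 0.49]))
```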

In [None]:
display_bundle_profile_reliability_stats(
    metadata,
    bundle_afq_profiles[metadata['model_scalars'][0]]
)

##### clusters

In [None]:
display_population_cluster_profile(
    metadata, 
    cluster_afq_profiles[metadata['n_clusters']][metadata['model_scalars'][0]],
    'DTI FA',
    metadata['n_clusters']
)

cluster streamline count statistics

**NOTE** can also check cluster `model` or `filtered` counts; here only looking at `clean`

In [None]:
csc = get_cluster_streamline_counts(metadata, model_data)
    
display_cluster_streamline_count_stats(metadata, csc, metadata['n_clusters'])

cluster weighted dice coefficient test-retest reliability

**NOTE**: 
- Dice was higher with DTI and without the two-stage cleaning. Could check model bundles and filtered bundles, not just clean bundles. 
- There is also a large variation between max and min dice, which is affecting the average

In [None]:
bdcs = get_bundle_dice_coeff_stats(
    bundle_dice_coeffs
)

display_cluster_dice_coef(
    metadata,
    cluster_dice_coeffs,
    metadata['n_clusters'],
    bdcs.loc['mean'][0]
)

cluster FA profile test-retest reliability

In [None]:
cpr = get_cluster_profile_reliability(
    metadata,
    cluster_afq_profiles
)
    
bprs = get_bundle_profile_reliability_stats(
    metadata,
    bundle_afq_profiles[metadata['model_scalars'][0]]
)

display_cluster_profile_reliability_stats(
    metadata,
    cpr,
    metadata['n_clusters'],
    bprs.loc['mean'][0]
)

#### Consensus Visualizations

##### bundle

anatomical plot

In [None]:
bundle_anatomy_figures = get_consensus_bundle_anatomy_figures(
    metadata, 
    model_data, 
    cluster_info, 
    metadata['n_clusters']
)

bundle streamline counts

In [None]:
csc = get_consensus_streamline_counts(metadata, model_data, cluster_info)

display(
    csc.loc[csc['n_clusters'] == metadata['n_clusters']].style.set_caption(f"{metadata['bundle_name']} consensus subject streamline counts")
)

bundle weighted dice coefficient test-retest reliability

In [None]:
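# use the configured n_clusters rather than the variable left over from an earlier loop
cbdc = get_bundle_dice_coeff(
    bundle_dice_coeffs,
    cluster_info[metadata['n_clusters']]['consensus_subject']
)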

display(
    cbdc.style.set_caption(f"{metadata['bundle_name']} consensus subject weighted dice coefficient")
)

In [None]:
cbpr = get_consensus_bundle_profile_reliability(
    metadata,
    bundle_afq_profiles[metadata['model_scalars'][0]],
    cluster_info,
    metadata['n_clusters']
)

display(
    cbpr.style.set_caption(f"{metadata['bundle_name']} consensus subject DTI FA pearson r")
)

##### adjacencies

In [None]:
display_consensus_adjacencies(
    metadata,
    model_data,
    cluster_info,
    metadata['n_clusters']
)

##### model artifacts: silhouette scores and pair plot

plot model artifacts

In [None]:
display_consensus_model_artifacts(
    metadata,
    model_data,
    cluster_info,
    metadata['n_clusters']
)

plot filtered artifacts

In [None]:
display_consensus_filtered_artifacts(
    metadata,
    model_data,
    cluster_info,
    metadata['n_clusters']
)

##### FA profiles

recall that the prior analysis used deterministic DTI; this one uses deterministic CSD

jointly plot streamline and bundle FA profiles for `subject` and `session`

In [None]:
display_consensus_streamline_bundle_profile(
    metadata,
    model_data,
    cluster_info,
    metadata['n_clusters']
)

plot cluster FA profiles for `subject`, `session`, and `n_clusters`

In [None]:
display_consensus_cluster_profiles(
    metadata,
    model_data,
    cluster_info,
    metadata['n_clusters']
)

##### cluster streamlines

**TODO** investigate why getting extra line in `bundle_to_tgram` plots

anatomical plot of the subject's cluster tractograms in a single visualization

In [None]:
consensus_cluster_figs = get_clean_consensus_cluster_tractograms(
    metadata,
    model_data,
    cluster_info,
    metadata['n_clusters']
)

cluster streamline counts

In [None]:
display_consensus_cluster_streamline_counts(
    metadata,
    model_data,
    cluster_info,
    metadata['n_clusters']
)

consensus cluster weighted bundle dice coefficients

In [None]:
display_consensus_cluster_dice_coef(
    metadata,
    cluster_info,
    cluster_dice_coeffs,
    metadata['n_clusters']
)

consensus cluster FA profile correlation

In [None]:
display_consensus_cluster_profile_reliability(
    metadata,
    cluster_info,
    cluster_afq_profiles,
    metadata['n_clusters']
)

#### Show (MNI space) results for individuals and group:

##### centroids

- Quality control check: it is much easier to view each cluster as a centroid

_optionally_ view consensus subject centroids

_optionally_ choose a subject to investigate

_optionally_ view centroids for the original model clusters, labeled by streamline count. 

- Compare to the labeling algorithm

view the labeling algorithm

**WARNING** plotly may crash when running this
- `display_centroids` generates multiple plotly figures
    - best to run `visualize_tractogram` one at a time

In [None]:
def get_centroid_figures(metadata, cluster_info, n_clusters, save_sfts=False):
    centroid_figs = {}

    for session_name in metadata['experiment_sessions']:
        mni_centroids = get_relabeled_centroids(
            metadata,
            n_clusters,
            session_name,
            cluster_info[n_clusters]['consensus_subject']
        )

        mni_sft = convert_centroids(
            n_clusters,
            session_name,
            mni_centroids,
            metadata['experiment_bundle_dict'],
            save_sfts
        )

        centroid_figs[session_name] = visualize_tractogram(
            mni_sft,
            metadata['experiment_bundle_dict']
        )
        
    return centroid_figs

In [None]:
centroid_figs = get_centroid_figures(
    metadata,
    cluster_info,
    metadata['n_clusters'],
    True # temporary saving for ariel
)

#### Choose $K$ 

From `metadata['experiment_range_n_clusters']`, choose the model that is most reliable across sessions.

The choice is based on the scalar profiles of the subjects' clusters.

The criterion is the average RMSE (root mean squared difference) per subject.
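
A minimal sketch of that quantity, assuming the RMSE compares each subject's test and retest profiles node by node and is then averaged over subjects. How `find_K` combines clusters and bundles is not shown here, so the functions and values below are illustrative only.

```python
from math import sqrt


def rmse(test_profile, retest_profile):
    """Root mean squared difference between two equal-length profiles."""
    squared = [(a - b) ** 2 for a, b in zip(test_profile, retest_profile)]
    return sqrt(sum(squared) / len(squared))


def average_rmse(profiles_by_subject):
    """Average the per-subject RMSE over {subject: (test_profile, retest_profile)}."""
    values = [rmse(test, retest) for test, retest in profiles_by_subject.values()]
    return sum(values) / len(values)


# illustrative profiles, not real data
profiles_by_subject = {
    'subject_a': ([0.40, 0.50, 0.60], [0.40, 0.50, 0.60]),
    'subject_b': ([0.40, 0.50, 0.60], [0.50, 0.60, 0.70]),
}
print(average_rmse(profiles_by_subject))
```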

In [None]:
K, data = find_K(metadata, bundle_afq_profiles, cluster_afq_profiles)
print(metadata['bundle_name'], metadata['algorithm'], 'Choosing n_cluster', K)

In [None]:
show_choose_k_data(data)

In [None]:
# temporary saving for ariel
import os
import numpy as np

os.makedirs('output', exist_ok=True)  # np.save does not create missing directories
profile_tensor = get_bundle_profile_tensor(
    metadata,
    bundle_afq_profiles[metadata['model_scalars'][0]]
)
np.save(f'output/{BUNDLE_NAME}_bundle_tensor.npy', profile_tensor)
print(profile_tensor.shape)


for n_clusters in metadata['experiment_range_n_clusters']:  
    profile_tensor = get_cluster_profile_tensor(
        metadata,
        cluster_afq_profiles[n_clusters][metadata['model_scalars'][0]],
        n_clusters
    )
    
    np.save(f'output/{BUNDLE_NAME}_n_clusters_{n_clusters}_tensor.npy', profile_tensor)
    print(profile_tensor.shape)

#### Calinski-Harabasz criterion clustering evaluation
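
The Calinski-Harabasz criterion (pseudo-F) is the ratio of between-cluster dispersion to within-cluster dispersion, each divided by its degrees of freedom: $[B/(k-1)] / [W/(n-k)]$ for $n$ points in $k$ clusters. The sketch below follows that standard definition on plain lists; how `get_pseudo_f` applies it to the cluster profiles is not shown here. scikit-learn provides the same statistic as `sklearn.metrics.calinski_harabasz_score`.

```python
def calinski_harabasz(points, labels):
    """Pseudo-F for feature vectors `points` with one cluster label per point."""
    n = len(points)
    dim = len(points[0])

    clusters = {}
    for point, label in zip(points, labels):
        clusters.setdefault(label, []).append(point)
    k = len(clusters)

    overall = [sum(p[d] for p in points) / n for d in range(dim)]
    between = 0.0
    within = 0.0
    for members in clusters.values():
        centroid = [sum(p[d] for p in members) / len(members) for d in range(dim)]
        # dispersion of the cluster centroid around the overall mean
        between += len(members) * sum((c - o) ** 2 for c, o in zip(centroid, overall))
        # dispersion of the members around their own centroid
        within += sum(
            sum((x - c) ** 2 for x, c in zip(p, centroid)) for p in members
        )
    return (between / (k - 1)) / (within / (n - k))


# two well-separated one-dimensional clusters (illustrative)
points = [[0.0], [1.0], [10.0], [11.0]]
labels = [0, 0, 1, 1]
print(calinski_harabasz(points, labels))
```

Higher values indicate clusters that are compact relative to their separation.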

In [None]:
pseudo_f = get_pseudo_f(metadata, cluster_afq_profiles)

In [None]:
# temporary saving for ariel
np.save(f'output/{BUNDLE_NAME}_n_clusters_2_pseudo_f.npy', pseudo_f[2][metadata['model_scalars'][0]])