Visualizations of subject cluster centroids in MNI space

In [None]:
from identify_subbundles import *
from dipy.io.streamline import save_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram
import AFQ.data as afd
from AFQ.utils.streamlines import bundles_to_tgram
from AFQ.viz.plotly_backend import visualize_volume, visualize_bundles

In [None]:
# HCP_1200 data directory, i.e. the test session
DATA_DIR = DATA_DIRS[0]

# MNI T2 template: the image is the reference passed to bundles_to_tgram and
# its voxel data is the backdrop passed to visualize_volume
MNI_T2_IMG = afd.read_mni_template()
MNI_T2_IMG_DATA = MNI_T2_IMG.get_fdata()

# SLF sub-bundle name -> metadata; "uid" is the bundle's integer label
BUNDLE_DICT = {
    "SLF_0": {"uid": 0},
    "SLF_1": {"uid": 1},
    "SLF_2": {"uid": 2},
}

# subject whose clustering serves as the consensus for the relabeling strategy
CONSENSUS = '200614'

In [None]:
def save_centroids(centroids, centroids_name):
    """
    Save one tractogram file per cluster centroid for every subject.

    Files are written to the current working directory as
    ``{subject}_{cluster_id}_centroid_{centroids_name}.trk``, where
    ``cluster_id`` is the centroid's position in that subject's sequence.

    Parameters
    ----------
    centroids : dict
        Maps subject id to a sequence of centroid tractograms; each element
        is passed unchanged to ``dipy.io.streamline.save_tractogram``.

    centroids_name : string
        Label appended to every output filename (e.g. ``"prealignment"``).
    """
    for subject, subject_centroids in centroids.items():
        for cluster_id, cluster_centroid in enumerate(subject_centroids):
            save_tractogram(
                cluster_centroid,
                f'{subject}_{cluster_id}_centroid_{centroids_name}.trk'
            )

# Example usage (once the centroids are computed further below):
# save_centroids(prealign_centroids, "prealignment")
# save_centroids(mni_prealign_centroids, "MNI_prealignment")

In [None]:
def convert_centroids(mni_centroids):
    """
    Create one bundle per cluster containing that cluster's centroid from
    every subject, and merge the bundles into a single tractogram.

    The i-th centroid of each subject is assigned to the bundle named by the
    i-th key of ``BUNDLE_DICT``, so the bundle set always matches
    ``BUNDLE_DICT`` rather than a hard-coded list.

    Parameters
    ----------
    mni_centroids : dict
        Maps subject id to an indexable sequence of centroid tractograms
        (expected to already be in MNI space). Only the first streamline of
        each centroid tractogram is used.

    Returns
    -------
    StatefulTractogram
        Result of ``bundles_to_tgram(bundles, BUNDLE_DICT, MNI_T2_IMG)``.

    Raises
    ------
    ValueError
        If ``mni_centroids`` is empty, or if a subject has more centroids
        than ``BUNDLE_DICT`` has bundles.
    """
    if not mni_centroids:
        raise ValueError("mni_centroids is empty; nothing to convert")

    bundle_names = list(BUNDLE_DICT)
    clusters = {name: [] for name in bundle_names}

    for subject, subject_centroids in mni_centroids.items():
        for cluster_id, cluster_centroid in enumerate(subject_centroids):
            if cluster_id >= len(bundle_names):
                raise ValueError(
                    f"subject {subject} has more than {len(bundle_names)} "
                    f"centroids, but BUNDLE_DICT only defines "
                    f"{len(bundle_names)} bundles"
                )
            clusters[bundle_names[cluster_id]].append(
                cluster_centroid.streamlines[0]
            )

    # any subject/tractogram will do as the spatial reference,
    # so grab the first centroid of the first subject
    reference = next(iter(mni_centroids.values()))[0]

    bundles = {
        name: StatefulTractogram.from_sft(streamlines, reference)
        for name, streamlines in clusters.items()
    }

    return bundles_to_tgram(bundles, BUNDLE_DICT, MNI_T2_IMG)

In [None]:
def visualize_centriods(sft):
    """
    Plot cluster centroids over the MNI T2 template using the plotly backend.

    Parameters
    ----------
    sft : StatefulTractogram
        Tractogram to draw; bundles are interpreted using ``BUNDLE_DICT``.

    Returns
    -------
    Whatever ``AFQ.viz.plotly_backend.visualize_bundles`` returns for the
    combined figure.
    """
    # template volume figure (non-interactive, not inline); the bundles are
    # drawn by passing it as ``figure=`` to visualize_bundles
    template_figure = visualize_volume(MNI_T2_IMG_DATA, interact=False, inline=False)

    # Variant from John's PR 641, kept for reference:
    # visualize_bundles(sft, flip_axial=False, bundle_dict=BUNDLE_DICT, figure=template_figure)
    return visualize_bundles(
        sft,
        bundle_dict=BUNDLE_DICT,
        figure=template_figure,
    )

In [None]:
def visualize_subject_clusters(subject, centroids):
    """
    Show the centroid of each cluster for a single subject.

    Parameters
    ----------
    subject : string
        Key into ``centroids`` selecting the subject to display.

    centroids : dict
        Maps subject id to a sequence of centroid tractograms; only the
        first streamline of each centroid tractogram is plotted.

    Returns
    -------
    The figure returned by ``visualize_centriods``.

    Raises
    ------
    KeyError
        If ``subject`` is not in ``centroids``.
    ValueError
        If the subject has no centroids, so there is no tractogram to use
        as the spatial reference.
    """
    subject_centroids = list(centroids[subject])
    if not subject_centroids:
        raise ValueError(f"no cluster centroids found for subject {subject}")

    streamlines = [centroid.streamlines[0] for centroid in subject_centroids]

    # use the subject's last centroid tractogram as the spatial reference
    # (presumably all of a subject's centroids share the same space)
    sft = StatefulTractogram.from_sft(streamlines, subject_centroids[-1])
    return visualize_centriods(sft)

In [None]:
# Centroids before any relabeling, then the same centroids moved into MNI space.
# NOTE(review): BASE_DIR, MODEL_NAME, SUBJECTS and BUNDLE_NAME are not defined in
# this view — presumably provided by `from identify_subbundles import *`.
prealign_centroids = prealignment_centroids(BASE_DIR, DATA_DIR, MODEL_NAME, SUBJECTS, BUNDLE_NAME)
mni_prealign_centroids = move_centriods_to_MNI(DATA_DIR, SUBJECTS, prealign_centroids)

#### Show consensus subject centroids

In [None]:
# Plot the consensus subject's MNI-space cluster centroids over the template;
# the commented call below shows the same plot for another subject.
visualize_subject_clusters(CONSENSUS, mni_prealign_centroids)
# visualize_subject_clusters('877168', mni_prealign_centroids)

#### Cluster Centroids Labeled by Streamline Count

In [None]:
# Group every subject's i-th (un-relabeled) centroid into bundle SLF_i and plot
# all subjects together in MNI space.
mni_prealign_sft = convert_centroids(mni_prealign_centroids)
visualize_centriods(mni_prealign_sft)

#### Cluster centroids labeled by best weighted Dice coefficient

In [None]:
# Relabel each subject's clusters against the CONSENSUS subject using
# load_relabeled_clusters' default algorithm (per the heading: best weighted
# Dice), recompute the centroids under the new labels, move them to MNI space,
# and plot all subjects together.
cluster_idxs, cluster_names = load_relabeled_clusters(BASE_DIR, DATA_DIR, MODEL_NAME, SUBJECTS, CONSENSUS)
dice_centriods = relabled_centriods(BASE_DIR, DATA_DIR, SUBJECTS, BUNDLE_NAME, cluster_idxs, cluster_names)
mni_dice_centriods = move_centriods_to_MNI(DATA_DIR, SUBJECTS, dice_centriods)
mni_dice_sft = convert_centroids(mni_dice_centriods)
visualize_centriods(mni_dice_sft)

#### Cluster centroids labeled by Munkres (maximal-trace) weighted Dice coefficient

In [None]:
# Same pipeline as the Dice cell, but relabeling with algorithm='munkres'.
# Note: this rebinds the shared globals cluster_idxs / cluster_names.
cluster_idxs, cluster_names = load_relabeled_clusters(BASE_DIR, DATA_DIR, MODEL_NAME, SUBJECTS, CONSENSUS, algorithm='munkres')
munkres_centriods = relabled_centriods(BASE_DIR, DATA_DIR, SUBJECTS, BUNDLE_NAME, cluster_idxs, cluster_names)
mni_munkres_centriods = move_centriods_to_MNI(DATA_DIR, SUBJECTS, munkres_centriods)
mni_munkres_sft = convert_centroids(mni_munkres_centriods)
visualize_centriods(mni_munkres_sft)

#### Cluster centroids labeled by MDF

In [None]:
# Same pipeline as the Dice cell, but relabeling with algorithm='mdf'.
# Note: this rebinds the shared globals cluster_idxs / cluster_names.
cluster_idxs, cluster_names = load_relabeled_clusters(BASE_DIR, DATA_DIR, MODEL_NAME, SUBJECTS, CONSENSUS, algorithm='mdf')
mdf_centriods = relabled_centriods(BASE_DIR, DATA_DIR, SUBJECTS, BUNDLE_NAME, cluster_idxs, cluster_names)
mni_mdf_centriods = move_centriods_to_MNI(DATA_DIR, SUBJECTS, mdf_centriods)
mni_mdf_sft = convert_centroids(mni_mdf_centriods)
visualize_centriods(mni_mdf_sft)