# Subbundles Part 2: Streamlines

**Subbundle** - a subgroup of streamlines with a set of common properties

Part 2: get `streamlines` and `affine`

##### <span style="color:red">NOTE: Part 2 is mostly standalone quality control; the files it saves are not used by any other part</span>

In [None]:
from utils import *

import os
import os.path as op

import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
from dipy.io.streamline import load_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.stats.analysis import afq_profile, gaussian_weights
from dipy.tracking.streamline import set_number_of_points

import pandas as pd

from AFQ import api
import AFQ.data as afd
from AFQ.viz.fury_backend import visualize_volume
from AFQ.viz.fury_backend import visualize_bundles

import tempfile
from mpl_toolkits.mplot3d import Axes3D
from dipy.viz import window, actor
from IPython.display import Image, display

## AFQ (from Part 1)

Instantiate the AFQ object `myafq` for the desired dataset

In [None]:
compare_test_retest = True

test_retest_dir = 'HCP_test_retest'
test_retest_sessions = ['test', 'retest']
test_retest_names = ['HCP', 'HCP_retest']

# dataset_name = 'HCP'
dataset_name = 'HCP_retest'

# subjects = get_subjects(dataset_name)
subjects = get_subjects_small(dataset_name)
# subjects = get_subjects_medium(dataset_name)

if compare_test_retest:
    sub_dir = test_retest_dir
    
    print('HCP')
    myafq_test = get_afq('HCP')
    display(myafq_test.data_frame)
    
    print('HCP_retest')
    myafq_retest = get_afq('HCP_retest')
    display(myafq_retest.data_frame)
else:
    sub_dir = dataset_name
    
    print(dataset_name)
    myafq = get_afq(dataset_name)
    display(myafq.data_frame)

## Bundles

1. SLF
2. Corpus Callosum
3. Novel bundle
4. Whole brain

#### 1. SLF - *superior longitudinal fasciculus* (reproduce)

- Grotheer, M., Zhen, Z., Lerma-Usabiaga, G., & Grill-Spector, K. (2019). Separate lanes for adding and reading in the white matter highways of the human brain. Nature Communications, 10(1), 1-14.

  https://www.nature.com/articles/s41467-019-11424-1


- Schurr, R., Zelman, A., & Mezer, A. A. (2020). Subdividing the superior longitudinal fasciculus using local quantitative MRI. NeuroImage, 208, 116439.

  https://www.sciencedirect.com/science/article/pii/S1053811919310304
  

- De Schotten, M. T., Dell’Acqua, F., Forkel, S., Simmons, A., Vergani, F., Murphy, D. G., & Catani, M. (2011). A lateralized brain network for visuo-spatial attention. Nature Precedings, 1-1.

  https://www.nature.com/articles/npre.2011.5549.1
  https://www.researchgate.net/publication/281573090_A_lateralized_brain_network_for_spatial_attention

<span style="color:blue">**TODO: SLF subbundle tractometry (in part 1)**</span>

Represent anatomical correlate

- define ROIs [Issue #600](https://github.com/yeatmanlab/pyAFQ/issues/600)

##### <span style="color:red">NOTE: By default, use the SLF bundles; otherwise:</span>

- **Change `bundle_names`**

  - pyAFQ segmentation was run with the default bundles. To determine valid names, either:
  
     - refer to [documentation](https://yeatmanlab.github.io/pyAFQ/), or 
     
     - inspect the `myafq.bundle_dict` object

In [None]:
# if compare_test_retest:
#     bundle_names = [*myafq_retest.bundle_dict]
# else:
#     bundle_names = [*myafq.bundle_dict]

# bundle_names = ['SLF_L', 'SLF_R']
# bundle_names = ['ARC_L', 'ARC_R', 'CST_L', 'CST_R', 'FP'] 
bundle_names = ['SLF_L', 'SLF_R', 'ARC_L', 'ARC_R', 'CST_L', 'CST_R', 'FP']

#### 2. Corpus callosum tract profiles (baseline)

<span style="color:blue">**TODO: Corpus callosum tractometry (in part 1)**</span>

- define ROIs

  - use a midsagittal inclusion ROI so that streamlines pass through the midline
  
  - union of all callosum bundles

#### 3. Novel bundles (predictive)

Bundles where results are less established and more speculative

<span style="color:blue">**TODO: select existing bundles**</span>

- <span style="color:red">**Question: which bundles should we choose?**</span>

  - Are there other bundles that would be ideal candidates? If so, why?
  
  - What does literature say?
  
  - Could do a greedy run on all bundles defined by RECO or waypoint ROIs

- <span style="color:red">**Question: are there any bundles that do not result in subbundles?**</span>

  - Or does this approach always subdivide?
  
    - For example, if the outputs are recursively used as inputs, will there always be more subbundles?

#### 4. Whole Brain Tractometry

<span style="color:blue">**TODO: Run on whole brain tractometry**</span>

- See whether it reproduces the same top-level bundles

### Get Tractogram Files

`tg_fname` is the name of the tractogram file used to import the initial streamlines. The streamlines may represent the whole brain or some subset (a bundle).

In [None]:
tg_fnames = {}

for subject in subjects:
    tg_fnames[subject] = {}

    if compare_test_retest:
        loc_test  = get_iloc(myafq_test, subject)
        loc_retest = get_iloc(myafq_retest, subject)
    else:
        loc = get_iloc(myafq, subject)
        
    for bundle_name in bundle_names:
        tg_fnames[subject][bundle_name] = {}
        
        if compare_test_retest:
            iterables = zip(test_retest_names, test_retest_sessions, [myafq_test, myafq_retest], [loc_test, loc_retest])
        else:
            iterables = zip([sub_dir], [dataset_name], [myafq], [loc])

        for name, ses, myafq, loc in iterables:
            tg_fnames[subject][bundle_name][ses] = get_tractogram_filename(myafq, bundle_name, loc)

#### Check Tractogram Header

In [None]:
check_header = False

if check_header:
    for subject in subjects:
        for bundle_name in bundle_names:
            if compare_test_retest:
                tg_fname = tg_fnames[subject][bundle_name]['retest']
            else:    
                tg_fname = tg_fnames[subject][bundle_name][dataset_name]
            
            print(sub_dir, subject, bundle_name, nib.streamlines.load(tg_fname).header)

### Get Streamlines

In [None]:
tractograms = {}

for subject in subjects:
    tractograms[subject] = {}
    
    for bundle_name in bundle_names:
        tractograms[subject][bundle_name] = {}

        if compare_test_retest:
            iterables = zip(test_retest_names, test_retest_sessions)
        else:
            iterables = zip([sub_dir], [dataset_name])

        for name, ses in iterables:
            tg_fname = tg_fnames[subject][bundle_name][ses]
            tractograms[subject][bundle_name][ses] = load_tractogram(tg_fname, 'same')

            
tg_df = pd.DataFrame.from_dict(
    {
        (i, j, k): [len(tractograms[i][j][k].streamlines), tractograms[i][j][k].affine, tg_fnames[i][j][k]]
        for i in tractograms.keys()
        for j in tractograms[i].keys()
        for k in tractograms[i][j].keys()
    },
    orient='index',
    columns=['number of streamlines', 'affine', 'tractogram files']
)

# show full file paths without truncation (None replaces the deprecated -1)
with pd.option_context('display.max_colwidth', None):
    display(tg_df)

os.makedirs(op.join('subbundles', sub_dir), exist_ok=True)
f_name = op.join('subbundles', sub_dir, 'tractogram_info.csv')
print(f_name)
tg_df.to_csv(f_name)

## QC: Bundle Visualization (OPTIONAL)

In [None]:
interact = False

##### Bundle Streamlines

<span style="color:blue">**TODO: Left and Right hemispheres appear flipped between HARDI and HCP**</span>

In [None]:
show_bundle_streamlines = False

if show_bundle_streamlines:    
    for subject in subjects:
        if compare_test_retest:
            loc_test  = get_iloc(myafq_test, subject)
            loc_retest = get_iloc(myafq_retest, subject)
        else:
            loc = get_iloc(myafq, subject)
            
        for bundle_name in bundle_names:
            if compare_test_retest:
                iterables = zip(test_retest_names, test_retest_sessions, [myafq_test, myafq_retest], [loc_test, loc_retest])
            else:
                iterables = zip([sub_dir], [dataset_name], [myafq], [loc])

            for name, ses, myafq, loc in iterables:
                fname = tempfile.NamedTemporaryFile().name + '.png'
                
                volume, color_by_volume = myafq._viz_prepare_vols(
                    myafq.data_frame.iloc[loc],
                    volume=None,
                    xform_volume=False,
                    color_by_volume=None,
                    xform_color_by_volume=False
                )
                
                scene = visualize_volume(
                    volume,
                    interact=False,
                    inline=False
                )

                tractogram = tractograms[subject][bundle_name][ses]
                sft = StatefulTractogram.from_sft(tractogram.streamlines, tractogram)
                sft.to_vox()
                visualize_bundles(sft, figure=scene, interact=interact)

                print(name, subject, bundle_name, ses, 'streamlines')
                window.record(scene, out_path=fname, size=(300, 300))
                display(Image(filename=fname))

##### Bundle Tract Profile

In [None]:
show_bundle_tract_profiles = True

if show_bundle_tract_profiles:
    for subject in subjects:
        if compare_test_retest:
            loc_test  = get_iloc(myafq_test, subject)
            loc_retest = get_iloc(myafq_retest, subject)
        else:
            loc = get_iloc(myafq, subject)

        for bundle_name in bundle_names:
            if compare_test_retest:
                make_dirs(myafq_retest, sub_dir, bundle_name, subjects)
                target_dir = get_dir_name(myafq_retest, sub_dir, bundle_name, loc_retest)
                scalars = myafq_retest.scalars
            else:
                make_dirs(myafq, sub_dir, bundle_name, subjects)
                target_dir = get_dir_name(myafq, sub_dir, bundle_name, loc)
                scalars = myafq.scalars
            
            for scalar_name in scalars:
                truncated_name = scalar_name.split('_')[-1]

                profiles = []
                
                if compare_test_retest:
                    iterables = zip(test_retest_names, test_retest_sessions, [myafq_test, myafq_retest], [loc_test, loc_retest])
                else:
                    iterables = zip([sub_dir], [dataset_name], [myafq], [loc])

                for name, ses, myafq, loc in iterables:
                    scalar_data = nib.load(get_scalar_filename(myafq, scalar_name, loc)).get_fdata()
                    
                    if len(tractograms[subject][bundle_name][ses].streamlines) == 0:
                        profiles.append(np.zeros(100))
                    else:
                        profiles.append(afq_profile(
                            scalar_data,
                            tractograms[subject][bundle_name][ses].streamlines,
                            tractograms[subject][bundle_name][ses].affine,
                            weights=gaussian_weights(tractograms[subject][bundle_name][ses].streamlines)
                        ))

                if compare_test_retest:
                    print('test-retest bundle profile correlation:')
                    # Calculate Pearson correlations between profiles (test-retest reliability)
                    test_retest_corr_matrix = pd.DataFrame(zip(*profiles), columns=test_retest_sessions).corr()
                
                    # select only the upper triangle off diagonals of the correlation matrix
                    test_retest_corr = pd.Series(test_retest_corr_matrix.where(np.triu(np.ones(test_retest_corr_matrix.shape), 1).astype(bool)).stack(), name='corr')
                    
                    display(test_retest_corr)
                    
                    # save the correlation
                    f_name = op.join(target_dir, f'{truncated_name}_trt_tract_profile_corr.csv')
                    print(f_name)
                    test_retest_corr.to_csv(f_name)    
                
                plt.figure()
                
                if compare_test_retest and not test_retest_corr.empty:
                    plt.title(f'{dataset_name} {subject} {bundle_name} {scalar_name} tract profiles\ncorrelation {test_retest_corr.iloc[0]:.5f}')
                else:
                    plt.title(f'{dataset_name} {subject} {bundle_name} {scalar_name} tract profiles')
                
                if compare_test_retest:
                    iterables = zip(test_retest_sessions, profiles)
                else:
                    iterables = zip([dataset_name], profiles)
                    
                for ses, profile in iterables:
                    if ses == 'retest':
                        plt.plot(profile, c='k', linestyle='dashed', label=ses)
                    else:
                        plt.plot(profile, c='k', label=ses)
                                  
                plt.xlabel('node index')
                plt.ylabel(f'{scalar_name} values')
                plt.legend()
                if compare_test_retest:
                    f_name = op.join(target_dir, f'{truncated_name}_trt_tract_profile.png')
                else:
                    f_name = op.join(target_dir, f'{truncated_name}_tract_profile.png')
                print(f_name)
                plt.savefig(f_name)
                plt.show()
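With only two sessions, the upper-triangle, off-diagonal selection above leaves exactly one value: the Pearson correlation between the test and retest profiles. A minimal NumPy sketch of the same selection, using synthetic 100-node profiles (hypothetical data, not HCP values), shows that it reduces to `np.corrcoef`. The helper name `upper_triangle_correlations` is ours, not part of pyAFQ or dipy.

```python
import numpy as np

def upper_triangle_correlations(profiles):
    """Pairwise Pearson correlations between profiles, upper triangle only.

    profiles: sequence of equal-length 1-D arrays, one per session.
    Returns a 1-D array with one value per unordered pair of sessions.
    """
    corr = np.corrcoef(np.asarray(profiles))          # rows are sessions
    rows, cols = np.triu_indices(corr.shape[0], k=1)  # strictly above the diagonal
    return corr[rows, cols]

# Hypothetical test/retest profiles: the same smooth curve plus small noise
rng = np.random.default_rng(0)
nodes = np.linspace(0, np.pi, 100)
test_profile = 0.5 + 0.1 * np.sin(nodes)
retest_profile = test_profile + rng.normal(scale=0.005, size=nodes.size)

trt_corr = upper_triangle_correlations([test_profile, retest_profile])
print(trt_corr.shape)  # (1,): a single test-retest correlation
```

With more than two sessions the same function returns one correlation per session pair, which is why the notebook keeps the result as a `Series` rather than a scalar.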

#### QC: Streamline Metrics/Statistical Measurements

<span style="color:blue">**TODO: outliers, mean, and variation**</span>

- See [Streamline analysis and connectivity](https://dipy.org/documentation/1.3.0./examples_index/#streamline-analysis-and-connectivity)

##### Frequency Distribution of Streamline Length (number of points)

In [None]:
show_min_max_ref_streamline = False

In [None]:
show_bundle_voxel_freq = False or show_min_max_ref_streamline

In [None]:
if show_bundle_voxel_freq:
    voxel_freqs = {}

    for subject in subjects:
        voxel_freqs[subject] = {}

        for bundle_name in bundle_names:
            voxel_freqs[subject][bundle_name] = {}

            if compare_test_retest:
                iterables = test_retest_sessions
            else:
                iterables = [dataset_name]

            for ses in iterables:
                voxel_freqs[subject][bundle_name][ses] = [len(streamline) for streamline in tractograms[subject][bundle_name][ses].streamlines]
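`len(streamline)` counts the points in each streamline rather than measuring a physical length, so it only tracks length in millimeters when the tracking step size is uniform. As a sketch, arc length in the streamlines' own units (mm for RASMM tractograms) can be computed from the distances between consecutive points; dipy's `length` function in `dipy.tracking.streamline` does this for whole tractograms, and the NumPy helper below (`arc_length`, a hypothetical name) is a self-contained stand-in.

```python
import numpy as np

def arc_length(streamline):
    """Sum of Euclidean distances between consecutive points of an (N, 3) array."""
    steps = np.diff(np.asarray(streamline, dtype=float), axis=0)
    return float(np.sqrt((steps ** 2).sum(axis=1)).sum())

# Two hypothetical streamlines with the same number of points but different step sizes
short = np.column_stack([np.arange(5) * 0.5, np.zeros(5), np.zeros(5)])  # 0.5 mm steps
long_ = np.column_stack([np.arange(5) * 2.0, np.zeros(5), np.zeros(5)])  # 2.0 mm steps

print(len(short), len(long_))                # 5 5: identical point counts
print(arc_length(short), arc_length(long_))  # 2.0 8.0: different physical lengths
```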

In [None]:
if show_bundle_voxel_freq:
    for subject in subjects:
        for bundle_name in bundle_names:
            plt.figure()
            plt.title(f'{dataset_name} {subject} {bundle_name} Streamline Length Frequency')

            voxel_freq = voxel_freqs[subject][bundle_name]
            sessions = test_retest_sessions if compare_test_retest else [dataset_name]

            # bins of width 50, aligned to multiples of 50, spanning every session plotted
            # (previously the bins only covered the retest range, dropping test values outside it)
            all_lengths = [n for ses in sessions for n in voxel_freq[ses]]
            bin_start = min(all_lengths) - min(all_lengths) % 50
            bin_stop = max(all_lengths) + 50 - max(all_lengths) % 50
            plt.hist([voxel_freq[ses] for ses in sessions], bins=range(bin_start, bin_stop, 50), label=sessions)

            plt.xlabel('length (number of points)')
            plt.ylabel('num streamlines')
            if compare_test_retest:
                plt.legend()
            plt.show()

##### Various Visualizations for Min, Max, Mean Streamlines

<span style="color:red">**Question: What is an appropriate reference?**</span>

- First pass: compute a streamline from the mean position of all streamlines in the bundle

##### Create mean position (centroid) streamline

Resample the streamlines to the same number of points so that the pointwise mean can be computed

In [None]:
if show_min_max_ref_streamline:
    reference_streamlines = {}

    for subject in subjects:
        reference_streamlines[subject] = {}
        for bundle_name in bundle_names:
            reference_streamlines[subject][bundle_name] = {}
            if compare_test_retest:
                iterables = zip(test_retest_names, test_retest_sessions)
            else:
                iterables = zip([sub_dir], [dataset_name])

            for name, ses in iterables:
                reference_streamlines[subject][bundle_name][ses] = np.mean(set_number_of_points(tractograms[subject][bundle_name][ses].streamlines, 100), axis=0)

    display(pd.DataFrame.from_dict(
        {(i,j,k): len(reference_streamlines[i][j][k]) for i in reference_streamlines.keys() for j in reference_streamlines[i].keys() for k in reference_streamlines[i][j].keys()}, 
        orient='index', 
        columns=['reference streamline n_points']
    ))
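Averaging point by point assumes every streamline in the bundle runs in the same direction. Tractography does not guarantee this, so two streamlines tracing the same path from opposite ends pull the mean toward the middle of the bundle. A minimal sketch, assuming the streamlines are already resampled to the same number of points, flips each one to match a reference before averaging. `oriented_mean` is a hypothetical helper; dipy's `orient_by_streamline` in `dipy.tracking.streamline` serves a similar purpose.

```python
import numpy as np

def oriented_mean(streamlines, reference=None):
    """Pointwise mean of equal-length streamlines after flipping each to match a reference.

    streamlines: sequence of (N, 3) arrays with the same N (e.g. after resampling).
    reference: (N, 3) array; defaults to the first streamline.
    """
    sl = np.asarray(streamlines, dtype=float)  # shape (S, N, 3)
    ref = sl[0] if reference is None else np.asarray(reference, dtype=float)
    # Mean point-to-point distance to the reference, as stored and reversed
    direct = np.linalg.norm(sl - ref, axis=2).mean(axis=1)
    flipped = np.linalg.norm(sl[:, ::-1] - ref, axis=2).mean(axis=1)
    # Reverse the streamlines that sit closer to the reference when flipped
    aligned = np.where((flipped < direct)[:, None, None], sl[:, ::-1], sl)
    return aligned.mean(axis=0)

# Hypothetical bundle: the same straight path, stored once in each direction
path = np.column_stack([np.linspace(0, 10, 5), np.zeros(5), np.zeros(5)])
bundle = [path, path[::-1]]

naive = np.mean(bundle, axis=0)  # every node collapses to (5, 0, 0)
fixed = oriented_mean(bundle)    # recovers the original path
```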

##### Identify min and max streamlines by length

In [None]:
if show_min_max_ref_streamline:
    max_streamlines = {}
    min_streamlines = {}

    for subject in subjects:
        max_streamlines[subject] = {}
        min_streamlines[subject] = {}
        for bundle_name in bundle_names:
            max_streamlines[subject][bundle_name] = {}
            min_streamlines[subject][bundle_name] = {}

            if compare_test_retest:
                iterables = zip(test_retest_names, test_retest_sessions)
            else:
                iterables = zip([sub_dir], [dataset_name])

            for name, ses in iterables:
                max_streamlines[subject][bundle_name][ses] = tractograms[subject][bundle_name][ses].streamlines[np.argmax(voxel_freqs[subject][bundle_name][ses])]
                min_streamlines[subject][bundle_name][ses] = tractograms[subject][bundle_name][ses].streamlines[np.argmin(voxel_freqs[subject][bundle_name][ses])]

    display(pd.DataFrame.from_dict(
        {(i,j,k): [len(min_streamlines[i][j][k]), len(max_streamlines[i][j][k])] for i in reference_streamlines.keys() for j in reference_streamlines[i].keys() for k in reference_streamlines[i][j].keys()}, 
        orient='index', 
        columns=['min n_points', 'max n_points']
    ))

##### Resample the min and max streamlines

So that they both have the same number of points per streamline

[`set_number_of_points`](https://dipy.org/documentation/1.3.0./reference/dipy.segment/#dipy.segment.benchmarks.bench_quickbundles.set_number_of_points)

Changes the number of points of each streamline to `nb_points`, yielding `nb_points-1` segments of equal length
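A minimal NumPy sketch of this kind of resampling, assuming linear interpolation along cumulative arc length (not dipy's actual implementation): the first and last points are kept, and `nb_points` points are placed at equally spaced distances along the original path. `resample_streamline` is a hypothetical name.

```python
import numpy as np

def resample_streamline(streamline, nb_points=100):
    """Resample an (N, 3) streamline to nb_points points equally spaced along its arc length.

    The result has nb_points - 1 segments of equal length measured along the
    original polyline, and keeps the original start and end points.
    """
    pts = np.asarray(streamline, dtype=float)
    seg_lengths = np.linalg.norm(np.diff(pts, axis=0), axis=1)
    cum_length = np.concatenate([[0.0], np.cumsum(seg_lengths)])  # arc length at each input point
    targets = np.linspace(0.0, cum_length[-1], nb_points)         # equally spaced arc lengths
    # Interpolate each coordinate independently as a function of arc length
    return np.column_stack([np.interp(targets, cum_length, pts[:, d]) for d in range(3)])

# Hypothetical L-shaped streamline: 3 points, two 4 mm segments
corner = np.array([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0], [4.0, 4.0, 0.0]])
resampled = resample_streamline(corner, nb_points=5)
```

Here `resampled` has 5 points spaced 2 mm apart along the path, including the corner at (4, 0, 0).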

In [None]:
if show_min_max_ref_streamline:
    sampled_max_streamlines = {}
    sampled_min_streamlines = {}


    for subject in subjects:
        sampled_max_streamlines[subject] = {}
        sampled_min_streamlines[subject] = {}
        for bundle_name in bundle_names:
            sampled_max_streamlines[subject][bundle_name] = {}
            sampled_min_streamlines[subject][bundle_name] = {}

            if compare_test_retest:
                iterables = zip(test_retest_names, test_retest_sessions)
            else:
                iterables = zip([sub_dir], [dataset_name])

            for name, ses in iterables:
                sampled_max_streamlines[subject][bundle_name][ses] = set_number_of_points(max_streamlines[subject][bundle_name][ses], 100)
                sampled_min_streamlines[subject][bundle_name][ses] = set_number_of_points(min_streamlines[subject][bundle_name][ses], 100)

    display(pd.DataFrame.from_dict(
        {(i,j,k): [len(sampled_min_streamlines[i][j][k]), len(sampled_max_streamlines[i][j][k])] for i in reference_streamlines.keys() for j in reference_streamlines[i].keys() for k in reference_streamlines[i][j].keys()}, 
        orient='index', 
        columns=['sampled min n_points', 'sampled max n_points']
    ))

#### QC: Streamline Visualization

Plot the min, max, and mean streamlines in RAS+

In [None]:
if show_min_max_ref_streamline:
    for subject in subjects:
        for bundle_name in bundle_names:
            if compare_test_retest:
                iterables = zip(test_retest_names, test_retest_sessions)
            else:
                iterables = zip([sub_dir], [dataset_name])

            for name, ses in iterables:
                maxlen = max_streamlines[subject][bundle_name][ses]
                minlen = min_streamlines[subject][bundle_name][ses]
                mean = reference_streamlines[subject][bundle_name][ses]
                fig = plt.figure()
                # recent matplotlib versions no longer add Axes3D(fig) to the figure automatically
                ax = fig.add_subplot(projection='3d')
                ax.set_title(f'{name} {subject} {bundle_name} {ses} Streamlines (RAS mm)')
                ax.scatter3D(maxlen[:,0], maxlen[:,1], maxlen[:,2], c='tab:green', label='max')
                ax.scatter3D(minlen[:,0], minlen[:,1], minlen[:,2], c='g', label='min')
                ax.scatter3D(mean[:,0], mean[:,1], mean[:,2], c='tab:red', label='mean')
                ax.legend()
                ax.set_xlabel('x: left (-) / right (+)')
                ax.set_ylabel('y: posterior (-) / anterior (+)')
                ax.set_zlabel('z: inferior (-) / superior (+)')
                plt.show()

<span style="color:blue">**TODO: determine camera `position` and `focal_point`**</span>

<span style="color:red">**Question: What coordinate system is the volume in? What are the dimensions of the volume? How does `visualize_bundles` get the horizontal slice perspective?**</span>

- Given the name of the `trk` file (*RASMM*), the streamlines should be in subject space with RAS+ millimeter coordinates. Another indication is that some coordinates are negative, which would not happen for points inside the volume in voxel space.
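One way to check the coordinate system is to map streamline points through the inverse of the voxel-to-RASMM affine, assuming `tractograms[...].affine` is that matrix for the reference image. A sketch in NumPy with a hypothetical 2 mm isotropic affine; `nibabel.affines.apply_affine` performs the same homogeneous-coordinate multiplication. `rasmm_to_voxel` is a hypothetical helper name.

```python
import numpy as np

def rasmm_to_voxel(points, vox_to_rasmm):
    """Map (N, 3) RAS+ mm points to voxel coordinates using the inverse of a 4x4 affine."""
    rasmm_to_vox = np.linalg.inv(vox_to_rasmm)
    homogeneous = np.column_stack([points, np.ones(len(points))])  # (N, 4)
    return (homogeneous @ rasmm_to_vox.T)[:, :3]

# Hypothetical 2 mm isotropic affine that places voxel (45, 54, 45) at the origin (0, 0, 0) mm
affine = np.array([
    [2.0, 0.0, 0.0, -90.0],
    [0.0, 2.0, 0.0, -108.0],
    [0.0, 0.0, 2.0, -90.0],
    [0.0, 0.0, 0.0, 1.0],
])
rasmm_points = np.array([[-30.0, -20.0, 10.0], [0.0, 0.0, 0.0]])  # negative values are normal in RASMM
voxel_points = rasmm_to_voxel(rasmm_points, affine)  # non-negative for points inside the volume
```

Here `voxel_points` comes out as `[[30, 44, 50], [45, 54, 45]]`, so negative RASMM coordinates map to valid, non-negative voxel indices.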

```
scene.elevation(90)
scene.set_camera(view_up=(0.0, 0.0, 1.0))
```

```
print(scene.get_camera())
scene.set_camera(position=(0, 0, 1), focal_point=(1, 0, 0))
print(scene.get_camera())
```

In [None]:
if show_min_max_ref_streamline:
    for subject in subjects:
        for bundle_name in bundle_names:
            if compare_test_retest:
                iterables = zip(test_retest_names, test_retest_sessions)
            else:
                iterables = zip([sub_dir], [dataset_name])

            for name, ses in iterables:
                maxlen = max_streamlines[subject][bundle_name][ses]
                minlen = min_streamlines[subject][bundle_name][ses]
                mean = reference_streamlines[subject][bundle_name][ses]

                fname = tempfile.NamedTemporaryFile().name + '.png'
                scene = window.Scene()
                # scene.add(actor.streamtube([maxlen, minlen], linewidth=0.5))
                scene.add(actor.point(maxlen, window.colors.green, point_radius=1))
                scene.add(actor.point(minlen, window.colors.green, point_radius=1))
                scene.add(actor.point(mean, window.colors.red, point_radius=1))

                if interact:
                    window.show(scene, title=f'{name} {subject} {bundle_name} {ses} min/max/ref streamlines', size=(300, 300))

                print(name, subject, bundle_name, ses, 'min/max/ref streamlines')
                window.record(scene, out_path=fname, size=(300, 300))
                display(Image(filename=fname))

<span style="color:red">**Question: Why downsample streamlines? Why choose `n=100`?**</span>

- Data reduction/runtime performance? Simplify calculations?
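One concrete reason for a fixed number of points: pointwise comparisons between streamlines, such as the minimum average direct-flip (MDF) distance used by QuickBundles, pair point i of one streamline with point i of another, so both must have the same number of points. dipy provides this as `bundles_distances_mdf` in `dipy.tracking.distances`; the NumPy `mdf_distance` below is a hypothetical stand-in. Note that 100 is also the default `n_points` of dipy's `afq_profile` and `gaussian_weights`, though whether that motivated the choice here is still the open question above.

```python
import numpy as np

def mdf_distance(s1, s2):
    """Minimum average direct-flip distance between two (N, 3) streamlines with equal N."""
    s1 = np.asarray(s1, dtype=float)
    s2 = np.asarray(s2, dtype=float)
    if s1.shape != s2.shape:
        raise ValueError("MDF needs streamlines with the same number of points; resample first")
    direct = np.linalg.norm(s1 - s2, axis=1).mean()
    flipped = np.linalg.norm(s1 - s2[::-1], axis=1).mean()
    return min(direct, flipped)

# Hypothetical parallel streamlines 1 mm apart, the second stored in reverse order
a = np.column_stack([np.linspace(0, 10, 5), np.zeros(5), np.zeros(5)])
b = (a + [0.0, 1.0, 0.0])[::-1]

print(mdf_distance(a, b))  # 1.0: the flipped pairing lines the points up
```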

In [None]:
if show_min_max_ref_streamline:
    for subject in subjects:
        for bundle_name in bundle_names:
            if compare_test_retest:
                iterables = zip(test_retest_names, test_retest_sessions)
            else:
                iterables = zip([sub_dir], [dataset_name])

            for name, ses in iterables:
                maxlen = sampled_max_streamlines[subject][bundle_name][ses]
                minlen = sampled_min_streamlines[subject][bundle_name][ses]
                mean = reference_streamlines[subject][bundle_name][ses]

                fname = tempfile.NamedTemporaryFile().name + '.png'
                scene = window.Scene()
                # scene.add(actor.streamtube([maxlen, minlen], linewidth=0.5))
                scene.add(actor.point(maxlen, window.colors.green, point_radius=1))
                scene.add(actor.point(minlen, window.colors.green, point_radius=1))
                scene.add(actor.point(mean, window.colors.red, point_radius=1))

                if interact:
                    window.show(scene, title=f'{name} {subject} {bundle_name} {ses} min/max/ref sampled streamlines', size=(300, 300))

                print(name, subject, bundle_name, ses, 'min/max/ref sampled streamlines')
                window.record(scene, out_path=fname, size=(300, 300))
                display(Image(filename=fname))

### Streamline Registration (Move to part 7)

See [Streamline-based Registration](https://dipy.org/tutorials/#id10)

Register two bundles from two different subjects directly in the space of streamlines

This will be useful for multi-subject comparison

<span style="color:red">**Question: Which subject is `static` and which is `moving`? Is there some similar notion to MNI space?**</span>

#### [Streamline-based Linear Registration (SLR)](https://dipy.org/documentation/1.3.0./examples_built/bundle_registration/)

<span style="color:blue">**TODO: Provide some context and description of problem**</span>

- https://www.sciencedirect.com/science/article/pii/S1053811915003961