In [1]:
# Download original atlas data first:

# from dipy.data.fetcher import get_two_hcp842_bundles
# from dipy.data.fetcher import (fetch_target_tractogram_hcp,
#                                fetch_bundle_atlas_hcp842,
#                                get_bundle_atlas_hcp842,
#                                get_target_tractogram_hcp)

# target_file, target_folder = fetch_target_tractogram_hcp()
# atlas_file, atlas_folder = fetch_bundle_atlas_hcp842()
# atlas_file, all_bundles_files = get_bundle_atlas_hcp842()
# atlas_file, all_bundles_files

In [2]:
# Convert Atlas
import os
from os.path import join as pjoin
import numpy as np
from dipy.io.streamline import load_trk
from pykdtree.kdtree import KDTree

# Locate the DIPY data directory: honour $DIPY_HOME, default to ~/.dipy.
dipy_home = os.environ.get('DIPY_HOME', pjoin(os.path.expanduser('~'), '.dipy'))

atlas_dir = pjoin(dipy_home,
                  'bundle_atlas_hcp842',
                  'Atlas_80_Bundles')
p = pjoin(atlas_dir, 'bundles')

# Only .trk files are bundles: any stray file (e.g. .DS_Store) would make
# load_trk fail, and the `b[:-4]` below assumes the 4-character '.trk' suffix.
bundle_names = sorted(b for b in os.listdir(p) if b.endswith('.trk'))

# NOTE(review): 'IF0F_R.trk' (digit zero) is skipped explicitly; the reason is
# not visible in this notebook -- confirm it is a duplicate/misnamed atlas file.
EXCLUDED_BUNDLES = {'IF0F_R.trk'}

bundles = {b[:-4]: load_trk(pjoin(p, b), "same", bbox_valid_check=False).streamlines
           for b in bundle_names
           if b not in EXCLUDED_BUNDLES}

# Flatten all bundles into two parallel lists: one fiber and one integer label
# per streamline. Labels are 1-based positions in the sorted bundle names, so
# that label 0 can mean "unlabeled". Single-point streamlines are dropped.
label_names = sorted(bundles)
fibers, labels = [], []
for label, name in enumerate(label_names, start=1):
    for fiber in bundles[name]:
        if len(fiber) > 1:
            fibers.append(fiber)
            labels.append(label)
label_names = [''] + label_names  # 0 = unlabeled fibers

In [3]:
# Some fibers are labeled several times, e.g. CC also contains CC_ForcepsMajor.
# We want each fiber to have a single label. Assuming that bundles may have a
# hierarchical structure and no intersections, we relabel fibers.

# Spatial index over every point of every labelled fiber, plus a parallel
# array telling, for each indexed point, which fiber (index into `fibers`) it
# came from.
tree = KDTree(np.concatenate(fibers))
fiber_ixs = np.repeat(np.arange(len(fibers)), [len(f) for f in fibers])

def find(fiber, threshold=0.1, k=4, *, index=None, point_fiber_ixs=None,
         ref_fibers=None):
    """Return the indices of reference fibers that coincide with `fiber`.

    Two fibers match when they have the same number of points and the
    Euclidean norm of their point-wise difference, in either point order, is
    below `threshold`.

    Candidates come from a k-nearest-neighbour query of every point of
    `fiber`. For each neighbour rank j (0..k-1) whose distance vector over all
    points has norm below `threshold`, the fibers owning those neighbours are
    compared to `fiber` exactly.

    Parameters
    ----------
    fiber : array of shape (n_points, dim)
        Query streamline.
    threshold : float
        Matching tolerance, in the same units as the coordinates.
    k : int
        Number of nearest neighbours queried per point (k=1 is supported).
    index : object with a ``query(points, k=k)`` method, optional
        Spatial index over all reference points. Defaults to the notebook's
        module-level ``tree``.
    point_fiber_ixs : array of int, optional
        For every indexed point, the index of the reference fiber it belongs
        to. Defaults to the module-level ``fiber_ixs``.
    ref_fibers : sequence of arrays, optional
        Reference fibers. Defaults to the module-level ``fibers``.

    Returns
    -------
    set
        Indices into `ref_fibers` of all matching fibers (possibly empty).
    """
    if index is None:
        index = tree
    if point_fiber_ixs is None:
        point_fiber_ixs = fiber_ixs
    if ref_fibers is None:
        ref_fibers = fibers

    dists, nbr_ixs = index.query(fiber, k=k)
    # KD-tree libraries drop the neighbour axis when k == 1; restore it so that
    # column j always holds every point's j-th nearest neighbour.
    dists = np.asarray(dists).reshape(len(fiber), -1)
    nbr_ixs = np.asarray(nbr_ixs).reshape(len(fiber), -1)

    matches = set()
    for rank_dists, rank_ixs in zip(dists.T, nbr_ixs.T):
        if not np.linalg.norm(rank_dists) < threshold:
            continue
        for cand in set(point_fiber_ixs[rank_ixs]):
            other = ref_fibers[cand]
            if len(fiber) != len(other):
                continue
            # Fibers are unoriented: accept a match in either point order.
            if (np.linalg.norm(fiber - other) < threshold or
                    np.linalg.norm(fiber[::-1] - other) < threshold):
                matches.add(cand)

    return matches

In [4]:
# Size of every bundle, indexed by label (label 0 = unlabeled, has no bundle).
bundle_sizes = [0] + [len(bundles[l]) for l in label_names[1:]]
# hierarchy[name] starts with `name` and collects the bundles found to share
# fibers with it while being smaller (presumably nested inside it).
hierarchy = {l: [l] for l in label_names}
clean_ixs = []  # indices into `fibers` that survive de-duplication

for i, f in enumerate(fibers):
    # Sorted so that ties in bundle size are broken deterministically (the
    # lowest fiber index wins), independent of set iteration order.
    ixs = sorted(find(f))

    if len(ixs) <= 1:
        # The fiber belongs to a single bundle: keep it.
        clean_ixs.append(i)
        continue

    # The fiber is duplicated across bundles: keep only the copy that lives
    # in the least populated (assumed most nested) bundle.
    dup_sizes = [bundle_sizes[labels[j]] for j in ixs]
    keeper = ixs[np.argmin(dup_sizes)]
    if i != keeper:
        continue

    clean_ixs.append(i)
    inner_name = label_names[labels[i]]
    # Register the inner bundle in the hierarchy of every outer bundle.
    for j in ixs:
        outer = hierarchy[label_names[labels[j]]]
        if inner_name not in outer:
            outer.append(inner_name)

In [5]:
# Add unlabeled fibers
# Whole-brain tractogram of the atlas. Its fibers that match no bundle fiber
# become the unlabeled (label 0) samples in the next cell.
whole_brain_dir = pjoin(atlas_dir, 'whole_brain')
atlas_file = pjoin(whole_brain_dir, 'whole_brain_MNI.trk')
sft_atlas = load_trk(atlas_file, "same",
                     bbox_valid_check=False)
whole_brain = sft_atlas.streamlines

In [6]:
def _is_negative(fib):
    """True for multi-point whole-brain fibers that match no labelled fiber."""
    return len(fib) > 1 and not find(fib)


unlabeled = list(filter(_is_negative, whole_brain))

In [7]:
def _as_object_array(seq):
    """Pack a sequence of per-fiber arrays into a 1-D object array.

    ``np.array(list_of_arrays, dtype=object)`` only yields a 1-D array of
    arrays when the fibers are ragged. If they all had the same length, numpy
    would silently build an (N, n_points, dim) object array instead. Filling
    a preallocated 1-D array avoids that, and keeps references to the
    original fiber arrays.
    """
    out = np.empty(len(seq), dtype=object)
    for k, item in enumerate(seq):
        out[k] = item
    return out


# De-duplicated bundle fibers first (with their labels), then the unlabeled
# whole-brain fibers with label 0.
final_fibers = _as_object_array([fibers[i] for i in clean_ixs] + unlabeled)
final_labels = np.array([labels[i] for i in clean_ixs] + [0] * len(unlabeled),
                        dtype=np.int32)

In [8]:
# Orient bundles: within each label, repeatedly estimate the bundle's main
# direction (sum of end points minus sum of start points) and reverse every
# fiber pointing against it, until an iteration flips nothing.
MAX_ORIENT_ITERS = 100

for l, ln in enumerate(label_names):
    # Boolean indexing copies the object array, but its elements are the same
    # fiber arrays as in `final_fibers`, so the in-place flips below update them.
    bundle = final_fibers[final_labels == l]

    for i in range(MAX_ORIENT_ITERS):
        # Main direction; unaffected by the flips within this iteration.
        start = np.sum([f[0] for f in bundle], axis=0)
        end = np.sum([f[-1] for f in bundle], axis=0)
        main_dir = end - start

        flipped = False
        for f in bundle:
            if np.dot(main_dir, f[-1] - f[0]) < 0:
                f[:] = f[::-1]  # reverse fiber in place
                flipped = True  # direction may change: iterate again

        if not flipped:
            if i:
                print(f'{l} {ln} - Finished reorienting at iteration {i}')
            break
    else:
        # Previously this case was silent; make non-convergence visible.
        print(f'{l} {ln} - WARNING: orientation did not converge '
              f'after {MAX_ORIENT_ITERS} iterations')

0  - Finished reorienting at iteration 7
1 AC - Finished reorienting at iteration 0
2 AF_L - Finished reorienting at iteration 1
3 AF_R - Finished reorienting at iteration 1
4 AR_L - Finished reorienting at iteration 0
5 AR_R - Finished reorienting at iteration 1
6 AST_L - Finished reorienting at iteration 1
7 AST_R - Finished reorienting at iteration 0
8 CB_L - Finished reorienting at iteration 1
9 CB_R - Finished reorienting at iteration 0
10 CC - Finished reorienting at iteration 1
11 CC_ForcepsMajor - Finished reorienting at iteration 1
12 CC_ForcepsMinor - Finished reorienting at iteration 1
13 CC_Mid - Finished reorienting at iteration 1
14 CNIII_L - Finished reorienting at iteration 0
15 CNIII_R - Finished reorienting at iteration 0
16 CNII_L - Finished reorienting at iteration 0
17 CNII_R - Finished reorienting at iteration 0
18 CNIV_L - Finished reorienting at iteration 0
19 CNIV_R - Finished reorienting at iteration 0
20 CNVIII_L - Finished reorienting at iteration 0
21 CNVII

In [9]:
# Save the relabelled atlas. `atlas` (object array of fibers) and `hierarchy`
# (a dict) end up as pickled object arrays, so reading the file back requires
# np.load(fname, allow_pickle=True).
fname = 'hcp842_80_atlas.npz'
payload = dict(atlas=final_fibers,
               labels=final_labels,
               label_names=label_names,
               hierarchy=hierarchy)
np.savez_compressed(fname, **payload)

In [11]:
# Sanity check: the written atlas is exactly the de-duplicated bundle fibers
# plus the unlabeled negatives.
assert len(final_fibers) == len(clean_ixs) + len(unlabeled)

# Previously this line printed len(final_fibers), i.e. the same number as
# "Total fibers written", instead of the whole-brain tractogram size.
print(f'{len(whole_brain)} Original whole brain')
print(f'{len(fibers)} Original bundles')
print(f'{len(labels)-len(clean_ixs)} Duplicates removed')
print(f'{len(clean_ixs)} Fibers w/o duplicates')
print(f'{len(unlabeled)} Negative samples')
print(f'{len(final_fibers)} Total fibers written')

144641 Original whole brain
126764 Original bundles
1455 Duplicates removed
125309 Fibers w/o duplicates
19332 Negative samples
144641 Total fibers written


In [None]:
from fury import actor, window

# Label to display: 0 = unlabeled negatives, or e.g. label_names.index('MLF_R').
LABEL_TO_SHOW = 0

# NOTE(review): this cell used to draw `final_fibers[centroid_ixs]`, but
# `centroid_ixs` is defined nowhere in this notebook (presumably left over from
# a deleted cell), so it raised NameError on a fresh kernel. It now draws every
# fiber carrying LABEL_TO_SHOW; restore a centroid subset if one gets computed
# upstream.
scene = window.Scene()

lines = final_fibers[final_labels == LABEL_TO_SHOW]
if len(lines) > 0:
    scene.add(actor.line(lines, colors=[0, 1, 0], opacity=0.5))

window.show(scene)