In [1]:
# Load new atlas
import os
from os.path import join as pjoin
import numpy as np
from classibundler import *

fname = 'hcp842_80_atlas.npz'
# allow_pickle=True is required for object arrays stored in the archive, but
# unpickling can execute arbitrary code: only load archives from a trusted source.
# The context manager closes the underlying zip file once the arrays are read.
with np.load(fname, allow_pickle=True) as f:
    atlas = f['atlas']
    labels = f['labels']
    label_names = f['label_names']
    hierarchy = f['hierarchy']

In [3]:
# Parameters used by the classification / clustering cells below.
threshold = 7          # threshold passed to classify() and clusterize()
expand_threshold = 2   # threshold used when spilling between CC sub-bundles
NF = 5                 # points per fiber after resampling

# Resample every fiber to NF points, append a copy mirrored along X, and
# flatten each fiber into a row of NF*3 coordinates.
resampled = np.array(set_number_of_points(list(atlas), NF))
mirrored = resampled * np.array([-1, 1, 1])
atlas = np.concatenate([resampled, mirrored]).reshape(-1, NF * 3)
del resampled, mirrored

In [4]:
# Map every label to the label of its mirror-image bundle (X_L <-> X_R) so the
# X-reflected half of `atlas` receives matching labels. Labels without a side
# suffix (or without an existing counterpart) map to themselves.
labels_refl = np.arange(len(label_names))
lns = label_names.tolist()
for i, ln in enumerate(lns):
    if ln.endswith('_R'):
        mirror = ln[:-1] + 'L'
    elif ln.endswith('_L'):
        mirror = ln[:-1] + 'R'
    else:
        continue  # no side suffix: keep identity mapping
    if mirror in lns:
        labels_refl[i] = lns.index(mirror)
    else:
        # e.g. 'F_L_R', whose would-be counterpart 'F_L_L' does not exist;
        # treat it as bilateral and keep the identity mapping.
        print(f'No mirror label for {ln!r}; mapping it to itself')

# Guard against re-running this cell: before extension, labels must cover
# only the un-mirrored half of `atlas`.
assert 2 * len(labels) == len(atlas), 'labels already extended (cell re-run?)'
labels = np.concatenate((labels, labels_refl[labels]))

'F_L_L' is not in list


In [5]:
# Orient bundles: flip fibers so that every fiber of a bundle runs in the same
# direction as the bundle's main (start -> end) direction.
# `reverser` is a column permutation for a flattened (NF*3,) fiber: it reverses
# the order of the NF points while keeping each point's 3 coordinates in order.
reverser = np.arange(NF * 3).reshape((-1, 3))[::-1].flatten()
MAX_ORIENT_ITERS = 100  # safety cap on re-orientation passes per bundle

for l, ln in enumerate(label_names):
    b_ixs = np.flatnonzero(labels == l)
    # Indexing with a mask / index array returns a COPY. Flips are made on
    # `bundle` and written back to `atlas` below; flipping rows of
    # `atlas[labels == l]` in place would leave `atlas` unchanged.
    bundle = atlas[b_ixs]

    for i in range(MAX_ORIENT_ITERS):
        # Seed the main direction from the first half of the bundle on the
        # first pass, then use the whole (partly re-oriented) bundle.
        sub = bundle[:len(bundle) // 2] if i == 0 else bundle
        main_dir = sub[:, -3:].sum(axis=0) - sub[:, :3].sum(axis=0)

        # Flip every fiber whose own start->end vector points against main_dir.
        flip = (bundle[:, -3:] - bundle[:, :3]) @ main_dir < 0
        if not flip.any():
            if i:
                print(f'{l} {ln} - Finished reorienting at iteration {i}')
            break
        bundle[flip] = bundle[flip][:, reverser]
    else:
        print(f'{l} {ln} - Did not converge after {MAX_ORIENT_ITERS} iterations')

    atlas[b_ixs] = bundle

0  - Finished reorienting at iteration 4
1 AC - Finished reorienting at iteration 1
8 CB_L - Finished reorienting at iteration 1
9 CB_R - Finished reorienting at iteration 1
10 CC - Finished reorienting at iteration 1
11 CC_ForcepsMajor - Finished reorienting at iteration 1
12 CC_ForcepsMinor - Finished reorienting at iteration 1
13 CC_Mid - Finished reorienting at iteration 1
22 CNVII_L - Finished reorienting at iteration 1
23 CNVII_R - Finished reorienting at iteration 1
28 CS_L - Finished reorienting at iteration 11
29 CS_R - Finished reorienting at iteration 15
32 CT_L - Finished reorienting at iteration 8
33 CT_R - Finished reorienting at iteration 3
35 C_R - Finished reorienting at iteration 2
62 PC - Finished reorienting at iteration 1
76 V - Finished reorienting at iteration 1


In [None]:
# Iteratively pull unlabeled (label 0) fibers into nearby labeled bundles.
MAX_CLASSIFY_ITERS = 5
for i in range(MAX_CLASSIFY_ITERS):
    print(f'Iteration {i}')
    garbage_ixs = np.flatnonzero(labels == 0)
    known = labels != 0
    result, is_reversed = classify(
        atlas[garbage_ixs], atlas[known], labels[known],
        threshold=threshold)

    if not np.any(result):
        break

    # Flip fibers that classify() reports as reversed. Rows are flattened
    # (NF*3,) fibers, so reverse the point order with `reverser`; a plain
    # [::-1] would also reverse x/y/z inside each point.
    # NOTE(review): assumes is_reversed indexes the garbage fibers (mask or
    # positions), as in the original code — confirm against classify().
    flip_ixs = garbage_ixs[is_reversed]
    if len(flip_ixs):
        atlas[flip_ixs] = atlas[flip_ixs][:, reverser]
    labels[garbage_ixs] = result

Iteration 0
Iteration 1
Iteration 2
Iteration 3


In [None]:
def spill(fr, to, threshold=None):
    """Relabel fibers of bundle `to` that classify() matches to bundle `fr`.

    Fibers currently labeled `to` are classified against the fibers labeled
    `fr`; every matched fiber (non-zero result) is relabeled `fr`, and fibers
    reported as reversed are flipped. Mutates the global `atlas` and `labels`.

    Args:
        fr: label index whose bundle absorbs matching fibers.
        to: label index whose fibers are tested.
        threshold: threshold passed to classify(); defaults to the global
            `expand_threshold` (looked up at call time).
    """
    if threshold is None:
        threshold = expand_threshold
    print(f'Spilling {fr} to {to}')
    to_ixs = np.flatnonzero(labels == to)
    fr_mask = labels == fr
    if len(to_ixs) == 0 or not fr_mask.any():
        return  # nothing to move / nothing to match against

    result, is_reversed = classify(
        atlas[to_ixs], atlas[fr_mask], labels[fr_mask],
        threshold=threshold)

    # Reverse point order (not the flattened coordinates) of flagged fibers.
    flip_ixs = to_ixs[is_reversed]
    if len(flip_ixs):
        atlas[flip_ixs] = atlas[flip_ixs][:, reverser]

    # Assign through integer indices: `labels[labels == to][mask] = fr` would
    # write into a temporary copy and leave `labels` unchanged.
    labels[to_ixs[np.asarray(result) != 0]] = fr

ln = [*label_names]
for sub_bundle in ('CC_ForcepsMajor', 'CC_ForcepsMinor', 'CC_Mid'):
    spill(ln.index(sub_bundle), ln.index('CC'))

In [None]:
# Select centroid fibers per bundle for the compact centroid atlas.
centroid_ixs = []

# Iterate over label_names directly rather than `ln`, which holds a single
# label-name string if the spill cell was not run.
for L in range(len(label_names)):
    b_ixs = np.flatnonzero(labels == L)
    if len(b_ixs) == 0:
        continue  # empty bundle: nothing to cluster

    # Move convex-hull vertices of the end, middle and start points to the
    # front of the bundle ordering (presumably so clusterize() prefers extreme
    # fibers as centroids — TODO confirm). ConvexHull needs >= 4 points in 3-D.
    if len(b_ixs) > 3:
        for i in [NF - 1, NF // 2, 0]:
            try:
                H = ConvexHull(atlas[b_ixs, i*3:i*3 + 3]).vertices
            except QhullError:
                print('QhullError', L, i)
                continue
            except Exception as e:
                print(L, i, e)
                continue
            # True permutation: hull vertices first, the rest in current order.
            # A tuple swap `b[:k], b[H] = b[H], b[:k]` is wrong here because
            # `b[:k]` is a view overwritten by the first assignment, which
            # duplicates some indices and drops others.
            b_ixs = np.concatenate((b_ixs[H], np.delete(b_ixs, H)))

    ixs = clusterize(atlas[b_ixs], threshold)
    centroid_ixs.extend(b_ixs[ixs])

In [None]:
# Save the selected centroid fibers together with the label metadata.
fname = 'hcp842_80_centroids.npz'
payload = dict(
    atlas=atlas[centroid_ixs],
    labels=labels[centroid_ixs],
    label_names=label_names,
    hierarchy=hierarchy,
)
np.savez_compressed(fname, **payload)