# Moving Averages CINGULATE region

### 1) Imports

In [24]:
import os
import glob
import sys
import json

p = os.path.abspath('../')
if p not in sys.path:
    sys.path.append(p)

In [25]:
import moving_averages as ma
import colorado as cld
import dico_toolbox as dtx
from tqdm import tqdm

import pandas as pd
import numpy as np
from scipy.spatial import distance, distance_matrix

import plotly.graph_objects as go
import pickle
import matplotlib.pyplot as plt

from soma import aims

import torch

from sklearn.cluster import KMeans, SpectralClustering, AffinityPropagation
from sklearn import metrics

import matplotlib.cm as cm

from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf, DictConfig

In [26]:
import anatomist.notebook as ana
a = ana.Anatomist()
print(a.headless_info.__dict__)

{'xvfb': <subprocess.Popen object at 0x7fcf9843d198>, 'original_display': ':1', 'display': 2, 'glx': 2, 'virtualgl': None, 'headless': True, 'mesa': False, 'qtapp': None}


In [27]:
def closest_distance(centroid, df):
    """Returns closest point to centroid of a given cluster
    /!\ central_point is part of the defined set of points
    Args:
        IN: centroid: tuple of coordinates of centroid
            df: pandas.DataFrame with index corresponding to subjects and as many columns as coordinates
        OUT: central_point: subject ID corresponding to closest point to the centroid
    """
    # df must have as many columns as dimensions to consider
    distances = {}
    # Get distances to centroid for each point
    for sub in list(df.index):
        pos = [df.loc[sub][k] for k in range(1, len(df.columns)+1)]
        distances[sub] = distance.euclidean(pos, centroid)
    # Get closest point to centroid
    central_point = min(distances, key=distances.get)
    return central_point

In [28]:
src_dir = '/neurospin/dico/data/deep_folding/papers/midl2022/crops/CINGULATE/mask/sulcus_based/2mm/centered_combined/hcp'

In [29]:
path = f"{src_dir}/Rcrops"

In [30]:
bucket_dir = f"{src_dir}/Rbuckets"

In [31]:
# If true, meshes are saved as files
save_mesh = False

# if true, buckets are saved as files
save_bucket = False

### 2) Loading of subjects's distribution in the latent space
DataFrame is obtained during analysis of beta-VAE latent space (2 dimensions) for CINGULATE crops, according to the following steps:
- loading of trained model
- encoding of test controls and asymmetry benchmark subjects

In [32]:
if torch.cuda.is_available():
    device = "cuda:0"
print(device)

cuda:0


In [33]:
# We load the mebeddings
result_dir = '/host/volatile/jc225751/Runs/33_MIDL_2022_reviews/Output/t-0.1-analysis_output/n-004_o-4'
val_embeddings_file = f"{result_dir}/train_val_embeddings.pt"
embeddings = torch.load(val_embeddings_file)
print(embeddings.shape)

torch.Size([550, 4])


In [34]:
# We load the labels
with open(f"{result_dir}/train_val_filenames.json", 'r') as f:
    labels = json.load(f)
print(labels[0])
print(labels[:10])
print(labels[-10:-1])

129634
['129634', '136833', '206222', '138837', '987983', '562345', '182032', '275645', '151021', '177342']
['395958', '154734', '385046', '200008', '792867', '618952', '433839', '188448', '166640']


In [35]:
kmeans = KMeans(n_clusters=3, random_state=0).fit(embeddings)
cluster_labels = kmeans.labels_
clusters_centroids = kmeans.cluster_centers_
print(f"clusters centroids = {clusters_centroids.shape}")
print(f"Average silhouette score: {metrics.silhouette_score(embeddings, cluster_labels)}")


clusters centroids = (3, 4)
Average silhouette score: 0.3474253714084625


### Determines closest and furthest subjects

In [36]:
type(embeddings)

torch.Tensor

In [37]:
embeddings

tensor([[-0.7047,  0.4643,  1.2135,  0.2982],
        [-0.3508,  0.4110,  0.9183,  1.0133],
        [-0.7645, -1.2402,  1.7594,  0.1883],
        ...,
        [-0.1585,  0.3070,  0.7922,  1.4341],
        [-0.2880, -1.2718,  1.3776,  1.2204],
        [ 0.6378,  1.3223, -1.0439,  0.8875]])

In [38]:
embeddings[0]

tensor([-0.7047,  0.4643,  1.2135,  0.2982])

In [39]:
x = embeddings.numpy()

In [40]:
type(x)

numpy.ndarray

In [41]:
x.shape

(550, 4)

In [42]:
dist = distance_matrix(x,x)

In [43]:
(dist+1000*np.eye(x.shape[0])).min()

0.02212502621114254

In [44]:
dist.argmax(axis=None)

109694

In [45]:
np.unravel_index(dist.argmax(axis=None), dist.shape)

(199, 244)

In [122]:
dist[199,244]

IndexError: invalid index to scalar variable.

In [47]:
def allmax(a):
    if len(a) == 0:
        return []
    all_ = [0]
    max_ = a[0]
    for i in range(1, len(a)):
        if a[i] > max_:
            all_ = [i]
            max_ = a[i]
        elif a[i] == max_:
            all_.append(i)
    return all_

In [48]:
def allmin(a):
    if len(a) == 0:
        return []
    all_ = [1000.]
    min_ = 1000.
    for i in range(0, len(a)):
        if a[i] < min_ and a[i] > 0.:
            all_ = [i]
            min_ = a[i]
        elif a[i] == min_:
            all_.append(i)
    return all_

In [49]:
dist.flatten()[:10]

array([0.        , 0.85240144, 1.79422212, 2.80109072, 3.4863627 ,
       1.05017078, 2.59089971, 1.26592112, 3.39883661, 2.33007884])

In [50]:
def get_indices_of_k_smallest(arr, k):
    idx = np.argpartition(arr.ravel(), k)
    return np.array(np.unravel_index(idx, arr.shape))[:, range(k)].transpose().tolist()

In [51]:
def get_indices_of_k_biggest(arr, k):
    idx = np.argpartition(arr.ravel(), -k)
    return np.array(np.unravel_index(idx, arr.shape))[:, range(-k,0)].transpose().tolist()

In [52]:
def get_subject_id_by_index(list_index, labels):
    subjects = [labels[idx] for idx in list_index]
    return subjects

In [53]:
dist.shape

(550, 550)

In [54]:
smallest_idx = get_indices_of_k_smallest(dist + 1000.*np.eye(dist.shape[0]), 10)
print(smallest_idx)
for idx in range(len(smallest_idx)):
    print(dist[smallest_idx[idx][0], smallest_idx[idx][1]])

[[262, 128], [88, 29], [128, 262], [29, 88], [82, 350], [338, 137], [89, 336], [336, 89], [137, 338], [350, 82]]
0.03441549092531204
0.02212502621114254
0.03441549092531204
0.02212502621114254
0.03787141293287277
0.03972189873456955
0.03799424692988396
0.03799424692988396
0.03972189873456955
0.03787141293287277


In [55]:
[get_subject_id_by_index(list_idx, labels) for list_idx in smallest_idx]

[['568963', '107422'],
 ['318637', '174437'],
 ['107422', '568963'],
 ['174437', '318637'],
 ['792766', '248339'],
 ['111211', '117021'],
 ['644044', '198653'],
 ['198653', '644044'],
 ['117021', '111211'],
 ['248339', '792766']]

In [56]:
biggest_idx = get_indices_of_k_biggest(dist, 300)
print(biggest_idx)
for idx in range(len(biggest_idx)):
    print(dist[biggest_idx[idx][0], biggest_idx[idx][1]])

[[387, 199], [199, 387], [268, 169], [199, 451], [73, 365], [531, 363], [286, 232], [232, 286], [363, 531], [365, 73], [169, 268], [451, 199], [365, 317], [469, 120], [346, 73], [73, 346], [407, 448], [448, 407], [286, 8], [448, 206], [41, 29], [73, 363], [8, 286], [206, 448], [363, 73], [317, 365], [129, 4], [4, 129], [199, 206], [206, 199], [324, 473], [531, 169], [29, 41], [120, 469], [169, 531], [473, 324], [136, 378], [244, 166], [448, 361], [136, 67], [378, 136], [47, 244], [41, 88], [473, 286], [67, 136], [286, 4], [244, 139], [169, 4], [324, 443], [378, 18], [67, 18], [448, 339], [443, 324], [361, 448], [313, 199], [166, 244], [88, 41], [194, 286], [268, 341], [18, 378], [407, 199], [268, 324], [527, 120], [324, 531], [339, 448], [199, 313], [324, 268], [41, 363], [244, 47], [120, 527], [317, 67], [341, 268], [363, 41], [244, 483], [199, 407], [67, 317], [4, 286], [286, 194], [531, 324], [286, 473], [18, 67], [139, 244], [4, 169], [483, 244], [313, 346], [244, 542], [136, 286],

In [57]:
[get_subject_id_by_index(list_idx, labels) for list_idx in biggest_idx]

[['833148', '406836'],
 ['406836', '833148'],
 ['150928', '707749'],
 ['406836', '152427'],
 ['111009', '159845'],
 ['723141', '212015'],
 ['141119', '766563'],
 ['766563', '141119'],
 ['212015', '723141'],
 ['159845', '111009'],
 ['707749', '150928'],
 ['152427', '406836'],
 ['159845', '645551'],
 ['175136', '204218'],
 ['120414', '111009'],
 ['111009', '120414'],
 ['604537', '510326'],
 ['510326', '604537'],
 ['141119', '151021'],
 ['510326', '878877'],
 ['115825', '174437'],
 ['111009', '212015'],
 ['151021', '141119'],
 ['878877', '510326'],
 ['212015', '111009'],
 ['645551', '159845'],
 ['152831', '987983'],
 ['987983', '152831'],
 ['406836', '878877'],
 ['878877', '406836'],
 ['784565', '188145'],
 ['723141', '707749'],
 ['174437', '115825'],
 ['204218', '175136'],
 ['707749', '723141'],
 ['188145', '784565'],
 ['206828', '173536'],
 ['174841', '139637'],
 ['510326', '255740'],
 ['206828', '877168'],
 ['173536', '206828'],
 ['204622', '174841'],
 ['115825', '318637'],
 ['188145',

In [58]:
argmax_ = allmax(dist.flatten())

In [59]:
argmin_ = allmin(dist.flatten())
argmin_

[16038, 48429]

In [60]:
list_argmax = [np.unravel_index(arg, dist.shape) for arg in argmax_]
list_argmax

[(199, 244), (244, 199)]

In [61]:
list_argmin = [np.unravel_index(arg, dist.shape) for arg in argmin_]
list_argmin

[(29, 88), (88, 29)]

In [62]:
for l in list_argmax:
    print(labels[l[0]], labels[l[1]])

406836 174841
174841 406836


In [63]:
for l in list_argmin:
    print(labels[l[0]], labels[l[1]])

174437 318637
318637 174437


### DIsplays buckets of closest subjects

In [64]:
def display_bucket(src_dir: str, subject_id: str):
    """Displays the chosen bucket with bucket files in src_dir"""
    global a
    bucket_file = [s for s in glob.glob(f"{src_dir}/*.bck") if subject_id in s][0]
    bucket = aims.read(bucket_file)
    a_bucket = a.toAObject(bucket)
    w3d = a.createWindow("3D")
    w3d.addObjects(a_bucket)
    return w3d

In [65]:
def display_several_buckets(src_dir: str, subject_id_list: list):
    global a
    length = len(subject_id_list)
    bucket_file_l = []
    for subject_id in subject_id_list:
        bucket_file = [s for s in glob.glob(f"{src_dir}/*.bck") if subject_id in s][0]
        bucket_file_l.append(bucket_file)
    a_bucket_l = []
    for bucket_file in bucket_file_l:
        bucket = aims.read(bucket_file)
        a_bucket = a.toAObject(bucket)
        a_bucket_l.append(a_bucket)  
    block = a.AWindowsBlock(a, 2)
    globals()['w1'] = a.createWindow("3D", block=block)
    globals()['w1'].addObjects(a_bucket_l[0])
    globals()['w2'] = a.createWindow("3D", block=block)
    globals()['w2'].addObjects(a_bucket_l[1])
    return block

In [66]:
block = a.AWindowsBlock(a, 2)

In [67]:
fig0 = display_several_buckets(bucket_dir, ['644044', '198653'])

AnatomistInteractiveWidget(height=479, layout=Layout(height='auto', width='auto'), width=424)

AnatomistInteractiveWidget(height=479, layout=Layout(height='auto', width='auto'), width=424)

In [68]:
fig1 = display_bucket(bucket_dir, '644044')

AnatomistInteractiveWidget(height=479, layout=Layout(height='auto', width='auto'), width=424)

In [69]:
type(fig1)

anatomist.notebook.api.NotebookAnatomist.AWindow

In [70]:
fig2 = display_bucket(bucket_dir, '198653')

AnatomistInteractiveWidget(height=479, layout=Layout(height='auto', width='auto'), width=424)

### Displays in t-SNE

In [71]:
dstrb_sub = pd.DataFrame(embeddings.numpy(), index=labels, columns=[1, 2, 3, 4])
dstrb_sub['cluster_lab'] = cluster_labels

In [72]:
dstrb_sub.head()

Unnamed: 0,1,2,3,4,cluster_lab
129634,-0.704682,0.464345,1.213497,0.298184,1
136833,-0.350781,0.410975,0.918348,1.013292,0
206222,-0.764542,-1.240232,1.75938,0.188272,1
138837,0.730793,-0.473701,-0.269615,1.943166,0
987983,1.014665,1.228762,-1.676012,0.812938,2


In [73]:
clusters_centroids = kmeans.cluster_centers_
print(f"cluster's centroids coordinates: \n {clusters_centroids}")

cluster's centroids coordinates: 
 [[ 1.31399507e-01 -8.80867887e-04  3.96890843e-01  1.43585158e+00]
 [-8.36732560e-01 -1.23226578e-01  1.49042979e+00  1.23057654e-02]
 [ 1.52115546e-01  1.41179076e+00 -5.20310118e-01  2.20132112e-01]]


In [74]:
central_1= closest_distance(clusters_centroids[0], dstrb_sub.drop(['cluster_lab'], axis=1))
print(f"Closest subject to centroid of cluster 1 is {central_1}")
central_2 = closest_distance(clusters_centroids[1], dstrb_sub.drop(['cluster_lab'], axis=1))
print(f"Closest subject to centroid of cluster 2 is {central_2}")

Closest subject to centroid of cluster 1 is 519950
Closest subject to centroid of cluster 2 is 134324


In [75]:
arr = np.array([np.array([dstrb_sub[k][i] for k in dstrb_sub.columns[0:2]]) for i in range(len(dstrb_sub))])

color_dict = {0: 'red', 1:'blue', 2:'green'}
fig, ax = plt.subplots()

for g in np.unique([dstrb_sub.cluster_lab]):
    ix = np.where(dstrb_sub.cluster_lab == g)
    x = [arr[ix][k][0] for k in range(len(ix[0]))]
    y = [arr[ix][k][1] for k in range(len(ix[0]))]
    if g =='benchmark':
        g_lab = 'benchmark asymmetry'
    else:
        g_lab=g
    ax.scatter(x, y, c = color_dict[g], label = g_lab)
    
ax.scatter(clusters_centroids[0][0], clusters_centroids[0][1], color='crimson', marker='X')
ax.scatter(clusters_centroids[1][0], clusters_centroids[1][1], color='navy', marker='X')

plt.xlabel(f'dimension 1', fontsize=14)
plt.ylabel(f'dimension 2', fontsize=14)
plt.legend()
plt.show()

In [76]:
dstrb_sub.head()

Unnamed: 0,1,2,3,4,cluster_lab
129634,-0.704682,0.464345,1.213497,0.298184,1
136833,-0.350781,0.410975,0.918348,1.013292,0
206222,-0.764542,-1.240232,1.75938,0.188272,1
138837,0.730793,-0.473701,-0.269615,1.943166,0
987983,1.014665,1.228762,-1.676012,0.812938,2


In [77]:
dstrb_sub.index[0]

'129634'

In [78]:
arr = np.array([np.array([dstrb_sub[k][i] for k in dstrb_sub.columns[2:4]]) for i in range(len(dstrb_sub))])

color_dict = {0: 'red', 1:'blue', 2:'green'}
fig, ax = plt.subplots()

for g in np.unique([dstrb_sub.cluster_lab]):
    ix = np.where(dstrb_sub.cluster_lab == g)
    x = [arr[ix][k][0] for k in range(len(ix[0]))]
    y = [arr[ix][k][1] for k in range(len(ix[0]))]
    ax.scatter(x, y, c = color_dict[g], label = g)
    
ax.scatter(clusters_centroids[0][2], clusters_centroids[0][3], color='crimson', marker='X')
ax.scatter(clusters_centroids[1][2], clusters_centroids[1][3], color='navy', marker='X')
# ax.scatter(dstrb_sub[3]['113316'], dstrb_sub[4]['113316'], color='orange')
# ax.scatter(dstrb_sub[3]['144125'], dstrb_sub[4]['144125'], color='forestgreen')
# ax.scatter(dstrb_sub[3]['217126'], dstrb_sub[4]['217126'], color='black')

plt.xlabel(f'dimension 3', fontsize=14)
plt.ylabel(f'dimension 4', fontsize=14)
plt.legend()
plt.show()

In [79]:
fig, ax = plt.subplots()
for g in np.unique([dstrb_sub.cluster_lab]):
    ix = np.where(dstrb_sub.cluster_lab == g)
    x = [arr[ix][k][0] for k in range(len(ix[0]))]
    y = [arr[ix][k][1] for k in range(len(ix[0]))]
    ax.scatter(x, y, c = color_dict[g], label = g)
    for k, i in zip(range(len(ix[0])), ix[0]):
        ax.annotate(dstrb_sub.index[i], (x[k], y[k]),fontsize=7)
    
ax.scatter(clusters_centroids[0][2], clusters_centroids[0][3], color='crimson', marker='X')
ax.scatter(clusters_centroids[1][2], clusters_centroids[1][3], color='navy', marker='X')
# ax.scatter(dstrb_sub[3]['113316'], dstrb_sub[4]['113316'], color='orange')
# ax.scatter(dstrb_sub[3]['144125'], dstrb_sub[4]['144125'], color='forestgreen')
# ax.scatter(dstrb_sub[3]['217126'], dstrb_sub[4]['217126'], color='black')

plt.xlabel(f'dimension 3', fontsize=14)
plt.ylabel(f'dimension 4', fontsize=14)
plt.legend()
plt.show()

In [80]:
cluster1 = dstrb_sub[dstrb_sub.cluster_lab==0]
cluster2 = dstrb_sub[dstrb_sub.cluster_lab==1]
assert(len(np.unique(list(cluster1.cluster_lab)))==1)
assert(len(np.unique(list(cluster2.cluster_lab)))==1)

### Creation of buckets dictionnary

In [81]:
bucket_path = '/neurospin/dico/data/deep_folding/current/crops/CINGULATE/mask/sulcus_based/2mm/centered_combined/hcp/Rbuckets/'
suffix_path = '_normalized.bck'
buckets = {}
for sub in tqdm(list(labels)):
    bucket = aims.read(os.path.join(bucket_path, str(sub) + suffix_path))
    bucket = dtx.convert.bucket_aims_to_ndarray(bucket[0])
    buckets[sub] = bucket

100%|██████████| 550/550 [00:44<00:00, 12.23it/s]


In [82]:
type(buckets)

dict

In [83]:
subjects_c1 = cluster1.index
subjects_c2 = cluster2.index

buckets_c1 = {k: v for k,v in buckets.items() if k in list(subjects_c1)}
buckets_c2 = {k: v for k,v in buckets.items() if k in list(subjects_c2)}

In [84]:
cld.draw(list(buckets_c1.values())[0])

#### Alignement of the subjects to respective central subject

In [85]:
aligned_buckets_C1, aligned_rot_C1, aligned_transl_C1 = ma.align_buckets_by_ICP_batch(buckets_c1, central_1)
aligned_buckets_C2, aligned_rot_C2, aligned_transl_C2 = ma.align_buckets_by_ICP_batch(buckets_c2, central_2)

>>> INFO moving_averages.transform - using 45 cores out of 48
Aligning buckets to 519950: 100%|██████████| 185/185 [00:01<00:00, 95.72it/s] 
>>> INFO moving_averages.transform - using 45 cores out of 48
Aligning buckets to 134324: 100%|██████████| 166/166 [00:01<00:00, 115.79it/s]


In [86]:
plt.hist(cluster1[3], alpha=0.5)
plt.hist(cluster2[3], alpha=0.5)
plt.show()

In [87]:
def subj_count_extreme_coords(isodf, axis, min_coord, max_coord, num_coord=10):
    coord_values = np.linspace(min_coord, max_coord, num_coord)
    step = (coord_values[1] - coord_values[0])
    midpoint_min = coord_values[0] + step/2
    midpoint_max = coord_values[num_coord-1] - step/2
    coord_under = [elt for elt in isodf[axis] if elt<midpoint_min]
    coord_over = [elt for elt in isodf[axis] if elt>midpoint_max]
    num_subj_under = len(coord_under)
    num_subj_over = len(coord_over)
    return num_subj_under, num_subj_over

def get_MA_coords(isodf, axis, num_subj_threshold, num_coord=10):
    min_subj_coord = min(isodf[axis])
    max_subj_coord = max(isodf[axis])
    step = (max_subj_coord - min_subj_coord)/num_coord
    num_subj_under, num_subj_over = subj_count_extreme_coords(isodf, axis, min_subj_coord, max_subj_coord)
    while ((num_subj_under<num_subj_threshold)|(num_subj_over<num_subj_threshold)):
        step = (max_subj_coord - min_subj_coord)/num_coord
        if num_subj_under<num_subj_threshold:
            min_subj_coord = min_subj_coord + step
        if num_subj_over<num_subj_threshold:
            max_subj_coord = max_subj_coord - step
        num_subj_under, num_subj_over = subj_count_extreme_coords(isodf, axis, min_subj_coord, max_subj_coord)
    return np.linspace(min_subj_coord, max_subj_coord, num_coord)

In [88]:
MA_coords = get_MA_coords(cluster1, 3, num_subj_threshold=14, num_coord=2)
step = step = MA_coords[1]-MA_coords[0]

In [89]:
MA_coords

array([0.05151618, 0.05151618])

In [90]:
type(MA_coords)

numpy.ndarray

In [91]:
step

0.0

In [92]:
cluster1.index 

Index(['136833', '138837', '562345', '177342', '869472', '529953', '667056',
       '966975', '127731', '123824',
       ...
       '571548', '933253', '122418', '103515', '163331', '663755', '395958',
       '154734', '188448', '166640'],
      dtype='object', length=185)

In [93]:
SPAM_vols_c1, shift1 = ma.calc_MA_volumes_batch(\
    centers=MA_coords,
    buckets_dict=aligned_buckets_C1,
    distance_df=cluster1, axis_n=3, FWHM=0.5)

Calculating moving averages: 100%|██████████| 2/2 [00:01<00:00,  1.62it/s]


In [94]:
# SPAM_vols_c1, shift1 = ma.moving_averages_tools.calc_one_MA_volume(aligned_buckets_C1, cluster1, axis_n=3, FWHM=0.5)

In [95]:
SPAM_vols_c1.items()

dict_items([(0.051516175270080566, array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.

In [96]:
SPAM_meshes = {}

for k, volume in tqdm(list(SPAM_vols_c1.items())[:]):
    SPAM_meshes[k]=dtx.convert.volume_to_mesh(
                    vol=volume,
                    gblur_sigma=0.0,
                    threshold=0.45,
                    deciReductionRate=0,
                    smoothRate=0.25)
    #SPAM_meshes[k] = dtx.mesh.shift_aims_mesh(SPAM_meshes[k], 2*shift1[k])
    
shifted_SPAM_meshes = {}
for dist, mesh in SPAM_meshes.items():
    shifted_SPAM_meshes[str(dist)] = dtx.mesh.shift_aims_mesh_along_axis(mesh, dist, axis=0, scale=30)

100%|██████████| 1/1 [00:00<00:00, 16.19it/s]


In [97]:
fig = cld.draw(list(SPAM_vols_c1.values())[0], th_min=0.2)
cld.draw(list(SPAM_meshes.values())[0], fig=fig)

In [98]:
if save_mesh:
    for x, mesh in tqdm(shifted_SPAM_meshes.items()):
        aims.write(mesh, f"MA_{x}_2cluster33.mesh")#### Saving of average crops

### Whole cluster analysis

In [99]:
clusters_centroids

array([[ 1.31399507e-01, -8.80867887e-04,  3.96890843e-01,
         1.43585158e+00],
       [-8.36732560e-01, -1.23226578e-01,  1.49042979e+00,
         1.23057654e-02],
       [ 1.52115546e-01,  1.41179076e+00, -5.20310118e-01,
         2.20132112e-01]])

In [100]:
np.mean(clusters_centroids,axis=0)

array([-0.18440584,  0.42922777,  0.45567017,  0.55609649])

In [101]:
clusters_centroids = kmeans.cluster_centers_
print(f"cluster's centroids coordinates: \n {clusters_centroids}")
centroid = np.mean(clusters_centroids, axis=0)
print(centroid)

central = closest_distance(centroid, dstrb_sub.drop(['cluster_lab'], axis=1))
print(f"Closest subject to centroid of cluster is {central}")
print(dstrb_sub[3][str(central)], dstrb_sub[4][str(central)])

cluster's centroids coordinates: 
 [[ 1.31399507e-01 -8.80867887e-04  3.96890843e-01  1.43585158e+00]
 [-8.36732560e-01 -1.23226578e-01  1.49042979e+00  1.23057654e-02]
 [ 1.52115546e-01  1.41179076e+00 -5.20310118e-01  2.20132112e-01]]
[-0.18440584  0.42922777  0.45567017  0.55609649]
Closest subject to centroid of cluster is 146735
0.59765047 0.8947191


In [102]:
aligned_buckets, aligned_rot, aligned_transl = ma.align_buckets_by_ICP_batch(buckets, central)

>>> INFO moving_averages.transform - using 45 cores out of 48
Aligning buckets to 146735: 100%|██████████| 550/550 [00:04<00:00, 120.67it/s]


In [103]:
MA_coords = get_MA_coords(dstrb_sub, 3, num_subj_threshold=14, num_coord=5)
step = step = MA_coords[1]-MA_coords[0]

In [104]:
MA_coords

array([-1.01722126, -0.41663303,  0.18395519,  0.78454342,  1.38513165])

In [105]:
step

0.6005882263183594

In [106]:
len(dstrb_sub)

550

In [107]:
#SPAM_centers_c1 = [-2.05716]
#SPAM_vols_c1, shift1 = ma.calc_MA_volumes_batch(SPAM_centers_c1, aligned_buckets_C1, cluster1, axis_n=1, FWHM=1)
SPAM_vols, shift = ma.calc_MA_volumes_batch(MA_coords, aligned_buckets, dstrb_sub, axis_n=1, FWHM=1)

Calculating moving averages: 100%|██████████| 5/5 [00:03<00:00,  1.29it/s]


In [108]:
# Draws first volume in the brochette
cld.draw(list(SPAM_vols.values())[0], th_min=0.3)

In [109]:
# Draws last volume in the brochette
cld.draw(list(SPAM_vols.values())[-1], th_min=0.3)

In [110]:
SPAM_meshes = {}

for k, volume in tqdm(list(SPAM_vols.items())[:]):
    SPAM_meshes[k]=dtx.convert.volume_to_mesh(
                    vol=volume,
                    gblur_sigma=0,
                    threshold=0.45,
                    deciReductionRate=0,
                    smoothRate=0.4)
    #SPAM_meshes[k] = dtx.mesh.shift_aims_mesh(SPAM_meshes[k], 2*shift1[k])
    
shifted_SPAM_meshes = {}
for dist, mesh in SPAM_meshes.items():
    shifted_SPAM_meshes[str(dist)] = dtx.mesh.shift_aims_mesh_along_axis(mesh, dist, axis=0, scale=20)

100%|██████████| 5/5 [00:00<00:00, 26.76it/s]


In [111]:
if save_mesh:
    for x, mesh in tqdm(shifted_SPAM_meshes.items()):
        aims.write(mesh, f"MA_{x}_2cluster44.mesh")#### Saving of average crops

In [112]:
cld.draw(list(SPAM_meshes.values())[0])

In [113]:
cld.draw(list(SPAM_meshes.values())[4])

In [114]:
fig = cld.draw(list(SPAM_vols.values())[0], th_min=0.2)
cld.draw(list(SPAM_meshes.values())[0], fig=fig)

In [115]:
SPAM_centers_c2 = [0]
SPAM_vols_c2, shift2 = ma.calc_MA_volumes_batch(SPAM_centers_c2, aligned_buckets_C2, cluster2, axis_n=3, FWHM=1)

Calculating moving averages: 100%|██████████| 1/1 [00:00<00:00,  1.01it/s]


#### Visualization of average crops of both clusters

In [116]:
fig = cld.draw(shifted_SPAM_meshes)
ma.plot.brochette_layout(fig, "subjects meshes")

In [117]:
if save_mesh:
    for x, mesh in tqdm(shifted_SPAM_meshes.items()):
        aims.write(mesh, f"MA_{x}_2cluster2.mesh")

# Saves extreme bucket files

In [118]:
arr = np.array([np.array([dstrb_sub[k][i] for k in dstrb_sub.columns[2:4]]) for i in range(len(dstrb_sub))])

# Selects subjects whose 3rd dimension is higher than 14
ix_sup14 = np.where(arr[:,0] > 14.)
print(ix_sup14)

(array([], dtype=int64),)


In [119]:
if save_bucket:
    for ix in np.nditer(ix_sup14):
        sub = dstrb_sub.index[ix]
        aims.write(dtx.convert.bucket_numpy_to_bucketMap_aims(buckets[sub]), f">14_{sub}.bck")

In [120]:
# Selects subjects whose 3rd dimension is less than 1
ix_inf1 = np.where(arr[:,0] < 1.)
print(ix_inf1)

(array([  1,   3,   4,   5,   6,   8,   9,  11,  12,  13,  14,  17,  18,
        19,  21,  22,  23,  24,  25,  26,  27,  31,  32,  33,  36,  39,
        40,  41,  42,  43,  44,  45,  46,  49,  51,  53,  55,  56,  57,
        59,  60,  62,  63,  64,  65,  66,  68,  69,  72,  73,  75,  76,
        78,  79,  82,  83,  84,  85,  86,  87,  89,  90,  91,  92,  94,
        95,  96,  97,  98, 100, 101, 102, 103, 104, 107, 108, 109, 110,
       112, 116, 117, 118, 120, 121, 122, 123, 124, 125, 128, 130, 131,
       134, 135, 136, 138, 141, 142, 143, 144, 147, 148, 149, 151, 154,
       156, 157, 158, 160, 162, 163, 164, 165, 167, 168, 170, 171, 174,
       175, 176, 177, 179, 180, 181, 182, 183, 184, 186, 189, 190, 191,
       192, 194, 195, 197, 200, 201, 202, 203, 205, 206, 207, 209, 210,
       214, 215, 216, 217, 219, 220, 223, 224, 226, 227, 228, 232, 233,
       235, 237, 238, 241, 242, 244, 245, 247, 248, 249, 250, 254, 255,
       257, 258, 259, 260, 261, 262, 265, 267, 268, 270, 271, 2

In [121]:
if save_bucket:
    for ix in np.nditer(ix_inf1):
        sub = dstrb_sub.index[ix]
        aims.write(dtx.convert.bucket_numpy_to_bucketMap_aims(buckets[sub]), f"<1_{sub}.bck")