In [35]:
import anatomist.api as ana
from soma.qt_gui.qtThread import QtThreadCall
from soma.qt_gui.qt_backend import Qt

from soma import aims
import numpy as np
import pandas as pd
import os
import glob

from sklearn.preprocessing import StandardScaler
import statsmodels.api as sm

from sklearn.neighbors import NearestNeighbors

import matplotlib.pyplot as plt



In [36]:
a = ana.Anatomist()

### Variable definitions

In [37]:
path = "/neurospin/dico/data/deep_folding/current/models/Champollion_V0_trained_on_UKB40/SC-sylv_left"
participants_file = "/neurospin/dico/data/deep_folding/current/datasets/UkBioBank/participants.csv"
lab = "gravityCenter_x"
model_path = glob.glob(f"{path}/*")[0]
embeddings_file = f"{model_path}/hcp_random_epoch80_embeddings/full_embeddings.csv"
participants_file = "/neurospin/dico/data/human/hcp/derivatives/morphologist-2023/morphometry/spam_recognition/morpho_talairach/morpho_S.C._left.dat"

In [38]:
side = "L" # "R" or "L"
region = "S.C.-sylv." # "S.C.-sylv.", "ORBITAL", "CINGULATE", "SC-sylv", "F.I.P."
database='hcp'

# Building predictors

In [39]:
participants = pd.read_csv(participants_file, sep=' ', index_col=0)
participants = participants[[lab]].dropna()
participants

Unnamed: 0_level_0,gravityCenter_x
subject,Unnamed: 1_level_1
100206,42.9009
100307,38.5095
100408,40.8582
100610,39.6554
101006,38.4064
...,...
992774,41.7449
993675,41.8441
994273,41.9341
995174,40.5852


In [40]:
ukb_emb = pd.read_csv(f"{embeddings_file}", index_col=0)
merged = participants[[lab]].merge(ukb_emb, left_index=True, right_index=True)


In [41]:
embeddings = merged.iloc[:,1:]
label = merged.iloc[:,0:1]
label_random = label.copy()
label_random[:] = label_random.sample(frac=1).values
# print(embeddings.head())
print(label.head())
print(label_random.head())

# Construct nearest_neighbor matrix
nbrs = NearestNeighbors(n_neighbors=int(len(embeddings)/2), algorithm='brute').fit(embeddings)
distances, indices = nbrs.kneighbors(embeddings)
indices.shape

         gravityCenter_x
subject                 
100206           42.9009
100307           38.5095
100408           40.8582
100610           39.6554
101006           38.4064
         gravityCenter_x
subject                 
100206           41.7174
100307           37.8192
100408           40.2906
100610           38.8848
101006           43.4656


(1114, 557)

In [42]:
label

Unnamed: 0_level_0,gravityCenter_x
subject,Unnamed: 1_level_1
100206,42.9009
100307,38.5095
100408,40.8582
100610,39.6554
101006,38.4064
...,...
992774,41.7449
993675,41.8441
994273,41.9341
995174,40.5852


In [43]:
label.sort_values(by=lab)

Unnamed: 0_level_0,gravityCenter_x
subject,Unnamed: 1_level_1
604537,28.6862
901038,33.6232
211619,33.6675
580650,34.4887
139637,34.4968
...,...
268850,47.2445
376247,47.5965
111312,49.5779
179245,50.3080


### Function definitions

In [44]:
def to_bucket(obj):
    if obj.type() == obj.BUCKET:
        return obj
    avol = a.toAimsObject(obj)
    c = aims.Converter(intype=avol, outtype=aims.BucketMap_VOID)
    abck = c(avol)
    bck = a.toAObject(abck)
    bck.releaseAppRef()
    return bck

def build_gradient(pal):
    gw = ana.cpp.GradientWidget(None, 'gradientwidget', pal.header()['palette_gradients'])
    gw.setHasAlpha(True)
    nc = pal.shape[0]
    rgbp = gw.fillGradient(nc, True)
    rgb = rgbp.data()
    npal = pal.np['v']
    pb = np.frombuffer(rgb, dtype=np.uint8).reshape((nc, 4))
    npal[:, 0, 0, 0, :] = pb
    npal[:, 0, 0, 0, :3] = npal[:, 0, 0, 0, :3][:, ::-1]  # BGRA -> RGBA
    pal.update()

def buckets_average(subject_id_list, dataset_name_list):
    dic_vol = {}
    dim = 0
    rep = 0
    if len(subject_id_list) == 0:
        return False
    while dim == 0 and rep < len(subject_id_list):
        if dataset_name_list[rep].lower() in ['ukb', 'ukbiobank', 'projected_ukb']:
            dataset = 'UkBioBank'
        elif dataset_name_list[rep].lower() in ['hcp']:
            dataset = 'hcp'
        mm_skeleton_path = f"/neurospin/dico/data/deep_folding/current/datasets/{dataset}/crops/2mm/{region}/mask/{side}crops"
        if os.path.isfile(f'{mm_skeleton_path}/{subject_id_list[rep]}_cropped_skeleton.nii.gz'):
            sum_vol = aims.read(f'{mm_skeleton_path}/{subject_id_list[rep]}_cropped_skeleton.nii.gz').astype(float)
            dim = sum_vol.shape
            sum_vol.fill(0)
        else: 
            print(f'FileNotFound {mm_skeleton_path}/{subject_id_list[rep]}_cropped_skeleton.nii.gz')
            #raise FileNotFoundError(f'{mm_skeleton_path}/{subject_id_list[0]}_cropped_skeleton.nii.gz')
        rep += 1

    for subject_id, dataset in zip(subject_id_list,dataset_name_list):
        if dataset.lower() in ['ukb', 'ukbiobank',  'projected_ukb']:
            dataset = 'UkBioBank40'
        elif dataset.lower() == 'hcp':
            dataset = 'hcp'
            
        mm_skeleton_path = f"/neurospin/dico/data/deep_folding/current/datasets/{dataset}/crops/2mm/{region}/mask/{side}crops"

        if os.path.isfile(f'{mm_skeleton_path}/{subject_id}_cropped_skeleton.nii.gz'):
            vol = aims.read(f'{mm_skeleton_path}/{subject_id}_cropped_skeleton.nii.gz')
            # compare the dim with the first file dim

            if vol.np.shape != dim:
                raise ValueError(f"{subject_id_list[0]} and {subject_id} must have the same dim")

                
            # to have a binary 3D structure
            dic_vol[subject_id] = (vol.np > 0).astype(int)
            sum_vol.np[:] += (vol.np > 0).astype(int) 
        else: 
            print(f'FileNotFound {mm_skeleton_path}/{subject_id}_cropped_skeleton.nii.gz')
            #raise FileNotFoundError(f'{mm_skeleton_path}/{subject_id}_cropped_skeleton.nii.gz')

    sum_vol = sum_vol / len(subject_id_list)
    print(sum_vol, sum_vol.shape)
    return sum_vol

In [45]:
def visualize_averages_along_parameter(df, column_name, database, nb_subjects_per_average=200, nb_columns=3, nb_lines=1):
    # anatomist objects
    global _block
    global _average_dic
    global dic_packages
    nb_windows = nb_columns * nb_lines
    _average_dic = {}
    step = nb_subjects_per_average # number of subjects on which average is done
    # Creates the block if it has not been created
    # try:
    #     _block
    # except NameError:
    _block = a.createWindowsBlock(nb_columns)

    # Order according to column_name
    # df = df.sort_values(column_name)
    # predict_proba = df[[column_name]]
    predict_proba = df

    # Creates dictionary of subjects to average
    dic_packages = {}
    for i in range(0,len(predict_proba),step):
        list_idx = (predict_proba.index[i:i+step].to_numpy())
        dic_packages[i//step] = list_idx
    
    # Ensures that last list contains the last step subjects
    list_idx = (predict_proba.index[-step:].to_numpy())
    dic_packages[i//step] = list_idx
    
    list_database = [database for i in range(step)]
    n_pack = len(dic_packages)

    # Loop of display averages
    list_pack = [int(np.ceil(i*n_pack/float(nb_windows))) for i in range(0,nb_windows)]
    for i in list_pack:
        sum_vol = buckets_average(dic_packages[i], list_database)
        print(dic_packages[i], sum_vol.np.sum())
        _average_dic[f'a_sum_vol{i}'] = a.toAObject(sum_vol)
        _average_dic[f'a_sum_vol{i}'].setPalette(minVal=0, absoluteMode=True)
        #wsum = a.createWindow('Sagittal', block=block)
        #wsum.addObjects(a_sum_vol)
        _average_dic[f'rvol{i}'] = a.fusionObjects(
            objects=[_average_dic[f'a_sum_vol{i}']],
            method='VolumeRenderingFusionMethod')
        _average_dic[f'rvol{i}'].releaseAppRef()
        # custom palette
        n = len(dic_packages[i])
        pal = a.createPalette('VR-palette')
        pal.header()['palette_gradients'] = '0;0.459574;0.497872;0.910638;1;1#0;0;0.52766;0.417021;1;1#0;0.7;1;0#0;0;0.0297872;0.00851064;0.72766;0.178723;0.957447;0.808511;1;1'
        #f'0;0.244444;0.5;1;1;1#0;0;0.535897;0.222222;1;1#0;0.7;1;0#0;0;{0.5/n};0;1;1'
        build_gradient(pal)
        _average_dic[f'rvol{i}'].setPalette('VR-palette', minVal=0.05, absoluteMode=True)
        pal2 = a.createPalette('slice-palette')
        pal2.header()['palette_gradients'] = '0;0.459574;0.497872;0.910638;1;1#0;0;0.52766;0.417021;1;1#0;0.7;1;0#0;0;0.0297872;0.00851064;0.72766;0.178723;0.957447;0.808511;1;1'
        #f'0;0.244444;0.5;1;1;1#0;0;0.535897;0.222222;1;1#0;0.7;1;0#0;0;{0.3/n};0;{0.7/n};1;1;1'
        build_gradient(pal2)
        _average_dic[f'a_sum_vol{i}'].setPalette('slice-palette')
        # rvol.palette().fill()
        _average_dic[f'wvr{i}'] = a.createWindow('3D', block=_block)
        _average_dic[f'wvr{i}'].addObjects(_average_dic[f'rvol{i}'])

### Visualization

In [46]:
df = label.sort_values(by=lab)
_block = a.createWindowsBlock(3)
visualize_averages_along_parameter(df, "predicted", database, nb_subjects_per_average=2, nb_columns=3, nb_lines=1)

<soma.aims.Volume_DOUBLE object at 0x7491695d3370> (38, 36, 49, 1)
[604537 901038] 1478.5
<soma.aims.Volume_DOUBLE object at 0x7491695d3e20> (38, 36, 49, 1)
[987983 194645] 1341.0
<soma.aims.Volume_DOUBLE object at 0x749160b2c310> (38, 36, 49, 1)
[329844 203721] 1226.0


# Displaying subjects

In [47]:
# list_subjects = list_smallest
# block = a.createWindowsBlock(5) # 5 columns
# dic_windows = {}
# for subject in list_subjects:
#     path_to_t1mri = f'/neurospin/dico/data/bv_databases/human/not_labeled/hcp/hcp/{subject}/t1mri/BL'
#     dic_windows[f'w{subject}'] = a.createWindow("3D", block=block)
#     dic_windows[f'white_{subject}'] = a.loadObject(f'{path_to_t1mri}/default_analysis/segmentation/mesh/{subject}_{side}white.gii')
#     dic_windows[f'white_{subject}'].loadReferentialFromHeader()
#     dic_windows[f'sulci_{subject}'] = a.loadObject(f'{path_to_t1mri}/default_analysis/folds/3.1/{side}{subject}.arg')
#     dic_windows[f'sulci_{subject}'].loadReferentialFromHeader()
#     dic_windows[f'w{subject}'].addObjects([dic_windows[f'white_{subject}'], dic_windows[f'sulci_{subject}']])

In [48]:
# df.sort_values("predicted")[::10]

In [49]:
# df_concat = pd.concat([df.sort_values("predicted")[::10][:5], df.sort_values("predicted")[::10][-5:]])

In [50]:
# df_concat

In [51]:
# visualize_averages_along_parameter(df_concat, "predicted", database, nb_subjects_per_average=1, nb_columns=5, nb_lines=2)

: 