In [2]:
import sys
import importlib

sys.path.append("../")

from src import utils
from src.utils import *


from src import plot_utils
from src import graph_utils
from src import inpaint_utils
from src import fiberatlas_utils

## Goal

The goal of this notebook is to do create paths (voxel level) to match pairs of regions with highest connectivity paths

## Description
Generate paths according to the functional connectivity of pairs of regions. 
- E.g using highest sum of energy (energy being positive correlation)

Generate usual bundle paths prior to optimization and prepare functional connectivity information.

In [None]:
scale = 1
connFilename = f'../../atlas_data/fiber_atlas/probconnatlas/wm.connatlas.scale{scale}.h5'
hf = h5py.File(connFilename, 'r')

centers = np.array(hf.get('header').get('gmcoords'))
nsubject = hf.get('header').get('nsubjects')[()]
dim = hf.get('header').get('dim')[()]
fiber_affine = hf.get('header').get('affine')[()]

gmregions_names = hf.get('header').get('gmregions')[()]
nb_regions = gmregions_names.shape[0]

gm_mask_subj = nib.load('../../atlas_data/moviedata_fMRI_eg/gm_mask_subj7.nii').get_fdata() 
wm_mask_subj = (gm_mask_subj + 1) % 2


consistency_view = fiberatlas_utils.get_aggprop(hf, 'consistency')
length_view = fiberatlas_utils.get_aggprop(hf, 'length')
nbStlines_view = fiberatlas_utils.get_aggprop(hf, 'numbStlines')
nb_regions = consistency_view.shape[0]

# NOTE: consider bundles that appear at least in 30 % of the subjects
thresh_subjapp = int(np.ceil(nsubject * 0.1)) 

### No fibers Remove

In [4]:
X = []
bundles_labels = []
for i in tqdm(range(1,nb_regions + 1)):
    for j in range(i,nb_regions + 1):
        tmp = fiberatlas_utils.get_bundles_betweenreg(hf, i, j, verbose=False)
        if tmp is None: continue
        if np.sum(tmp[:,3] >= (thresh_subjapp)) == 0: continue
        bundles_labels.append((i,j))
        vec = np.zeros(nb_regions)
        vec[i-1] = 1.0
        vec[j-1] = 1.0
        X.append(vec)

X = np.array(X)

root = '../../atlas_data/fiber_atlas/yasser_datacomp/volspams_compress/'

atlas_of_interest = f'compresslausanne2018.scale{scale}.sym.corrected.ctx+subc.volspams.nii.gz'

prob_regions, prob_affine = (nib.load(root + atlas_of_interest).get_fdata()[:,:,:,1:], 
                             nib.load(root + atlas_of_interest).affine)

Xp = []
bundles_labels = []
for i in tqdm(range(1,nb_regions + 1)):
    for j in range(i,nb_regions + 1):
        tmp = fiberatlas_utils.get_bundles_betweenreg(hf, i, j, verbose=False)
        if tmp is None: continue
        if np.sum(tmp[:,3] >= (thresh_subjapp)) == 0: continue
        bundle_coords = tmp[:,[0,1,2]]

        prob_vox = np.zeros_like(prob_regions[:,:,:,0])
        prob_vox[bundle_coords[:,0], bundle_coords[:,1], bundle_coords[:,2]] = 1.0

        region_i = prob_regions[:,:,:,i-1]
        region_j = prob_regions[:,:,:,j-1]

        bundle_proba_i = (region_i * prob_vox)
        bproba_i = bundle_proba_i[bundle_proba_i!=0].mean()
        bundle_proba_j = (region_j * prob_vox)
        bproba_j = bundle_proba_j[bundle_proba_j!=0].mean()

        bundles_labels.append((i,j))
        vec = np.zeros(nb_regions)
        vec[i-1] = bproba_i
        vec[j-1] = bproba_j
        Xp.append(vec)

Xp = np.array(Xp)
Xp = np.nan_to_num(Xp)

region_ftimecourse = load(f"../../atlas_data/moviedata_fMRI_eg/yasseratlased_fmri/ftimecourse_95_scale{scale}.pkl")
regions_in_voxels = load(f'../../atlas_data/fiber_atlas/regions95_voxels_scale{scale}.pkl')[:,:,:,1:]

# spatial graph defining
bundle_graph = np.zeros((X.shape[0], X.shape[0]))
for k in range(X.shape[0]):
    avect1 = X[k]
    for s in range(X.shape[0]):
        if s == k: continue
        avect2 = X[s]
        if np.abs(avect1 - avect2).sum() <= 2:
            bundle_graph[k,s] = 1.0
            bundle_graph[s,k] = 1.0

# temporal graph defining
cycle = graph_utils.make_cycle(region_ftimecourse.shape[-1])

Ls = graph_utils.compute_directed_laplacian(bundle_graph)
Lt = graph_utils.compute_directed_laplacian(cycle)

Xmult = np.array([Xp.T for _ in range(region_ftimecourse.shape[-1])])

bundle_opt, logs = inpaint_utils.optimize_lreg(Xmult, region_ftimecourse, Ls=Ls, Lt=Lt, 
                                               verbose=True, num_epochs=200, logging=True, p1=0, p2=0, lr=1)

save(f"../resources/weights_regressors_activity/weighted_bundle_activity_timevertex{thresh_subjapp}_scale{scale}_stability-noremove.pkl", bundle_opt)

100%|██████████| 95/95 [00:30<00:00,  3.10it/s]
  bproba_i = bundle_proba_i[bundle_proba_i!=0].mean()
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 95/95 [06:38<00:00,  4.19s/it]
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
100%|██████████| 200/200 [00:57<00:00,  3.50it/s]

Losses are decomposed into:
generic loss=tensor([1.8552])
spatialloss=tensor([2965809.2500])
temporalloss=tensor([14861.0010])
sumloss=tensor([1.8552])





### Random fibers Remove

In [18]:
perc = 10
np.random.seed(99)

edges = np.array(np.where(consistency_view > 0)).T
toremove = edges[np.random.choice(np.arange(len(edges)), len(edges)//10, replace=False)]

In [30]:
np.min(np.linalg.norm(np.array([13,-1]) - toremove, axis=1)) < 

2.8284271247461903

In [None]:
# Example with 10% removed
X = []
bundles_labels = []
for i in tqdm(range(1,nb_regions + 1)):
    for j in range(i,nb_regions + 1):
        tmp = fiberatlas_utils.get_bundles_betweenreg(hf, i, j, verbose=False)
        if (i-1,j-1)
        if tmp is None: continue
        if np.sum(tmp[:,3] >= (thresh_subjapp)) == 0: continue
        bundles_labels.append((i,j))
        vec = np.zeros(nb_regions)
        vec[i-1] = 1.0
        vec[j-1] = 1.0
        X.append(vec)

X = np.array(X)

root = '../../atlas_data/fiber_atlas/yasser_datacomp/volspams_compress/'

atlas_of_interest = f'compresslausanne2018.scale{scale}.sym.corrected.ctx+subc.volspams.nii.gz'

prob_regions, prob_affine = (nib.load(root + atlas_of_interest).get_fdata()[:,:,:,1:], 
                             nib.load(root + atlas_of_interest).affine)

Xp = []
bundles_labels = []
for i in tqdm(range(1,nb_regions + 1)):
    for j in range(i,nb_regions + 1):
        tmp = fiberatlas_utils.get_bundles_betweenreg(hf, i, j, verbose=False)
        if tmp is None: continue
        if np.sum(tmp[:,3] >= (thresh_subjapp)) == 0: continue
        bundle_coords = tmp[:,[0,1,2]]

        prob_vox = np.zeros_like(prob_regions[:,:,:,0])
        prob_vox[bundle_coords[:,0], bundle_coords[:,1], bundle_coords[:,2]] = 1.0

        region_i = prob_regions[:,:,:,i-1]
        region_j = prob_regions[:,:,:,j-1]

        bundle_proba_i = (region_i * prob_vox)
        bproba_i = bundle_proba_i[bundle_proba_i!=0].mean()
        bundle_proba_j = (region_j * prob_vox)
        bproba_j = bundle_proba_j[bundle_proba_j!=0].mean()

        bundles_labels.append((i,j))
        vec = np.zeros(nb_regions)
        vec[i-1] = bproba_i
        vec[j-1] = bproba_j
        Xp.append(vec)

Xp = np.array(Xp)
Xp = np.nan_to_num(Xp)

region_ftimecourse = load(f"../../atlas_data/moviedata_fMRI_eg/yasseratlased_fmri/ftimecourse_95_scale{scale}.pkl")
regions_in_voxels = load(f'../../atlas_data/fiber_atlas/regions95_voxels_scale{scale}.pkl')[:,:,:,1:]

# spatial graph defining
bundle_graph = np.zeros((X.shape[0], X.shape[0]))
for k in range(X.shape[0]):
    avect1 = X[k]
    for s in range(X.shape[0]):
        if s == k: continue
        avect2 = X[s]
        if np.abs(avect1 - avect2).sum() <= 2:
            bundle_graph[k,s] = 1.0
            bundle_graph[s,k] = 1.0

# temporal graph defining
cycle = graph_utils.make_cycle(region_ftimecourse.shape[-1])

Ls = graph_utils.compute_directed_laplacian(bundle_graph)
Lt = graph_utils.compute_directed_laplacian(cycle)

Xmult = np.array([Xp.T for _ in range(region_ftimecourse.shape[-1])])

bundle_opt, logs = inpaint_utils.optimize_lreg(Xmult, region_ftimecourse, Ls=Ls, Lt=Lt, 
                                               verbose=True, num_epochs=200, logging=True, p1=0, p2=0, lr=1)

save(f"../resources/weights_regressors_activity/weighted_bundle_activity_timevertex{thresh_subjapp}_scale{scale}_stability-{perc}p_remove.pkl", bundle_opt)