In [44]:
import nibabel as nib
import numpy as np
import pandas as pd
from pathlib import Path
import os
import pickle
import argparse
import logging
import nibabel.processing
from tqdm import tqdm
import warnings

In [45]:
# -*- coding: utf-8 -*-


def reorient_to_ras(img):
    orig_ornt = nib.io_orientation(img.affine)
    targ_ornt = nib.orientations.axcodes2ornt("RAS")
    transform = nib.orientations.ornt_transform(orig_ornt, targ_ornt)

    img_orient = img.as_reoriented(transform)
    return img_orient

def get_affines(imgs):
    affines = []
    for img in tqdm(imgs):
        affines.append(img.affine)
    return np.array(affines)

def get_resolution(affines):
    # asserting
    for n,affine in enumerate(affines):
        if abs(affine[0,0]) != abs(affine[1,1]) != abs(affine[2,2]):
            warnings.warn(f"Affine matrix #{n} is not isotropic")

        i, j = np.nonzero(affine)
        if not np.all(i == j):
            warnings.warn(f"Affine matrix #{n} is not cartesian")

    counts = np.bincount([[affine[0,0],affine[1,1],affine[2,2]] for affine in affines])
    return np.argmax(counts)

class BBox :
    def __init__(self, datas):
        self.minx, self.maxx, self.miny, self.maxy, self.minz, self.maxz = [], [], [], [], [], []
        for data in tqdm(datas):
            bbox = self.compute_bbox(data)
            self.minx.append(bbox[0])
            self.maxx.append(bbox[1])
            self.miny.append(bbox[2])
            self.maxy.append(bbox[3])
            self.minz.append(bbox[4])
            self.maxz.append(bbox[5])

    def compute_bbox(self, data):
        shape = data.shape

        if np.sum(data) == 0:
            #print("VTA empty")
            #return (shape[0], 0, shape[1], 0, shape[2], 0)
            return (np.nan, np.nan, np.nan, np.nan, np.nan, np.nan)
        if np.sum(data) < 0.0:
            raise ValueError("Negative sum?")

        # x-axis
        voxel_sum_x = [np.sum(data[i,:,:]) for i in range(shape[0])]
        bbox_min_x = next((index for index, item in enumerate(voxel_sum_x) if item != 0), None)
        try:
            bbox_max_x = [i for i, slice in enumerate(voxel_sum_x) if slice > 0.0][-1]
        except Exception as e:
            print(f'Exception : {e}')
            test = np.array([i for i, slice in enumerate(voxel_sum_x) if slice > 0.0])
            print(test)

        # y-axis
        voxel_sum_y = [np.sum(data[:,i,:]) for i in range(shape[1])]
        bbox_min_y = next((index for index, item in enumerate(voxel_sum_y) if item != 0), None)
        bbox_max_y = [i for i, slice in enumerate(voxel_sum_y) if slice > 0.0][-1]

        # z-axis
        voxel_sum_z = [np.sum(data[:,:,i]) for i in range(shape[2])]
        bbox_min_z = next((index for index, item in enumerate(voxel_sum_z) if item != 0), None)
        bbox_max_z = [i for i, slice in enumerate(voxel_sum_z) if slice > 0.0][-1]

        return (bbox_min_x, bbox_max_x, bbox_min_y, bbox_max_y, bbox_min_z, bbox_max_z)      

    def get_max(self, index):
        return (self.maxx[index], self.maxy[index], self.maxz[index])

    def get_maxs(self):
        maxs = np.swapaxes(np.array((self.maxx, self.maxy, self.maxz)), 0, 1)
        return maxs
    
    def get_min(self, index):
        return (self.minx[index], self.miny[index], self.minz[index])

    def get_mins(self):
        mins = np.swapaxes(np.array((self.minx, self.miny, self.minz)), 0, 1)
        return mins


## Finding Null and net negative VTAs

## Finding Null and net negative VTAs

In [46]:
null_vtas, negative_vtas = [], []
table = pd.read_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/raw/tables/stn_space/merged/flipped/table.csv')
filenames_mixed = table['massive_filename'].to_list()

for filename in filenames_mixed:
    #print(filename)
    img = nib.load(filename)
    
    resolution = 0.5

    # flip all the VTAS to right side & force isotropy on space
    ras_img = reorient_to_ras(img)

    # binarizing and vectorizing numpy arrays of VTAs
    data_ = np.round(ras_img.get_fdata())

    if np.sum(data_) < 0.0:
        negative_vtas.append(filename)

    if np.sum(data_) == 0:
        null_vtas.append(filename)

print(f'null vtas : {null_vtas}, \nnegative vtas : {negative_vtas}')

null vtas : [], 
negative vtas : ['/media/brainstimmaps/DATA/2009_DeepMaps01/03_Data/04_MassiveMerged/p201_c18_a0.5.nii', '/media/brainstimmaps/DATA/2009_DeepMaps01/03_Data/04_MassiveMerged/p203_c0_a0.5.nii', '/media/brainstimmaps/DATA/2009_DeepMaps01/03_Data/04_MassiveMerged/p203_c16_a0.5.nii', '/media/brainstimmaps/DATA/2009_DeepMaps01/03_Data/04_MassiveMerged/p203_c17_a0.5.nii', '/media/brainstimmaps/DATA/2009_DeepMaps01/03_Data/04_MassiveMerged/p204_c16_a0.5.nii', '/media/brainstimmaps/DATA/2009_DeepMaps01/03_Data/04_MassiveMerged/p205_c16_a0.5.nii', '/media/brainstimmaps/DATA/2009_DeepMaps01/03_Data/04_MassiveMerged/p205_c17_a0.5.nii', '/media/brainstimmaps/DATA/2009_DeepMaps01/03_Data/04_MassiveMerged/p206_c16_a0.5.nii', '/media/brainstimmaps/DATA/2009_DeepMaps01/03_Data/04_MassiveMerged/p206_c17_a0.5.nii', '/media/brainstimmaps/DATA/2009_DeepMaps01/03_Data/04_MassiveMerged/p207_c15_a0.5.nii', '/media/brainstimmaps/DATA/2009_DeepMaps01/03_Data/04_MassiveMerged/p207_c18_a0.5.nii',

In [47]:
len(negative_vtas)

38

In [48]:
table = table[~table['massive_filename'].isin(negative_vtas)]
len(table)

8556

In [49]:
table.to_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/raw/tables/stn_space/merged/flipped/table.csv', index=False)

In [None]:
zeroed_filenames = {}
new_filenames = []

for filename in negative_vtas:
    nifti_img = nib.load(filename)
    voxel_data = nifti_img.get_fdata()
    negative_mask = voxel_data < 0
    voxel_data[negative_mask] = 0
    modified_nifti_img = nib.Nifti1Image(voxel_data, affine=nifti_img.affine, header=nifti_img.header)
    new_filename = filename[:-4] + '_zeroed.nii'
    nib.save(modified_nifti_img, new_filename)
    zeroed_filenames[filename] = new_filename
    new_filenames.append(new_filename)

for index, row in table.iterrows():
    filename = row['massive_filename']
    if filename in zeroed_filenames:
        new_path = zeroed_filenames[filename]
        table.at[index, 'massive_filename'] = new_path
        table.at[index, 'zeroed'] = 1
    else:
        table.at[index, 'zeroed'] = 0



table.to_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/raw/tables/stn_space/merged/flipped/table_zeroed.csv', index=False)
print(len(table))
print(negative_vtas)
table = table[~table['massive_filename'].isin(new_filenames)]
print(len(table))
table.to_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/raw/tables/stn_space/merged/flipped/table_subtracted.csv', index=False)


In [None]:
table_sub = pd.read_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/raw/tables/stn_space/merged/flipped/table_subtracted.csv')
table_zer = pd.read_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/raw/tables/stn_space/merged/flipped/table_zeroed.csv')

print(len(table_sub))
print(len(table_zer))


## Common space, STN/N-image center of mass & inclusion percentage comptuation

In [50]:

# table = pd.read_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/raw/psm_tables/bern+cologne/mapping/multicentricTableAllImprovedOnlyRev04B.csv')
table = pd.read_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/raw/tables/stn_space/merged/flipped/table.csv')

c_map = table['mapping'] == 1
c_bern = table['center'] == 'Bern'
c_koln = table['center'] == 'Cologne'

table_bern_massive = table[c_bern]
table_koln_massive = table[c_koln]

table_bern_mapping = table[c_bern & c_map]
table_koln_mapping = table[c_koln & c_map]

table_mapping = table[c_map]
filenames = {}
filenames['bern_massive'] = table_bern_massive['massive_filename'].to_list()
filenames['koln_massive'] = table_koln_massive['massive_filename'].to_list()
filenames['merged_massive'] = table['massive_filename'].to_list()

filenames['bern_mapping'] = table_bern_mapping['massive_filename'].to_list()
filenames['koln_mapping'] = table_koln_mapping['massive_filename'].to_list()
filenames['merged_mapping'] = table_mapping['massive_filename'].to_list()

#filenames = list(map(filename_process, filenames))

In [51]:
table_koln_mapping

Unnamed: 0,center,cohort,lead_model,patient,lead,hemisphere,contact,vercise,directional,ring,segmented,amplitude,mapping_score,mapping,massive_filename,part,lin_interp_score,step_interp_score
3547,Cologne,C01,M. 3389,201,201.5,left,11,15,False,R3,,4.0,0.500000,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.500000,0.500000
3553,Cologne,C01,M. 3389,201,201.5,left,8,8,False,R0,,2.0,0.500000,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.500000,0.500000
3555,Cologne,C01,M. 3389,201,201.5,left,8,8,False,R0,,3.0,0.750000,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.750000,0.750000
3562,Cologne,C01,M. 3389,201,201.5,left,9,18,False,R1,,2.0,0.500000,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.500000,0.500000
3564,Cologne,C01,M. 3389,201,201.5,left,9,18,False,R1,,3.0,0.750000,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.750000,0.750000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5419,Cologne,C02,B.S. Vercise,240,240.0,right,2,2,True,,R1_2,2.0,0.333333,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.333333,0.333333
5421,Cologne,C02,B.S. Vercise,240,240.0,right,2,2,True,,R1_2,3.0,0.333333,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.333333,0.333333
5423,Cologne,C02,B.S. Vercise,240,240.0,right,2,2,True,,R1_2,4.0,0.666667,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.666667,0.666667
5431,Cologne,C02,B.S. Vercise,240,240.0,right,3,3,True,,R1_3,3.0,0.500000,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.500000,0.500000


In [52]:
table_koln_massive

Unnamed: 0,center,cohort,lead_model,patient,lead,hemisphere,contact,vercise,directional,ring,segmented,amplitude,mapping_score,mapping,massive_filename,part,lin_interp_score,step_interp_score
3540,Cologne,C01,M. 3389,201,201.5,left,11,15,False,R3,,0.5,,0.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.0625,0.0
3541,Cologne,C01,M. 3389,201,201.5,left,11,15,False,R3,,1.0,,0.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.1250,0.0
3542,Cologne,C01,M. 3389,201,201.5,left,11,15,False,R3,,1.5,,0.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.1875,0.0
3543,Cologne,C01,M. 3389,201,201.5,left,11,15,False,R3,,2.0,,0.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.2500,0.0
3544,Cologne,C01,M. 3389,201,201.5,left,11,15,False,R3,,2.5,,0.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.3125,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5431,Cologne,C02,B.S. Vercise,240,240.0,right,3,3,True,,R1_3,3.0,0.5,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.5000,0.5
5432,Cologne,C02,B.S. Vercise,240,240.0,right,3,3,True,,R1_3,3.5,,0.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.5000,0.5
5433,Cologne,C02,B.S. Vercise,240,240.0,right,3,3,True,,R1_3,4.0,0.5,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,1.0,0.5000,0.5
5434,Cologne,C02,B.S. Vercise,240,240.0,right,3,3,True,,R1_3,4.5,,0.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,0.5000,0.5


In [53]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import center_of_mass, affine_transform
from nilearn import image
from collections import defaultdict
import gc

def create_dict_4_list():
    return defaultdict(create_dict_3_list)

def create_dict_3_list():
    return defaultdict(create_dict_2_list)

def create_dict_2_list():
    return defaultdict(create_dict_list)

def create_dict_list():
    return defaultdict(list)

def get_mask_cuboid(coords, size, cs_data_dummy):
    mask = np.zeros_like(cs_data_dummy)
    x, y, z = map(round, coords)  # round the coordinates to the nearest integer
    #print(x,y,z)
    mask[max(x-size,0):min(x+size,mask.shape[0]), max(y-size,0):min(y+size,mask.shape[1]), max(z-size,0):min(z+size,mask.shape[2])] = 1
    #print(np.sum(mask)/mask.size)
    return mask, np.sum(mask)/mask.size

def get_mask_sphere(coords, radius, cs_data_dummy):
    shape_ = cs_data_dummy.shape
    #mask = np.zeros(shape_, dtype=bool)
    
    x0, y0, z0 = map(round, coords)  # round the coordinates to the nearest integer

    # Generate the 3D grid and shift it by the offset
    x, y, z = np.ogrid[0:shape_[0], 0:shape_[1], 0:shape_[2]]

    # Create the mask by comparing the distance from the origin to the radius
    mask = (x-x0)**2 + (y-y0)**2 + (z-z0)**2 <= radius**2

    #print(shape_, mask.shape)
    return mask, np.sum(mask)/mask.size


def matrix_point_mul(matrix, point):
    return (matrix @ np.append(point,1))[:-1]

def display_space_bounds(shape, affine):
    mins = [affine[i,3] for i in range(3)]
    maxs = [affine[i,i]*shape[i] for i in range(3)]
    for i, dim in enumerate(zip(('x', 'y', 'z'), (' M <-> L', 'P <-> A', 'I <-> S'))):
        print(f'{dim[0]} : [{mins[i]:.2f}, {maxs[i]:.2f}] mm  {dim[1]}')

def display_volume_size(shape, affine):
    n_voxels = np.prod(shape)
    volume_size = n_voxels * np.prod([affine[i,i] for i in range(3)]) / 1000 # cm^3
    print(f"{n_voxels:,} voxels\n{volume_size:.2f} cm^3")
    
window_sizes = {
    0.25 : range(1, 50),
    0.5  : range(1, 30),
    1.0  : range(1, 20)
}

resolutions = (0.25, 0.5, 1.0)

In [54]:
class STN_ALGO_DATA:
    def __init__(self, filename=None):
        self.data = {}
        self.common_space_affine = {}
        self.common_space_shape = {}
        self.stn_com_coords_in_common_space = {}
        self.n_com_coords_in_common_space = {}
        self.vta_inclusion_3sigma = {}
        if filename:
            self.load_from_file(filename)

    def add_value(self, resolution, center, inclusion_type, shape, focus, value):
        # Create a unique key for the combination of categorical variables
        key = f"{resolution}-{center}-{inclusion_type}-{shape}-{focus}"

        # If the key is not in the data dictionary, create a new empty array
        if key not in self.data:
            self.data[key] = []

        # Append the value to the corresponding array
        self.data[key].append(value)

    def get_array(self, resolution, center, inclusion_type, shape, focus):
        # Create a unique key for the combination of categorical variables
        key = f"{resolution}-{center}-{inclusion_type}-{shape}-{focus}"

        # If the key is in the data dictionary, return the corresponding array as a numpy array
        if key in self.data:
            return np.array(self.data[key])

        # If the key is not in the data dictionary, return None
        return None
    
 
    def save_to_file(self, filename):
        self.opti_edge_radii = {}
        self.resolutions = []
        self.centers = []

        for key in self.data.keys():
            resolution = key.split('-')[0]
            if resolution not in self.resolutions:
                self.resolutions.append(resolution)
            center = key.split('-')[1]
            if center not in self.centers:
                self.centers.append(center)

        for resolution in self.resolutions:
            for center in self.centers:
                arr = self.get_array(resolution,center,'vta', 'cuboid', 'stn')
                if arr is not None:
                    for i, sigma in enumerate((0.6827, 0.945, 0.9973)):
                        new_key = f"{resolution}-{center}-{i+1}sigma"
                        self.opti_edge_radii[new_key] = np.argmax(arr > sigma)

        try:
            with open(filename, 'wb') as file:
                data_to_save = {
                    'data': self.data,
                    'common_space_affine': self.common_space_affine,
                    'common_space_shape': self.common_space_shape,
                    'stn_com_coords_in_common_space': self.stn_com_coords_in_common_space,
                    'n_com_coords_in_common_space': self.n_com_coords_in_common_space,
                    'vta_inclusion_3sigma' : self.vta_inclusion_3sigma,
                    'opti_edge_radii' : self.opti_edge_radii
                }
                pickle.dump(data_to_save, file)
            print(f"Data saved to {filename} successfully.")
        except Exception as e:
            print(f"Error while saving data to {filename}: {e}")

    def load_from_file(self, filename):
        try:
            with open(filename, 'rb') as file:
                data_to_load = pickle.load(file)
                self.data = data_to_load['data']
                self.common_space_affine = data_to_load['common_space_affine']
                self.common_space_shape = data_to_load['common_space_shape']
                self.stn_com_coords_in_common_space = data_to_load['stn_com_coords_in_common_space']
                self.n_com_coords_in_common_space = data_to_load['n_com_coords_in_common_space']
                self.vta_inclusion_3sigma = data_to_load['vta_inclusion_3sigma']
                self.opti_edge_radii = data_to_load['opti_edge_radii']
            print(f"Data loaded from {filename} successfully.")
        except Exception as e:
            print(f"Error while loading data from {filename}: {e}")

    # Setters and Getters for the new variables
    def set_common_space_affine(self, resolution, center, value):
        key = (resolution, center)
        self.common_space_affine[key] = value

    def get_common_space_affine(self, resolution, center):
        key = (resolution, center)
        return self.common_space_affine.get(key, None)

    def set_common_space_shape(self, resolution, center, value):
        key = (resolution, center)
        self.common_space_shape[key] = value

    def get_common_space_shape(self, resolution, center):
        key = (resolution, center)
        return self.common_space_shape.get(key, None)

    def set_stn_com_coords_in_common_space(self, resolution, center, value):
        key = (resolution, center)
        self.stn_com_coords_in_common_space[key] = value

    def get_stn_com_coords_in_common_space(self, resolution, center):
        key = (resolution, center)
        return self.stn_com_coords_in_common_space.get(key, None)

    def set_n_com_coords_in_common_space(self, resolution, center, value):
        key = (resolution, center)
        self.n_com_coords_in_common_space[key] = value

    def get_n_com_coords_in_common_space(self, resolution, center):
        key = (resolution, center)
        return self.n_com_coords_in_common_space.get(key, None)


In [55]:
from importlib import reload  # Not needed in Python 2
from datetime import datetime
import logging
reload(logging)


# datetime object containing current date and time
now = datetime.now()
 
# dd/mm/YY H:M:S
dt_string = now.strftime("%Y_%m_%d_%H_%M_%S")
#print("date and time =", dt_string)

# Configure the logging format
log_format = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=log_format)

# Create a FileHandler to log messages to a file
log_file = f"inclusion_{dt_string}.log"
file_handler = logging.FileHandler(filename=log_file, mode='a')
file_handler.setLevel(logging.DEBUG)

# Create a Formatter to format the log messages
formatter = logging.Formatter(log_format)
file_handler.setFormatter(formatter)

# Add the handlers to the logger
logger = logging.getLogger()
logger.addHandler(file_handler)

In [56]:
centers = filenames.keys()

min_range_mm, max_range_mm = {}, {}

data = STN_ALGO_DATA()

for resolution in resolutions:

    #percentages = create_dict_4_list() # com,stn -> vta,mask -> cuboid,sphere
    #common_space_affine, common_space_shape, stn_com_coords_in_common_space, n_com_coords_in_common_space = {}, {}, {}, {}
    logger.info(f'------ {int(resolution*1000)}um ------')
    for center in centers:
        logger.info(center.capitalize())
        filenames_ = filenames[center]
        logger.info('loading niftis')
        imgs = [nib.load(filename) for filename in tqdm(filenames_)]

        # ----- PROCESSING --------------------------------------------------------
        affines = get_affines(imgs)

        # if resolution is not explicitly given, we get the most frequent resolution
        if resolution is None:
            resolution = get_resolution(affines)

        # flip all the VTAS to right side & force isotropy on space
        logger.info('converting to ras')
        imgs = [reorient_to_ras(img) for img in imgs]

        logger.info('rounding and converting to array')
        # binarizing and vectorizing numpy arrays of VTAs
        datas = [np.round(img.get_fdata()) for img in imgs]

        # computing bounding boxes for each binarized VTA
        logger.info('computing bounding box')
        bbox = BBox(datas)

        del datas

        # computing furthest point in each axis for each VTA
        bbox_maxs = bbox.get_maxs()
        bbox_mins = bbox.get_mins()

        # cloest corner -> most [Medial Posterior Inferior]
        shortest_points = [matrix_point_mul(affine, bbox_min) for affine, bbox_min in zip(affines, bbox_mins)]
        min_range_mm[center] = np.nanmin(shortest_points, axis=0)
        logger.info(f'min corner [mm] : {min_range_mm[center]}')

        # furthest corner -> most [Lateral Anterior Superior]
        furthest_points = [matrix_point_mul(affine, bbox_max) for affine, bbox_max in zip(affines, bbox_maxs)]
        max_range_mm[center] = np.nanmax(furthest_points, axis=0)
        logger.info(f'max corner [mm] : {max_range_mm[center]}')

        # container space size
        common_space_shape = np.ceil((max_range_mm[center] - min_range_mm[center])/resolution).astype(int)
        data.set_common_space_shape(resolution, center, common_space_shape)
        logger.info(f'container shape : {data.get_common_space_shape(resolution, center)}')

        # new affine for resampling VTAs
        common_space_affine = np.array([
            [resolution, 0, 0, min_range_mm[center][0]],
            [0, resolution, 0, min_range_mm[center][1]],
            [0, 0, resolution, min_range_mm[center][2]],
            [0, 0, 0, 1]]
        )
        
        data.set_common_space_affine(resolution, center, common_space_affine)
        # resampling vtas to new affine and container space size 
        # then rounding and converting back to numpy array
        logger.info('converting to common space (heavy processing)') # HEAVY PROCESSING
        common_space_vtas = [nib.processing.resample_from_to(from_img=img, to_vox_map=(common_space_shape, common_space_affine), order=1) for img in tqdm(imgs)]

        del imgs
        gc.collect()

        logger.info('rounding and converting to array')
        # previously : cs_datas = np.array([np.round(vta.get_fdata()) for vta in common_space_vtas])
        cs_datas = []
        for vta in common_space_vtas:
            cs_datas.append(np.round(vta.get_fdata()).astype(np.int8))
        cs_datas = np.array(cs_datas)

        del common_space_vtas
        gc.collect()
        
        # Load the STN volume and align it with the dataset
            
        logger.info('computing stn center of mass')

        stn_volume = nib.load('STNRight.nii.gz')
        stn_affine = stn_volume.affine
        logger.info(f'STN space affine : \n{stn_affine}')
        logger.info(f'CS->RL : \n{common_space_affine}')

        stn_com_coords_in_stn_space = np.array(center_of_mass(stn_volume.get_fdata()))
        stn_to_cs_mat = np.linalg.inv(common_space_affine) @ stn_affine

        logger.info(f'STN->CS affine : \n{stn_to_cs_mat}')

        stn_com_coords_in_common_space = matrix_point_mul(stn_to_cs_mat,stn_com_coords_in_stn_space)
        data.set_stn_com_coords_in_common_space(resolution, center, stn_com_coords_in_common_space)

        logger.info(f'stn coords in common space : {stn_com_coords_in_common_space}')
        logger.info(f'stn coords in real world : {matrix_point_mul(common_space_affine, stn_com_coords_in_common_space)}')
        logger.info('center of mass computation')

        n_image_cs = np.sum(cs_datas, axis=0).astype(float)
        n_image_cs_img = nib.Nifti1Image(
            dataobj=n_image_cs, 
            affine=common_space_affine)
        nib.save(n_image_cs_img, f'n_images/{center}_n_image_{int(1000*resolution)}um.nii')

        n_com_coords_in_common_space = center_of_mass(n_image_cs)
        data.set_n_com_coords_in_common_space(resolution, center, n_com_coords_in_common_space)
        logger.info(n_com_coords_in_common_space)

        # Iterate over a range of window sizes

        logger.info('inclusion percentage computation')
        inclusion_info = True
        for size in tqdm(window_sizes[resolution]):
            # Compute the percentage of positive voxels inside the window for each volume
            #logger.info(f'size : {size}')
            
            total_mask_cuboid_com, perc_mask_cuboid_com = get_mask_cuboid(n_com_coords_in_common_space, size, cs_datas[0])
            total_mask_sphere_com, perc_mask_sphere_com = get_mask_sphere(n_com_coords_in_common_space, size, cs_datas[0])
            total_mask_cuboid_stn, perc_mask_cuboid_stn = get_mask_cuboid(stn_com_coords_in_common_space, size, cs_datas[0])
            total_mask_sphere_stn, perc_mask_sphere_stn = get_mask_sphere(stn_com_coords_in_common_space, size, cs_datas[0])
            data.add_value(resolution, center, 'mask', 'cuboid', 'stn', perc_mask_cuboid_stn)
            data.add_value(resolution, center, 'mask', 'cuboid', 'com', perc_mask_cuboid_com)
            data.add_value(resolution, center, 'mask', 'sphere', 'stn', perc_mask_sphere_stn)
            data.add_value(resolution, center, 'mask', 'sphere', 'com', perc_mask_sphere_com)
            #logger.info(total_mask.shape)
            total_positive_voxels_cuboid_com, total_positive_voxels_sphere_com = 0, 0
            total_positive_voxels_cuboid_stn, total_positive_voxels_sphere_stn = 0, 0

            total_voxels = np.sum(cs_datas)
            for vta in cs_datas:
                total_positive_voxels_cuboid_com += np.sum(total_mask_cuboid_com * vta)
                total_positive_voxels_sphere_com += np.sum(total_mask_sphere_com * vta)
                total_positive_voxels_cuboid_stn += np.sum(total_mask_cuboid_stn * vta)
                total_positive_voxels_sphere_stn += np.sum(total_mask_sphere_stn * vta)

            data.add_value(resolution, center, 'vta', 'cuboid', 'stn', total_positive_voxels_cuboid_stn / total_voxels)
            data.add_value(resolution, center, 'vta', 'cuboid', 'com', total_positive_voxels_cuboid_com / total_voxels)
            data.add_value(resolution, center, 'vta', 'sphere', 'stn', total_positive_voxels_sphere_stn / total_voxels)
            data.add_value(resolution, center, 'vta', 'sphere', 'com', total_positive_voxels_sphere_com / total_voxels)

            if data.get_array(resolution, center, 'vta', 'cuboid', 'stn')[-1] >= 0.997 and inclusion_info:
                logger.info(f'3-sigma inclusion at edge radius {size}')
                inclusion_info = False


            if (data.get_array(resolution, center, 'vta', 'cuboid', 'stn')[-1] >= 0.9999) and \
               (data.get_array(resolution, center, 'vta', 'cuboid', 'com')[-1] >= 0.9999):
                logger.info(f'Breaking out of edge radius loop at edge radius {size}')
                break


        del cs_datas
        gc.collect()
        logger.info(' ')
data.save_to_file('data_28_09_23.pkl')
# 18m50s

2023-09-28 12:45:27,721 - INFO - ------ 250um ------
2023-09-28 12:45:27,722 - INFO - Bern_massive
2023-09-28 12:45:27,722 - INFO - loading niftis
100%|██████████| 7395/7395 [00:02<00:00, 3676.12it/s]
100%|██████████| 7395/7395 [00:00<00:00, 1713166.42it/s]
2023-09-28 12:45:29,768 - INFO - converting to ras
2023-09-28 12:45:30,941 - INFO - rounding and converting to array
2023-09-28 12:45:32,182 - INFO - computing bounding box
100%|██████████| 7395/7395 [00:03<00:00, 2328.94it/s]
2023-09-28 12:45:35,396 - INFO - min corner [mm] : [  5.5 -20.  -15. ]
2023-09-28 12:45:35,427 - INFO - max corner [mm] : [19.5 -4.   5. ]
2023-09-28 12:45:35,427 - INFO - container shape : [56 64 80]
2023-09-28 12:45:35,427 - INFO - converting to common space (heavy processing)
100%|██████████| 7395/7395 [01:45<00:00, 70.42it/s]
2023-09-28 12:47:20,553 - INFO - rounding and converting to array
2023-09-28 12:47:23,756 - INFO - computing stn center of mass
2023-09-28 12:47:23,759 - INFO - STN space affine : 
[[

Data saved to data_28_09_23.pkl successfully.


In [57]:
data = STN_ALGO_DATA('data_28_09_23.pkl')

Data loaded from data_28_09_23.pkl successfully.


In [None]:
edge_radii = {
    0.25 : 32,
    0.5  : 16,
    1.0  : 9
}

In [None]:
center = 'bern_mapping'

filename = filenames[center][0]
for resolution in resolutions:
    common_space_shape = data.get_common_space_shape(resolution, center)
    common_space_affine = data.get_common_space_affine(resolution, center)
    img = nib.load(filename)
    vta = nib.processing.resample_from_to(from_img=img, to_vox_map=(common_space_shape, common_space_affine), order=1)
    vta = np.round(vta.get_fdata()).astype(np.int8)
    vta = vta.fill(1)

    for size in (1,5,10,15,20,25,30):
        mask = get_mask_cuboid(n_com_coords_in_common_space, size, vta)[0]

        nifti_obj = nib.Nifti1Image(
            dataobj=mask, 
            affine=common_space_affine)
        nib.save(nifti_obj, f'masks/{size}_edge_radius_stn_{int(1000*resolution)}um.nii')

In [None]:
centers = filenames.keys()
for resolution in resolutions:
    for center in centers:
        common_space_shape = data.get_common_space_shape(resolution, center)
        common_space_affine = data.get_common_space_affine(resolution, center)
        filename = filenames[center][0]

        img = nib.load(filename)
        vta = nib.processing.resample_from_to(from_img=img, to_vox_map=(common_space_shape, common_space_affine), order=1)
        vta = np.round(vta.get_fdata()).astype(np.int8)
        vta.fill(1)

        nifti_obj = nib.Nifti1Image(
            dataobj=vta, 
            affine=common_space_affine)
        nib.save(nifti_obj, f'common_spaces/{center}_common_space_{int(1000*resolution)}um.nii')

## Ploting STN one center (cuboid)

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)

resolution = 0.25
center = 'merged'
center_massive = ''.join([center,'_massive'])
center_mapping = ''.join([center,'_mapping'])

massive_color = 'seagreen'
mapping_color = 'red'

perc_vta_massive = data.get_array(resolution, center_massive, 'vta', 'cuboid', 'stn')
perc_mask_massive = data.get_array(resolution, center_massive, 'mask', 'cuboid', 'stn')
perc_vta_mapping = data.get_array(resolution, center_mapping, 'vta', 'cuboid', 'stn')
perc_mask_mapping = data.get_array(resolution, center_mapping, 'mask', 'cuboid', 'stn')
# Plot the results

window_80_massive = np.argmax(perc_vta_massive>0.80)
window_90_massive = np.argmax(perc_vta_massive>0.9)
window_3sigma_massive = np.argmax(perc_vta_massive>0.997)
window_100_mapping = np.argmax(perc_vta_mapping>0.999)



In [None]:
perc_vta_massive

In [None]:

fig = plt.figure()
ax1 = fig.add_subplot(111)

ax1.plot(range(len(perc_vta_massive)), perc_vta_massive, label='Massive voxels', c=massive_color, ls='-')
ax1.plot(range(len(perc_vta_mapping)), perc_vta_mapping, label='Mapping voxels', c=mapping_color, ls='-')
ax1.plot(range(len(perc_mask_massive)), perc_mask_massive, label='Massive container space', c=massive_color, ls=':')
ax1.plot(range(len(perc_mask_mapping)), perc_mask_mapping, label='Mapping container space', c=mapping_color, ls=':')
ax1.xaxis.tick_top()
ax1.xaxis.set_label_position('top')
ax1.xaxis.set_ticks_position('top')
ax1.legend(bbox_to_anchor=(1.05, 1))



axins = zoomed_inset_axes(ax1, 3, loc='lower right', borderpad=2) # zoom = 6
#axins.patch.set_facecolor('lightgray')


axins.plot(range(len(perc_vta_massive)), perc_vta_massive, c=massive_color, ls='-', zorder=10)
axins.plot(range(len(perc_vta_mapping)), perc_vta_mapping, c=mapping_color, ls='-', zorder=10)


axins.set_aspect(aspect=80, adjustable='box', anchor='E')


# sub region of the original image
x1, x2, y1, y2 = window_90_massive, window_100_mapping, 0.9, 1.0
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)


axins.xaxis.set_minor_locator(MultipleLocator(1))
axins.yaxis.set_minor_locator(MultipleLocator(0.01))


#plt.xticks(visible=False)
#plt.yticks(visible=False)

# draw a bbox of the region of the inset axes in the parent axes and
# connecting lines between the bbox and the inset axes area
mark_inset(ax1, axins, loc1=3, loc2=1, fc="none", ec="0.2", zorder=11)

axins.set_zorder(ax1.get_zorder() + 1)

axins.axvline(x=window_3sigma_massive, ymin=0.00, ymax=1-(10*(1.002-perc_vta_massive[window_3sigma_massive])), c=massive_color, alpha=0.7, ls='--')
axins.axhline(y=0.99, xmin=0.00, xmax=(window_3sigma_massive-window_90_massive)/(window_100_mapping-window_90_massive), c=massive_color, alpha=0.7, ls='--')
#axins.axvline(x=window_99_mapping, ymin=0.00, ymax=1-(5*(1.008-perc_vta_mapping[window_99_mapping])), c=mapping_color, alpha=0.7, ls='--')

#axins.text(x=window_99_mapping-1.3,y=0.88, s='99%', c=mapping_color, fontsize=9, backgroundcolor='white')
axins.text(x=window_3sigma_massive-0.5,y=0.92, s='99%', c=massive_color, fontsize=9, backgroundcolor='white')

ax1.set_xlabel('Window size [voxels]')
ax1.set_ylabel('% inclusion')
ax1.set_title(f"Inclusion of {center.replace('_', ' ').capitalize()} VTAs into STN-centered\n space with increasing edge radius, {int(resolution*1000)}um")
ax1.grid(True)

#plt.draw()
#plt.show()

In [None]:
print('Merged')
display_space_bounds(common_space_shape['merged_massive'], common_space_affine['merged_massive'])
print('Bern')
display_space_bounds(common_space_shape['bern_massive'], common_space_affine['bern_massive'])
print('Cologne')
display_space_bounds(common_space_shape['koln_massive'], common_space_affine['koln_massive'])

In [None]:
(window_99_massive-window_90_massive)/(window_100_mapping-window_90_massive)

In [None]:
print('Merged')
display_volume_size(common_space_shape['merged_massive'], common_space_affine['merged_massive'])
print('Bern')
display_volume_size(common_space_shape['bern_massive'], common_space_affine['bern_massive'])
print('Cologne')
display_volume_size(common_space_shape['koln_massive'], common_space_affine['koln_massive'])

## STN-centered mixed cuboid

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)

def res_to_vox(x):
    return x/resolution

def vox_to_res(x):
    return x*resolution

perc_vta_bern = percentages['stn']['vta']['cuboid']['bern']
perc_mask_bern = percentages['stn']['mask']['cuboid']['bern']
perc_vta_koln = percentages['stn']['vta']['cuboid']['koln']
perc_mask_koln = percentages['stn']['mask']['cuboid']['koln']
perc_vta_mixed = percentages['stn']['vta']['cuboid']['mixed']
perc_mask_mixed = percentages['stn']['mask']['cuboid']['mixed']

window_95_bern = np.argmax(perc_vta_bern>0.95)
window_99_bern = np.argmax(perc_vta_bern>0.99)
window_95_koln = np.argmax(perc_vta_koln>0.95)
window_99_koln = np.argmax(perc_vta_koln>0.99)
window_95_mixed = np.argmax(perc_vta_mixed>0.95)
window_99_mixed = np.argmax(perc_vta_mixed>0.99)

# Plot the results
fig = plt.figure(figsize=(8,5))
ax1 = fig.add_subplot(111)

ax1.plot(window_sizes, perc_vta_bern, label='Bern VTA voxels', c='indianred', ls='-', linewidth=3)
ax1.plot(window_sizes, perc_mask_bern, label='Bern container space', c='indianred', ls=':', linewidth=3)

ax1.plot(window_sizes, perc_vta_koln, label='Köln VTA voxels', c='seagreen', ls='-', linewidth=3)
ax1.plot(window_sizes, perc_mask_koln, label='Köln container space', c='seagreen', ls=':', linewidth=3)

ax1.plot(window_sizes, perc_vta_mixed, label='All VTA voxels', c='slateblue', ls='-', linewidth=3)
#ax1.plot(window_sizes, perc_mask_mixed, label='All container space', c='slateblue', ls=':')
#ax1.plot([1], [1], label='All VTA voxels', c='slateblue')
ax1.legend(loc='lower right', fontsize='small')

ax1.xaxis.set_minor_locator(MultipleLocator(1))

axins = zoomed_inset_axes(ax1, 3, loc='right', borderpad=2) # zoom = 6
axins.plot(window_sizes, perc_vta_bern, c='indianred', ls='-')
axins.plot(window_sizes, perc_vta_koln, c='seagreen', ls='-')
axins.plot(window_sizes, perc_vta_mixed, c='slateblue', ls='-',zorder=10, linewidth=3)

axins.set_aspect(aspect=20, adjustable='box', anchor='E')

# sub region of the original image
x1, x2, y1, y2 = 8, 15, 0.8, 1.0
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)

axins.xaxis.set_minor_locator(MultipleLocator(1))
axins.yaxis.set_minor_locator(MultipleLocator(0.01))
#axins.xaxis.tick_top()
axins.yaxis.tick_right()
secaxins = axins.secondary_xaxis('top', functions=(vox_to_res, res_to_vox))
secaxins.xaxis.set_minor_locator(MultipleLocator(1))

#plt.xticks(visible=False)
#plt.yticks(visible=False)

# draw a bbox of the region of the inset axes in the parent axes and
# connecting lines between the bbox and the inset axes area
mark_inset(ax1, axins, loc1=1, loc2=3, fc="none", ec="0.8")

inters95 = 0.62
inters99 = 0.92
axins.axvline(x=window_95_mixed, ymin=0.00, ymax=inters95, c='black', alpha=0.7, ls='-')
axins.axvline(x=window_99_mixed, ymin=0.00, ymax=inters99, c='black', alpha=0.7, ls='-')
axins.axvline(x=window_95_mixed, ymin=inters95, ymax=1.00, c='black', alpha=0.5, ls=':')
axins.axvline(x=window_99_mixed, ymin=inters99, ymax=1.00, c='black', alpha=0.5, ls=':')

axins.text(x=11.3,y=0.91, s='95%', c='black')
axins.text(x=14.3,y=0.91, s='99%', c='black')

secax = ax1.secondary_xaxis('top', functions=(vox_to_res, res_to_vox))
secax.set_xlabel('Window size [mm]')
secax.xaxis.set_minor_locator(MultipleLocator(1))

ax1.set_xlabel('Window size [voxels]')
ax1.set_ylabel('% inclusion')
ax1.set_title('Inclusion of mapping VTAs into STN-centered\n MNI space with increasing window size', pad=20,fontdict={'fontsize':18})

plt.draw()
plt.show()

In [None]:
perc_vta_bern

In [None]:
perc_vta_mixed[13]

## Printing points

In [None]:
n_com_coords_in_common_space

In [None]:
print('Center of mass :')
for center,coords in n_com_coords_in_common_space.items():
    print(center)
    print(f'common space coords : {coords}')

    transformed_point = matrix_point_mul(common_space_affine[center], coords)
    print(f'real world coords  : {transformed_point}')

In [None]:
print('STN-center :')
for center,coords in stn_com_coords_in_common_space.items():
    print(center)
    print(f'common space coords : {coords}')

    transformed_point = matrix_point_mul(common_space_affine[center], coords)
    print(f'real world coords  : {transformed_point}')

In [None]:
print('Common space bounding box in real world coordinates [mm]')
for center in ('bern', 'koln', 'mixed'):
    print(center)
    shp = common_space_shape[center]
    min_pnt = np.array([0,0,0])
    max_pnt = np.array(shp)

    min_pnt = matrix_point_mul(common_space_affine[center], min_pnt)
    max_pnt = matrix_point_mul(common_space_affine[center], max_pnt)
    
    print(f'real world coords  -> min : {min_pnt}, max : {max_pnt}, diff : {max_pnt-min_pnt}')

In [None]:
print('STN-center - COM diff :')
print('[R A S] [neg -> COM more right, neg -> COM more anterior, neg -> COM more superior\n')
for (center,stn_coord), (center_, com_coord) in zip(stn_com_coords_in_common_space.items(), n_com_coords_in_common_space.items()):
    print(center)
    point = np.array(stn_coord) - np.array(com_coord)
    print(f'common space coords : {point} [voxels]')
    
    transformed_point =  point * resolution

    print(f'real world coords   : {transformed_point} [mm]\n')

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)

def res_to_vox(x):
    return x/0.25

def vox_to_res(x):
    return x*0.25

def size_to_vox(x):
    return ((x**(1./3.))-1)/70

def vox_to_size(x):
    return (70*x+1)**3

perc_vta_bern = percentages['com']['vta']['cuboid']['bern']
perc_mask_bern = percentages['com']['mask']['cuboid']['bern']
perc_vta_koln = percentages['com']['vta']['cuboid']['koln']
perc_mask_koln = percentages['com']['mask']['cuboid']['koln']
perc_vta_mixed = percentages['com']['vta']['cuboid']['mixed']
perc_mask_mixed = percentages['com']['mask']['cuboid']['mixed']

window_95_bern = np.argmax(perc_vta_bern>0.95)
window_99_bern = np.argmax(perc_vta_bern>0.99)
window_95_koln = np.argmax(perc_vta_koln>0.95)
window_99_koln = np.argmax(perc_vta_koln>0.99)
window_95_mixed = np.argmax(perc_vta_mixed>0.95)
window_99_mixed = np.argmax(perc_vta_mixed>0.99)

# Plot the results
fig = plt.figure(figsize=(8,5))
ax1 = fig.add_subplot(111)

ax1.plot(window_sizes, perc_vta_bern, label='Bern VTA voxels', c='indianred', ls='-', linewidth=3)
ax1.plot(window_sizes, perc_mask_bern, label='Bern container space', c='indianred', ls=':', linewidth=3)

ax1.plot(window_sizes, perc_vta_koln, label='Köln VTA voxels', c='seagreen', ls='-', linewidth=3)
ax1.plot(window_sizes, perc_mask_koln, label='Köln container space', c='seagreen', ls=':', linewidth=3)

ax1.plot(window_sizes, perc_vta_mixed, label='All VTA voxels', c='slateblue', ls='-', linewidth=3)
#ax1.plot(window_sizes, perc_mask_mixed, label='All container space', c='slateblue', ls=':')
#ax1.plot([1], [1], label='All VTA voxels', c='slateblue')
ax1.legend(loc='lower right', fontsize='small')

ax1.xaxis.set_minor_locator(MultipleLocator(1))

ax1.set_xlim(0, 35)

axins = zoomed_inset_axes(ax1, 3, loc='upper left', borderpad=2) # zoom = 6
axins.plot(window_sizes, perc_vta_bern, c='indianred', ls='-')
axins.plot(window_sizes, perc_vta_koln, c='seagreen', ls='-')
axins.plot(window_sizes, perc_vta_mixed, c='slateblue', ls='-',zorder=10, linewidth=3)

axins.set_aspect(aspect=90, adjustable='box', anchor='NW')

# sub region of the original image
x1, x2, y1, y2 = 17, 30, 0.9, 1.0
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)

axins.xaxis.set_minor_locator(MultipleLocator(1))
axins.yaxis.set_minor_locator(MultipleLocator(0.01))
#axins.xaxis.tick_top()
axins.yaxis.tick_right()
secaxins = axins.secondary_xaxis('top', functions=(vox_to_res, res_to_vox))
secaxins.xaxis.set_minor_locator(MultipleLocator(1))

#plt.xticks(visible=False)
#plt.yticks(visible=False)

# draw a bbox of the region of the inset axes in the parent axes and
# connecting lines between the bbox and the inset axes area
mark_inset(ax1, axins, loc1=4, loc2=2, fc="none", ec="0.8")

axins.axvline(x=window_95_mixed, ymin=0.00, ymax=0.46, c='black', alpha=0.7, ls='-')
axins.axvline(x=window_99_mixed, ymin=0.00, ymax=0.89, c='black', alpha=0.7, ls='-')
axins.axvline(x=window_95_mixed, ymin=0.46, ymax=1.00, c='black', alpha=0.5, ls=':')
axins.axvline(x=window_99_mixed, ymin=0.89, ymax=1.00, c='black', alpha=0.5, ls=':')

axins.text(x=22.3,y=0.91, s='95%', c='black')
axins.text(x=26.3,y=0.91, s='99%', c='black')

secax = ax1.secondary_xaxis('top', functions=(vox_to_res, res_to_vox))
secax.set_xlabel('Window size [mm]')
secax.xaxis.set_minor_locator(MultipleLocator(1))

ax2 = ax1.twiny()
fig.subplots_adjust(bottom=0.2)
# Move twinned axis ticks and label from top to bottom
ax2.xaxis.set_ticks_position("bottom")
ax2.xaxis.set_label_position("bottom")
# Offset the twin axis below the host
ax2.spines["bottom"].set_position(("axes", -0.15))
# Turn on the frame for the twin axis, but then hide all 
# but the bottom spine
ax2.set_frame_on(True)
ax2.patch.set_visible(False)
for sp in ax2.spines.values():
    sp.set_visible(False)
ax2.spines["bottom"].set_visible(True)

bot_xtickslabels = np.array((10_000, 50_000, 100_000, 150_000, 200_000, 250_000))
new_tick_locations = np.array([size_to_vox(x) for x in bot_xtickslabels])

ax2.set_xticks(new_tick_locations)
ax2.set_xticklabels([f"{x/1000:,.0f}k" for x in bot_xtickslabels])
ax2.set_xlabel(r"Volume size [voxels]")

#thirdax = ax1.secondary_xaxis('bottom', functions=(vox_to_size, size_to_vox))
#thirdax.set_xlabel('Space size')
#thirdax.xaxis.set_major_locator(MultipleLocator(1000))

ax1.set_xlabel('Window size [voxels]')
ax1.set_ylabel('% inclusion')
ax1.set_title('Inclusion of mapping VTAs into center of mass centered\n MNI space with increasing window size', pad=20,fontdict={'fontsize':18})

plt.draw()
plt.show()

## COM all centers sphere

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
import math

def res_to_vox(x):
    return x/0.25

def vox_to_res(x):
    return x*0.25

def size_to_vox(x):
    return ((x/40)**(1./3))*((3/4)/math.pi)

def vox_to_size(x):
    return (4/3)*math.pi*(40*x)**3

perc_vta_bern = percentages['com']['vta']['sphere']['bern']
perc_mask_bern = percentages['com']['mask']['sphere']['bern']
perc_vta_koln = percentages['com']['vta']['sphere']['koln']
perc_mask_koln = percentages['com']['mask']['sphere']['koln']
perc_vta_mixed = percentages['com']['vta']['sphere']['mixed']
perc_mask_mixed = percentages['com']['mask']['sphere']['mixed']

window_95_bern = np.argmax(perc_vta_bern>0.95)
window_99_bern = np.argmax(perc_vta_bern>0.99)
window_95_koln = np.argmax(perc_vta_koln>0.95)
window_99_koln = np.argmax(perc_vta_koln>0.99)
window_95_mixed = np.argmax(perc_vta_mixed>0.95)
window_99_mixed = np.argmax(perc_vta_mixed>0.99)

# Plot the results
fig = plt.figure(figsize=(8,5))
ax1 = fig.add_subplot(111)

ax1.plot(window_sizes, perc_vta_bern, label='Bern VTA voxels', c='indianred', ls='-', linewidth=3)
ax1.plot(window_sizes, perc_mask_bern, label='Bern container space', c='indianred', ls=':', linewidth=3)

ax1.plot(window_sizes, perc_vta_koln, label='Köln VTA voxels', c='seagreen', ls='-', linewidth=3)
ax1.plot(window_sizes, perc_mask_koln, label='Köln container space', c='seagreen', ls=':', linewidth=3)

ax1.plot(window_sizes, perc_vta_mixed, label='All VTA voxels', c='slateblue', ls='-', linewidth=3)
#ax1.plot(window_sizes, perc_mask_mixed, label='All container space', c='slateblue', ls=':')
#ax1.plot([1], [1], label='All VTA voxels', c='slateblue')
ax1.legend(loc='lower right', fontsize='small')

ax1.xaxis.set_minor_locator(MultipleLocator(1))

ax1.set_xlim(0, 40)

axins = zoomed_inset_axes(ax1, 3, loc='upper left', borderpad=2) # zoom = 6
axins.plot(window_sizes, perc_vta_bern, c='indianred', ls='-')
axins.plot(window_sizes, perc_vta_koln, c='seagreen', ls='-')
axins.plot(window_sizes, perc_vta_mixed, c='slateblue', ls='-',zorder=10, linewidth=3)

axins.set_aspect(aspect=90, adjustable='box', anchor='NW')

# sub region of the original image
x1, x2, y1, y2 = 21, 35, 0.9, 1.0
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)

axins.xaxis.set_minor_locator(MultipleLocator(1))
axins.yaxis.set_minor_locator(MultipleLocator(0.01))
#axins.xaxis.tick_top()
axins.yaxis.tick_right()
secaxins = axins.secondary_xaxis('top', functions=(vox_to_res, res_to_vox))
secaxins.xaxis.set_minor_locator(MultipleLocator(1))

#plt.xticks(visible=False)
#plt.yticks(visible=False)

# draw a bbox of the region of the inset axes in the parent axes and
# connecting lines between the bbox and the inset axes area
mark_inset(ax1, axins, loc1=4, loc2=2, fc="none", ec="0.8")

axins.axvline(x=window_95_mixed, ymin=0.00, ymax=0.46, c='black', alpha=0.7, ls='-')
axins.axvline(x=window_99_mixed, ymin=0.00, ymax=0.89, c='black', alpha=0.7, ls='-')
axins.axvline(x=window_95_mixed, ymin=0.46, ymax=1.00, c='black', alpha=0.5, ls=':')
axins.axvline(x=window_99_mixed, ymin=0.89, ymax=1.00, c='black', alpha=0.5, ls=':')

axins.text(x=26.3,y=0.91, s='95%', c='black')
axins.text(x=31.3,y=0.91, s='99%', c='black')

secax = ax1.secondary_xaxis('top', functions=(vox_to_res, res_to_vox))
secax.set_xlabel('Window size [mm]')
secax.xaxis.set_minor_locator(MultipleLocator(1))

ax2 = ax1.twiny()
fig.subplots_adjust(bottom=0.2)
# Move twinned axis ticks and label from top to bottom
ax2.xaxis.set_ticks_position("bottom")
ax2.xaxis.set_label_position("bottom")
# Offset the twin axis below the host
ax2.spines["bottom"].set_position(("axes", -0.15))
# Turn on the frame for the twin axis, but then hide all 
# but the bottom spine
ax2.set_frame_on(True)
ax2.patch.set_visible(False)
for sp in ax2.spines.values():
    sp.set_visible(False)
ax2.spines["bottom"].set_visible(True)

bot_xtickslabels = np.array((10_000, 50_000, 100_000, 150_000, 200_000, 250_000))
new_tick_locations = np.array([size_to_vox(x) for x in bot_xtickslabels])

ax2.set_xticks(new_tick_locations)
ax2.set_xticklabels([f"{x/1000:,.0f}k" for x in bot_xtickslabels])
ax2.set_xlabel(r"Volume size [voxels]")

#thirdax = ax1.secondary_xaxis('bottom', functions=(vox_to_size, size_to_vox))
#thirdax.set_xlabel('Space size')
#thirdax.xaxis.set_major_locator(MultipleLocator(1000))

ax1.set_xlabel('Window size [voxels]')
ax1.set_ylabel('% inclusion')
ax1.set_title('Inclusion of mapping VTAs into center of mass centered\n MNI space with increasing window size', pad=20,fontdict={'fontsize':18})

plt.draw()
plt.show()

## N-image vs STN cuboid

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)

def res_to_vox(x):
    return x/0.25

def vox_to_res(x):
    return x*0.25

def size_to_vox(x):
    return ((x**(1./3.))-1)/70

def vox_to_size(x):
    return (70*x+1)**3


perc_vta_mixed_com = percentages['com']['vta']['cuboid']['mixed']
perc_vta_mixed_stn = percentages['stn']['vta']['cuboid']['mixed']

window_95_mixed_com = np.argmax(perc_vta_mixed_com>0.95)
window_99_mixed_com = np.argmax(perc_vta_mixed_com>0.99)
window_95_mixed_stn = np.argmax(perc_vta_mixed_stn>0.95)
window_99_mixed_stn = np.argmax(perc_vta_mixed_stn>0.99)

# Plot the results
fig = plt.figure(figsize=(8,5))
ax1 = fig.add_subplot(111)


ax1.plot(window_sizes, perc_vta_mixed_com, label='N-image CoM', c='forestgreen', ls='-', linewidth=3)
ax1.plot(window_sizes, perc_vta_mixed_stn, label='STN CoM', c='darkgoldenrod', ls='-', linewidth=3)
#ax1.plot(window_sizes, perc_mask_mixed, label='All container space', c='slateblue', ls=':')
#ax1.plot([1], [1], label='All VTA voxels', c='slateblue')
ax1.legend(loc='lower right', fontsize='small')

ax1.xaxis.set_minor_locator(MultipleLocator(1))

ax1.set_xlim(0, 17)

axins = zoomed_inset_axes(ax1, 6, loc='upper left', borderpad=2) # zoom = 6
axins.plot(window_sizes, perc_vta_mixed_com, label='N-image CoM', c='forestgreen', ls='-', linewidth=3)
axins.plot(window_sizes, perc_vta_mixed_stn, label='STN CoM', c='darkgoldenrod', ls='-', linewidth=3)

axins.set_aspect(aspect=120, adjustable='box', anchor='NW')

# sub region of the original image
x1, x2, y1, y2 = 7, 16, 0.94, 1.0
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)

axins.xaxis.set_minor_locator(MultipleLocator(1))
axins.yaxis.set_minor_locator(MultipleLocator(0.01))
#axins.xaxis.tick_top()
axins.yaxis.tick_right()
secaxins = axins.secondary_xaxis('top', functions=(vox_to_res, res_to_vox))
secaxins.xaxis.set_minor_locator(MultipleLocator(1))

#plt.xticks(visible=False)
#plt.yticks(visible=False)

# draw a bbox of the region of the inset axes in the parent axes and
# connecting lines between the bbox and the inset axes area
mark_inset(ax1, axins, loc1=4, loc2=2, fc="none", ec="0.8")

axins.axvline(x=window_99_mixed_stn, ymin=0.00, ymax=0.76, c='black', alpha=0.7, ls='-')
axins.axvline(x=window_99_mixed_com, ymin=0.00, ymax=0.715, c='black', alpha=0.7, ls='-')
#axins.axhline(y=0.99, xmin=0.55, xmax=1.00, c='black', alpha=0.7, ls='-')

secax = ax1.secondary_xaxis('top', functions=(vox_to_res, res_to_vox))
secax.set_xlabel('Window size [mm]')
secax.xaxis.set_minor_locator(MultipleLocator(1))

ax2 = ax1.twiny()
fig.subplots_adjust(bottom=0.2)
# Move twinned axis ticks and label from top to bottom
ax2.xaxis.set_ticks_position("bottom")
ax2.xaxis.set_label_position("bottom")
# Offset the twin axis below the host
ax2.spines["bottom"].set_position(("axes", -0.15))
# Turn on the frame for the twin axis, but then hide all 
# but the bottom spine
ax2.set_frame_on(True)
ax2.patch.set_visible(False)
for sp in ax2.spines.values():
    sp.set_visible(False)
ax2.spines["bottom"].set_visible(True)

bot_xtickslabels = np.array((10_000, 50_000, 100_000, 150_000, 200_000, 250_000))
new_tick_locations = np.array([size_to_vox(x) for x in bot_xtickslabels])

ax2.set_xticks(new_tick_locations)
ax2.set_xticklabels([f"{x/1000:,.0f}k" for x in bot_xtickslabels])
ax2.set_xlabel(r"Volume size [voxels]")

#thirdax = ax1.secondary_xaxis('bottom', functions=(vox_to_size, size_to_vox))
#thirdax.set_xlabel('Space size')
#thirdax.xaxis.set_major_locator(MultipleLocator(1000))

ax1.set_xlabel('Window size [voxels]')
ax1.set_ylabel('Inclusion')
ax1.set_title('N-image-centered vs STN-centered \nmethods for mapping VTAs inclusion', pad=20,fontdict={'fontsize':18})
# after plotting the data, format the labels
current_values = plt.gca().get_yticks()
# using format string '{:.0f}' here but you can choose others
plt.gca().set_yticklabels([f'{x*100:.0f}%' for x in current_values])


# Here is the label and arrow code of interest
x_anot = 0.665
axins.annotate('$\Delta$ = 9k voxels', xy=(x_anot, -0.05), xytext=(x_anot, -0.4), xycoords='axes fraction', 
            fontsize=9, ha='center', va='bottom',
            arrowprops=dict(arrowstyle='-[, widthB=1.3, lengthB=0.2', lw=1.0))

plt.draw()
plt.show()

In [None]:
(2*14+1)**3 - (2*12+1)**3

In [None]:
(2*14+1)**3

## Creating optimized dataset-independent affine matrix and container shape

In [58]:
opti_edge_radii = {}
resolutions_ = []
centers = []

for key in data.data.keys():
    resolution = key.split('-')[0]
    if resolution not in resolutions_:
        resolutions_.append(resolution)
    center = key.split('-')[1]
    if center not in centers:
        centers.append(center)

for resolution in resolutions_:
    for center in centers:
        arr = data.get_array(resolution, center, 'vta', 'cuboid', 'stn')
        if arr is not None:
            for i, sigma in enumerate((0.6827, 0.945, 0.9973)):
                new_key = f"{resolution}-{center}-{i+1}sigma"
                opti_edge_radii[new_key] = np.argmax(arr > sigma)

In [59]:
# CS CoM + windows size -> affine + container
# container space = [2 * window size + 1 for 3 dims]
# affine matrix = I of resolutions, offsets = CS_mixed @ CoM_mixed - window size * res
opti_affine_matrix, opti_container_shape = {}, {}
for resolution in resolutions:
    for i in range(3):
        print(float(resolution), i+1)
        key = f'{resolution}-merged_massive-{i+1}sigma'
        
        opti_window_size = data.opti_edge_radii[key]
        opti_container_shape[(resolution, i+1)] =np.full((3,),2*opti_window_size+1, dtype=int)
        common_space_affine = data.get_common_space_affine(resolution, 'merged_massive')
        stn_com_coords_in_common_space = data.get_stn_com_coords_in_common_space(resolution, 'merged_massive')
        opti_offsets = (matrix_point_mul(common_space_affine,stn_com_coords_in_common_space) - opti_window_size * resolution)
        opti_affine_matrix[(resolution, i+1)] = np.array([
            [resolution, 0, 0, opti_offsets[0]],
            [0, resolution, 0, opti_offsets[1]],
            [0, 0, resolution, opti_offsets[2]],
            [0, 0, 0, 1]
        ])

0.25 1
0.25 2
0.25 3
0.5 1
0.5 2
0.5 3
1.0 1
1.0 2
1.0 3


In [60]:
opti_affine_matrix

{(0.25,
  1): array([[  0.25      ,   0.        ,   0.        ,   7.81873258],
        [  0.        ,   0.25      ,   0.        , -15.99244764],
        [  0.        ,   0.        ,   0.25      , -11.11541605],
        [  0.        ,   0.        ,   0.        ,   1.        ]]),
 (0.25,
  2): array([[  0.25      ,   0.        ,   0.        ,   5.81873258],
        [  0.        ,   0.25      ,   0.        , -17.99244764],
        [  0.        ,   0.        ,   0.25      , -13.11541605],
        [  0.        ,   0.        ,   0.        ,   1.        ]]),
 (0.25,
  3): array([[  0.25      ,   0.        ,   0.        ,   3.56873258],
        [  0.        ,   0.25      ,   0.        , -20.24244764],
        [  0.        ,   0.        ,   0.25      , -15.36541605],
        [  0.        ,   0.        ,   0.        ,   1.        ]]),
 (0.5,
  1): array([[  0.5       ,   0.        ,   0.        ,   7.81873258],
        [  0.        ,   0.5       ,   0.        , -15.99244764],
        [  0.      

In [61]:
opti_container_shape

{(0.25, 1): array([29, 29, 29]),
 (0.25, 2): array([45, 45, 45]),
 (0.25, 3): array([63, 63, 63]),
 (0.5, 1): array([15, 15, 15]),
 (0.5, 2): array([23, 23, 23]),
 (0.5, 3): array([33, 33, 33]),
 (1.0, 1): array([7, 7, 7]),
 (1.0, 2): array([13, 13, 13]),
 (1.0, 3): array([17, 17, 17])}

In [None]:
for resolution in resolutions:
    print(f'{int(resolution*1000)}um ')
    common_space_shape = data.get_common_space_shape(resolution, 'merged_massive')
    print(common_space_shape)

In [None]:
for resolution in resolutions:
    print(f'{int(resolution*1000)}um ')
    common_space_shape = data.get_common_space_shape(resolution, 'merged_massive')
    print(f'Voxels usage compared to fully inclusive common space : {100*np.prod(opti_container_shape[int(1000*resolution)])/np.prod(common_space_shape):.2f}%')
    print(f'{int(np.floor(np.prod(common_space_shape)/np.prod(opti_container_shape[resolution])))}-fold voxel reduction')

In [None]:
data

In [62]:
import gc

table = pd.read_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/raw/tables/stn_space/merged/flipped/table.csv')

table_bern = table[table['center'] == 'Bern']
table_koln = table[table['center'] == 'Cologne']

filenames_bern = table_bern['massive_filename'].to_list()
filenames_koln = table_koln['massive_filename'].to_list()
filenames_merged = table['massive_filename'].to_list()

filenames_dict = {
    'bern' : filenames_bern,
    'koln' : filenames_koln,
    'merged': filenames_merged
}

filenames = filenames_merged

for resolution in resolutions:
    print(f'resolution : {int(1000*resolution)}')
    print('loading niftis')
    imgs = [nib.load(filename) for filename in tqdm(filenames)]

    # flip all the VTAS to right side & force isotropy on space
    #print('converting to ras')
    #imgs = [reorient_to_ras(img) for img in tqdm(imgs)]
    sum_voxels = 0
    for img in tqdm(imgs):
        img_data = np.round(img.get_fdata())
        if img_data.ndim > 1:
            sum_voxels += np.sum(np.ravel(img_data))
        else:
            sum_voxels += img_data
    # binarizing and vectorizing numpy arrays of VTAs
    print('converting to container common space')
    common_space_affine = data.get_common_space_affine(resolution, 'merged_massive')
    common_space_shape = data.get_common_space_shape(resolution, 'merged_massive')
    common_space_datas = [nib.processing.resample_from_to(from_img=img, to_vox_map=(common_space_shape, common_space_affine)) for img in tqdm(imgs)]

    print('rounding and converting to int')
    common_space_vtas = []
    for vta in tqdm(common_space_datas):
        common_space_vtas.append(np.where(vta.get_fdata() >= 0.5, 1, 0).astype(np.int8))
    common_space_vtas = np.array(common_space_vtas)
    common_space_sum = np.sum(common_space_vtas)

    output_path = f'/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/interim/common_space/merged/flipped/VTAs_uncleaned/{int(resolution*1000)}um.npz'
    Path(output_path).parents[0].mkdir(parents=True, exist_ok=True)
    np.savez_compressed(output_path, common_space_vtas)

    del common_space_datas, common_space_vtas
    gc.collect()

    for i in range(3):
        print(f'sigma : {i+1}')

        print('converting to opti common space')
        opti_space_datas = [nib.processing.resample_from_to(from_img=img, to_vox_map=(opti_container_shape[(resolution, i+1)], opti_affine_matrix[(resolution, i+1)]), order=1) for img in tqdm(imgs)]
        opti_space_vtas = []
        for vta in tqdm(opti_space_datas):
            opti_space_vtas.append(np.where(vta.get_fdata() >= 0.5, 1, 0).astype(np.int8))
        opti_space_vtas = np.array(opti_space_vtas)
        print(f'activated voxels original space : {sum_voxels}')
        print(f'container act vox : {common_space_sum:.0f}, opti act vox : {np.sum(opti_space_vtas):.0f}, ratio : {(100*np.sum(opti_space_vtas)/common_space_sum):.2f}%') 
        output_path = f'/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/interim/stn_space_{i+1}sigma/merged/flipped/VTAs_uncleaned/{int(resolution*1000)}um.npz'
        Path(output_path).parents[0].mkdir(parents=True, exist_ok=True)
        np.savez_compressed(output_path, opti_space_vtas)
        del opti_space_datas, opti_space_vtas
        gc.collect()
# 15m02s

resolution : 250
loading niftis


  0%|          | 0/8556 [00:00<?, ?it/s]

100%|██████████| 8556/8556 [00:02<00:00, 3753.21it/s]
100%|██████████| 8556/8556 [00:06<00:00, 1420.73it/s]


converting to container common space


100%|██████████| 8556/8556 [08:14<00:00, 17.32it/s]


rounding and converting to int


100%|██████████| 8556/8556 [00:06<00:00, 1321.65it/s]


sigma : 1
converting to opti common space


100%|██████████| 8556/8556 [00:14<00:00, 580.89it/s]
100%|██████████| 8556/8556 [00:01<00:00, 5461.25it/s]


activated voxels original space : 17358200.0
container act vox : 44260877, opti act vox : 30134353, ratio : 68.08%
sigma : 2
converting to opti common space


100%|██████████| 8556/8556 [00:44<00:00, 193.63it/s]
100%|██████████| 8556/8556 [00:02<00:00, 4027.34it/s]


activated voxels original space : 17358200.0
container act vox : 44260877, opti act vox : 41710840, ratio : 94.24%
sigma : 3
converting to opti common space


100%|██████████| 8556/8556 [01:41<00:00, 84.65it/s] 
100%|██████████| 8556/8556 [00:05<00:00, 1558.07it/s]


activated voxels original space : 17358200.0
container act vox : 44260877, opti act vox : 43560962, ratio : 98.42%
resolution : 500
loading niftis


100%|██████████| 8556/8556 [00:02<00:00, 3844.17it/s]
100%|██████████| 8556/8556 [00:06<00:00, 1405.75it/s]


converting to container common space


100%|██████████| 8556/8556 [01:27<00:00, 97.91it/s] 


rounding and converting to int


100%|██████████| 8556/8556 [00:00<00:00, 9411.93it/s] 


sigma : 1
converting to opti common space


100%|██████████| 8556/8556 [00:05<00:00, 1683.46it/s]
100%|██████████| 8556/8556 [00:00<00:00, 92787.67it/s]


activated voxels original space : 17358200.0
container act vox : 5548054, opti act vox : 3971612, ratio : 71.59%
sigma : 2
converting to opti common space


100%|██████████| 8556/8556 [00:09<00:00, 939.85it/s] 
100%|██████████| 8556/8556 [00:00<00:00, 43363.19it/s]


activated voxels original space : 17358200.0
container act vox : 5548054, opti act vox : 5285429, ratio : 95.27%
sigma : 3
converting to opti common space


100%|██████████| 8556/8556 [00:17<00:00, 485.87it/s]
100%|██████████| 8556/8556 [00:00<00:00, 16614.08it/s]


activated voxels original space : 17358200.0
container act vox : 5548054, opti act vox : 5478982, ratio : 98.76%
resolution : 1000
loading niftis


100%|██████████| 8556/8556 [00:02<00:00, 3653.41it/s]
100%|██████████| 8556/8556 [00:06<00:00, 1400.03it/s]


converting to container common space


100%|██████████| 8556/8556 [00:36<00:00, 231.83it/s]


rounding and converting to int


100%|██████████| 8556/8556 [00:00<00:00, 82245.97it/s]


sigma : 1
converting to opti common space


100%|██████████| 8556/8556 [00:03<00:00, 2370.27it/s]
100%|██████████| 8556/8556 [00:00<00:00, 164803.54it/s]


activated voxels original space : 17358200.0
container act vox : 693445, opti act vox : 460737, ratio : 66.44%
sigma : 2
converting to opti common space


100%|██████████| 8556/8556 [00:04<00:00, 1851.88it/s]
100%|██████████| 8556/8556 [00:00<00:00, 112556.74it/s]


activated voxels original space : 17358200.0
container act vox : 693445, opti act vox : 676791, ratio : 97.60%
sigma : 3
converting to opti common space


100%|██████████| 8556/8556 [00:05<00:00, 1524.10it/s]
100%|██████████| 8556/8556 [00:00<00:00, 78641.99it/s]


activated voxels original space : 17358200.0
container act vox : 693445, opti act vox : 685340, ratio : 98.83%


In [None]:
common_space_shape