In [None]:
import glob
import json
import os
import subprocess
import tempfile

import numpy as np
from soma import aims

from sulci.registration.spam import spam_register

import anatomist.api as ana
from soma.qt_gui.qtThread import QtThreadCall
from soma.qt_gui.qt_backend import Qt

from soma.aimsalgo import MorphoGreyLevel_S16

# Global static variables
_AIMS_BINARY_ONE = 32767
_dilation = 1
_threshold = 2

# Launch Anatomist; the handle `a` is used by the visualization cells below
a = ana.Anatomist()

In [None]:
def dilate(mask, radius=_dilation):
    """Binarize ``mask`` and morphologically dilate it by ``radius`` mm.

    Warning: ``mask`` is modified in place. Its voxels are set to 0 (value < 1)
    or ``_AIMS_BINARY_ONE`` (value >= 1) before the dilation.

    Parameters
    ----------
    mask : aims volume
        Volume handed to ``MorphoGreyLevel_S16.doDilation`` (so presumably an
        S16 volume -- TODO confirm for other input types).
    radius : number
        Dilation radius, in mm. Defaults to ``_dilation``.

    Returns
    -------
    aims volume
        The dilated volume, with every voxel >= 1 set to 1.
    """
    arr = mask.np
    # Binarize in place: "on" voxels take the int16 maximum so that the
    # grey-level S16 morphology treats them as foreground
    arr[arr < 1] = 0
    arr[arr >= 1] = _AIMS_BINARY_ONE
    # Grey-level dilation with radius `radius` (mm)
    dilated = MorphoGreyLevel_S16().doDilation(mask, radius)
    # Bring the dilated foreground back to a 0/1 mask
    dilated_arr = dilated.np
    dilated_arr[dilated_arr >= 1] = 1
    return dilated

In [None]:
# Inputs: left Sc.Cal.-S.Li. SPAM mask (2mm) and the pclean binarized skeletons
spam_file = '/neurospin/dico/data/deep_folding/current/mask/2mm/regions/L/Sc.Cal.-S.Li._left.nii.gz'
skel_path = '/neurospin/dico/data/deep_folding/current/datasets/pclean/binarized_skeletons/L'
# glob returns files in arbitrary order; sort so that list_after (built from
# this list) has a reproducible order across runs and machines
skel_files = sorted(glob.glob(os.path.join(skel_path, '*.nii.gz')))
# Show how many skeletons were found plus a short preview, not the whole list
len(skel_files), skel_files[:5]

In [None]:
mask_result = aims.read(spam_file)
# Zero out SPAM voxels whose value is <= _threshold. Index through `.np`, as the
# rest of the notebook does, rather than relying on Volume.__setitem__
# accepting a boolean numpy mask.
mask_result.np[mask_result.np <= _threshold] = 0
# Binarize and dilate by `_dilation` mm; `dilate` returns a 0/1 volume
mask_result.np[:] = dilate(mask_result).np

In [None]:
def realign(skel_f):
    """Realign one skeleton onto the SPAM and return the resampled skeleton.

    The skeleton is binarized and masked with the global dilated SPAM mask
    ``mask_result``. It is then registered to the SPAM read from ``spam_file``
    with ``spam_register`` and resampled with the ``AimsApplyTransform``
    command-line tool.

    Parameters
    ----------
    skel_f : str
        File name of the skeleton volume.

    Returns
    -------
    aims volume
        The realigned skeleton, read back from AimsApplyTransform's output.

    Raises
    ------
    subprocess.CalledProcessError
        If AimsApplyTransform exits with an error. Previously the exit status
        was ignored, so a stale output file from an earlier call could be
        returned silently.
    """
    # Per-call scratch directory: no clobbering between concurrent runs, no
    # stale outputs, no leftover files in /tmp
    with tempfile.TemporaryDirectory() as tmp_dir:
        skel_before_f = os.path.join(tmp_dir, "skel_before.nii.gz")
        transform_f = os.path.join(tmp_dir, "transform.trm")
        skel_realigned_f = os.path.join(tmp_dir, "skel_realigned.nii.gz")

        # Reads the skeleton and binarizes it to 0/1
        skel_data = aims.read(skel_f)
        skel_data.np[:] = (skel_data.np > 0).astype(np.int16)

        # Masks skeleton data with the dilated spam
        skel_data.np[mask_result.np <= 0] = 0
        aims.write(skel_data, skel_before_f)

        # Reads the initial spam volume as float for the registration
        spam_vol = aims.read(spam_file, dtype="Volume_FLOAT")

        # Makes realignment
        out_tr = spam_register(spam_vol,
                               skel_data,
                               do_mask=False,
                               R_angle_var=np.pi / 128,
                               t_var=5.,
                               verbose=False,
                               in_log=False,
                               calibrate_distrib=15)
        aims.write(out_tr, transform_f)
        print(out_tr.np)

        # Applies the realignment; check=True raises instead of silently
        # continuing with a missing or stale output file
        subprocess.run(["AimsApplyTransform",
                        "-i", skel_before_f,
                        "-o", skel_realigned_f,
                        "-I", transform_f],
                       check=True)

        # Loads the realigned file before the scratch directory is removed
        after = aims.read(skel_realigned_f)

    return after

In [None]:
# Inspect the thresholded and dilated SPAM mask used to mask each skeleton
mask_result

In [None]:
# Realign every skeleton. Keep each realigned volume in `list_after` and their
# voxel-wise sum in `after_all`. If the realigned skeletons stay 0/1, each voxel
# of the sum counts how many skeletons cover it.
after_all = aims.Volume(mask_result.getSize(), 'S16')
after_all.copyHeaderFrom(mask_result.header())
# New aims volumes are not documented as zero-initialized; they are used as
# accumulators, so clear them explicitly
after_all.fill(0)
list_after = []
for skel_f in skel_files:
    # S16 volume carrying the mask's header, filled with the realigned skeleton
    after = aims.Volume(mask_result.getSize(), 'S16')
    after.copyHeaderFrom(mask_result.header())
    after.fill(0)
    after += realign(skel_f)
    after_all += after
    list_after.append(after)

In [None]:
# Distinct values present in the summed volume (presumably overlap counts,
# assuming each realigned skeleton stays 0/1 -- verify)
np.unique(after_all.np)

In [None]:
# Visualization
spam = a.loadObject(spam_file)
spam.setPalette("Blues")
spam_after = a.toAObject(after_all)
spam_after.setPalette("Reds")
w = a.createWindow('Sagittal')
w.addObjects(spam)
w.addObjects(spam_after)

In [None]:
# Alternative view (disabled): one red layer per realigned skeleton instead of their sum.
# spam = a.loadObject(spam_file)
# spam.setPalette("Blues")
# list_after_a = [a.toAObject(after) for after in list_after]
# for after in list_after_a:
#     after.setPalette("RED-lfusion")
# w = a.createWindow('Sagittal')
# w.addObjects(list_after_a)
# w.addObjects(spam)