In [503]:

import numpy as np

from scipy.ndimage import gaussian_filter
import subprocess
from soma import aims

from sulci.registration.spam import spam_register
from soma.aimsalgo import MorphoGreyLevel_S16

import sulci.registration.spam
from sulci.models import distribution_aims
from sulci.registration.spam import (
    SpamRegistration, dilate_spam_mask, move_image_slightly, spam_register)
from soma import aims, aimsalgo

# Module-level configuration constants
# Foreground value set before grey-level dilation in dilate() (largest int16 value)
_AIMS_BINARY_ONE = 32767
# Default dilation radius (mm) of dilate(); also passed as the first-pass dilation
_dilation = 0
# First-pass threshold passed to do_masking_dilation() (currently ignored there)
_threshold = 1
# Dilation radius (mm) used for the toy mask and for the final masking pass
_dilation_final = 5
# Final-pass threshold passed to do_masking_dilation() (currently ignored there)
_threshold_final = 0
# Smoothing width passed to Gaussian3DSmoothing_FLOAT in realign_spam_register()
_edge_smoothing = 10.
# Number of subjects in the real SPAM (not referenced in the cells shown)
_nb_spam_subjects = 61

# Anatomist

import anatomist.api as ana
from soma.qt_gui.qtThread import QtThreadCall
from soma.qt_gui.qt_backend import Qt
# Launch Anatomist; the `a` handle is used later to display the volumes
a = ana.Anatomist()

# Define functions

In [504]:
def dilate(mask, radius=_dilation):
    """Binarizes a mask and optionally dilates it.

    The input is copied first, so the caller's volume is left untouched
    (the previous version binarized ``mask`` in place).

    Parameters
    ----------
    mask: aims volume
        Mask to binarize: voxels >= 1 are foreground, all others background.
        Dilation uses ``MorphoGreyLevel_S16``, so an S16 volume is required
        when ``radius > 0``.
    radius: float
        Dilation radius in mm, passed to ``MorphoGreyLevel_S16.doDilation``.
        When ``radius <= 0`` no dilation is performed.

    Returns
    -------
    A binary (0/1) volume: a binarized copy of ``mask`` when no dilation is
    done, otherwise the binarized dilation result.
    """
    # Copy-construct with the input's own volume type so dtype is preserved
    binary = type(mask)(mask)
    arr = binary.np
    arr[arr < 1] = 0

    if radius <= 0:
        arr[arr >= 1] = 1
        return binary

    # Grey-level dilation acts on intensities: give the foreground the
    # largest S16 value before dilating, then re-binarize the result.
    arr[arr >= 1] = _AIMS_BINARY_ONE
    dilated = MorphoGreyLevel_S16().doDilation(binary, radius)
    dilated_arr = dilated.np
    dilated_arr[dilated_arr >= 1] = 1
    return dilated

In [505]:
def realign_mi_register(spam_vol: aims.Volume_FLOAT, skel_vol: aims.Volume_S16,
                        work_dir: str = "/tmp"):
    """Realigns a skeleton to a SPAM with ``AimsMIRegister``.

    Processing steps:

    - The SPAM is scaled by 100.
    - The skeleton is binarized, and a Gaussian-smoothed (x100) copy of it
      drives the mutual-information registration.
    - The estimated transform is applied to the *unsmoothed* binary skeleton
      with ``AimsApplyTransform``.

    The input volumes are not modified.

    Parameters
    ----------
    spam_vol: aims.Volume_FLOAT
        Reference SPAM.
    skel_vol: aims.Volume_S16
        Skeleton to realign (the test volume); voxels > 0 are foreground.
    work_dir: str
        Directory receiving the intermediate files. These are
        ``skel_before.nii.gz``, ``skel_before_filtered.nii.gz``,
        ``spam.nii.gz``, ``transform.trm`` and ``skel_realigned.nii.gz``.
        Existing files are overwritten.

    Returns
    -------
    The realigned skeleton, read back from ``skel_realigned.nii.gz``.

    Raises
    ------
    subprocess.CalledProcessError
        If one of the AIMS command-line tools fails.
    """
    import os

    def _path(name):
        return os.path.join(work_dir, name)

    skel_path = _path("skel_before.nii.gz")
    skel_filtered_path = _path("skel_before_filtered.nii.gz")
    spam_path = _path("spam.nii.gz")
    transform_path = _path("transform.trm")
    realigned_path = _path("skel_realigned.nii.gz")

    # Work on a copy of the SPAM, rescaled by 100
    spam_vol = aims.Volume_FLOAT(spam_vol)
    spam_vol.np[:] = spam_vol.np * 100

    # Binary FLOAT copy of the skeleton carrying the input header
    skel_s16 = aims.Volume_S16(skel_vol)
    skel_float = aims.Volume(skel_s16.getSize(), 'FLOAT')
    skel_float.copyHeaderFrom(skel_s16.header())
    skel_float.np[:] = (skel_s16.np > 0).astype(np.float32)
    aims.write(skel_float, skel_path)

    # Smoothed skeleton, used only to drive the registration
    # (scipy's gaussian_filter sigma is expressed in voxels)
    skel_float.np[:] = 100 * gaussian_filter(skel_float.np, sigma=2)
    aims.write(skel_float, skel_filtered_path)
    aims.write(spam_vol, spam_path)

    # Argument lists: no shell parsing, robust to paths with spaces
    subprocess.check_call(["AimsMIRegister", "-r", spam_path,
                           "-t", skel_filtered_path, "--dir", transform_path])
    subprocess.check_call(["AimsApplyTransform", "-i", skel_path,
                           "-o", realigned_path, "-m", transform_path])

    return aims.read(realigned_path)

In [506]:
def do_masking_dilation(spam_vol, skel_vol, dilation, threshold, do_binarization):
    """Returns an S16 copy of the skeleton, optionally binarized.

    NOTE(review): despite its name, this function currently performs neither
    masking nor dilation. ``spam_vol``, ``dilation`` and ``threshold`` are
    accepted for interface compatibility but ignored, because the SPAM-based
    masking step is disabled.

    If masking is re-enabled, check which side of the mask is cleared. The
    previously commented-out implementation zeroed skeleton voxels where
    the dilated SPAM mask was > 0 (i.e. *inside* the mask), which looks
    inverted for a masking step. Confirm the intent first.

    Parameters
    ----------
    spam_vol: aims volume
        Unused.
    skel_vol: aims volume
        Skeleton; copied as S16, the input is not modified.
    dilation, threshold:
        Unused.
    do_binarization: bool
        If True, voxels > 0 become 1 and all others 0 (for registration).

    Returns
    -------
    aims.Volume_S16
        The (possibly binarized) skeleton copy.
    """
    skel_copy = aims.Volume_S16(skel_vol)
    if do_binarization:
        skel_copy.np[:] = (skel_copy.np > 0).astype(np.int16)
    return skel_copy

In [507]:
def realign_spam_register(spam_vol: aims.Volume_FLOAT, skel_vol: aims.Volume_S16,
                          mask_vol: aims.Volume_FLOAT, do_edge_smoothing: bool,
                          work_dir: str = "/tmp"):
    """Realigns a skeleton to a SPAM with ``spam_register``.

    Processing steps:

    - The transform is estimated between the SPAM and a binarized copy of
      the skeleton. If ``do_edge_smoothing`` is set, the SPAM is first
      weighted by the smoothed mask.
    - The transform is written to ``transform.trm``.
    - The skeleton is resampled with the direct transform.
    - The (unweighted) input SPAM is resampled with the inverse transform.

    The input volumes are not modified.

    Parameters
    ----------
    spam_vol: aims.Volume_FLOAT
        Reference SPAM.
    skel_vol: aims.Volume_S16
        Skeleton to realign.
    mask_vol: aims.Volume_FLOAT
        Mask used to weight the SPAM when ``do_edge_smoothing`` is True.
    do_edge_smoothing: bool
        Whether to multiply the SPAM by the Gaussian-smoothed mask
        (normalized to a maximum of 1) before registration.
    work_dir: str
        Directory receiving the intermediate files (overwritten).

    Returns
    -------
    (after, spam_after)
        The skeleton resampled with the direct transform, and the input
        SPAM resampled with the inverse transform.

    Raises
    ------
    ValueError
        If ``do_edge_smoothing`` is True and the smoothed mask is all zero.
    subprocess.CalledProcessError
        If ``AimsApplyTransform`` fails.
    """
    import os

    def _path(name):
        return os.path.join(work_dir, name)

    spam_path = _path("spam_before.nii.gz")
    spam_realigned_path = _path("spam_realigned.nii.gz")
    skel_path = _path("skel_before.nii.gz")
    skel_final_path = _path("skel_final_before.nii.gz")
    skel_realigned_path = _path("skel_final_realigned.nii.gz")
    transform_path = _path("transform.trm")

    # Private copies: the caller's volumes are never modified
    spam_vol = aims.Volume_FLOAT(spam_vol)
    skel_vol = aims.Volume_S16(skel_vol)

    # Save the SPAM before any edge weighting: it is the volume resampled with
    # the inverse transform at the end. This used to be read from a
    # /tmp/spam_init.nii.gz written by another cell, which failed or used
    # stale data when that cell had not been (re-)run.
    aims.write(spam_vol, spam_path)

    # First pass: binarized skeleton used to estimate the transform
    skel_vol_before = do_masking_dilation(spam_vol, skel_vol, _dilation, _threshold, True)
    aims.write(skel_vol_before, skel_path)

    if do_edge_smoothing:
        # Weight the SPAM by the smoothed mask normalized to a maximum of 1,
        # so SPAM values fade out where the smoothed mask is low.
        smoother = aimsalgo.Gaussian3DSmoothing_FLOAT(
            _edge_smoothing, _edge_smoothing, _edge_smoothing)
        weights = smoother.doit(mask_vol)
        weights_max = weights.max()
        if weights_max <= 0:
            # Dividing by 0 would silently turn the whole SPAM into NaN
            raise ValueError("mask_vol is empty: cannot use it to weight the SPAM")
        spam_vol.np[:] = spam_vol.np * (weights.np / weights_max)

    # Computes the transform for realignment
    out_tr = spam_register(spam_vol,
                           skel_vol_before,
                           do_mask=False,
                           eps=1e-5,
                           R_angle_var=np.pi / 8,
                           t_var=10.,
                           verbose=True,
                           in_log=False,
                           calibrate_distrib=30)
    print(out_tr)
    aims.write(out_tr, transform_path)

    # Final pass: non-binarized skeleton, which is the one resampled
    skel_final = do_masking_dilation(spam_vol, skel_vol, _dilation_final, _threshold_final, False)
    aims.write(skel_final, skel_final_path)

    # Skeleton resampled with the direct transform
    subprocess.check_call(["AimsApplyTransform", "-i", skel_final_path,
                           "-o", skel_realigned_path, "-m", transform_path])
    # Input SPAM resampled with the inverse transform
    subprocess.check_call(["AimsApplyTransform", "-i", spam_path,
                           "-o", spam_realigned_path, "-I", transform_path])

    after = aims.read(skel_realigned_path)
    spam_after = aims.read(spam_realigned_path)
    return after, spam_after

# Toy model

In [508]:
def _blank_volume(dims, hdr, dtype):
    """Returns a zero-filled aims volume of size `dims` carrying header `hdr`."""
    vol = aims.Volume(list(dims), dtype=dtype)
    vol.copyHeaderFrom(hdr)
    vol.fill(0)
    return vol


def _smoothed_probability(vol, sigma, floor):
    """Gaussian-smooths a FLOAT volume into a map with maximum 1.

    The smoothing uses the same width `sigma` on all three axes. Negative
    values produced by the smoothing are clipped to 0, the result is divided
    by its maximum, and values below `floor` are set to 0.
    """
    smoother = aimsalgo.Gaussian3DSmoothing_FLOAT(sigma, sigma, sigma)
    smoothed = smoother.doit(vol)
    arr = smoothed.np
    arr[arr < 0] = 0.
    arr[:] = arr / smoothed.max()
    arr[arr < floor] = 0.
    return smoothed


# Template geometry: ICBM 2009c header with voxels twice as large
hdr = dict(aims.StandardReferentials.icbm2009cTemplateHeader())
dims = np.ceil(np.array(hdr['volume_dimension']) / 2).astype(int)
vs = np.array(hdr['voxel_size']) * 2
hdr['voxel_size'] = vs

################################
# Creates skeleton as one line
################################
skel_vol = _blank_volume(dims, hdr, 'S16')
# Alternative, longer line: skel_vol[60, 10:100, 50] = 1
skel_vol[30, 30:70, 50] = 1

################################
# Creates spam centered on skeleton
################################
spam_vol = _blank_volume(dims, hdr, 'FLOAT')
spam_vol.np[:] = skel_vol.np
# NOTE(review): the real SPAM averages ~60 subjects, so its smallest non-zero
# value would be ~1/60 (the extra lines below use 1/60), but this floor is
# 1/200 — confirm which value is intended.
spam_vol = _smoothed_probability(spam_vol, 8., 1 / 200.)

################################
# Creates mask: SPAM support, dilated, then smoothed
################################
mask_vol = _blank_volume(dims, hdr, 'S16')
mask_vol.np[spam_vol.np > 0] = 1
mask_vol = dilate(mask_vol, _dilation_final)
to_float = aims.Converter(intype=aims.Volume('S16'), outtype=aims.Volume('FLOAT'))
mask_vol = to_float(mask_vol)
mask_vol = aimsalgo.Gaussian3DSmoothing_FLOAT(10., 10., 10.).doit(mask_vol)

################################
# Adds some smoothed lines on SPAM
################################
spam_vol2 = _blank_volume(dims, hdr, 'FLOAT')
# Alternative, longer line: spam_vol2[70, 10:100, 60] = 1
spam_vol2[20, 30:70, 40] = 1
spam_vol2 = _smoothed_probability(spam_vol2, 5., 1 / 60.)
# Keep the extra lines only inside the support of the main SPAM
spam_vol2.np[spam_vol.np <= 0] = 0.
spam_vol.np[:] = spam_vol.np + spam_vol2.np

aims.write(spam_vol, '/tmp/spam_init.nii.gz')

# SPAM mass center as a (3, 1) column vector
# NOTE(review): indexing of mc.infos() kept as-is; coordinates presumably in mm
mc = aims.MassCenters_FLOAT(spam_vol)
mc.doit()
gravity_center = np.expand_dims(np.array(mc.infos()['0'][0][0]), axis=1)

################################
# Changes one line of skeleton, then moves it slightly
################################
skel_vol[30, 30:70, 50] = 0
skel_vol[25, 25:65, 50] = 1
# Alternative matching the longer line: skel_vol[60, 10:100, 50] = 0
#                                       skel_vol[55, 5:95, 50] = 1

# NOTE(review): arguments presumably are rotation axis, angle (rad),
# translation and rotation center — see sulci.registration.spam
moved_skel, tr = move_image_slightly(skel_vol, (0, 0, 1), np.pi / 20 * 1.5,
                                     np.array((-1.2, 2.3, -3.5)) * 2,
                                     gravity_center)

DILATION
  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 %

In [509]:
# Dimensions of the toy skeleton volume
skel_vol.shape

(97, 115, 97, 1)

In [510]:
# Distinct voxel values of the moved skeleton and their counts
np.unique(moved_skel.np, return_counts=True)

(array([0, 1], dtype=int16), array([1081994,      41]))

In [511]:
# Peak value of the combined toy SPAM
spam_vol.np.max()

1.0

In [512]:
# Header of the moved skeleton (geometry and referential information)
moved_skel.header()

{ 'referentials' : [ 'Talairach-MNI template-SPM' ], 'transformations' : [ [ -1, 0, 0, 96, 0, -1, 0, 96, 0, 0, -1, 114, 0, 0, 0, 1 ] ], 'voxel_size' : [ 2, 2, 2, 2 ], 'volume_dimension' : [ 97, 115, 97, 1 ], 'sizeX' : 97, 'sizeY' : 115, 'sizeZ' : 97, 'sizeT' : 1, 'referential' : '84b1989b-eb68-8665-0049-8feaf3c22679' }

In [513]:
# SPAM mass center used by move_image_slightly, as a (3, 1) column vector
gravity_center

array([[58.35338211],
       [98.99966431],
       [98.35235596]])

In [514]:
# Distinct voxel values (with counts) of the SPAM, then of the moved skeleton
for vol in (spam_vol, moved_skel):
    print(np.unique(vol, return_counts=True))

(array([0.        , 0.00501042, 0.00501048, ..., 0.9999733 , 0.99999946,
       1.        ], dtype=float32), array([1055551,       1,       2, ...,       1,       1,       1]))
(array([0, 1], dtype=int16), array([1081994,      41]))


In [515]:
# Check which copy of sulci.registration.spam is actually imported
sulci.registration.spam.__file__

'/casa/host/build/python/sulci/registration/spam.py'

In [516]:
# Realign the moved skeleton onto the toy SPAM, weighting the SPAM by the mask
after_vol, spam_after = realign_spam_register(
    spam_vol, moved_skel, mask_vol, do_edge_smoothing=True)
np.unique(after_vol.np)

en: -30.000000395425936 0.5179090515839313 -0.0 6.2106932391050868603
powell, en = -23.271398  [[0. 0. 0.]] [[0. 0. 0.]]
en: -30.000000395425936 0.5179090515839313 -0.0 6.2106932391050868603
powell, en = -23.271398  [[0. 0. 0.]] [[0. 0. 0.]]
en: -10.395025346337295 1.688519617168439 -0.0 6.2106932391050868603
powell, en = -2.495812  [[1. 0. 0.]] [[0. 0. 0.]]
en: -16.229308993351168 3.184633157384961 -0.0 6.2106932391050868603
powell, en = -6.833983  [[-1.618034  0.        0.      ]] [[0. 0. 0.]]
en: -30.000000395425936 0.5179090515839313 -0.0 6.2106932391050868603
powell, en = -23.271398  [[0. 0. 0.]] [[0. 0. 0.]]
en: -36.07139051832804 0.9889587582814627 -0.0 6.2106932391050868603
powell, en = -28.871739  [[-0.61803397  0.          0.        ]] [[0. 0. 0.]]
en: -22.596381449117892 1.688519583853946 -0.0 6.2106932391050868603
powell, en = -14.697169  [[-0.99999998  0.          0.        ]] [[0. 0. 0.]]
en: -45.528918382598135 0.7260728258913884 -0.0 6.2106932391050868603
powell, en = -



loading direct transformations
Output dimensions: 97, 115, 97
Output voxel size: 2, 2, 2 mm
Resampling carto_volume of S16...   0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 %




loading inverse transformations
Output dimensions: 97, 115, 97
Output voxel size: 2, 2, 2 mm
Resampling carto_volume of FLOAT...   0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 %


array([0, 1], dtype=int16)

In [517]:
# Optional comparison with the AimsMIRegister-based realignment (disabled)
# after_vol_mi = realign_mi_register(spam_vol, moved_skel)
# np.unique(after_vol_mi.np)

In [518]:
# Wrap the volumes as Anatomist objects (same order as they are converted)
skel_a, spam_a, after_a, mask_a = (
    a.toAObject(vol) for vol in (moved_skel, spam_vol, after_vol, mask_vol))

# Colour maps: moved skeleton in blue, realigned skeleton in red, mask in green
# (the MI result, when enabled, would use "VIOLET-lfusion")
for obj, palette in ((skel_a, "BLUE-lfusion"),
                     (after_a, "RED-lfusion"),
                     (mask_a, "Greens")):
    obj.setPalette(palette)

# Overlay everything in a coronal view
w = a.createWindow('Coronal')
for obj in (spam_a, skel_a, after_a, mask_a):
    w.addObjects(obj)

observable 0x6414059c0de0(N9anatomist7AVolumeIsEE) could not be removed from observer 0x641404d39008 (N9anatomist8Fusion2DE)
observable 0x641407191d90(N9anatomist7AVolumeIsEE) could not be removed from observer 0x6414065789b8 (N9anatomist8Fusion2DE)
observable 0x641407191d90(N9anatomist7AVolumeIsEE) could not be removed from observer 0x6414065789b8 (N9anatomist8Fusion2DE)


: 