In [1]:
import numpy as np
import os
from soma import aims
from skimage.morphology import ball, binary_dilation



In [2]:
class TrimEdgesTensor(object):
    """
    Trim the lateral edges of the folds based on sample_extremities.

    A branch is the set of voxels sharing the same value of
    ``sample_foldlabel % 1000`` (0 is background). Each branch is selected
    with probability ``p``. In a selected branch, the voxels flagged in
    ``sample_extremities`` are removed, except those covered by
    ``protective_structure`` centred on the branch centre of mass.
    Voxels whose fold label is background are dropped from the output.

    Parameters
    ----------
    sample_extremities : array with the skeleton's shape; non-zero voxels
        are candidates for trimming (binary mask of the trimmed skeleton
        voxels).
    sample_foldlabel : array of fold labels with the skeleton's shape.
    input_size : expected input shape; stored but not used by the transform.
    protective_structure : array such as ``morphology.ball(n)``, optionally
        with trailing singleton axes (e.g. ``np.expand_dims(ball(n), -1)``).
        Its axis lengths should be odd so it has an integer centre. Parts
        falling outside the volume are clipped.
    p : probability to trim each branch (i.e. proportion of trimmed
        branches). Must be >= 0.
    """

    def __init__(self, sample_extremities, sample_foldlabel,
                 input_size, protective_structure, p=0.5):
        self.input_size = input_size
        self.protective_structure = protective_structure
        self.p = p
        self.sample_foldlabel = sample_foldlabel
        self.sample_extremities = sample_extremities

    def __call__(self, arr_skel):
        """
        Return a float32 version of ``arr_skel`` restricted to labelled
        voxels, with the extremities of randomly selected branches removed.

        Draws exactly one ``np.random.uniform()`` per branch, in ascending
        branch-index order, so results depend on the global numpy RNG state.
        Raises AssertionError if ``p < 0`` or if the foldlabel has no branch.
        """
        arr_foldlabel = self.sample_foldlabel
        # True where a voxel is NOT an extremity. Comparing with 0 (instead
        # of computing 1 - mask) stays correct for bool or non-0/1 masks.
        keep_mask = np.asarray(self.sample_extremities) == 0
        assert self.p >= 0

        arr_trimmed_branches = np.zeros(arr_skel.shape)
        indexed_branches = np.mod(arr_foldlabel, 1000)
        indexes = np.unique(indexed_branches)
        # 0 is background: drop it explicitly rather than assuming it is
        # present and sorts first.
        branch_ids = indexes[indexes != 0]
        assert branch_ids.size > 0, 'No branch in foldlabel'

        for index in branch_ids:
            branch = arr_skel * (indexed_branches == index)
            if np.random.uniform() < self.p:
                # Only trim when the branch actually contains extremities.
                if np.any((branch != 0) & ~keep_mask):
                    center = self._mass_center(branch)
                    protection = self._protection_mask(branch.shape, center)
                    branch = branch * np.logical_or(protection, keep_mask)
            arr_trimmed_branches += branch

        return arr_trimmed_branches.astype('float32')

    @staticmethod
    def _mass_center(branch):
        """Integer (rounded) mean coordinates of the non-zero voxels."""
        coords = np.nonzero(branch)
        return np.round([np.mean(axis_coords) for axis_coords in coords]).astype(int)

    def _protection_mask(self, shape, center):
        """Boolean mask of ``shape`` holding ``protective_structure``
        centred on ``center``, clipped to the volume bounds."""
        struct = np.asarray(self.protective_structure)
        assert struct.ndim <= len(shape), 'protective_structure has too many axes'
        # Let a 3D structure be used on a volume with trailing singleton axes.
        if struct.ndim < len(shape):
            struct = struct.reshape(struct.shape + (1,) * (len(shape) - struct.ndim))

        mask = np.zeros(shape, dtype=bool)
        dst, src = [], []
        for c, s, dim in zip(center, struct.shape, shape):
            start = c - s // 2
            # Clip to [0, dim) and crop the structure by the same amount, so
            # branches near the border do not cause a shape mismatch.
            lo, hi = max(start, 0), min(start + s, dim)
            dst.append(slice(lo, hi))
            src.append(slice(lo - start, hi - start))
        mask[tuple(dst)] = struct[tuple(src)] != 0
        return mask


In [3]:
# Output directory for the augmented volumes, and the subject to process.
save_dir, subject = '/volatile/jl277509/data/tmp/ukb', 'sub-1000021'

In [4]:
## Load the raw skeleton and foldlabel (to be matched to the trimmed shapes),
## the extremities mask and the ss-edges skeleton of the subject.
ukb_root = '/neurospin/dico/data/deep_folding/current/datasets/UkBioBank'
skeleton = aims.read(f'{ukb_root}/skeletons/raw/L/Lskeleton_generated_{subject}.nii.gz')
foldlabel = aims.read(f'{ukb_root}/foldlabels/raw/L/Lfoldlabel_{subject}.nii.gz')
trimmed = aims.read(f'{ukb_root}/trimmed_skeletons/L/Lextremities_{subject}.nii.gz')
ss = aims.read(f'{ukb_root}/trimmed_skeletons/L/Lskeleton_ss_edges_{subject}.nii.gz')

In [5]:
# Value histogram of each loaded volume.
for volume in (skeleton, foldlabel, trimmed):
    print(np.unique(volume.np, return_counts=True))

(array([  0,  30,  35,  60, 100, 120], dtype=int16), array([5478346,    4527,    6669,   27811,    1724,     101]))
(array([   0,    1,    2,    3,    4,    5,    6,    9,   10,   11,   12,
         13,   14,   16,   18,   19,   20,   22,   23,   24,   25,   26,
         27,   29,   30,   32,   33,   34,   35,   37,   38,   41,   43,
         44,   46,   47,   48,   49,   50,   51,   52,   53,   55,   56,
         57,   58,   59,   60,   63,   66,   67,   68,   70,   72,   74,
         75,   76,   77,   78,   79,   83,   84,   85,   86,   87,   88,
         89,   91,   92,   93,   94,   95,   96,   97,   98,   99,  100,
        101,  102,  103,  104,  105,  106,  107,  108,  109,  110,  111,
        112,  113,  114,  115,  116,  117,  118,  119,  121,  122,  123,
        124,  125,  126,  128,  130,  131,  132,  133,  134,  135,  136,
        138,  141,  142,  144,  145,  146,  149,  151,  153,  157,  159,
        160,  163,  164,  166,  167,  168,  170,  171,  172,  173,  174,
       

In [6]:
## Pad skeleton and foldlabel to the trimmed size.
# First: shape and mean coordinates of the non-zero skeleton voxels.
shape_skel = skeleton.np.shape
print(shape_skel)
coords = np.nonzero(skeleton.np)
center_skel = [np.mean(axis_coords) for axis_coords in coords]
print(center_skel)

(157, 217, 162, 1)
[116.11848550156739, 134.07849235893417, 103.03607464733543, 0.0]


In [7]:
# Same measurements on the ss-edges skeleton, for comparison.
shape_ss = ss.np.shape
print(shape_ss)
coords = np.nonzero(ss.np)
center_trimmed = [np.mean(axis_coords) for axis_coords in coords]
print(center_trimmed)

(176, 235, 210, 1)
[116.22329645048204, 133.85166520595968, 102.89723926380368, 0.0]


In [8]:
# Both volumes have (nearly) the same centre of mass, so the skeleton only
# needs zero-padding at the high end of each spatial axis; the last axis is
# left untouched. np.pad returns a new array, so skeleton.np is not modified.
pad_shape = tuple(
    (0, n_ss - n_skel) for n_ss, n_skel in zip(shape_ss[:3], shape_skel[:3])
) + ((0, 0),)
skel = np.pad(skeleton.np, pad_shape, constant_values=0)

In [9]:
# Pad the fold labels identically so they stay voxel-aligned with skel.
label = np.pad(foldlabel.np, pad_shape, constant_values=0)

In [10]:
from skimage.morphology import ball, binary_dilation

# Dilate the extremities so that neighbouring "top" voxels (skeleton value
# TOP_VALUE) are trimmed as well.
# Needs two arrays: skel (padded skeleton) and trimmed (extremities).
# NOTE(review): trimmed.np is overwritten in place, so re-running this cell
# dilates the already-augmented mask again; reload `trimmed` first.
TOP_VALUE = 35
dilation_magnitude = 2

print(f'Voxels to trim without tops: {np.count_nonzero(trimmed.np)}')
trimmed_dilated = binary_dilation(trimmed.np[..., 0], ball(dilation_magnitude))
trimmed_dilated = trimmed_dilated[..., np.newaxis]
tops_to_trim = np.logical_and(trimmed_dilated, skel == TOP_VALUE)
new_trimmed = np.logical_or(trimmed.np, tops_to_trim)
print(f'Voxels to trim after adding tops: {np.count_nonzero(new_trimmed)}')
trimmed.np[:] = new_trimmed

Voxels to trim without tops: 2820
Voxels to trim after addig tops : 4294


In [11]:
proba = 1 / 2  # probability of trimming each branch

In [12]:
# Trim each branch with probability `proba`; voxels inside a ball of radius 3
# centred on the branch centre of mass are protected from trimming.
# input_size is taken from the padded skeleton instead of a hard-coded copy
# of its shape, so it cannot drift out of sync with the data.
trimededges = TrimEdgesTensor(sample_extremities=trimmed.np,
                              sample_foldlabel=label,
                              input_size=skel.shape,
                              protective_structure=np.expand_dims(ball(3), axis=-1),
                              p=proba)

In [13]:
# TrimEdgesTensor draws from the global numpy RNG; seed it right before the
# call so re-running this cell reproduces the same trimming.
SEED = 0
np.random.seed(SEED)
trimmed_arr = trimededges(skel)

In [14]:
np.shape(trimmed_arr)

(176, 235, 210, 1)

In [15]:
# Number of non-zero voxels before/after trimming, and the trimming budget.
n_before = np.sum(skel != 0)
n_after = np.sum(trimmed_arr != 0)
print(f'nb vx before trim: {n_before}')
print(f'nb vx after trim: {n_after}')
print(f'total vx trimmed: {n_before - n_after}')
print(f'total vx possible to trim (without protection): {np.sum(trimmed.np)}')

nb vx before trim: 40832
nb vx after trim: 38690
total vx trimmed: 2142
total vx possible to trim (without protection): 4294


In [16]:
# Check that the topological values of skel are still present after trimming.
for arr in (skel, trimmed_arr):
    print(np.unique(arr, return_counts=True))
# Many top values are removed because of the dilation, which seems acceptable.

(array([  0,  30,  35,  60, 100, 120], dtype=int16), array([8644768,    4527,    6669,   27811,    1724,     101]))
(array([  0.,  30.,  35.,  60., 100., 120.], dtype=float32), array([8646910,    4176,    5605,   27112,    1697,     100]))


In [17]:
# Each voxel must be either kept intact (difference 0) or removed entirely
# (difference equal to its skel value): trimming never alters a value.
diff = skel - trimmed_arr
print(np.unique(diff, return_counts=True))
assert np.all((diff == 0) | (diff == skel)), 'trimming altered some skeleton values'

(array([  0.,  30.,  35.,  60., 100., 120.], dtype=float32), array([8683458,     351,    1064,     699,      27,       1]))


In [176]:
# Save the binarised trimmed skeleton, reusing the ss volume as container.
out_file = os.path.join(save_dir, f'L{subject}_augm_test_proba_{proba}_tops_included.nii.gz')
ss.np[:] = trimmed_arr != 0
aims.write(ss, out_file)

In [94]:
(trimmed_arr != 0).sum()

38012

In [95]:
(skel != 0).sum()

40832