In [82]:
import glob
from skimage.measure import marching_cubes_lewiner, marching_cubes
import nibabel
import os
import trimesh
import numpy as np
from natsort import natsorted
import re
import pymeshlab
sorted = natsorted
from numba import njit
import cc3d
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
from IPython.core.debugger import set_trace
from shutil import rmtree

from tqdm import tqdm_notebook

%config Completer.use_jedi = False

BRAIN_FOLDER = '../fcd_newdataset/brain_folder/'
MASK_FOLDER = '../fcd_newdataset/mask_folder/'


SAVE_FOLDER_INNER = '../fcd_newdataset_meshes/inner'
SAVE_FOLDER_OUTER = '../fcd_newdataset_meshes/outer/'
SAVE_FOLDER_INNER_OUTER = '../fcd_newdataset_meshes/inner_and_outer/'

REBUILT = True
if REBUILT:
    rmtree(SAVE_FOLDER_INNER)
    rmtree(SAVE_FOLDER_OUTER)
    rmtree(SAVE_FOLDER_INNER_OUTER)
    
for f in [SAVE_FOLDER_INNER, SAVE_FOLDER_OUTER, SAVE_FOLDER_INNER_OUTER]:
    os.makedirs(f, exist_ok=True)
    os.makedirs(os.path.join(f,'seg'), exist_ok=True)
    os.makedirs(os.path.join(f,'sseg'), exist_ok=True)

brain_mask_names = sorted(glob.glob(MASK_FOLDER + '*.nii.gz'))
labels = list(map(lambda x: re.findall('\d+\.\d+',x), brain_mask_names))

brain_img_names = {}
brain_mask_names = {}

for [label] in labels:
    
    brain_path = os.path.join(BRAIN_FOLDER, f'fcd_{label}.nii.gz')
    mask_path = os.path.join(MASK_FOLDER, f'mask_fcd_{label}.nii.gz')
    
    if os.path.isfile(brain_path):
        
        brain_img_names[label] = brain_path
        brain_mask_names[label] = mask_path



# TODO

 - connect falling apart components into one
 - build full w\g-matter graphs

## Util functions

In [20]:
def count_edges(faces):
    """
        input: array of faces
        output: dict, keys = id of edge, value = faces that contains this edge
    """
    edge_dict = {}
    for face in faces:
        keys = [
            tuple(sorted((face[0], face[1]))), tuple(sorted((face[1], face[2]))), tuple(sorted((face[0], face[2])))
        ]
        for key in keys:
            edge_dict[key] = edge_dict.get(key, 0) + 1
    return edge_dict

def get_vertex_labels(verts, mask):
    """
        verts: list of 3d coordinates of vertices
        mask: 3d binary array of mask
        
        labels: binary labels for each vertice
    """
    vs = verts.astype('int32')
    labels = []
    for v_ in vs:
        a, b, c = v_
        labels += [mask[a, b, c]]
    labels = np.array(labels)
    return labels

def get_edge_labels(v_labels, e_dict):
    """
        v_labels: vertex_labels
        e_dict: edge_dict
        
        e_labels: binary labels for each edge
    """
    es = list(e_dict.keys())
    e_labels = []
    for i, (v1, v2) in enumerate(es):
        e_labels += [v_labels[v1] or v_labels[v2]]
    e_labels = np.array(e_labels)
    return e_labels

@njit
def morph_3d(mask, l=5):
    """
        3d dilation for masks
    """
    d, h, w = mask.shape
    mask_padded = np.zeros((d+2*l//2, h+2*l//2, w+2*l//2))
    mask_padded[l//2: -l//2, l//2: -l//2, l//2: -l//2] = mask.copy()
    for i in range(d):
        for j in range(h):
            for k in range(w):
                mask[i, j, k] = np.all(mask_padded[i: i + l, j: j + l, k: k + l])
                
def show_slices(brain_tensor, n_slices_show=5, mask_tensor=None):
    
    fig, axes = plt.subplots(ncols=3, nrows=n_slices_show, figsize=(15,n_slices_show*5))
    X_max, Y_max, Z_max = brain_tensor.shape
    for i in range(n_slices_show):

        x_slice_pos = (X_max//(n_slices_show+2))*(i+1)
        y_slice_pos = (Y_max//(n_slices_show+2))*(i+1)
        z_slice_pos = (Z_max//(n_slices_show+2))*(i+1)

        brain_tensor_x_slice = brain_tensor[x_slice_pos,:,:]
        brain_tensor_y_slice = brain_tensor[:,y_slice_pos,:]
        brain_tensor_z_slice = brain_tensor[:,:,z_slice_pos]

        axes[i,0].imshow(brain_tensor_x_slice, 'gray')
        axes[i,1].imshow(brain_tensor_y_slice, 'gray')
        axes[i,2].imshow(brain_tensor_z_slice, 'gray')
        
        if mask is not None:
            
            mask_tensor_x_slice = mask_tensor[x_slice_pos,:,:]
            mask_tensor_y_slice = mask_tensor[:,y_slice_pos,:]
            mask_tensor_z_slice = mask_tensor[:,:,z_slice_pos]

            axes[i,0].imshow(mask_tensor_x_slice, 'jet', interpolation='none', alpha=0.7)
            axes[i,1].imshow(mask_tensor_y_slice, 'jet', interpolation='none', alpha=0.7)
            axes[i,2].imshow(mask_tensor_z_slice, 'jet', interpolation='none', alpha=0.7)

    plt.show()

##  Extracting example

In [3]:
i = 0
br_name = brain_img_names.values()[i]

brain = nibabel.load(br_name)
mask = nibabel.load(brain_mask_names[i]).get_fdata() > 0 # why cast to bool? 
brain_tensor_orig = brain.get_fdata() 

In [4]:
# plt.hist(brain_tensor_orig[brain_tensor_orig>0].flatten(), bins=100)
# plt.show()

In [5]:
# show_slices(brain_tensor_orig, n_slices_show=2, mask_tensor=brain_tensor_orig > white_matter_thresold)

In [7]:
white_matter_thresold = 300

brain_tensor_white = brain_tensor_orig > white_matter_thresold # for the inner
l = 1
t = 0.4
brain_tensor_white = gaussian_filter(brain_tensor_white.astype('float'), sigma=l)
brain_tensor_white = brain_tensor_white > t # gray matter vs white matter? 
labels_in = brain_tensor_white # where is a white matter

In [8]:
# show_slices(brain_tensor_orig, n_slices_show=2, mask_tensor=labels_in)

In [9]:
labels_out = cc3d.connected_components(labels_in, connectivity=6)
u, c = np.unique(labels_out, return_counts=True)
max_connected_label = np.argmax(c[1:]) + 1
brain_tensor_white = labels_out == max_connected_label

In [10]:
# show_slices(brain_tensor_orig, n_slices_show=2, mask_tensor=brain_tensor_white)

In [11]:
verts, faces, normals, values = marching_cubes(brain_tensor_white, 
                                               step_size=1, # 6
                                               allow_degenerate=False, 
                                               gradient_direction='ascent')

In [12]:
mesh_n = trimesh.base.Trimesh(vertices = verts, faces = faces, vertex_normals=normals, process = True)
# mesh_n.export(SAVE_FOLDER + f"{i}.obj")

In [13]:
mesh_n.show()

In [70]:
v_labels = get_vertex_labels(verts, mask)

e_dict = count_edges(faces)
e_labels = get_edge_labels(v_labels, e_dict)

# CHECKING
np_e_dict = np.array(list(e_dict.values()))
print('number of manifold edges', (np_e_dict == 2).sum())
print('number of non-manifold edges', (np_e_dict != 2).sum())

array([[     0,      1,      2],
       [     0,      3,      4],
       [     2,      3,      0],
       ...,
       [473654, 473693, 473652],
       [473694, 473693, 473654],
       [473642, 473694, 473654]])

In [93]:
e_labels

array([False, False, False, ..., False, False, False])

In [97]:
seg = e_labels + 1 # EDGE LABELS, 1 (False) and 2 (True)
sseg = np.zeros((len(e_dict), 2), dtype=np.int32) # 
sseg[np.arange(seg.size), seg-1] = 1 # 

# with open(SAVE_FOLDER + f"{i}.obj", 'a') as f:
#     for j, e in enumerate(e_dict):
#         f.write(f'\ne {e[0]} {e[1]} {seg[j]}')

# with open(SAVE_FOLDER + f"{i}.obj", 'r') as fin:
#     data = fin.read().splitlines(True)

# with open(SAVE_FOLDER + f"{i}.obj", 'w') as fout:
#     fout.writelines(data[1:])

# np.savetxt(SAVE_FOLDER + f'seg/{i}.eseg', seg)
# np.savetxt(SAVE_FOLDER + f'sseg/{i}.seseg', sseg)

# Extracting all meshes

In [76]:
# for mesh_type,SAVE_FOLDER in {'inner':SAVE_FOLDER_INNER, 
#                               'outer':SAVE_FOLDER_OUTER}.items():

#     for i, (label,br_name) in tqdm_notebook(enumerate(brain_img_names.items())):
# #         old_path1 = os.path.join(SAVE_FOLDER, f'seg/{i}.eseg')
#         old_path2 = os.path.join(SAVE_FOLDER, f'sseg/{i}.seseg')

# #         new_path1 = os.path.join(SAVE_FOLDER, f'seg/{label}.eseg')
#         new_path2 = os.path.join(SAVE_FOLDER, f'sseg/{label}.seseg')

#         try:
# #             os.rename(old_path1, new_path1)
#             os.rename(old_path2, new_path2)    
#         except:
#             pass

In [83]:
with open('./foo', 'w') as file:
    for i in range(10):
        file.write(f'a {i} \n')

In [90]:
!cat foo

a 1 
a 2 
a 3 
a 4 
a 5 
a 6 
a 7 
a 8 
a 9 


In [85]:
with open('./foo', 'r') as file:
    data = file.read().splitlines(True)

In [89]:
with open('./foo', 'w') as file:
    file.writelines(data[1:])

In [38]:
for mesh_type,SAVE_FOLDER in {'inner':SAVE_FOLDER_INNER, 
                              'outer':SAVE_FOLDER_OUTER}.items():
    
    for i, (label,br_name) in tqdm_notebook(enumerate(brain_img_names.items())):

        brain_tensor = nibabel.load(br_name).get_fdata()  > 0 
        mask_tensor = nibabel.load(brain_mask_names[label]).get_fdata() > 0
        
        if mesh_type == 'inner':
            brain_tensor = brain.get_fdata() > 300
            l = 1
            t = 0.4
            brain_tensor = gaussian_filter(brain_tensor.astype('float'), sigma=l)
            brain_tensor = brain_tensor > 0.4 # gray matter vs white matter? 
            labels_in = brain_tensor

            labels_out = cc3d.connected_components(labels_in, connectivity=6)
            u, c = np.unique(labels_out, return_counts=True)
            max_connected_label = np.argmax(c[1:]) + 1
            brain_tensor = labels_out == max_connected_label

        else:
            brain_tensor = brain.get_fdata() > 0

        verts, faces, normals, values = marching_cubes(brain_tensor, 
                                                       step_size=1, 
                                                       allow_degenerate=False, 
                                                       gradient_direction = 'ascent')
        
        v_labels = get_vertex_labels(verts, mask)
        print('vertex labels mean', v_labels.mean())
        
        e_dict = count_edges(faces)
        np_e_dict = np.array(list(e_dict.values()))
        e_labels = get_edge_labels(v_labels, e_dict)
        print('edge labels mean', v_labels.mean())

        print('number of manifold edges', (np_e_dict == 2).sum())
        print('number of non-manifold edges', (np_e_dict != 2).sum())

        mesh_n = trimesh.base.Trimesh(vertices = verts, faces = faces, process = False)
        mesh_n.export(os.path.join(SAVE_FOLDER, f"{label}.obj"))

        seg = e_labels + 1
        sseg = np.zeros((len(e_dict), 2), dtype=np.int32)
        sseg[np.arange(seg.size), seg-1] = 1

        with open(os.path.join(SAVE_FOLDER, f"{label}.obj"), 'a') as f:
            for j, e in enumerate(e_dict):
                f.write(f'\ne {e[0]} {e[1]} {seg[j]}') # why we append anyway?

        with open(os.path.join(SAVE_FOLDER, f"{label}.obj"), 'r') as fin:
            data = fin.read().splitlines(True)

        with open(os.path.join(SAVE_FOLDER, f"{label}.obj"), 'w') as fout:
            fout.writelines(data[1:]) # why we getting rid of firs element?

        np.savetxt(os.path.join(SAVE_FOLDER, f'seg/{label}.eseg'), seg)
        np.savetxt(os.path.join(SAVE_FOLDER, f'sseg/{label}.seseg'), sseg)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


0it [00:00, ?it/s]

vertex labels mean 0.0027336642674060524
edge labels mean 0.0027336642674060524
number of manifold edges 4386081
number of non-manifold edges 0
vertex labels mean 0.0027336642674060524
edge labels mean 0.0027336642674060524
number of manifold edges 4386081
number of non-manifold edges 0
vertex labels mean 0.0027336642674060524
edge labels mean 0.0027336642674060524
number of manifold edges 4386081
number of non-manifold edges 0
vertex labels mean 0.0027336642674060524
edge labels mean 0.0027336642674060524
number of manifold edges 4386081
number of non-manifold edges 0
vertex labels mean 0.0027336642674060524
edge labels mean 0.0027336642674060524
number of manifold edges 4386081
number of non-manifold edges 0
vertex labels mean 0.0027336642674060524
edge labels mean 0.0027336642674060524
number of manifold edges 4386081
number of non-manifold edges 0
vertex labels mean 0.0027336642674060524
edge labels mean 0.0027336642674060524
number of manifold edges 4386081
number of non-manifold 

0it [00:00, ?it/s]

vertex labels mean 0.0029385997318949956
edge labels mean 0.0029385997318949956
number of manifold edges 1420851
number of non-manifold edges 0
vertex labels mean 0.0029385997318949956
edge labels mean 0.0029385997318949956
number of manifold edges 1420851
number of non-manifold edges 0
vertex labels mean 0.0029385997318949956
edge labels mean 0.0029385997318949956
number of manifold edges 1420851
number of non-manifold edges 0
vertex labels mean 0.0029385997318949956
edge labels mean 0.0029385997318949956
number of manifold edges 1420851
number of non-manifold edges 0
vertex labels mean 0.0029385997318949956
edge labels mean 0.0029385997318949956
number of manifold edges 1420851
number of non-manifold edges 0
vertex labels mean 0.0029385997318949956
edge labels mean 0.0029385997318949956
number of manifold edges 1420851
number of non-manifold edges 0
vertex labels mean 0.0029385997318949956
edge labels mean 0.0029385997318949956
number of manifold edges 1420851
number of non-manifold 

##  Combining inner and outter meshes

In [39]:
def count_vertices(data):
    """
        return count of vertices, faces and edges
    """
    for j in range(len(data)):
        if data[j][0] != 'v':
            a = j
            break
    for j in range(a, len(data)):
        if data[j][0] != 'f':
            b = j
            break
    return a, b - a, len(data) - b

def shift_line(line, n):
    """
        line: line from text description of .obj file. it starts with 'v', 'e' or 'f'
        n: shift indexing by this number
        
        return: shifted line
    """
    ls = line.split(' ')
    joins = [ls[0], str(int(ls[1]) + n), str(int(ls[2]) + n)]
    try:
        if ls[0] == 'e':
            joins.append(str(int(ls[3])))
        else:
            joins.append(str(int(ls[3]) + n))
    except:
        print(line)
        print(ls)
        raise ValueError
    return " ".join(joins) + '\n'

In [80]:
label = '0.1'
with open(os.path.join(SAVE_FOLDER_OUTER, f"{label}.obj"), 'r') as fin:
    data_outter = fin.read().splitlines(True)

In [81]:
data_outter

['e 0 1 1\n',
 'e 1 2 1\n',
 'e 0 2 1\n',
 'e 0 3 1\n',
 'e 3 4 1\n',
 'e 0 4 1\n',
 'e 2 3 1\n',
 'e 4 5 1\n',
 'e 5 6 1\n',
 'e 4 6 1\n',
 'e 3 5 1\n',
 'e 6 7 1\n',
 'e 7 8 1\n',
 'e 6 8 1\n',
 'e 5 7 1\n',
 'e 7 9 1\n',
 'e 8 9 1\n',
 'e 10 11 1\n',
 'e 11 12 1\n',
 'e 10 12 1\n',
 'e 10 13 1\n',
 'e 13 14 1\n',
 'e 10 14 1\n',
 'e 12 13 1\n',
 'e 14 15 1\n',
 'e 15 16 1\n',
 'e 14 16 1\n',
 'e 13 15 1\n',
 'e 16 17 1\n',
 'e 17 18 1\n',
 'e 16 18 1\n',
 'e 15 17 1\n',
 'e 18 19 1\n',
 'e 19 20 1\n',
 'e 18 20 1\n',
 'e 17 19 1\n',
 'e 1 20 1\n',
 'e 1 19 1\n',
 'e 0 20 1\n',
 'e 0 21 1\n',
 'e 20 21 1\n',
 'e 4 22 1\n',
 'e 0 22 1\n',
 'e 21 22 1\n',
 'e 6 23 1\n',
 'e 4 23 1\n',
 'e 22 23 1\n',
 'e 8 24 1\n',
 'e 6 24 1\n',
 'e 23 24 1\n',
 'e 8 25 1\n',
 'e 9 25 1\n',
 'e 24 25 1\n',
 'e 26 27 1\n',
 'e 27 28 1\n',
 'e 26 28 1\n',
 'e 11 26 1\n',
 'e 11 28 1\n',
 'e 10 26 1\n',
 'e 10 29 1\n',
 'e 26 29 1\n',
 'e 14 30 1\n',
 'e 10 30 1\n',
 'e 29 30 1\n',
 'e 16 31 1\n',
 'e 14

In [5]:
# for i in range(len(brain_img_names)):
for label in brain_img_names:
    
    with open(os.path.join(SAVE_FOLDER_INNER, f"{label}.obj"), 'r') as fin:
        data_inner = fin.read().splitlines(True)
        
    with open(os.path.join(SAVE_FOLDER_OUTER, f"{label}.obj"), 'r') as fin:
        data_outter = fin.read().splitlines(True)
        
    a1, b1, c1 = count_vertices(data_inner)
    a2, b2, c2 = count_vertices(data_outter)

    data = []
    data += data_inner[:a1]
    data += data_outter[:a2]
    data += data_inner[a1: a1 + b1]
    data += list(map(lambda x: shift_line(x, a1), data_outter[a2: a2 + b2]))
    
#     data += data_inner[a1 + b1:]
#     data += list(map(lambda x: shift_line(x, a1), data_outter[a2 + b2:]))

    with open(SAVE_FOLDER_INNER_OUTER + f"{label}.obj", 'w') as fout:
        fout.writelines(data)
    
    seg_inner = np.loadtxt(os.path.join(SAVE_FOLDER_INNER, 'seg/' + f"{label}.eseg"))
    sseg_inner = np.loadtxt(os.path.join(SAVE_FOLDER_INNER, 'sseg/' + f"{label}.seseg"))    
    
    seg_outter = np.loadtxt(os.path.join(SAVE_FOLDER_OUTER + 'seg/' + f"{label}.eseg"))
    sseg_outter = np.loadtxt(os.path.join(SAVE_FOLDER_OUTER + 'sseg/' + f"{label}.seseg"))
        
    seg = np.concatenate([seg_inner, seg_outter])
    sseg = np.concatenate([sseg_inner, sseg_outter])
    print(i, c1 + c2, seg.mean() - 1)

    np.savetxt(os.path.join(SAVE_FOLDER_INNER_OUTER, f'seg/{label}.eseg'), seg)
    np.savetxt(os.path.join(SAVE_FOLDER_INNER_OUTER, f'sseg/{label}.seseg'), sseg)

0 16200 0.0011111111111110628
1 15465 0.00032331070158431174
2 14301 0.003985735263268397
3 13530 0.00539541759053952
4 17787 0.0026986001011974903
5 16596 0.004760183176669042
6 15633 0.0019829847118275623
7 13767 0.0010895619960775704
8 15414 0.0004541326067211138
9 16653 0.0010208370864108751
10 15084 0.0012596128347919233
11 15624 0.0017921146953405742
12 13365 0.0011971567527122584
13 15192 0.0
14 15387 0.0
15 14871 0.0038329634859795014
16 14904 0.0019457863660761188
17 13914 0.0034497628288054916
18 15330 0.004566210045662045
19 16110 0.003910614525139744
20 15690 0.001274697259400881
21 14562 0.004188985029528913
22 12708 0.0010229776518728695
23 14070 0.0010660980810235365
24 15315 0.0024159320927195083
25 13209 0.0007570595805890257
26 16401 0.0037192854094263428


## Preparing data for training MeshCNN

In [10]:
import shutil

FCD_FOLDER = '/home/ruslan/MeshCNN/datasets/fcd_seg/'

if os.path.exists(FCD_FOLDER + 'test/cache'):
    shutil.rmtree(FCD_FOLDER + 'test/cache')
if os.path.exists(FCD_FOLDER + 'train/cache'):
    shutil.rmtree(FCD_FOLDER + 'train/cache')
    
files = glob.glob(FCD_FOLDER + 'test/*')
for f in files:
    os.remove(f)
files = glob.glob(FCD_FOLDER + 'train/*')
for f in files:
    os.remove(f)

test_idxs = [1, 14]
train_idxs = [i for i in range(len(brain_img_names)) if i not in test_idxs]

for i in train_idxs:
    shutil.copyfile(SAVE_FOLDER + f"{i}.obj", FCD_FOLDER + 'train/' + f"{i}.obj")
    shutil.copyfile(SAVE_FOLDER + 'seg/'+ f"{i}.eseg", FCD_FOLDER + 'seg/' + f"{i}.eseg")
    shutil.copyfile(SAVE_FOLDER + 'sseg/'+ f"{i}.seseg", FCD_FOLDER + 'sseg/' + f"{i}.seseg")

for i in test_idxs:
    shutil.copyfile(SAVE_FOLDER + f"{i}.obj", FCD_FOLDER + 'test/' + f"{i}.obj")
    shutil.copyfile(SAVE_FOLDER + 'seg/'+ f"{i}.eseg", FCD_FOLDER + 'seg/' + f"{i}.eseg")
    shutil.copyfile(SAVE_FOLDER + 'sseg/'+ f"{i}.seseg", FCD_FOLDER + 'sseg/' + f"{i}.seseg")   

## Shifting meshes from FreeSurfer

In [None]:
N = 15
base_names = ['lh.pial', 'lh.orig', 'rh.pial', 'rh.orig']
def shift_mesh(i, base_name):
    coords, faces, meta = nibabel.freesurfer.io.read_geometry(f'{i}_1_surf/{base_name}', read_metadata=True)
    coords += meta['cras']
    m = pymeshlab.Mesh(coords, faces)
    ms = pymeshlab.MeshSet()
    ms.add_mesh(m, base_name)
    ms.save_current_mesh(f"meshlab_objects/{i}_{base_name}.obj")

for i in range(N):
    for base_name in base_names:
        shift_mesh(i, base_name)

## Checking good value for gauss sigma and threshold value

In [17]:

for i, br_name in enumerate(brain_img_names):
    print(i)
    brain = nibabel.load(br_name)
    mask = nibabel.load(brain_mask_names[i]).get_fdata() > 0
    brain_tensor = brain.get_fdata() > 95
    l = 7
    brain_tensor = gaussian_filter(brain_tensor.astype('float'), sigma=l)
    nifti = nibabel.Nifti1Image(brain_tensor.astype('float'), brain.affine)
    nibabel.save(nifti, f"{i}_gauss_g_{l}.nii.gz")
    t = 0.4
    nifti_t = nibabel.Nifti1Image((brain_tensor > t).astype('float'), brain.affine)
    nibabel.save(nifti_t, f"{i}_gauss_g_{l}_t_{t}.nii.gz") 
    

KeyboardInterrupt: 