In [None]:
import glob
from skimage.measure import marching_cubes_lewiner, marching_cubes
import nibabel
import os
import trimesh
import numpy as np
from natsort import natsorted
import pymeshlab
sorted = natsorted
from numba import njit
import cc3d
from scipy.ndimage import gaussian_filter

%config Completer.use_jedi = False


BRAIN_FOLDER = '/home/ruslan/mri/server_brains/'
MASK_FOLDER = '/home/ruslan/mri/server_masks/'

SAVE_FOLDER_INNER = 'new_mesh_objects_inner/'
SAVE_FOLDER_OUTTER = 'new_mesh_objects_outter/'
SAVE_FOLDER_FINAL = 'new_mesh_objects/'


for f in [SAVE_FOLDER_INNER, SAVE_FOLDER_OUTTER, SAVE_FOLDER_FINAL]:
    os.makedirs(f, exist_ok=True)
    os.makedirs(f + 'seg', exist_ok=True)
    os.makedirs(f + 'sseg', exist_ok=True)

brain_img_names = sorted(glob.glob(BRAIN_FOLDER + '*_fcd.nii.gz'))
brain_mask_names = sorted(glob.glob(MASK_FOLDER + '*.nii.gz'))

## Util functions

In [None]:
def count_edges(faces):
    """
        Count how many faces share each undirected edge.

        input: iterable of triangular faces; face[0], face[1], face[2] are
               vertex indices
        output: dict, key = edge as a (smaller_index, larger_index) tuple,
                value = number of faces that contain this edge.
                Keys are kept in first-seen order (face by face, and within a
                face the edges (0,1), (1,2), (0,2)); later cells pair the
                keys positionally with the per-edge labels, so this order matters.
    """
    edge_dict = {}
    for face in faces:
        v0, v1, v2 = face[0], face[1], face[2]
        for a, b in ((v0, v1), (v1, v2), (v0, v2)):
            # min/max rather than sorting: the notebook-level `sorted` is
            # natsort's natsorted, which is needlessly slow for a pair of ints.
            key = (min(a, b), max(a, b))
            edge_dict[key] = edge_dict.get(key, 0) + 1
    return edge_dict

def get_vertex_labels(verts, mask):
    """
        Look up the mask value at the voxel of every vertex.

        verts: array of 3d vertex coordinates, one row (x, y, z) per vertex
        mask: 3d array indexed as mask[x, y, z]

        labels: np.array with one mask value per vertex; coordinates are
                converted with astype('int32'), i.e. truncated, not rounded
    """
    voxel_idx = verts.astype('int32')
    return np.array([mask[x, y, z] for x, y, z in voxel_idx])

def get_edge_labels(v_labels, e_dict):
    """
        Derive one label per edge from the labels of its two endpoints.

        v_labels: per-vertex labels, indexable by vertex index
        e_dict: dict keyed by (v1, v2) vertex-index pairs; only the key order
                is used, the values are ignored

        e_labels: np.array with `v_labels[v1] or v_labels[v2]` for every key
                  of e_dict, in key order (an edge is positive if either
                  endpoint is)
    """
    return np.array([v_labels[v1] or v_labels[v2] for v1, v2 in e_dict])

@njit
def morph_3d(mask, l=5):
    """
        3d binary erosion of `mask` with an l x l x l cubic window, in place.

        A voxel stays set only if *all* voxels of its window are set (np.all),
        so this is an erosion (older docs called it a dilation). Voxels outside
        the volume count as 0, so the volume border is eroded too.

        mask: 3d array, overwritten in place with the eroded result
        l: edge length of the cubic window; for voxel i the window spans
           i - l//2 .. i + (l - 1 - l//2) along each axis

        returns: mask (the same array, for convenience; the previous version
                 returned None)
    """
    d, h, w = mask.shape
    lo = l // 2          # zero padding placed before the data on each axis
    hi = l - 1 - lo      # zero padding placed after the data on each axis
    # A padded copy keeps every window mask_padded[i:i+l, ...] in bounds, and
    # keeps the reads on the original values while `mask` is being overwritten.
    mask_padded = np.zeros((d + lo + hi, h + lo + hi, w + lo + hi))
    mask_padded[lo:lo + d, lo:lo + h, lo:lo + w] = mask
    for i in range(d):
        for j in range(h):
            for k in range(w):
                mask[i, j, k] = np.all(mask_padded[i: i + l, j: j + l, k: k + l])
    return mask


## Extracting inner and outer meshes

In [None]:
# Parameters of the inner-surface volume (intensity units of the input images).
INNER_INTENSITY_THRESHOLD = 95   # voxels brighter than this seed the inner volume
INNER_GAUSS_SIGMA = 7            # sigma of the gaussian_filter applied to that binary volume
INNER_SMOOTH_THRESHOLD = 0.4     # re-binarisation threshold after smoothing
MC_STEP_SIZE = 6                 # marching_cubes step_size


def _largest_component(binary):
    """Keep only the largest 6-connected foreground component of a binary volume."""
    labels_out = cc3d.connected_components(binary, connectivity=6)
    u, c = np.unique(labels_out, return_counts=True)
    is_fg = u != 0  # label 0 is background
    if not is_fg.any():
        raise ValueError('inner volume is empty after thresholding')
    return labels_out == u[is_fg][np.argmax(c[is_fg])]


def _surface_volume(brain_data, inner):
    """Binary volume to mesh: all nonzero voxels for the outer surface, or the
    largest component of the smoothed bright-voxel volume for the inner one."""
    if not inner:
        return brain_data > 0
    seed = (brain_data > INNER_INTENSITY_THRESHOLD).astype('float')
    smoothed = gaussian_filter(seed, sigma=INNER_GAUSS_SIGMA)
    return _largest_component(smoothed > INNER_SMOOTH_THRESHOLD)


def _write_obj_with_edges(path, verts, faces, e_dict, seg):
    """Export the mesh to `path` as OBJ, drop the file's first line, and append
    one 'e <v1> <v2> <label>' line per key of e_dict (in key order)."""
    trimesh.base.Trimesh(vertices=verts, faces=faces, process=False).export(path)
    with open(path, 'r') as fin:
        exported = fin.read()
    # NOTE(review): edge indices are written 0-based (as in `faces`), while OBJ
    # 'f' lines are 1-based by the format; kept unchanged so outputs stay identical.
    edge_text = ''.join(f'\ne {e[0]} {e[1]} {seg[j]}' for j, e in enumerate(e_dict))
    with open(path, 'w') as fout:
        # The first exported line is dropped, as before (presumably the
        # exporter's header comment).
        fout.writelines((exported + edge_text).splitlines(True)[1:])


for SAVE_FOLDER in [SAVE_FOLDER_INNER, SAVE_FOLDER_OUTTER]:
    is_inner = 'inner' in SAVE_FOLDER
    for i, br_name in enumerate(brain_img_names):
        print(i)
        brain = nibabel.load(br_name)
        # Brains and masks are paired by sorted position only.
        mask = nibabel.load(brain_mask_names[i]).get_fdata() > 0
        brain_tensor = _surface_volume(brain.get_fdata(), is_inner)

        verts, faces, normals, values = marching_cubes(
            brain_tensor, step_size=MC_STEP_SIZE, allow_degenerate=False,
            gradient_direction='ascent')
        v_labels = get_vertex_labels(verts, mask)
        print('vertex labels mean', v_labels.mean())

        e_dict = count_edges(faces)
        np_e_dict = np.array(list(e_dict.values()))
        e_labels = get_edge_labels(v_labels, e_dict)
        # Previously this printed the vertex label mean a second time.
        print('edge labels mean', e_labels.mean())

        # An edge shared by exactly two faces is counted as manifold.
        print('number of manifold edges', (np_e_dict == 2).sum())
        print('number of non-manifold edges', (np_e_dict != 2).sum())

        # Edge class ids: 1 = neither endpoint in the mask, 2 = at least one is.
        seg = e_labels + 1
        # One-hot version of seg, one row per edge.
        sseg = np.zeros((len(e_dict), 2), dtype=np.int32)
        sseg[np.arange(seg.size), seg - 1] = 1

        _write_obj_with_edges(SAVE_FOLDER + f"{i}.obj", verts, faces, e_dict, seg)
        np.savetxt(SAVE_FOLDER + f'seg/{i}.eseg', seg)
        np.savetxt(SAVE_FOLDER + f'sseg/{i}.seseg', sseg)
    

## Combining inner and outer meshes

In [4]:
                
def count_vertices(data):
    """
        Split the lines of an .obj text into three consecutive runs and count them.

        data: list of lines (e.g. from str.splitlines(True))

        return: (n_vertex_lines, n_face_lines, n_remaining_lines), where the
                vertex run is the leading lines starting with 'v', the face run
                is the following lines starting with 'f', and everything after
                that (edge lines, blank lines, ...) is the remainder.
                Any run may be empty; the three counts always sum to len(data).
    """
    n_lines = len(data)
    # data[j][:1] instead of data[j][0] so an empty string ends a run instead of raising.
    a = next((j for j in range(n_lines) if data[j][:1] != 'v'), n_lines)
    b = next((j for j in range(a, n_lines) if data[j][:1] != 'f'), n_lines)
    return a, b - a, n_lines - b

def shift_line(line, n):
    """
        Shift the vertex indices of one 'f' or 'e' line of an .obj text by n.

        line: a face line 'f <i> <j> <k> ...' or an edge line 'e <i> <j> <label>'
        n: offset added to every vertex index

        return: the shifted line, tokens joined by single spaces and ending in
                '\n'. For 'e' lines the third value is a label and is not
                shifted; for any other tag every index is shifted, so faces with
                more than three vertices are kept whole.

        raises: ValueError if the line has fewer than three values after the tag
                or a value is not an integer (e.g. 'v' coordinate lines or
                'f 1/1 2/2 3/3').
    """
    tokens = line.split()
    if len(tokens) < 4:
        raise ValueError(f'expected a tag and at least 3 values in obj line {line!r}')
    tag = tokens[0]
    try:
        if tag == 'e':
            # v1 v2 label: only the two vertex indices move.
            shifted = [int(tokens[1]) + n, int(tokens[2]) + n, int(tokens[3])]
        else:
            shifted = [int(tok) + n for tok in tokens[1:]]
    except ValueError as exc:
        raise ValueError(f'non-integer index in obj line {line!r}') from exc
    return ' '.join([tag, *map(str, shifted)]) + '\n'

In [5]:

# Merge each subject's inner and outer meshes into one OBJ file: all vertices
# (inner, then outer), then the inner faces, then the outer faces re-indexed by
# the inner vertex count so they keep pointing at the outer vertices.
# The 'e' lines of the two inputs are deliberately not copied into the result.
# NOTE(review): so whatever reads the merged .eseg files must reproduce the
# inner-then-outer edge order from the faces on its own; confirm.
for i in range(len(brain_img_names)):
    obj_name = f"{i}.obj"
    with open(SAVE_FOLDER_INNER + obj_name, 'r') as fin:
        inner_lines = fin.read().splitlines(True)
    with open(SAVE_FOLDER_OUTTER + obj_name, 'r') as fin:
        outer_lines = fin.read().splitlines(True)

    n_v_in, n_f_in, n_rest_in = count_vertices(inner_lines)
    n_v_out, n_f_out, n_rest_out = count_vertices(outer_lines)

    outer_faces = [shift_line(line, n_v_in)
                   for line in outer_lines[n_v_out:n_v_out + n_f_out]]
    merged = (inner_lines[:n_v_in]
              + outer_lines[:n_v_out]
              + inner_lines[n_v_in:n_v_in + n_f_in]
              + outer_faces)

    with open(SAVE_FOLDER_FINAL + obj_name, 'w') as fout:
        fout.writelines(merged)

    # Per-edge labels: inner edges first, then outer edges.
    label_folders = (SAVE_FOLDER_INNER, SAVE_FOLDER_OUTTER)
    seg = np.concatenate([np.loadtxt(folder + f'seg/{i}.eseg') for folder in label_folders])
    sseg = np.concatenate([np.loadtxt(folder + f'sseg/{i}.seseg') for folder in label_folders])
    # Second value: number of lines after the face runs (the 'e' lines plus any
    # other trailing lines); third: fraction of edges labelled 2.
    print(i, n_rest_in + n_rest_out, seg.mean() - 1)

    np.savetxt(SAVE_FOLDER_FINAL + f'seg/{i}.eseg', seg)
    np.savetxt(SAVE_FOLDER_FINAL + f'sseg/{i}.seseg', sseg)
    

0 16200 0.0011111111111110628
1 15465 0.00032331070158431174
2 14301 0.003985735263268397
3 13530 0.00539541759053952
4 17787 0.0026986001011974903
5 16596 0.004760183176669042
6 15633 0.0019829847118275623
7 13767 0.0010895619960775704
8 15414 0.0004541326067211138
9 16653 0.0010208370864108751
10 15084 0.0012596128347919233
11 15624 0.0017921146953405742
12 13365 0.0011971567527122584
13 15192 0.0
14 15387 0.0
15 14871 0.0038329634859795014
16 14904 0.0019457863660761188
17 13914 0.0034497628288054916
18 15330 0.004566210045662045
19 16110 0.003910614525139744
20 15690 0.001274697259400881
21 14562 0.004188985029528913
22 12708 0.0010229776518728695
23 14070 0.0010660980810235365
24 15315 0.0024159320927195083
25 13209 0.0007570595805890257
26 16401 0.0037192854094263428


## Preparing data for training MeshCNN

In [10]:
import shutil

FCD_FOLDER = '/home/ruslan/MeshCNN/datasets/fcd_seg/'

# Folder the dataset is built from. The previous version read `SAVE_FOLDER`,
# the loop variable leaked from the extraction cell. That resolves to
# SAVE_FOLDER_OUTTER after that cell runs, and raises NameError on a fresh
# kernel if it did not.
# NOTE(review): the merged meshes (SAVE_FOLDER_FINAL) are assumed to be the
# intended source; set this to SAVE_FOLDER_OUTTER to reproduce the old datasets.
DATASET_SOURCE_FOLDER = SAVE_FOLDER_FINAL

test_idxs = [1, 14]
train_idxs = [i for i in range(len(brain_img_names)) if i not in test_idxs]


def _clear_dir(path):
    """Create `path` if needed and delete its (non-hidden) contents, including
    subfolders such as the 'cache' folder."""
    os.makedirs(path, exist_ok=True)
    for entry in glob.glob(os.path.join(path, '*')):
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry)
        else:
            os.remove(entry)


def _copy_sample(i, split):
    """Copy mesh i into FCD_FOLDER/<split>/ and its labels into FCD_FOLDER/seg|sseg/."""
    shutil.copyfile(DATASET_SOURCE_FOLDER + f"{i}.obj", FCD_FOLDER + f"{split}/{i}.obj")
    shutil.copyfile(DATASET_SOURCE_FOLDER + f"seg/{i}.eseg", FCD_FOLDER + f"seg/{i}.eseg")
    shutil.copyfile(DATASET_SOURCE_FOLDER + f"sseg/{i}.seseg", FCD_FOLDER + f"sseg/{i}.seseg")


for split in ('train', 'test'):
    _clear_dir(FCD_FOLDER + split)
# The label folders are shared by both splits; files are overwritten, not cleared.
os.makedirs(FCD_FOLDER + 'seg', exist_ok=True)
os.makedirs(FCD_FOLDER + 'sseg', exist_ok=True)

for i in train_idxs:
    _copy_sample(i, 'train')
for i in test_idxs:
    _copy_sample(i, 'test')




## Shifting FreeSurfer surface meshes by their `cras` offset

In [None]:
N = 15  # number of subjects to convert; surfaces are read from '{i}_1_surf/'
base_names = ['lh.pial', 'lh.orig', 'rh.pial', 'rh.orig']


def shift_mesh(i, base_name, out_dir='meshlab_objects'):
    """
        Convert one FreeSurfer surface to .obj, shifted by its 'cras' offset.

        Reads `{i}_1_surf/{base_name}` with nibabel (including metadata), adds
        the metadata's 'cras' vector to every vertex coordinate, and saves the
        mesh through pymeshlab as `{out_dir}/{i}_{base_name}.obj`.
        out_dir is created if it does not exist.

        The surface file must carry 'cras' metadata, otherwise the
        meta['cras'] lookup fails.
    """
    coords, faces, meta = nibabel.freesurfer.io.read_geometry(
        f'{i}_1_surf/{base_name}', read_metadata=True)
    coords += meta['cras']
    os.makedirs(out_dir, exist_ok=True)
    ms = pymeshlab.MeshSet()
    ms.add_mesh(pymeshlab.Mesh(coords, faces), base_name)
    ms.save_current_mesh(os.path.join(out_dir, f"{i}_{base_name}.obj"))


for i in range(N):
    for base_name in base_names:
        shift_mesh(i, base_name)

## Choosing the Gaussian sigma and threshold values

In [17]:

# For each brain, write the smoothed bright-voxel volume and its thresholded
# version as NIfTI files in the working directory, so candidate sigma and
# threshold values can be inspected visually.
GAUSS_CHECK_INTENSITY_THRESHOLD = 95
GAUSS_CHECK_SIGMA = 7
GAUSS_CHECK_THRESHOLD = 0.4

for i, br_name in enumerate(brain_img_names):
    print(i)
    brain = nibabel.load(br_name)
    # No mask is needed here, so none is loaded.
    seed = (brain.get_fdata() > GAUSS_CHECK_INTENSITY_THRESHOLD).astype('float')
    smoothed = gaussian_filter(seed, sigma=GAUSS_CHECK_SIGMA)
    nibabel.save(nibabel.Nifti1Image(smoothed.astype('float'), brain.affine),
                 f"{i}_gauss_g_{GAUSS_CHECK_SIGMA}.nii.gz")
    nibabel.save(nibabel.Nifti1Image((smoothed > GAUSS_CHECK_THRESHOLD).astype('float'), brain.affine),
                 f"{i}_gauss_g_{GAUSS_CHECK_SIGMA}_t_{GAUSS_CHECK_THRESHOLD}.nii.gz")
    

KeyboardInterrupt: 