In [48]:
import sys
import SimpleITK as sitk
import os
import numpy as np
%matplotlib inline 
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from ipywidgets import interact, fixed
import math
import glob
from scipy.interpolate import RegularGridInterpolator, interpn

from skimage import measure

import pylab as pl
import trimesh
from stl import mesh
import re

from scipy.spatial import KDTree
import pyquaternion
from scipy.spatial.transform import Rotation as R
from test_utils import rigid_transform_3D, metrics
from sklearn.neighbors import NearestNeighbors

In [43]:
#prepare data set
def centeroidnp(arr):
    """get the centroid of a point cloud"""
    length = arr.shape[0]
    sum_x = np.sum(arr[:, 0])
    sum_y = np.sum(arr[:, 1])
    sum_z = np.sum(arr[:, 2])
    return math.ceil(sum_x/length), \
            math.ceil(sum_y/length), \
            math.ceil(sum_z/length)

In [44]:
def create_7D(source_pc, source_center, target_center):
    """create a 7D pointcloud as explained in Fu et al."""
    v_s = np.zeros((len(source_pc), 7))
    for i in range(len(v_s)):
        v_ss = source_center - source_pc[i,:3]
        v_st = target_center - source_pc[i,:3]
        v_s[i,:3] = v_ss 
        v_s[i,3:6] = v_st
        v_s[i,6] = source_pc[i,3]
    return v_s


In [45]:
def set_position_and_orientation(files_deform, files_regular):
    """
    set the positioning and orientation of the deformed images so that they align
    NOTE: this should be done in the begining when the initial deformed mhd files are created
    """
    for d ,f in zip(files_deform, files_regular):
        with open(d, 'a+') as deformed:
            with open(f, 'r') as regular:
                for r in regular.readlines():
                    if "Position" in r: #or "Orientation" in r:
                        deformed.write(r)

# Make connections between vertebrae bodies and laminas with facets for XML scene in SOFA framework

In [46]:

def dist_pts(a, b):
    return np.linalg.norm(a-b)

def min_dist(points, p):
    min_=min([dist_pts(a[:3],p) for a in points])
    return [dist_pts(a[:3],p) for a in points].index(min_)

def intersect2d(X, Y):
        """
        Function to find intersection of two 2D arrays.
        Returns index of rows in X that are common to Y.
        """
        X = np.tile(X[:,:,None], (1, 1, Y.shape[0]) )
        Y = np.swapaxes(Y[:,:,None], 0, 2)
        Y = np.tile(Y, (X.shape[0], 1, 1))
        eq = np.all(np.equal(X, Y), axis = 1)
        eq = np.any(eq, axis = 1)
        return np.nonzero(eq)[0]


In [54]:
def get_bbox(position):
    """ 
    Gets the bounding box of the object defined by the given vertices.

    Arguments
    -----------
    position : list
    List with the coordinates of N points (position field of Sofa MechanicalObject).

    Returns
    ----------
    xmin, xmax, ymin, ymax, zmin, zmax : floats
    min and max coordinates of the object bounding box.
    """
    points_array = np.asarray(position)
    m = np.min(points_array, axis=0)
    xmin, ymin, zmin = m[0], m[1], m[2]

    m = np.max(points_array, axis=0)
    xmax, ymax, zmax = m[0], m[1], m[2]

    return xmin, xmax, ymin, ymax, zmin, zmax

def get_indices_in_bbox( positions, bbox ):
    """
    Get the indices of the points falling within the specified bounding box.

    Arguments
    ----------
    positions (list):
    N x 3 list of points coordinates.
    bbox (list):
    [xmin, ymin, zmin, xmax, ymax, zmax] extremes of the bounding box.

    Returns
    ----------
    indices:
    List of indices of points enclosed in the bbox.

    """
    # bbox is in the format (xmin, ymin, zmin, xmax, ...)
    assert len(bbox) == 6
    indices = []
    for i, x in enumerate( positions ):
        if x[0] >= bbox[0] and x[0] <= bbox[3] and x[1] >= bbox[1] and x[1] <= bbox[4] and x[2] >= bbox[2] and x[2] <= bbox[5]:
            indices.append( i )
    return indices
        

def print_stiff_springs(vert1,vert2, bbox_v1_v2,bbox_v2_v1, s, d):
    """
    vert1 and vert2: two adjecent vertebrae
    bbox_v1_v2 and bbox_v2_v1 are the bounding boxes representing area where the 
    springs are found on the closer sides of two adjecent vertebrae
    """
    idx1 = get_indices_in_bbox(vert1, bbox_v1_v2)[::5]
    idx2 = get_indices_in_bbox(vert2, bbox_v2_v1)[::5]
    print("SPRINGS: ")
    np.random.shuffle(idx1)
    np.random.shuffle(idx2)
    print(min(len(idx1),len(idx2)))
    print()
    for i,j in zip(idx1,idx2):
        print("{0} {1} {2} {3} {4}  ".format(i,j,s,d,dist_pts(vert1[i],vert2[j])), end='')
        
def print_positions(vert1, bbox_v1_t12):
    """
    this function is used for seting the fixed points on L1 and L5 simulating
    connection with T12 and S1 respectively
    """
    
    idx1 = get_indices_in_bbox(vert1, bbox_v1_t12)[::5]
    print("POSITIONS:")
    for i in idx1:
        print("{0} {1} {2}  ".format(vert1[i][0],vert1[i][1],vert1[i][2]), end='')
    print("Indexes for fixed constraint")
    for i,_ in enumerate(idx1):
        print(i, end=" ")
    print("SPRINGS: ")
    for i,j in enumerate(idx1):
        print("{0} {1} {2} {3} {4}  ".format(i,j,1000,10,0.00001), end='')



## uncomment the block below to print the connections and vertices for the springs in SOFA framework change paths and bounding boxes accordingly

In [64]:
vert1 = np.loadtxt("/Users/janelameski/Desktop/jane/Thesis/prethesis/DataJane/Spine8/v1.txt")[:,:3]
vert2 = np.loadtxt("/Users/janelameski/Desktop/jane/Thesis/prethesis/DataJane/Spine8/v2.txt")[:,:3]
# vert3 = np.loadtxt("/Users/janelameski/Desktop/jane/Thesis/prethesis/DataJane/Spine10/v3.txt")[:,:3]
# vert4 = np.loadtxt("/Users/janelameski/Desktop/jane/Thesis/prethesis/DataJane/Spine10/v4.txt")[:,:3]
# vert5 = np.loadtxt("/Users/janelameski/Desktop/jane/Thesis/prethesis/DataJane/Spine10/v5.txt")[:,:3] 
              
bbox_v1_t12 = [-0.135, 0.189, -0.26, -0.127, 0.258, -0.219]
bbox_v1_v2 = [-0.109, 0.189, -0.26, -0.09, 0.258, -0.219]
bbox_v2_v1 = [-0.109, 0.189, -0.25, -0.09, 0.258, -0.219]
bbox_v2_v3 = [-0.076, 0.189, -0.25, -0.02, 0.258, -0.215]
bbox_v3_v2 = [-0.076, 0.202, -0.25, -0.061, 0.258, -0.215]
bbox_v3_v4 = [-0.042, 0.202, -0.25, -0.03, 0.258, -0.215]
bbox_v4_v3 = [-0.041, 0.202, -0.247, -0.0265, 0.258, -0.215]
bbox_v4_v5 = [-0.009, 0.195, -0.247, 0.01, 0.258, -0.215]
bbox_v5_v4 = [-0.009, 0.195, -0.247, 0.01, 0.258, -0.215]
bbox_v5_s1 = [0.018, 0.195, -0.27, 0.04, 0.258, -0.215]

bbox_bone_v1_v2 = [-0.11, 0.195, -0.28, -0.088, 0.25, -0.255]
bbox_bone_v2_v1 = [-0.13, 0.195, -0.275, -0.088, 0.25, -0.255]
bbox_bone_v2_v3 = [-0.07, 0.195, -0.275, -0.06, 0.25, -0.25]
bbox_bone_v3_v2 = [-0.08, 0.195, -0.27, -0.06, 0.25, -0.25]
bbox_bone_v3_v4 = [-0.054, 0.19, -0.27, -0.027, 0.24, -0.25]
bbox_bone_v4_v3 = [-0.06, 0.19, -0.27, -0.027, 0.24, -0.25]
bbox_bone_v4_v5 = [-0.02, 0.18, -0.285, 0, 0.24, -0.26]
bbox_bone_v5_v4 = [-0.05, 0.18, -0.285, 0, 0.24, -0.26]

# print_positions(vert1, bbox_v1_t12)
# print()
# print_positions(vert5, bbox_v5_s1)
# print()
# # format(i,j,500,3,dist_pts(vert1[i],vert2[j])), end='')
# print_stiff_springs(vert1, vert2, bbox_v1_v2, bbox_v2_v1,500,3)
# print()
# print_stiff_springs(vert2, vert3, bbox_v2_v3, bbox_v3_v2,500,3)
# print()
# print_stiff_springs(vert3, vert4, bbox_v3_v4, bbox_v4_v3,500,3)
# print()
# print_stiff_springs(vert4, vert5, bbox_v4_v5, bbox_v5_v4,500,3)
# # format(i,j,8000,500,dist_pts(vert1[i],vert2[j])), end='')
# print()
# print()
print_stiff_springs(vert1, vert2, bbox_bone_v1_v2, bbox_bone_v2_v1,8000,500)
# print()
# print_stiff_springs(vert2, vert3, bbox_bone_v2_v3, bbox_bone_v3_v2,8000,500)
# print()
# print_stiff_springs(vert3, vert4, bbox_bone_v3_v4, bbox_bone_v4_v3,8000,500)
# print()
# print_stiff_springs(vert4, vert5, bbox_bone_v4_v5, bbox_bone_v5_v4,8000,500)

SPRINGS: 
461

1697 16984 8000 500 0.013638694988890976  2134 11629 8000 500 0.02264568835341508  5526 16729 8000 500 0.018569577297289253  1007 15978 8000 500 0.0182612486155794  2493 15973 8000 500 0.02774905405594936  396 16651 8000 500 0.013447646671444042  4828 16834 8000 500 0.009486727623369396  5476 12169 8000 500 0.035539713012347186  947 16053 8000 500 0.012049116191654897  1417 14159 8000 500 0.016368396408933873  1437 12570 8000 500 0.023608964420321375  4868 16681 8000 500 0.01606334961955319  4783 16023 8000 500 0.03392297747839951  4863 11311 8000 500 0.02735754376767036  1198 11351 8000 500 0.020257127165518796  598 15444 8000 500 0.021953642089639718  2069 16884 8000 500 0.015817812775475643  5905 17029 8000 500 0.010489957149578838  603 11577 8000 500 0.015822237547199205  4838 12015 8000 500 0.04204434802681566  1233 16696 8000 500 0.02856557370332338  5171 17019 8000 500 0.007063653516417697  5840 14560 8000 500 0.013964598132420415  5934 12907 8000 500 0.0240151452

## Take points for Biomechanical constraint

In [9]:
def find_nearest_point_idx(point_cloud, input_point):
    """
    Returns the index closest point to <input_point> in the input <point_cloud>
    """
    idx = np.array([np.linalg.norm(x+y+z) for (x,y,z) in np.abs(point_cloud[:,:3]-input_point[:3])]).argmin()
    return int(idx)

In [10]:
def bboxes(values):
    """
    Returns the bounding boxes from a file created manualy in ImFusion
    """
    return [np.min(values[:,0]), np.min(values[:,1]) ,np.min(values[:,2]),
      np.max(values[:,0]), np.max(values[:,1]),np.max(values[:,2])]

In [11]:
# vert1 = np.loadtxt("DataJane/Spine10/v1.txt")[:,:3]
# vert2 = np.loadtxt("DataJane/Spine10/v2.txt")[:,:3]
# vert3 = np.loadtxt("DataJane/Spine10/v3.txt")[:,:3]
# vert4 = np.loadtxt("DataJane/Spine10/v4.txt")[:,:3]
# vert5 = np.loadtxt("DataJane/Spine10/v5.txt")[:,:3]

# verts = [vert1,vert2,vert3,vert4,vert5]
# #bounding boxes of the vertebrae, produced manually and saved in the above cell for 
# #every spine separately
# bbox_v1_t12 = [-0.0239, 0.2275, -0.0591, 0.0629, 0.2764, -0.0442]
# bbox_v1_v2 = [-0.0239, 0.2275, -0.08, 0.0629, 0.272, -0.062]
# bbox_v2_v1 = [-0.0239, 0.2275, -0.087, 0.0629, 0.266, -0.062]
# bbox_v2_v3 = [-0.0239, 0.2265, -0.11, 0.0629, 0.262, -0.09]
# bbox_v3_v2 = [-0.0239, 0.2, -0.115, 0.0629, 0.253, -0.09]
# bbox_v3_v4 = [-0.0239, 0.2, -0.14, 0.0629, 0.248, -0.122]
# bbox_v4_v3 = [-0.0239, 0.2, -0.15, 0.0629, 0.245, -0.122]
# bbox_v4_v5 = [-0.0239, 0.2, -0.18, 0.0629, 0.245, -0.16]
# bbox_v5_v4 = [-0.0239, 0.19, -0.187, 0.0629, 0.24, -0.16]
# bbox_v5_s1 = [-0.0239, 0.19, -0.214, 0.0629, 0.24, -0.195]

# bboxes = [bbox_v1_t12, bbox_v1_v2,
#           bbox_v2_v1, bbox_v2_v3, 
#           bbox_v3_v2, bbox_v3_v4,
#           bbox_v4_v3, bbox_v4_v5,
#           bbox_v5_v4, bbox_v5_s1]

# len_v = 0
# with open("Spine10_biomechanical.txt", "w+") as file:
#     indices1 = get_indices_in_bbox(verts[0], bboxes[1])
#     indices2 = get_indices_in_bbox(verts[1], bboxes[2])
#     indices3 = get_indices_in_bbox(verts[1], bboxes[3])
#     indices4 = get_indices_in_bbox(verts[2], bboxes[4])
#     indices5 = get_indices_in_bbox(verts[2], bboxes[5])
#     indices6 = get_indices_in_bbox(verts[3], bboxes[6])
#     indices7 = get_indices_in_bbox(verts[3], bboxes[7])
#     indices8 = get_indices_in_bbox(verts[4], bboxes[8])

    
#     a1 = find_nearest_vector(verts[0],np.mean(verts[0][indices1], axis=0))
#     a2 = find_nearest_vector(verts[1],np.mean(verts[1][indices2], axis=0))
#     a3 = find_nearest_vector(verts[1],np.mean(verts[1][indices3], axis=0))
#     a4 = find_nearest_vector(verts[2],np.mean(verts[2][indices4], axis=0))
#     a5 = find_nearest_vector(verts[2],np.mean(verts[2][indices5], axis=0))
#     a6 = find_nearest_vector(verts[3],np.mean(verts[3][indices6], axis=0))
#     a7 = find_nearest_vector(verts[3],np.mean(verts[3][indices7], axis=0))
#     a8 = find_nearest_vector(verts[4],np.mean(verts[4][indices8], axis=0))
    
#     file.write("{0} {1} {2} {3} {4} {5} {6} {7}".format(a1,a2,a3,a4,a5,a6,a7,a8))
    #SANITY CHECK
#     file.write("{0} {1} {2}\n{3} {4} {5}\n{6} {7} {8}\n{9} {10} {11}\n{12} {13} {14}\n{15} {16} {17}\n{18} {19} {20}\n{21} {22} {23}\n"
#                                 .format(verts[0][a1][0],
#                                         verts[0][a1][1],
#                                         verts[0][a1][2],
#                                         verts[1][a2][0],
#                                         verts[1][a2][1],
#                                         verts[1][a2][2],
#                                          verts[1][a3][0],
#                                         verts[1][a3][1],
#                                         verts[1][a3][2],
#                                         verts[2][a4][0],
#                                         verts[2][a4][1],
#                                         verts[2][a4][2],
#                                          verts[2][a5][0],
#                                         verts[2][a5][1],
#                                         verts[2][a5][2],
#                                         verts[3][a6][0],
#                                         verts[3][a6][1],
#                                         verts[3][a6][2],
#                                          verts[3][a7][0],
#                                         verts[3][a7][1],
#                                         verts[3][a7][2],
#                                         verts[4][a8][0],
#                                         verts[4][a8][1],
#                                         verts[4][a8][2],
#                                                          ))

## 2 Creating the dataset
The following scripts describe how to reorder and preprocess the .vtu data, in output from the sofa framework to a
data format compatible with the network loader

### 2. 1 Reordering Files generated from sofa:
The output files from sofa are reordered such that DbFolder (i.e. where the
dataset is saved) is organized as follows: DbFolder/Spine<i>/timestamp<t> with <i> being the spine id and <t> the
simulation timestamp. Each of such folder contains 5 files, one corresponding to each vertebra

In [32]:
import os
from shutil import copy2

def extract_spine_id(filename):
    """
    Given a file, it extracts the id of the spine.

    Example 1.

    .. code-block:: console
    >> filename = <spine_folder>\sspine1_vert1_1_0.txt
    >> extract_spine_id(filename)
    spine_1

    Example 2:
    >> filename = spine1_vert1_0.txt
    >> extract_spine_id(filename)
    spine_1

    """

    filename = os.path.split(filename)[-1]

    return filename.split("_")[0]

def extract_vertebra_id(filename):
    """
    Given a file, it extracts the id of the vertebra

    Example 1.

    .. code-block:: console
    >> filename = <spine_folder>\sspine1_vert1_1_0.txt
    >> extract_vertebra_id(filename)
    vert1

    Example 2:
    >> filename = spine1_vert1_0.txt
    >> extract_vertebra_id(filename)
    vert1

    """

    filename = os.path.split(filename)[-1]
    return filename.split("_")[1][0:5]


def extract_timestamp_id(filename):
    """
    Given a file, it extracts the id of the timestamp

    Example 1.

    .. code-block:: console
    >> filename = <spine_folder>\sspine1_vert1_1_0.txt
    >> extract_timestamp_id(filename)
    1_0

    Example 2:
    >> filename = spine1_vert1_1_0.txt
    >> extract_timestamp_id(filename)
    1_0

    """
    filename = os.path.split(filename)[-1]
    spine_id = extract_spine_id(filename)
    vertebra_id = extract_vertebra_id(filename)

    timestamp_id = filename.replace(spine_id + "_" + vertebra_id, "")
    timestamp_id = timestamp_id.split(".")[0]

    return timestamp_id


def order_files_in_fold(src_filepath, dst_folder, copy=True):
    """
    Given a certain filepath, it copies (or moves it if copy==False) it in the correct folder location

    Example:

        .. code-block:: console
        >> src_filepath = VTU_output_from_SOFA\spine1_vert1_1_0.vtu
        >> dst_folder = DB_Folder
        >> patient_id = spine1
        >> timestamp_id = _1_0

        by running:
        >> order_files_in_fold(src_filepath, dst_folder, patient_id, timestamp_id, copy=True)

        The file will be copied in DB_Folder/spine_1/ts_1_0/spine1_vert1_1_0.vtu
    """

    patient_id = extract_spine_id(src_filepath)
    timestamp_id = extract_timestamp_id(src_filepath)

    src_filename = os.path.split(src_filepath)[-1]

    # minor correction in naming - the first deformation file misses the timestamp 0
    if timestamp_id.count("_") == 1:
        timestamp_id = "_0" + timestamp_id

    dst_folder = os.path.join(dst_folder, patient_id, "ts" + timestamp_id)

    if not os.path.exists(dst_folder):
        os.makedirs(dst_folder)

    dst_filepath = os.path.join(dst_folder, src_filename)

    if copy:
        copy2(src_filepath, dst_filepath)

# order_files_in_fold(src_filepath="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/obj_files 2",
#                   dst_folder="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/obj_files", copy=False)

        
def reorder_vtu_files(src_vtu_dir, dst_vtu_dir, copy=True):

    """
    Given a certain filepath containing the .vtu (or any other format) obtained by sofa with the following naming:
    spine<spineId>_vert<vertId><timestamp_id>.vtu, it reorders them in the dst_vtu_dir, such that a given file
    spine<spineId>_vert<vertId><timestamp_id>.vtu is stored in
    dst_vtu_dir/spine<spineId>/<timestamp_id>/spine<spineId>_vert<vertId><timestamp_id>.vtu

    Example:
        given a src_vtu_dir containing
        ['spine1_vert10.vtu', 'spine1_vert1_0.vtu', 'spine1_vert1_1.vtu', 'spine1_vert1_10_0.vtu',
        'spine1_vert1_10_1.vtu', ..., 'spine1_vert5_7_1.vtu', 'spine1_vert5_8_0.vtu', 'spine1_vert5_8_1.vtu',
        'spine1_vert5_9_0.vtu', 'spine1_vert5_9_1.vtu''spine2_vert10.vtu', 'spine2_vert1_10_0.vtu',
        'spine2_vert1_11_0.vtu', 'spine2_vert1_12_0.vtu', 'spine2_vert1_13_0.vtu', ... ]

    The function copies (or moves if copy == False) them

    Example:
        .. code-block:: console
            >> src_filepath = VTU_output_from_SOFA\spine1_vert1_1_0.vtu
            >> dst_folder = DB_Folder
            >> patient_id = spine1
            >> timestamp_id = _1_0

    by running:
    >> order_files_in_fold(src_filepath, dst_folder, patient_id, timestamp_id, copy=True)

    The files are reordered in the dst_vtu_dir as follows:
    dst_vtu_dir\spine1\ts10\spine1_vert10.vtu
    dst_vtu_dir\spine1\ts10\spine1_vert20.vtu
    dst_vtu_dir\spine1\ts10\spine1_vert30.vtu
    dst_vtu_dir\spine1\ts10\spine1_vert40.vtu
    dst_vtu_dir\spine1\ts10\spine1_vert50.vtu
                    ...
    dst_vtu_dir\spine1\ts10\spine1_vert1_10_0.vtu
    dst_vtu_dir\spine1\ts10\spine1_vert2_10_0.vtu
    dst_vtu_dir\spine1\ts10\spine1_vert3_10_0.vtu
    dst_vtu_dir\spine1\ts10\spine1_vert4_10_0.vtu
    dst_vtu_dir\spine1\ts10\spine1_vert5_10_0.vtu
                    ...
    dst_vtu_dir\spine2\ts10\spine1_vert1_10_0.vtu
    dst_vtu_dir\spine2\ts10\spine1_vert2_10_0.vtu
    dst_vtu_dir\spine2\ts10\spine1_vert3_10_0.vtu
    dst_vtu_dir\spine2\ts10\spine1_vert4_10_0.vtu
    dst_vtu_dir\spine2\ts10\spine1_vert5_10_0.vtu
                    ...
    """

    for file in [item for item in os.listdir(src_vtu_dir) if ".vtu" in item]:
        order_files_in_fold(src_filepath=os.path.join(src_vtu_dir, file),
                            dst_folder=dst_vtu_dir,
                            copy=copy)
# reorder_vtu_files(src_vtu_dir="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/obj_files 2",
#                   dst_vtu_dir="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/obj_files")
reorder_vtu_files(src_vtu_dir="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/vtu_files",
                  dst_vtu_dir="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/reordered_vtu_files")

### 2. 2 vtu to txt
Generating .txt point cloud files from the Dataset folder containing the vtu files

In [33]:
import re
import os

def vtu2txt(src_vtu_dir, dst_txt_dir):
    """
    Converts a .vtu file in output from sofa to a .txt point cloud file by copying the points coordinates to the
    .txt file

    :param src_vtu_dir: str: path to the folder containing the .vtu files, ordered according to the reorder_vtu_files
    script (e.g. where files are saved according to:
     vtu_dir\spine<spineId>\ts<timestampId>\spine<spineId>_vert<vertId>_<timestampId>.vtu)
    :param dst_txt_dir: str: path to the folder where the .txt files will be saved, according to the usual folder
    structure dst_txt_dir\spine<spineId>\ts<timestampId>\spine<spineId>_vert<vertId>_<timestampId>.txt)
    """
    #check lines which start anything but numbers(in the vtu files the rows 
    #with numbers are the rows containing the point cloud)
    regex = re.compile("^ *<|^  +\d|^\t<|^\t\d|^\t-|^ +-")

    patient_id_list = [item for item in os.listdir(src_vtu_dir) if "spine" in item]

    for patient_id in patient_id_list:
        timestamp_list = [item for item in os.listdir(os.path.join(src_vtu_dir, patient_id)) if "ts" in item]

        for timestamp in timestamp_list:
            file_list = [item for item in os.listdir(os.path.join(src_vtu_dir, patient_id, timestamp))
                         if item.endswith(".vtu")]

            #write them in a file with same name but ending txt
            dst_folder = os.path.join(dst_txt_dir, patient_id, timestamp)
            if not os.path.exists(dst_folder):
                os.makedirs(dst_folder)

            #read all files one by one
            for file in file_list:
                with open(os.path.join(src_vtu_dir, patient_id, timestamp, file), "r") as f:
                    lines = f.readlines()
                #filter them using the regex above
                filtered = [i for i in lines if not regex.match(i)]
                dst_filename = os.path.join(dst_folder, file.replace(".vtu", ".txt"))

                with open(dst_filename, "w+") as f:
                    for l in filtered:

                        x,y,z = l.split(" ")
                        f.write("{0} {1} {2}\n".format(float(x)*1e+3,float(y)*1e+3,float(z)*1e+3))

vtu2txt(src_vtu_dir = "/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/reordered_vtu_files",
        dst_txt_dir = "/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/txt_files")

### 2. 3 vtu to obj
Generating .obj meshes from the Dataset folder containing the vtu files

In [13]:
import os
import meshio

def vtu2obj(src_vtu_dir, dst_obj_dir):
    """
    Converts a .vtu file in output from sofa to a .obj file containing the point cloud mesh.

    :param src_vtu_dir: str: path to the folder containing the .vtu files, ordered according to the reorder_vtu_files
    script (e.g. where files are saved according to:
     vtu_dir\spine<spineId>\ts<timestampId>\spine<spineId>_vert<vertId>_<timestampId>.vtu)
    :param dst_obj_dir: str: path to the folder where the .txt files will be saved, according to the usual folder
    structure dst_obj_dir\spine<spineId>\ts<timestampId>\spine<spineId>_vert<vertId>_<timestampId>.obj)
    """

    patient_id_list = [item for item in os.listdir(src_vtu_dir) if "spine" in item]

    for patient_id in patient_id_list:
        timestamp_list = [item for item in os.listdir(os.path.join(src_vtu_dir, patient_id)) if "ts" in item]

        for timestamp in timestamp_list:
            file_list = [item for item in os.listdir(os.path.join(src_vtu_dir, patient_id, timestamp))
                         if item.endswith(".vtu")]

            #write them in a file with same name but ending txt
            dst_folder = os.path.join(dst_obj_dir, patient_id, timestamp)
            if not os.path.exists(dst_folder):
                os.makedirs(dst_folder)

            for file in file_list:
                mesh_vtu = meshio.read(os.path.join(src_vtu_dir, patient_id, timestamp, file))

                mesh = meshio.Mesh(
                    mesh_vtu.points*1e3,
                    mesh_vtu.cells,
                    # Optionally provide extra data on points, cells, etc.
                    mesh_vtu.point_data,
                    # Each item in cell data must match the cells array
                    mesh_vtu.cell_data,
                    )

                dst_filename = os.path.join(dst_folder, file.replace(".vtu", ".obj"))
                mesh.write(dst_filename)

vtu2obj(src_vtu_dir = "/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/reordered_vtu_files",
        dst_obj_dir = "/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/obj_files")

### 2.4 Generating the label maps objects from the .obj files using in ImFusion.
To generate the .mhd label maps from the ImFusion we run the imfusion_workspace/create_mhd_spines_from_obj.iws
using the batch file that can be generated with the script below.
The generated .mhd labelmaps will be saved in the dst_mhd_dir, according to the previously described data structure
(i.e. (<dst_mhd_dir>\spine<spine_id>\<timestamp_id>.mhd))

In [38]:
import os
def generate_obj2mhd_batch_file(src_obj_dir, dst_mhd_dir, batch_file_path):
    """
    Generates the batch file for the imfusion_workspaces/create_mhd_spines_from_obj.iws imfusion workspace.
    Example
    .. code-block:: text
        INPUT1;INPUT2;INPUT3;INPUT4;INPUT5;OUTPUT
        obj_dir\spine<spineId>\ts<timestampId>\spine<spineId>_vert1_<timestampId>.obj;  \
            obj_dir\spine<spineId>\ts<timestampId>\spine<spineId>_vert2_<timestampId>.obj; \
            obj_dir\spine<spineId>\ts<timestampId>\spine<spineId>_vert3_<timestampId>.obj; \
            obj_dir\spine<spineId>\ts<timestampId>\spine<spineId>_vert4_<timestampId>.obj; \
            obj_dir\spine<spineId>\ts<timestampId>\spine<spineId>_vert5_<timestampId>.obj; \
            mhd_dir\spine<spineId>\ts<timestampId>.mhd; \
         ...

    """
    patient_id_list = [item for item in os.listdir(src_obj_dir) if "spine" in item]

    fid = open(batch_file_path, 'w')
    fid.write('INPUT1;INPUT2;INPUT3;INPUT4;INPUT5;OUTPUT')
    for patient_id in patient_id_list:
        timestamp_list = [item for item in os.listdir(os.path.join(src_obj_dir, patient_id)) if "ts" in item]

        for timestamp in timestamp_list:
            if timestamp.endswith(".imf"):
                continue
            print(timestamp)
            file_list = [item for item in os.listdir(os.path.join(src_obj_dir, patient_id, timestamp))
                         if item.endswith(".obj")]

            if not os.path.exists(os.path.join(dst_mhd_dir, patient_id)):
                os.makedirs(os.path.join(dst_mhd_dir, patient_id))

            mhd_filepath = os.path.join(dst_mhd_dir, patient_id, timestamp + ".mhd")

            fid.write("\n" + ";".join([os.path.normpath(os.path.join(src_obj_dir, patient_id, timestamp, file))
                                 for file in file_list]) + ";" + mhd_filepath)

generate_obj2mhd_batch_file(src_obj_dir="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/obj_files",
                            dst_mhd_dir="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/mhd_files",
                            batch_file_path="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/imfusion_workspaces/obj2mhd.txt")

### 2.5 Raycast the generated .mhd labelmap
The script reads the labelmaps and generate the ray-casted .mhd file. The file are saved in the save_root, according
to the previously described data_structure (<save_root>\spine<spine_id>\<timestamp_id>.mhd)

In [23]:
import os
import SimpleITK as sitk
import numpy as np

def ray_cast_slice(image, spine_id=""):
    """
    Generate a ray-casted version of the input slice, in the rows direction (arrows indicating ray-casting direction).

    -> __________________________________________________
    -> |                                                 |
    -> |                                                 |
    -> |                                                 |
    -> |                                                 |
    -> |                                                 |
    -> |                                                 |
    -> |_________________________________________________|

    Only for spine_8, the ray casting is done from right to left instead of left to right
    spine_20 and spine_21 are special case where the slice direction is the same as spine_8 but the dimmensions 
    on which we iterate are different

    __________________________________________________
    |                                                 | <-
    |                                                 | <-
    |                                                 | <-
    |                                                 | <-
    |                                                 | <-
    |                                                 | <-
    |_________________________________________________| <-
    """

    rays = np.zeros_like(np.squeeze(np.squeeze(image)))
    if spine_id == "spine8" :

        j_range = range(image.shape[0])

        for i in range(image.shape[1]):
            for j in j_range:
                if image[j, i] != 0:
                    rays[j, i] = 1
                    break
    elif spine_id == "spine20" or spine_id == "spine21" :
        j_range = range(image.shape[1]-1, 0, -1)

        for i in range(image.shape[0]):
            for j in j_range:
                if image[i, j] != 0:
                    rays[i, j] = 1
                    break
    else:

        j_range = range(image.shape[0]-1, 0, -1)
        for i in range(image.shape[1]):
            for j in j_range:
                if image[j, i] != 0:
                    rays[j, i] = 1
                    break

    return rays

def ray_cast_data(data_path, spine_id):
    """
    Ray cast all the data in the spine dataset

    :param: data_path: str: The path to the .mhd file
    :param: spine_id:str: The spine id (the raycasting direction is selected depending on the spine_id
    """

    assert data_path.endswith(".mhd")

    img_mhd = sitk.ReadImage(data_path)
    im = sitk.GetArrayFromImage(img_mhd)

    raycasted = np.zeros_like(im)

    
    if spine_id in ["spine1", "spine2", "spine3", "spine4", "spine6", "spine7", "spine10",\
                   "spine11", "spine12", "spine13", "spine14", "spine16", "spine17", "spine22",\
                   "spine15", "spine18", "spine19"]:
        for i in range(im.shape[0]):
            raycasted[i,...] = ray_cast_slice(im[i,...], spine_id)
    elif spine_id in ["spine5", "spine9", "spine8"]:

        if spine_id in ["spine5", "spine9"]:
            for i in range(im.shape[2]):
                raycasted[...,i] = ray_cast_slice(im[...,i], spine_id)#spine 5, 9
        elif spine_id in ["spine8"]:
            for i in range(im.shape[2]):
                raycasted[...,i] = ray_cast_slice(im[...,i], spine_id)#spine 8,
    else:

        for i in range(im.shape[1]):
            raycasted[:,i,:] = ray_cast_slice(im[:,i,:], spine_id)

    # Setting the position and orientation of the ray-casted image
    raycasted_img = sitk.GetImageFromArray(raycasted)
    raycasted_img.SetDirection(img_mhd.GetDirection())
    raycasted_img.SetOrigin(img_mhd.GetOrigin())

    return raycasted_img


def ray_cast_files(data_root, save_root):
    """
    Ray-casts all the files contained in the dataroot directory and saved according to:
    <data_root>\spine<spineId>\ts<timestampId>.mhd

    :param: data_root: str: The path to the data, which must be saved according to:
        <data_root>\spine<spineId>\ts<timestampId>.mhd
    :param: save_root: the path where the raycasted data will be saved, according to:
        <save_root>\spine<spineId>\raycasted_ts<timestampId>.mhd
    """

    spine_ids =[item for item in os.listdir(data_root) if os.path.isdir(os.path.join(data_root, item)) and
                "spine" in item]

    for spine_id in spine_ids:
        spine_folder_path = os.path.join(data_root, spine_id)

        save_spine_folder = os.path.join(save_root, spine_id)
        if not os.path.exists(save_spine_folder):
            os.makedirs(save_spine_folder)

        for file in [item for item in os.listdir(spine_folder_path) if item.endswith(".mhd")]:

            raycasted_img = ray_cast_data(data_path=os.path.join(spine_folder_path, file),
                                      spine_id=spine_id)

            save_path = os.path.join(save_spine_folder, "raycasted_" + file.split(".")[0] + ".mhd")
            sitk.WriteImage(raycasted_img,  save_path)


ray_cast_files(data_root="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/mhd_files",
               save_root="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/mhd_files_raycasted")


### 2.6. Convert the generated .mhd label maps to .txt point cloud file.
mhd to .txt is done using the ImFusion workspace imfusion_workspaces/labelmap2pc.iws.
The batch file to be used with the imfusion workspace can be generated with the script below

In [31]:
import os

def generate_mhd2pc_batch_file(src_labelmaps_dir, dst_pcs_dir, batch_file_path):
    """
    Generate the script to generate the batch file to be used for launching the imfusion_workspaces/labelmap2pc.iws.

    :param: src_labelmaps_dir: str: The path to the labelmaps
    :param: dst_pcs_dir: str: The directory where the point cloud files will be saved
    :param: batch_file_path: str: The path where the (imfusion) .txt batch file will be generated

    Example:

        .. code-block:: text
            INPUTMHD;OUTPUTPC
            <src_labelmaps_dir>\spine<spineId>\raycasted_ts<timestampId>.mhd;<dst_pcs_dir>\spine<spineId>\raycasted_ts<timestampId>.txt
                                        ...

    """
    spine_ids = os.listdir(src_labelmaps_dir)

    fid = open(batch_file_path, "w")

    fid.write("INPUTMHD;OUTPUTPC")

    for spine_id in spine_ids:
        #it appeared in my pc i had to hard code it @Jane
        if spine_id == ".DS_Store":# or spine_id in ["spine11", "spine12", "spine13", "spine14", "spine16", "spine17", "spine22",\
                 # "spine15", "spine18", "spine19", "spine20","spine21"]:
            continue
        dst_spine_id_folder = os.path.join(dst_pcs_dir, spine_id)
        if not os.path.exists(dst_spine_id_folder):
            os.makedirs(dst_spine_id_folder)
        
        for file in [item for item in os.listdir(os.path.join(src_labelmaps_dir, spine_id)) if ".mhd" in item]:

            input_mhd = os.path.join(src_labelmaps_dir, spine_id, file)
            output_pc = os.path.join(dst_pcs_dir, spine_id, file.replace(".mhd", ".txt"))

            fid.write("\n" + input_mhd + ";" + output_pc)

    fid.close()

generate_mhd2pc_batch_file(src_labelmaps_dir="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/mhd_files",
                           dst_pcs_dir="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/txt_files",
                           batch_file_path="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/imfusion_workspaces/labelmap2pc.txt")

# todo add imfusion workspace

# 2.7 Divide ray-casted vertebrae



# 2.8 Prepare spine data .txt to .npz (to be used as an input to the network)

In [39]:
import visualization_utils as utils
from scipy.spatial import KDTree
import numpy as np
import math
import os

list_files = "/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/" + "txtFiles/"

class Point:
    def __init__(self, x, y, z, color):
        """
        :param: x: float: x coordinate of the point (in mm)
        :param: y: float: y coordinate of the point (in mm)
        :param: z: float: z coordinate of the point (in mm)
        :param: color: int: integer indicating the color of the point
        """
        self.x = x
        self.y = y
        self.z = z
        self.color = color
    
    def __str__(self):
        return f"[{self.x}, {self.y}, {self.z}, {self.color}]"
    
    def _get_pt_as_array(self):
        return np.array([self.x, self.y, self.z])

    def get_closest_point_in_cloud(self, pc, filter_by_color=True):

        distances = np.array(
            [np.linalg.norm(x + y + z) for (x, y, z) in np.abs(pc[:, :3] - self._get_pt_as_array())])

        if not filter_by_color:
            idx = distances.argmin()
            return idx, pc[idx]

        if len(np.where(pc[:, 3] == self.color)) == 0:
            return None, None

        distances[pc[:, 3] != self.color] = np.max(distances) + 1

        idx = distances.argmin()

        return idx, pc[idx]


def indexes2points(idxes_list, point_cloud, color=0):
    """
    Converts a list of indexes or a index to a set of Points, extracting the point coordinates from the input
    point_cloud

    :param: idxes_list: list(int): list of input indexes
    :param: point_cloud: np.ndarray of size [Nx3] or [Nx4]. If the array size is [Nx4], the last dimension is considered
        to be the color of the point
    :param: color: if the array has size [Nx3], the the color of all the points in the point cloud is set to color
        (Default to 0).
    :return: a list of Point objects, containing the 3d coordinates and color of the input point cloud at the input
        indexes

    Example:
        .. code-block:: console
            >> idxes_list = [1, 3]
            >> point_cloud = np.ndarray([ 10, 14, 20, 1
                                        [ 10, 30, 20, 4],
                                        [ 18, 20, 18, 2],
                                        [40, 1, 20, 2])

            >> indexes2points(idxes_list, point_cloud)
            [Point(x = 10, y = 30, z = 4, color = 1), Point(x = 40, y=1, z = 20, color = 2)]
    """

    if point_cloud.shape[1] > 3:
        color = point_cloud[:, 3]
    else:
        color = np.ones([point_cloud.shape[0],])*color

    if isinstance(idxes_list, int) or isinstance(idxes_list, float):
        idxes_list = [idxes_list]

    constraints_points = []
    for item in idxes_list:
        if isinstance(item, tuple) or isinstance(item, list):
            assert all(isinstance(x, int) for x in item) or all(isinstance(x, float) for x in item)

            constraints_points.append(tuple(Point(x=point_cloud[idx, 0],
                                                  y=point_cloud[idx, 1],
                                                  z=point_cloud[idx, 2],
                                                  color=color[idx]) for idx in item))

        else:
            constraints_points.append(Point(x=point_cloud[item, 0],
                                             y=point_cloud[item, 1],
                                             z=point_cloud[item, 2],
                                             color=color[item]))

    if len(constraints_points) == 1:
        return constraints_points[0]

    return constraints_points


def points2indexes(point_list, point_cloud):
    """
    Converts a list of indexes or a points to a set of indexes, corresponding to indexes of the closest points in the
    source point cloud.

    :param: point_list: list(Point): list of input Point
    :param: point_cloud: np.ndarray of size [Nx3]. If the number of columns is higher than 3 (e.g. the input array
        has size [Nx4], then only the first 3 columns are considered)

    Example:
        .. code-block:: console
            >> idxes_list = [Point(x = 11, y = 29, z = 3, color = 1), Point(x = 41, y=1, z = 20, color = 2)]
            >> point_cloud = np.ndarray([ 10, 14, 20, 1
                                        [ 10, 30, 20, 4],
                                        [ 18, 20, 18, 2],
                                        [40, 1, 20, 2])

            >> indexes2points(idxes_list, point_cloud)
            [1, 3]
    """

    idxes_list = []

    for item in point_list:
        if isinstance(item, tuple) or isinstance(item, list):
            assert all(isinstance(x, Point) for x in item)
            idxes_list.append(tuple(p.get_closest_point_in_cloud(point_cloud)[0] for p in item))

        else:
            idxes_list.append(item.get_closest_point_in_cloud(point_cloud[0]))

    return idxes_list

def obtain_indices_raycasted_original_pc(spine_target, r_target):
    """
    Find indices in spine_target w.r.t. r_target such that they are the closest points between the two
    point clouds

    :param: spine_target: np.ndarray with size [Nx3] with the point cloud for which the closest point indexes are
        extracted If the second dimension is higher then 3, only the first 3 dimensions are considered
    :param: r_target: np.ndarray with size [Nx3] with the point cloud used to find the closest points in spine_target.
        If the second dimension is higher then 3, only the first 3 dimensions are considered

    Example:

        .. code-block:: console
            >> spine_target = np.ndarray([ 10, 14, 20, 1
                                [ 10, 30, 20, 4],
                                [ 18, 20, 18, 2],
                                [40, 1, 20, 2])

            >> r_target = np.ndarray([ 18, 21, 18, 2],
                                     [40, 1, 20, 3])

            >> obtain_indices_raycasted_original_pc(spine_target, r_target)
            [2, 3]
    """
    kdtree=KDTree(spine_target[:,:3])
    dist,points=kdtree.query(r_target[:,:3],1)

    return list(set(points))

def create_source_target_with_vertebra_label(source_pc, target_pc, vert):
    """
    source_pc: source point cloud
    target_pc: target point cloud
    vert: [1-5] for [L1-L5] vertebra respectively
    
    this function is to create source and target point clouds with label for each vertebra
    """
    
    source = np.ones((source_pc.shape[0], source_pc.shape[1]+1))
    source[:, :3]=source_pc
    source[:, 3] = source[:, 3]*vert
    target = np.ones((target_pc.shape[0], target_pc.shape[1]+1))
    target[:, :3]=target_pc
    target[:, 3]= target[:, 3]*vert
    
    return source, target

def create_source_target_flow_spine(source_pc, target_pc, vert):
    """
    source_pc: source point cloud
    target_pc: target point cloud
    vert: [1-5] for [L1-L5] vertebra respectively
    
    this function is to create source and target point clouds with 7D
    where the point clouds are centered.
    """
    
    source_pc, target_pc = create_source_target_with_vertebra_label(source_pc, target_pc, vert)

    centroid_source = centeroidnp(source_pc)
    centroid_target = centeroidnp(target_pc)
    
    source_7d = create_7D(source_pc, centroid_source, centroid_target)
    target_7d = create_7D(target_pc, centroid_source, centroid_target)
    
    flow = target_7d[:,:3]-source_7d[:,:3]
    
    return source_7d, target_7d, flow


def get_lumbar_vertebrae_dict(folder_path):
    """
    Given a timestamp folder, containing the 5 .txt files corresponding to the lumbar vertebrae, the function loads
    the vertebra and returns a dict containing the point clouds.
    Example, given the folder TestDataOrderingJane\txt_files\spine1\ts_0_0 containing the files (spine1_vert1_0.txt,
    spine1_vert2_0.txt, spine1_vert3_0.txt, spine1_vert4_0.txt, spine1_vert5_0.txt), the function returns a dict like
    {"vert1" : np.array(..), "vert2" : np.array(..), "vert3" : np.array(..), "vert4" : np.array(..),
    "vert5" : np.array(..)}, where the np.arrays are Nx3 arrays containing the 3D coordinates of the point clouds
    of each vertebra

    :param folder_path: str: The path to the folder containing the vertebra point clouds .txt files
    """

    vertebra_files = [item for item in os.listdir(folder_path) if "vert" in item]

    vertebrae_dict = dict()
    for vertebra in ["vert1", "vert2", "vert3", "vert4", "vert5"]:

        vert_file = [item for item in vertebra_files if vertebra in item]
        assert len(vert_file) == 1

        vertebrae_dict[vertebra] = np.loadtxt(os.path.join(folder_path, vert_file[0]))

    return vertebrae_dict

def load_biomechanical_constraints(spine_folder_path, source_vertebrae_dict):
    """
    Loads the biomechanical constraints and returns them as a list of tuples like
    [(c1_1, c1_2), (c2_1, c2_2), ..., (cn_1, cn_2)] where each tuple contains the Point objects
     of the "starting" point connected to the spring and the index of the "ending" point connected to the spring:

    ci_1 _/\/\/\/\_ ci_2

    :param: spine_folder_path: str: The path containing the data for a given spine dataset, where the text file
        containing the biomechanical constraints is stored
    :param: source_vertebrae_dict: The vertebrae dict containing the point cloud corresponding to each vertebra
        like:
        source_vertebrae_dict = {"vert1" : np.array(..), "vert2" : np.array(..), "vert3" : np.array(..),
            "vert4" : np.array(..), "vert5" : np.array(..)}
    """

    biomechanical_constraints_path = os.path.join(spine_folder_path,
                                                  extract_spine_id(spine_folder_path).replace("s", "S")
                                                  + "_biomechanical.txt")

    if not os.path.exists(biomechanical_constraints_path):
        return []

    # The biomechanical constraints are saved in an array on a single row, like:
    # idx_c1_1, idx_c1_2, idx_c2_1, idx_c2_2, ..., idx_cn_1, idx_cn_2
#     biomechanical_constraints_array = np.squeeze(np.loadtxt(biomechanical_constraints_path))
#     biomechanical_constraint_list = []
#     for i in range(0, biomechanical_constraints_array.shape[0] - 1, 2):
#         biomechanical_constraint_list.append(
#             (int(biomechanical_constraints_array[i]), int(biomechanical_constraints_array[i + 1]) ))
    
    
    # The biomechanical constraints are saved in an n x 2 on a rows, like:
    # idx_c1_1, idx_c1_2, 
    # idx_c2_1, idx_c2_2, 
    # ..., 
    # idx_cn_1, idx_cn_2
    biomechanical_constraints_array = np.loadtxt(biomechanical_constraints_path)
    biomechanical_constraints_array = np.array(
        [item for sublist in biomechanical_constraints_array for item in sublist])

    biomechanical_constraint_list = []
    for i in range(0, biomechanical_constraints_array.shape[0] - 1, 2):
        biomechanical_constraint_list.append(
            (int(biomechanical_constraints_array[i]), int(biomechanical_constraints_array[i + 1]) ))
        
    ######################
    
    constraints_points = []
    dict_keys = [item for item in source_vertebrae_dict.keys()]

    for i in range(0, biomechanical_constraints_array.shape[0] - 1, 2):

        vert_name = dict_keys[int(i/2)]
        next_vert_name = dict_keys[int(i/2) + 1]
        p1 = indexes2points(int(biomechanical_constraints_array[i]),
                            point_cloud=source_vertebrae_dict[vert_name],
                            color=int(i/2) + 1)

        p2 = indexes2points(int(biomechanical_constraints_array[i+1]),
                            point_cloud=source_vertebrae_dict[next_vert_name],
                            color=int(i/2) + 2)

        constraints_points.append((p1, p2))

    return constraints_points


def preprocess_spine_data(spine_path):
    """
    Preprocess the data for a given spine dataset. Specifically, for the given spine (i.e. for a given spine_id),
    it ierates over all the timestamps for that given spine.
    The function does the following.
    1. It loads the "ts0" as the timestamp of the underformed spine, and therefore of the source spine.
    2. It loads the biomechanical constraints for the given spine
    3. It iterates over all the timestamps different from t0, where the spine is considered to be deformed compared to
        t0, and for each timestamp different from ts0:
        3.a Computes the flow from the source to the target points, assuming a correspondence
            between points at different timestamps
        3.b Concatenates all the vertebrae together for both source (ts0) and target (considered timestamp),
            indicating the vertebral level in the resulting concatenated point clouds in a 4th column,
            where L1 is indicated with 1, L2 with 2, L3 with 3,
            L4 with 4, L5 with 5.
        3.c. For each given source-deformed spine pair, generate a Data dict with the following keys:
            "spine_id", "source_ts_id", "target_ts_id", "source_pc", "target_pc", "flow", "biomechanical_constraint"

    """
    spine_id = os.path.split(spine_path)[-1]

    # Get the folder containing the data relative to the un-deformed spine (source) and the list of folders
    # containing the deformed spine
    source_timestamp = "ts0"
    deformed_timestamps = [item for item in os.listdir(spine_path) if item != source_timestamp and "ts" in item]

    # Getting the source vertebrae dict, as {"vert1" : np.array(..), "vert2" : np.array(..),
    # "vert3" : np.array(..), "vert4" : np.array(..), "vert5" : np.array(..)}
    source_vertebrae = get_lumbar_vertebrae_dict(os.path.join(spine_path, source_timestamp))

    # Load the biomechanical constraints for the selected spine. biomechanical_constraints is loaded as a list
    # of tuples (Point, Point). biomechanical_constraints = [(Point, Point), (Point, Point), ..., (Point, Point)]
    # For a given tuple, the first element is the point from which the spring starts, the second point is the point
    # where the spring ends. Note that the biomechanical_constraints contain tuple defining the 3D position of the
    # constraints, and not their indexes.
    biomechanical_constraints = load_biomechanical_constraints(spine_path, source_vertebrae)

    # Iterate over all the deformed versions (folders) of the source spine and generate the data list
    data = []
    for deformed_timestamp in deformed_timestamps:

        # Getting the target vertebrae dict, as {"vert1" : np.array(..), "vert2" : np.array(..),
        # "vert3" : np.array(..), "vert4" : np.array(..), "vert5" : np.array(..)}
        deformed_vertebrae = get_lumbar_vertebrae_dict(os.path.join(spine_path, deformed_timestamp))

        # Preprocess the point clouds of each given vertebra and then concatenate the vertebrae in a single point cloud
        preprocessed_source_vertebrae = []
        preprocessed_target_vertebrae = []
        for i, vertebra in enumerate(["vert1", "vert2", "vert3", "vert4", "vert5"]):
            preprocessed_source_pc, preprocessed_target_pc = \
                create_source_target_with_vertebra_label(source_pc=source_vertebrae[vertebra],
                                                         target_pc=deformed_vertebrae[vertebra],
                                                         vert=i + 1)
            preprocessed_source_vertebrae.append(preprocessed_source_pc)
            preprocessed_target_vertebrae.append(preprocessed_target_pc)

        # Concatenating source and target vertebrae into a single spine point cloud
        preprocessed_source_spine = np.concatenate(preprocessed_source_vertebrae)
        preprocessed_target_spine = np.concatenate(preprocessed_target_vertebrae)

        # Append the generated source-target pair to the data list
        data.append({
            "spine_id": spine_id,
            "source_ts_id": source_timestamp,
            "target_ts_id": deformed_timestamp,
            "source_pc": preprocessed_source_spine,
            "target_pc": preprocessed_target_spine,
            "flow": preprocessed_target_spine[:, :3] - preprocessed_source_spine[:, :3],
            "biomechanical_constraint": biomechanical_constraints
        })

    return data


def get_ray_casted_data(data, raycasted_txt_path):
    # Loading the raycasted point clouds
    source_ray_casted_pc = np.loadtxt(os.path.join(raycasted_txt_path, data["spine_id"],
                                                   "raycasted_" + data["source_ts_id"] + ".txt"))
    target_ray_casted_pc = np.loadtxt(os.path.join(raycasted_txt_path, data["spine_id"],
                                                   "raycasted_" + data["target_ts_id"] + ".txt"))

    # Getting the flow at the biomechanical_constraints points as it will be needed later
    constraint_indexes = points2indexes(point_list=data["biomechanical_constraint"],
                                        point_cloud=data["source_pc"])

    constraint_points, constraint_flows = [], []
    for (p1_idx, p2_idx) in constraint_indexes:
        p1_colored, p2_colored = data["source_pc"][p1_idx, :], data["source_pc"][p2_idx, :]
        p1_flow, p2_flow = data["flow"][p1_idx, :], data["flow"][p2_idx, :]

        constraint_points.append((p1_colored, p2_colored))
        constraint_flows.append((p1_flow, p2_flow))

    # Getting the indexes of the points in the source data which are closest to the ray_casted source points
    source_ray_casted_idxes = obtain_indices_raycasted_original_pc(spine_target=data["source_pc"],
                                                                   r_target=source_ray_casted_pc)
    data["source_pc"] = data["source_pc"][source_ray_casted_idxes]
    data["flow"] = data["flow"][source_ray_casted_idxes]

    # Getting the indexes of the points in the target data which are closest to the ray_casted target points
    target_ray_casted_idxes = obtain_indices_raycasted_original_pc(spine_target=data["target_pc"],
                                                                   r_target=target_ray_casted_pc)
    data["target_pc"] = data["target_pc"][target_ray_casted_idxes]

    # Adding the biomechanical constraints to the source as they might be not present due to the ray-casting
    new_constraints_idx = []
    for (p1, p2), (flow1, flow2) in zip(constraint_points, constraint_flows):
        data["source_pc"] = np.concatenate((data["source_pc"], np.reshape(p1, [1, 4]), 
                                            np.reshape(p2, [1, 4])), axis=0)
        data["flow"] = np.concatenate((data["flow"], np.reshape(flow1, [1, 3]), 
                                       np.reshape(flow2, [1, 3])), axis=0)

    return data

def get_color_code(color_name):
    color_code_dict = {
        "dark_green" : "0 0.333 0 1",
        "yellow": "1 1 0 1",
        "default": "1 1 0 1"
    }

    if color_name in color_code_dict.keys():
        return color_code_dict[color_name]

    else:
        return color_code_dict["default"]

def save_for_sanity_check(data, save_dir):
    """
    Saving the generated data in imfusion workspaces at specific location
    """

    source_pc = data["source_pc"][:, :3]
    target_pc = data["target_pc"][:, :3]

    gt_target_pc = source_pc + data["flow"]

    save_folder_path = os.path.join(save_dir, data["spine_id"], data["target_ts_id"])
    if not os.path.exists(save_folder_path):
        os.makedirs(save_folder_path)

    # saving the point clouds
    # 1. Saving the full point clouds
    np.savetxt(os.path.join(save_folder_path, "full_source_pc.txt"), source_pc[:, :3])
    np.savetxt(os.path.join(save_folder_path, "full_target_pc.txt"), target_pc[:, :3])
    np.savetxt(os.path.join(save_folder_path, "full_gt_pc.txt"), gt_target_pc[:, :3])

    ps_list = [("full_source_pc", os.path.join(save_folder_path, "full_source_pc.txt"),
                get_color_code("dark_green")),
               ("full_target_pc", os.path.join(save_folder_path, "full_target_pc.txt"), 
                get_color_code("yellow")),
               ("full_gt_pc", os.path.join(save_folder_path, "full_gt_pc.txt"), 
                get_color_code("yellow"))]

    imf_tree, imf_root = utils.get_empty_imfusion_ws()

    for i, (name, path, color) in enumerate(ps_list):

        imf_root = utils.add_block_to_xml(imf_root,
                                          parent_block_name="Annotations",
                                          block_name="point_cloud_annotation",
                                          param_dict={"referenceDataUid":"data" + str(i),
                                                      "name": str(name),
                                                      "color": str(color),
                                                      "labelText":"some",
                                                      "pointSize": "2"})

        imf_root = utils.add_block_to_xml(imf_root,
                                          parent_block_name="Algorithms",
                                          block_name="load_point_cloud",
                                          param_dict={"location": path,
                                                      "outputUids": "data" + str(i)})

    # Adding the biomechanical_constraints

    for i, (c1, c2) in enumerate(data["biomechanical_constraint"]):

        c1_idx, _ = c1.get_closest_point_in_cloud(data["source_pc"], filter_by_color=True)
        c2_idx, _ = c2.get_closest_point_in_cloud(data["source_pc"], filter_by_color=True)

        p1 = data["source_pc"][c1_idx, :3]
        p2 = data["source_pc"][c2_idx, :3]
        points = " ".join([str(item) for item in p1]) + " " + " ".join([str(item) for item in p2])
        imf_root = utils.add_block_to_xml(imf_root,
                                          parent_block_name="Annotations",
                                          block_name="segment_annotation",
                                          param_dict={"name": "constraint_" + str(i+1),
                                                      "points": points})

    utils.write_on_file(imf_tree, os.path.join(save_folder_path, "imf_ws.iws"))


def generate_npz_files(src_txt_pc_path, dst_npz_path, src_raycasted_pc_path="", ray_casted=False,
                       dst_sanity_check_data=""):

    if not os.path.exists(dst_npz_path):
        os.makedirs(dst_npz_path)

    # Iterate over all the patients (spine_id) in the dataset
    for spine_id in os.listdir(src_txt_pc_path):
        if spine_id == ".DS_Store": #or spine_id in ["spine11", "spine12", "spine13", "spine14", "spine16", "spine17", "spine22",\
                   #"spine15", "spine18", "spine19", "spine20","spine21"]:
            continue
        print(spine_id)
            # Getting the dataset for the specific patient id (spine). It is a list of dict like:
            # [{"source_ts_id": ts0,
            #   "target_ts_id": ts_19_0,
            #   "source_pc": np.ndarray([])
            #   "target_pc": np.ndarray([])
            #   "biomechanical_constraint": np.ndarray([])}, ...]
        spine_data = preprocess_spine_data(os.path.join(src_txt_pc_path, spine_id))

        for data in spine_data:
            if ray_casted:
                data = get_ray_casted_data(data, src_raycasted_pc_path)

#             save_for_sanity_check(data, dst_sanity_check_data)

            # convert biomechanical_constraint to a 1-d array, putting all the constraint on a single 
            # row - this needs to be changed in future to be a list of tuple or similar format where it is clear 
            # which point belongs to the same connecting spring

            constraint_indexes = points2indexes(data["biomechanical_constraint"], data["source_pc"])

            flattened_constraints = [i for sub in constraint_indexes for i in sub]
            np.savez_compressed(file=os.path.join(dst_npz_path,
                                                  "full_" + spine_id + "_" + data["target_ts_id"] + ".npz"),
                                flow=data["flow"],
                                pc1=data["source_pc"],
                                pc2=data["target_pc"],
                                ctsPts=flattened_constraints)


generate_npz_files(src_txt_pc_path="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/txt_files",
                   dst_npz_path="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/npz_data")
#                    ray_casted=True,
#                    src_raycasted_pc_path="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/txt_files_raycasted",
#                    dst_sanity_check_data="/Users/janelameski/Desktop/jane/sofa/SOFAZIPPED/install/bin/sanity_check")

spine18
spine20
spine16
spine11
spine10
spine17
spine21
spine19
spine5
spine2
spine3
spine4
spine12
spine15
spine14
spine13
spine22
spine1
spine6
spine8
spine9
spine7


In [164]:
def aligned_data(source, target, predicted_flow, gt_flow, tre_points=None):
    new_aligned_data = []
    new_source = np.zeros_like(source)
    new_flow = np.zeros_like(gt_flow)
    new_source[:,3] = source[:,3]
    new_pred_tre = np.zeros((2*5, 4))
    tmp_vert = 0
    for vertebrae_level in range(1,6):
        
        vertebrae_idxes = np.argwhere(source[:, 3] == vertebrae_level).flatten()
        predicted_deformed_v = source[vertebrae_idxes,0:3] + predicted_flow[vertebrae_idxes]
        source_v = source[vertebrae_idxes, ...]
        predict_t = metrics.compute_rigid_transform(source_v[..., 0:3], predicted_deformed_v)
        #make homogeneous
        predict_t[-1] = 1

        transposed_s_s = source_v.T
        rigidly_deformed_source = np.matmul(predict_t, transposed_s_s)

        final_deformed_source_vert = rigidly_deformed_source[0:3,...].T
        
        gt_deformed = source[vertebrae_idxes, 0:3] + gt_flow[vertebrae_idxes]
        
        new_source[vertebrae_idxes, :3] = final_deformed_source_vert
        new_flow[vertebrae_idxes] = gt_deformed - final_deformed_source_vert
        
        if tre_points is None:
            continue
        
        vertebrae_target = tre_points[tre_points[:,-1] == vertebrae_level]
        #homogeneous
        vertebrae_target[:, -1] = 1
        
        vertebrae_target = vertebrae_target.T
        predicted_registered_target = np.matmul(predict_t, vertebrae_target)  # Nx4
        print(predicted_registered_target)
        new_pred_tre[tmp_vert:vertebrae_level*2] = predicted_registered_target.T
        tmp_vert = vertebrae_level*2
    if tre_points is not None:
        new_pred_tre[:,3] = tre_points[:,3]
#         new_pred_tre[vertebrae_level] = predicted_registered_target.T[1,:]

        
    
    if tre_points is None:    
        return new_source, target, new_flow, None
    else:
        return new_source, target, new_flow, new_pred_tre

def files_to_folder_lists(path_to_model_output):
    """
    split the names of the files created from the output of the model in
    predicted, source, target and ground truth lists
    """
    predicted_paths = []
    source_paths = []
    target_paths = []
    gt_paths = []
    for path in os.listdir(path_to_model_output):
        if path.startswith("predicted"):
            predicted_paths.append(path)
        elif path.startswith("source"):
            source_paths.append(path)
        elif path.startswith("target"):
            target_paths.append(path)
        elif path.startswith("gt"):
            gt_paths.append(path)
    
    predicted_paths = sorted(predicted_paths)
    source_paths = sorted(source_paths)
    target_paths = sorted(target_paths)
    gt_paths = sorted(gt_paths)
    
    return predicted_paths, source_paths, target_paths, gt_paths

def aligned_to_npz(predicted_paths, source_paths, target_paths, gt_paths, dst_npz_path):
    """kjh"""
    for i in range(len(predicted_paths)):
        source, target, flow, tre = aligned_data(np.loadtxt(dst_npz_path+source_paths[i])
        ,np.loadtxt(dst_npz_path+target_paths[i])
        ,np.loadtxt(dst_npz_path+predicted_paths[i])[:,:3] - \
        np.loadtxt(dst_npz_path+source_paths[i])[:,:3]
        ,np.loadtxt(dst_npz_path+gt_paths[i])[:,:3] - \
        np.loadtxt(dst_npz_path+source_paths[i])[:,:3])
        
        if gt_paths[i][gt_paths[i].find("spine") + 6].isdigit()==True:
            
#             print(os.path.join(dst_npz_path,
#                                 "aligned_" + gt_paths[i][gt_paths[i].find("spine"):gt_paths[i].find("spine") + 7] + \
#                                 "_" + gt_paths[i][gt_paths[i].find("ts"):-4]+ ".npz"))
            np.savez_compressed(file=os.path.join(dst_npz_path,
                                "aligned_raycasted_" + gt_paths[i][gt_paths[i].find("spine"):gt_paths[i].find("spine") + 7] + \
                                "_" + gt_paths[i][gt_paths[i].find("ts"):-4]+ ".npz"),
                                    flow=flow,
                                    pc1=source,
                                    pc2=target,
                                    ctsPts=[i for i in range(4095 - 8)])
            if tre is None:
                continue
            np.savetxt(os.path.join(dst_npz_path, "aligned_"+ gt_paths[i][gt_paths[i].find("spine"):gt_paths[i].find("spine") + 7] + \
                                   "facet_target.txt"), tre)
        else:
#             print(os.path.join(dst_npz_path,
#                             "aligned_" + gt_paths[i][gt_paths[i].find("spine"):gt_paths[i].find("spine") + 6] + \
#                             "_" + gt_paths[i][gt_paths[i].find("ts"):-4]+ ".npz"))
            np.savez_compressed(file=os.path.join(dst_npz_path,
                                  "aligned_raycasted_" + gt_paths[i][gt_paths[i].find("spine"):gt_paths[i].find("spine") + 6] + \
                                  "_" + gt_paths[i][gt_paths[i].find("ts"):-4]+ ".npz"),
                                    flow=flow,
                                    pc1=source,
                                    pc2=target,
                                    ctsPts=[i for i in range(4095 - 8)])
            if tre is None:
                continue
            np.savetxt(os.path.join(dst_npz_path, "aligned_"+ gt_paths[i][gt_paths[i].find("spine"):gt_paths[i].find("spine") + 7] + \
                                   "facet_target.txt"), tre)
            
# predicted_paths, source_paths, target_paths, gt_paths = files_to_folder_lists("/Users/janelameski/PycharmProjects/pythonProject/thesis/Results/flownet3d/test_result/")    
# aligned_to_npz(predicted_paths, source_paths, target_paths, gt_paths,"/Users/janelameski/PycharmProjects/pythonProject/thesis/Results/flownet3d/test_result/")

In [166]:
dst_npz_path = "/Users/janelameski/PycharmProjects/pythonProject/thesis/Results/flownet3d/test_result/" 
source_paths = "source_raycasted_spine22_ts_10_0.txt"
target_paths = "target_raycasted_spine22_ts_10_0.txt"
predicted_paths = "predicted_raycasted_spine22_ts_10_0.txt"
gt_paths = "gt_raycasted_spine22_ts_10_0.txt"

tre_path = "spine22_facet_targets.txt"
source, target, flow, tre= aligned_data(np.loadtxt(dst_npz_path+source_paths)
        ,np.loadtxt(dst_npz_path+target_paths)
        ,np.loadtxt(dst_npz_path+predicted_paths)[:,:3] - \
        np.loadtxt(dst_npz_path+source_paths)[:,:3]
        ,np.loadtxt(dst_npz_path+gt_paths)[:,:3] - \
        np.loadtxt(dst_npz_path+source_paths)[:,:3], np.loadtxt(dst_npz_path+tre_path))
a = np.loadtxt(dst_npz_path + tre_path)
print(tre)
np.savetxt(dst_npz_path + "tre_aligned.txt", tre)
data = np.load("/Users/janelameski/PycharmProjects/pythonProject/thesis/Results/flownet3d/test_result/aligned_raycasted_spine22_ts_10_0.npz")

np.savetxt("/Users/janelameski/PycharmProjects/pythonProject/thesis/Results/flownet3d/test_result/test_npz_source.txt",data["pc1"])

[[ -11.38443584   13.57171542]
 [  99.12348151  100.0757703 ]
 [-346.63540389 -346.40615103]
 [-258.00434113 -231.72973251]]
[[  21.42998978   -5.45919497]
 [  92.79521938   94.12108146]
 [-375.75309139 -377.30108662]
 [-264.29425907 -291.61031342]]
[[  30.81651566   -1.2299494 ]
 [  94.76247103   93.54808014]
 [-403.94203584 -406.08096277]
 [-284.49339485 -320.10886383]]
[[  30.58662205   -5.31760195]
 [  97.50045235   97.13990955]
 [-428.40452188 -428.23351957]
 [-303.67665482 -339.9325304 ]]
[[  35.41309722   -7.97284737]
 [ 104.91855186  104.36747592]
 [-442.98484021 -444.02889946]
 [-308.06954384 -352.92299557]]
[[ -11.38443584   99.12348151 -346.63540389    1.        ]
 [  13.57171542  100.0757703  -346.40615103    1.        ]
 [  21.42998978   92.79521938 -375.75309139    2.        ]
 [  -5.45919497   94.12108146 -377.30108662    2.        ]
 [  30.81651566   94.76247103 -403.94203584    3.        ]
 [  -1.2299494    93.54808014 -406.08096277    3.        ]
 [  30.58662205   97.