In [1]:
import slicer
import vtk
import slicer_utils
import utils
import trim_utils
import os
import glob
import time
import numpy as np
import json
import pandas as pd

In [2]:
def read_vtpfile(filename):
    """Read a VTK XML PolyData (``.vtp``) file.

    Parameters
    ----------
    filename : str
        Path to the ``.vtp`` file.

    Returns
    -------
    vtk.vtkPolyData
        The output of ``vtkXMLPolyDataReader``.

    Raises
    ------
    FileNotFoundError
        If ``filename`` does not exist. The VTK reader itself does not raise
        for a missing file, so without this check the failure would only show
        up later as an empty/invalid polydata.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No such VTP file: {filename}")
    reader = vtk.vtkXMLPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()
    return reader.GetOutput()

def write_vtpfile(polydata, filename):
    """Write ``polydata`` to ``filename`` as VTK XML PolyData (``.vtp``).

    Parameters
    ----------
    polydata : vtk.vtkPolyData
        Data to write.
    filename : str
        Destination path.

    Raises
    ------
    OSError
        If the VTK writer reports failure (``Write()`` returned 0), e.g. the
        target directory does not exist or is not writable.
    """
    writer = vtk.vtkXMLPolyDataWriter()
    writer.SetFileName(filename)
    writer.SetInputData(polydata)
    # Write() returns 1 on success and 0 on failure; it does not raise.
    if not writer.Write():
        raise OSError(f"Failed to write VTP file: {filename}")

def write_vtkfile(polydata, filename, file_version=None):
    """Write ``polydata`` to ``filename`` as a legacy VTK (``.vtk``) file.

    Parameters
    ----------
    polydata : vtk.vtkPolyData
        Data to write.
    filename : str
        Destination path.
    file_version : int or None, optional
        Legacy file-format version passed to ``vtkPolyDataWriter.SetFileVersion``
        (e.g. ``42`` for format 4.2, as used elsewhere in this notebook).
        ``None`` (default) keeps the writer's default version.

    Raises
    ------
    OSError
        If the VTK writer reports failure (``Write()`` returned 0).
    """
    vtk_writer = vtk.vtkPolyDataWriter()
    vtk_writer.SetFileName(filename)
    vtk_writer.SetInputData(polydata)
    if file_version is not None:
        vtk_writer.SetFileVersion(file_version)
    # Write() returns 1 on success and 0 on failure; it does not raise.
    if not vtk_writer.Write():
        raise OSError(f"Failed to write VTK file: {filename}")

def read_vtkfile(filename):
    """Read a legacy VTK PolyData (``.vtk``) file.

    Parameters
    ----------
    filename : str
        Path to the ``.vtk`` file.

    Returns
    -------
    vtk.vtkPolyData
        The output of ``vtkPolyDataReader``.

    Raises
    ------
    FileNotFoundError
        If ``filename`` does not exist (the VTK reader does not raise itself).
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"No such VTK file: {filename}")
    reader = vtk.vtkPolyDataReader()
    reader.SetFileName(filename)
    reader.Update()
    return reader.GetOutput()


In [3]:
def _pa_oversample_segments(segmentationNode, oversampling_factor=2):
    """Resample every segment's binary labelmap in place onto the segmentation's geometry oversampled by ``oversampling_factor``."""
    geometryLogic = slicer.vtkSlicerSegmentationGeometryLogic()
    geometryLogic.SetInputSegmentationNode(segmentationNode)
    geometryLogic.SetSourceGeometryNode(segmentationNode)
    geometryLogic.SetOversamplingFactor(oversampling_factor)
    geometryLogic.CalculateOutputGeometry()
    geometryImageData = geometryLogic.GetOutputGeometryImageData()

    segmentation = segmentationNode.GetSegmentation()
    segmentIDs = vtk.vtkStringArray()
    segmentation.GetSegmentIDs(segmentIDs)
    for index in range(segmentIDs.GetNumberOfValues()):
        segmentID = segmentIDs.GetValue(index)
        labelmap = segmentation.GetSegment(segmentID).GetRepresentation("Binary labelmap")
        # Input and output are the same image, so the labelmap is resampled in place.
        # NOTE(review): the trailing flags are presumably (linearInterpolation=False, padImage=True) — confirm.
        success = slicer.vtkOrientedImageDataResample.ResampleOrientedImageToReferenceOrientedImage(
            labelmap, geometryImageData, labelmap, False, True)
        if not success:
            print("Segment {}/{} failed to be resampled".format(segmentationNode.GetName(), segmentID))
    segmentationNode.Modified()


def _pa_fill_endpoints(markupsNode, positions):
    """Replace the control points of ``markupsNode`` with ``positions`` (labels hidden, first point deselected)."""
    markupsNode.GetDisplayNode().PointLabelsVisibilityOff()
    markupsNode.RemoveAllMarkups()
    for position in positions:
        markupsNode.AddControlPoint(vtk.vtkVector3d(position))
    if markupsNode.GetNumberOfControlPoints() > 0:
        markupsNode.SetNthControlPointSelected(0, False)


def _pa_write_legacy_vtk(polydata, path):
    """Write ``polydata`` to ``path`` as a legacy ``.vtk`` file, format version 4.2; raise OSError on failure."""
    writer = vtk.vtkPolyDataWriter()
    writer.SetFileName(path)
    writer.SetInputData(polydata)
    writer.SetFileVersion(42)
    if not writer.Write():
        raise OSError(f"Failed to write VTK file: {path}")


def extract_pa_tree(file_name, do_centerlines=False, snap_endpoints=False,
                    oversampling_factor=2, target_points=5000, decimation_aggressiveness=4.0,
                    keep_largest_component=True):
    """Extract the vessel network from a pulmonary-artery label volume using Slicer.

    Pipeline:
    1. Load ``file_name`` and convert it to a labelmap, then to a segmentation.
    2. Resample the segments with ``oversampling_factor``.
    3. Build the closed surface of the first segment.
    4. Preprocess the surface with ``utils.ExtractCenterlineLogic.preprocess``.
    5. Optionally keep only the largest connected component.
    6. Extract the network and its end points.

    Files written to ``<directory of file_name>/vmtk_out_pa/``:

    * ``<ID>_network.vtp``          -- extracted network
    * ``<ID>_endpoints.fcsv``       -- end points of the network
    * ``<ID>_meshPreProcessed.vtk`` -- preprocessed surface (legacy VTK 4.2)
    * ``<ID>_centerline.vtk``       -- only if ``do_centerlines``
    * ``<ID>_centerline2.vtk`` and ``<ID>_endpoints2.fcsv`` -- only if
      ``do_centerlines`` and ``snap_endpoints`` (end points snapped with
      ``utils.robustEndPointDetection`` before a second centerline extraction)

    Parameters
    ----------
    file_name : str
        Path of the form ``base_path/ID.nii.gz``.
    do_centerlines : bool, optional
        Also extract and save centerlines (default False).
    snap_endpoints : bool, optional
        With ``do_centerlines``, also compute centerlines from snapped end points.
    oversampling_factor : int, optional
        Oversampling factor used before building the closed surface (default 2).
    target_points, decimation_aggressiveness : float, optional
        Parameters passed to ``logic.preprocess`` (defaults 5000 and 4.0).
    keep_largest_component : bool, optional
        Keep only the largest connected region of the preprocessed surface.

    Returns
    -------
    int or None
        ``0`` when fewer than two end points are found (the case is skipped and
        nothing is written), otherwise ``None``.

    Every MRML node created here is removed from the scene before returning,
    including on early return or exception.
    """
    base_name = os.path.basename(file_name)
    ID = base_name.split('.nii.gz')[0]
    # dirname instead of file_name.split(ID)[0]: splitting on ID cuts the path at the
    # first occurrence of ID, which is wrong when ID also appears in a directory name.
    base_path = os.path.dirname(file_name)

    save_path = os.path.join(base_path, 'vmtk_out_pa')
    os.makedirs(save_path, exist_ok=True)

    created_nodes = []  # every node added to the scene; removed in the finally block

    def track(node):
        created_nodes.append(node)
        return node

    try:
        loadedVolume = track(slicer.util.loadVolume(file_name))

        labelmapVolumeNode = track(slicer.mrmlScene.AddNewNodeByClass("vtkMRMLLabelMapVolumeNode"))
        slicer.modules.volumes.logic().CreateLabelVolumeFromVolume(slicer.mrmlScene, labelmapVolumeNode, loadedVolume)

        segmentationNode = track(slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode"))
        slicer.modules.segmentations.logic().ImportLabelmapToSegmentationNode(labelmapVolumeNode, segmentationNode)

        _pa_oversample_segments(segmentationNode, oversampling_factor)
        segmentationNode.CreateClosedSurfaceRepresentation()

        # Only the first segment's surface is used.
        segmentation = segmentationNode.GetSegmentation()
        segmentId = segmentation.GetNthSegmentID(0)
        vtk_mesh = segmentation.GetSegment(segmentId).GetRepresentation(
            slicer.vtkSegmentationConverter.GetClosedSurfaceRepresentationName())

        logic = utils.ExtractCenterlineLogic()
        subDivideInput = False
        preProcess = logic.preprocess(vtk_mesh, float(target_points), float(decimation_aggressiveness), subDivideInput)

        if keep_largest_component:
            connectivityFilter = vtk.vtkConnectivityFilter()
            connectivityFilter.SetInputData(preProcess)
            connectivityFilter.SetExtractionModeToLargestRegion()
            connectivityFilter.Update()
            preProcess = connectivityFilter.GetOutput()

        endPointsMarkupsNode = track(
            slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", "Centerline endpoints"))
        networkPolyData = logic.extractNetwork(preProcess, endPointsMarkupsNode, computeGeometry=True)

        # Clear the markups node (it was handed to extractNetwork) and refill it
        # with the end points of the extracted network.
        endpointPositions = logic.getEndPoints(networkPolyData, None)
        _pa_fill_endpoints(endPointsMarkupsNode, endpointPositions)

        if endPointsMarkupsNode.GetNumberOfControlPoints() < 2:
            print('*****Two end points are needed for centerline extraction!!!******')
            print(f'Skipping ID = {ID}. Best to do manually.')
            return 0

        network_file_path = os.path.join(save_path, f"{ID}_network.vtp")
        write_vtpfile(networkPolyData, network_file_path)
        print(f"Saved network data to {network_file_path}")

        endpoints_file_path = os.path.join(save_path, f"{ID}_endpoints.fcsv")
        slicer.util.saveNode(endPointsMarkupsNode, endpoints_file_path)
        print(f"Saved endpoints to {endpoints_file_path}")

        preprocess_file_path = os.path.join(save_path, f"{ID}_meshPreProcessed.vtk")
        _pa_write_legacy_vtk(preProcess, preprocess_file_path)
        print(f"Saved preprocessed mesh to {preprocess_file_path}")

        if do_centerlines:
            centerlinePolyData, _ = logic.extractCenterline(preProcess, endPointsMarkupsNode)
            centerline_file_path = os.path.join(save_path, f"{ID}_centerline.vtk")
            _pa_write_legacy_vtk(centerlinePolyData, centerline_file_path)
            print(f"Saved centerline data to {centerline_file_path}")

            if snap_endpoints:
                # Export into a separate node so labelmapVolumeNode stays tracked for removal.
                exportedLabelmapNode = track(slicer.mrmlScene.AddNewNodeByClass('vtkMRMLLabelMapVolumeNode'))
                slicer.modules.segmentations.logic().ExportAllSegmentsToLabelmapNode(segmentationNode, exportedLabelmapNode)
                segmentationArray = slicer.util.arrayFromVolume(exportedLabelmapNode)

                ijkToRas = vtk.vtkMatrix4x4()
                exportedLabelmapNode.GetIJKToRASMatrix(ijkToRas)
                aff = slicer.util.arrayFromVTKMatrix(ijkToRas)

                new_endpoints = [utils.robustEndPointDetection(endpoint, segmentationArray, aff, n=5)
                                 for endpoint in endpointPositions]

                snappedMarkupsNode = track(
                    slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", "Centerline endpoints NEW"))
                _pa_fill_endpoints(snappedMarkupsNode, new_endpoints)

                centerlinePolyData2, _ = logic.extractCenterline(preProcess, snappedMarkupsNode)
                centerline_file_path2 = os.path.join(save_path, f"{ID}_centerline2.vtk")
                _pa_write_legacy_vtk(centerlinePolyData2, centerline_file_path2)
                print(f"Saved centerline data to {centerline_file_path2}")

                endpoints_file_path2 = os.path.join(save_path, f"{ID}_endpoints2.fcsv")
                slicer.util.saveNode(snappedMarkupsNode, endpoints_file_path2)
                print(f"Saved endpoints to {endpoints_file_path2}")
    finally:
        # Remove every node added above so repeated calls do not accumulate
        # volumes/segmentations/markups in the scene.
        for node in created_nodes:
            if node is not None:
                slicer.mrmlScene.RemoveNode(node)


In [4]:
# Input segmentation; extract_pa_tree writes its outputs to vmtk_out_pa/ next to this file.
file_name = './example_patree.nii.gz'
extract_pa_tree(file_name);

Saved network data to ./vmtk_out_pa/example_patree_network.vtp
Saved endpoints to ./vmtk_out_pa/example_patree_endpoints.fcsv
Saved network data to ./vmtk_out_pa/example_patree_meshPreProcessed.vtk


# Trim the extracted network and compute per-generation metrics

In [5]:
def assign_topology(dicts, polydata):
    """Label every branch as belonging to the left or right pulmonary-artery tree.

    ``dicts[1]`` and ``dicts[2]`` are taken to be the two generation-1 branches.
    The longer of the two (by the ``'Length'`` cell array) is labelled ``'RPA'``
    and the other ``'LPA'``. This follows the author's stated assumption that the
    RPA is always longer than the LPA. Every branch of generation >= 2 then
    inherits ``'R'`` or ``'L'`` from its parent.

    Parameters
    ----------
    dicts : list of dict
        Branch records with keys ``'cell'`` (cell id in ``polydata``), ``'gen'``
        (generation number) and ``'parent'`` (cell id of the parent branch).
        The records are modified in place.
    polydata : vtk.vtkPolyData
        Network whose cell data contains a ``'Length'`` array.

    Returns
    -------
    list of dict
        ``dicts`` itself. Every entry of generation >= 1 has a ``'side'`` key.

    Raises
    ------
    ValueError
        If a generation >= 2 branch has no parent in ``dicts``, or its parent
        has no valid side.
    """
    length_array = polydata.GetCellData().GetArray('Length')

    first, second = dicts[1], dicts[2]
    if length_array.GetValue(first['cell']) > length_array.GetValue(second['cell']):
        first['side'], second['side'] = 'RPA', 'LPA'
    else:
        first['side'], second['side'] = 'LPA', 'RPA'

    cell_to_entry = {entry['cell']: entry for entry in dicts}
    inherited_side = {'R': 'R', 'RPA': 'R', 'L': 'L', 'LPA': 'L'}

    # Walk in generation order so every parent is labelled before its children,
    # whatever order the records arrive in. sorted() is stable, so the order
    # within a generation is unchanged.
    for entry in sorted(dicts, key=lambda record: record['gen']):
        if entry['gen'] < 2:
            continue
        parent = cell_to_entry.get(entry['parent'])
        parent_side = None if parent is None else parent.get('side')
        if parent_side not in inherited_side:
            raise ValueError(
                f"Cannot assign a side to cell {entry['cell']}: parent {entry['parent']!r} "
                f"has side {parent_side!r}")
        entry['side'] = inherited_side[parent_side]

    return dicts

def analyse_cells(polydata, cells, column_prefix_name = 'Gen1_R'):
    """Summarise the geometry of a group of network branches as a one-row DataFrame.

    For each cell id in ``cells`` the function collects two kinds of values.
    Per-cell values: volume and surface from ``trim_utils.calc_vol_surf``, plus
    the ``'Length'`` and ``'Tortuosity'`` cell arrays. Per-point values, for
    every point of the cell: the ``'Radius'``, ``'Torsion'`` and ``'Curvature'``
    point arrays. These are reduced to max/avg/min/total statistics. NaN values
    are replaced by 1e-9 so that later ratios stay finite.

    The taper of a cell is ``1 - r(q3) / r(q1)``, where ``r(q1)`` and ``r(q3)``
    are the radii at the first- and third-quartile positions along the cell's
    point list.

    Parameters
    ----------
    polydata : vtk.vtkPolyData
        Network carrying the arrays listed above.
    cells : sequence of int
        Cell ids to analyse. If empty, every statistic is computed from a single
        1e-9 placeholder, except the taper, which is 1.
    column_prefix_name : str
        Prefix of every output column, e.g. ``'Gen1_R'``.

    Returns
    -------
    pandas.DataFrame
        One row with 20 columns named ``f'{column_prefix_name}_<metric>'``.
    """
    def clean(value):
        # NaN -> tiny positive placeholder, applied identically to every metric.
        return 1e-9 if np.isnan(value) else value

    radii, torsions, curvs = [], [], []                     # one entry per point
    tapers, torts, vols, surfs, lengths = [], [], [], [], []  # one entry per cell

    cell_data = polydata.GetCellData()
    length_array = cell_data.GetArray('Length')
    tort_array = cell_data.GetArray('Tortuosity')

    point_data = polydata.GetPointData()
    curv_array = point_data.GetArray('Curvature')
    torsion_array = point_data.GetArray('Torsion')
    radius_array = point_data.GetArray('Radius')

    if len(cells) > 0:
        for cell_id in cells:
            cell = polydata.GetCell(cell_id)

            vol, surf = trim_utils.calc_vol_surf(polydata, cell_id)
            vols.append(clean(vol))
            surfs.append(clean(surf))
            lengths.append(clean(length_array.GetValue(cell_id)))
            torts.append(clean(tort_array.GetValue(cell_id)))

            point_ids = cell.GetPointIds()
            cell_rads = []
            for i in range(point_ids.GetNumberOfIds()):
                point_id = point_ids.GetId(i)
                radius = clean(radius_array.GetValue(point_id))
                radii.append(radius)
                cell_rads.append(radius)
                torsions.append(clean(torsion_array.GetValue(point_id)))
                curvs.append(clean(curv_array.GetValue(point_id)))

            n_pts = len(cell_rads)
            # Clamp the quartile indices. round() can return n_pts for short cells
            # (n_pts=1 -> round(0.75)=1, n_pts=2 -> round(1.5)=2), which is out of range.
            index_q1 = min(round(n_pts * 0.25), n_pts - 1)
            index_q3 = min(round(n_pts * 0.75), n_pts - 1)
            tapers.append(1 - cell_rads[index_q3] / cell_rads[index_q1])
    else:
        vols, surfs, lengths, torts = [1e-9], [1e-9], [1e-9], [1e-9]
        radii, torsions, curvs = [1e-9], [1e-9], [1e-9]
        tapers = [1]

    p = column_prefix_name
    record = {
        f'{p}_max_radius': max(radii),
        f'{p}_avg_radius': sum(radii) / len(radii),
        f'{p}_min_radius': min(radii),
        f'{p}_taper': sum(tapers) / len(tapers),
        f'{p}_max_curvature': max(curvs),
        f'{p}_avg_curvature': sum(curvs) / len(curvs),
        f'{p}_min_curvature': min(curvs),
        f'{p}_max_torsion': max(torsions),
        f'{p}_avg_torsion': sum(torsions) / len(torsions),
        f'{p}_min_torsion': min(torsions),
        f'{p}_total_vol': sum(vols),
        f'{p}_avg_vol': sum(vols) / len(vols),
        f'{p}_total_surf': sum(surfs),
        f'{p}_avg_surf': sum(surfs) / len(surfs),
        f'{p}_total_length': sum(lengths),
        f'{p}_vols_to_length': sum(vols) / sum(lengths),
        f'{p}_avg_length': sum(lengths) / len(lengths),
        f'{p}_max_tortuosity': max(torts),
        f'{p}_avg_tortuosity': sum(torts) / len(torts),
        f'{p}_min_tortuosity': min(torts),
    }
    # Build the frame in one go rather than inserting 20 columns one by one.
    return pd.DataFrame([record])

def get_metrics(polydata, dicts, verbose=True):
    """Compute per-generation morphology metrics of a side-labelled PA network.

    Parameters
    ----------
    polydata : vtk.vtkPolyData
        Network analysed by ``analyse_cells``. Whole-network volume and surface
        are summed over all of its cells.
    dicts : list of dict
        Branch records with keys ``'gen'`` and ``'cell'``. Entries of generation
        >= 1 also need ``'side'`` in ``{'L', 'LPA', 'R', 'RPA'}``
        (see ``assign_topology``).
    verbose : bool, optional
        Print each generation number and its left-side cell ids (default True).

    Returns
    -------
    pandas.DataFrame
        One row containing, in order:

        * the ``Gen0`` statistics;
        * for every generation ``g >= 1``: the L and R statistics, L/R ratios,
          combined L+R statistics and ``Gen{g-1}``/``Gen{g}`` ratios;
        * ``Gen01_Gen2Plus_vol``, the volume of generations 0-1 divided by the
          volume of the rest of the network.

    Raises
    ------
    ValueError
        If a generation >= 1 branch has an unexpected side label.
    """
    # Whole-network totals, used as denominators of the *_to_whole_* metrics.
    total_vol, total_surf = 0.0, 0.0
    for cell_id in range(polydata.GetNumberOfCells()):
        vol, surf = trim_utils.calc_vol_surf(polydata, cell_id)
        total_vol += vol
        total_surf += surf

    gen_to_cells = {}
    for entry in dicts:
        gen_to_cells.setdefault(entry['gen'], []).append(entry['cell'])
    cell_to_entry = {entry['cell']: entry for entry in dicts}

    full_df = None
    for gen in range(max(gen_to_cells) + 1):
        # A generation without branches is analysed as empty (analyse_cells then
        # uses placeholder values) instead of raising KeyError.
        ids = gen_to_cells.get(gen, [])
        if gen == 0:
            full_df = analyse_cells(polydata, ids, column_prefix_name='Gen0')
            continue

        L_cells, R_cells = [], []
        for cell_id in ids:
            side = cell_to_entry[cell_id]['side']
            if side in ('L', 'LPA'):
                L_cells.append(cell_id)
            elif side in ('R', 'RPA'):
                R_cells.append(cell_id)
            else:
                raise ValueError(f'Cell {cell_id} in generation {gen} has unexpected side {side!r}')
        if verbose:
            print(f'analysing Gen {gen}')
            print(L_cells)

        g = f'Gen{gen}'
        L_df = analyse_cells(polydata, L_cells, column_prefix_name=f'{g}_L')
        R_df = analyse_cells(polydata, R_cells, column_prefix_name=f'{g}_R')
        # metric name -> scalar, e.g. left['total_vol'] == L_df['Gen{gen}_L_total_vol'][0]
        left = {col[len(f'{g}_L_'):]: L_df[col][0] for col in L_df.columns}
        right = {col[len(f'{g}_R_'):]: R_df[col][0] for col in R_df.columns}

        # Left vs right
        row = {
            f'{g}_LR_total_vol_ratio': left['total_vol'] / right['total_vol'],
            f'{g}_LR_total_surf_ratio': left['total_surf'] / right['total_surf'],
            f'{g}_LR_max_radius_ratio': left['max_radius'] / right['max_radius'],
            f'{g}_LR_length_ratio': left['total_length'] / right['total_length'],
            f'{g}_LR_avg_tortuosity_ratio': left['avg_tortuosity'] / right['avg_tortuosity'],
        }

        # L and R combined: max of maxima, min of minima, mean of the two per-side averages.
        row[f'{g}_max_radius'] = max(left['max_radius'], right['max_radius'])
        row[f'{g}_avg_radius'] = (left['avg_radius'] + right['avg_radius']) / 2
        row[f'{g}_min_radius'] = min(left['min_radius'], right['min_radius'])
        row[f'{g}_taper'] = (left['taper'] + right['taper']) / 2
        for quantity in ('curvature', 'torsion'):
            row[f'{g}_max_{quantity}'] = max(left[f'max_{quantity}'], right[f'max_{quantity}'])
            row[f'{g}_avg_{quantity}'] = (left[f'avg_{quantity}'] + right[f'avg_{quantity}']) / 2
            row[f'{g}_min_{quantity}'] = min(left[f'min_{quantity}'], right[f'min_{quantity}'])
        gen_vol = left['total_vol'] + right['total_vol']
        gen_surf = left['total_surf'] + right['total_surf']
        gen_length = left['total_length'] + right['total_length']
        row[f'{g}_total_vol'] = gen_vol
        row[f'{g}_vol_to_whole_vol'] = gen_vol / total_vol
        row[f'{g}_avg_vol'] = (left['avg_vol'] + right['avg_vol']) / 2
        row[f'{g}_total_surf'] = gen_surf
        row[f'{g}_surf_to_whole_surf'] = gen_surf / total_surf
        row[f'{g}_avg_surf'] = (left['avg_surf'] + right['avg_surf']) / 2
        row[f'{g}_total_length'] = gen_length
        row[f'{g}_avg_length'] = (left['avg_length'] + right['avg_length']) / 2
        row[f'{g}_vols_to_length'] = gen_vol / gen_length
        row[f'{g}_max_tortuosity'] = max(left['max_tortuosity'], right['max_tortuosity'])
        row[f'{g}_avg_tortuosity'] = (left['avg_tortuosity'] + right['avg_tortuosity']) / 2
        row[f'{g}_min_tortuosity'] = min(left['min_tortuosity'], right['min_tortuosity'])

        # Previous generation vs this generation
        prev = f'Gen{gen - 1}'
        for metric in ('total_vol', 'total_surf', 'avg_vol', 'avg_surf',
                       'max_radius', 'avg_radius', 'total_length', 'avg_length'):
            row[f'{prev}_{g}_{metric}_ratio'] = full_df[f'{prev}_{metric}'][0] / row[f'{g}_{metric}']

        full_df = pd.concat([full_df, L_df, R_df, pd.DataFrame([row])], axis=1)

    # Generations 0 and 1 (MPA + LPA/RPA) vs the rest of the network
    mpa_lpa_rpa_vol = full_df['Gen0_total_vol'][0] + full_df['Gen1_total_vol'][0]
    summary = pd.DataFrame({'Gen01_Gen2Plus_vol': [mpa_lpa_rpa_vol / (total_vol - mpa_lpa_rpa_vol)]})
    return pd.concat([full_df, summary], axis=1)


In [6]:
def _trim_fill_endpoints(markupsNode, positions):
    """Replace the control points of ``markupsNode`` with ``positions`` (labels hidden, first point deselected)."""
    markupsNode.GetDisplayNode().PointLabelsVisibilityOff()
    markupsNode.RemoveAllMarkups()
    for position in positions:
        markupsNode.AddControlPoint(vtk.vtkVector3d(position))
    if markupsNode.GetNumberOfControlPoints() > 0:
        markupsNode.SetNthControlPointSelected(0, False)


def trim_network_pa_tree(file_name, n_bifurcation, calculate_centerlines=False, snap_endpoints=True):
    """Trim the extracted PA network and compute its per-generation metrics.

    Reads ``<dir of file_name>/vmtk_out_pa/<ID>_network.vtp``, the file written by
    ``extract_pa_tree``. The network is trimmed with
    ``trim_utils.reconstruct_network_with_cell_data(network, root, n_bifurcation)``.
    Results are cached in ``vmtk_out_pa/<n_bifurcation>_out/``:

    * ``<ID>_trimmed_net.vtp`` -- trimmed network
    * ``<ID>_root.vtp``        -- root cell
    * ``<ID>_dicts.json``      -- branch records

    When both the trimmed network and the records file already exist, they are
    reused instead of being recomputed.

    Parameters
    ----------
    file_name : str
        Path of the form ``base_path/ID.nii.gz``.
    n_bifurcation : int
        Passed to ``trim_utils.reconstruct_network_with_cell_data``; also names the output folder.
    calculate_centerlines : bool, optional
        Also extract centerlines between the trimmed network's end points from
        ``<ID>_meshPreProcessed.vtk`` (default False). Writes
        ``<ID>_centerline.vtk`` and ``<ID>_endpoints.fcsv``.
    snap_endpoints : bool, optional
        With ``calculate_centerlines``, also snap the end points with
        ``utils.robustEndPointDetection`` against the label volume and write
        ``<ID>_centerline2.vtk`` and ``<ID>_endpoints2.fcsv`` (default True).

    Returns
    -------
    pandas.DataFrame or int
        One-row metrics frame from ``get_metrics``, indexed by ``ID``. Returns
        ``0`` instead when ``calculate_centerlines`` is set and fewer than two
        end points are found.

    The centerline markups nodes are left in the Slicer scene; the reference
    volume loaded for snapping is removed after use.
    """
    base_name = os.path.basename(file_name)
    ID = base_name.split('.nii.gz')[0]
    # dirname instead of file_name.split(ID)[0], which breaks when ID also appears in a directory name.
    vtk_folder = os.path.join(os.path.dirname(file_name), 'vmtk_out_pa')
    save_path = os.path.join(vtk_folder, f'{n_bifurcation}_out')
    os.makedirs(save_path, exist_ok=True)

    network = read_vtpfile(os.path.join(vtk_folder, f'{ID}_network.vtp'))

    network_save_file = os.path.join(save_path, f'{ID}_trimmed_net.vtp')
    root_save_file = os.path.join(save_path, f'{ID}_root.vtp')
    dicts_file = os.path.join(save_path, f'{ID}_dicts.json')

    # The cache is only usable when both the trimmed network and its records exist.
    if os.path.exists(network_save_file) and os.path.exists(dicts_file):
        trimmed_net = read_vtpfile(network_save_file)
        print(f'Reading file {network_save_file}')
        with open(dicts_file, 'r') as json_file:
            dicts = json.load(json_file)
    else:
        source_cell = trim_utils.find_root_cellID_from_network(network)
        trimmed_net, dicts = trim_utils.reconstruct_network_with_cell_data(network, source_cell, n_bifurcation)
        write_vtpfile(trimmed_net, network_save_file)

        root_net = trim_utils.extract_single_cell_as_polydata(network, source_cell)
        write_vtpfile(root_net, root_save_file)

        with open(dicts_file, 'w') as json_file:
            json.dump(dicts, json_file)

    # Topology and metrics are computed on the full ``network``.
    # NOTE(review): the records' 'cell' ids presumably index the full network,
    # not trimmed_net — confirm against trim_utils.
    dicts = assign_topology(dicts, network)
    id_df = get_metrics(network, dicts)
    id_df.index = [ID]

    if not calculate_centerlines:
        return id_df

    preProcess = read_vtkfile(os.path.join(vtk_folder, f'{ID}_meshPreProcessed.vtk'))
    logic = slicer_utils.ExtractCenterlineLogic()

    endPointsMarkupsNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", "Centerline endpoints")
    endpointPositions = logic.getEndPoints(trimmed_net, None)
    _trim_fill_endpoints(endPointsMarkupsNode, endpointPositions)

    if endPointsMarkupsNode.GetNumberOfControlPoints() < 2:
        print('*****Two end points are needed for centerline extraction!!!******')
        print(f'Skipping ID = {ID}. Best to do manually.')
        slicer.mrmlScene.RemoveNode(endPointsMarkupsNode)
        return 0

    utils_logic = utils.ExtractCenterlineLogic()
    centerlinePolyData, _ = utils_logic.extractCenterline(preProcess, endPointsMarkupsNode)

    if snap_endpoints:
        referenceVolume = slicer.util.loadVolume(file_name)
        try:
            ijkToRas = vtk.vtkMatrix4x4()
            referenceVolume.GetIJKToRASMatrix(ijkToRas)
            aff = slicer.util.arrayFromVTKMatrix(ijkToRas)
            segmentationArray = slicer.util.arrayFromVolume(referenceVolume)
            new_endpoints = [utils.robustEndPointDetection(endpoint, segmentationArray, aff, n=5)
                             for endpoint in endpointPositions]
        finally:
            # The volume is only needed for its voxel array and IJK->RAS matrix.
            slicer.mrmlScene.RemoveNode(referenceVolume)

        endPointsMarkupsNode2 = slicer.mrmlScene.AddNewNodeByClass(
            "vtkMRMLMarkupsFiducialNode", "Centerline endpoints NEW")
        _trim_fill_endpoints(endPointsMarkupsNode2, new_endpoints)

        centerlinePolyData2, _ = utils_logic.extractCenterline(preProcess, endPointsMarkupsNode2)

        centerline_file_path2 = os.path.join(save_path, f"{ID}_centerline2.vtk")
        write_vtkfile(centerlinePolyData2, centerline_file_path2)
        print(f"Saved centerline data to {centerline_file_path2}")

        endpoints_file_path2 = os.path.join(save_path, f"{ID}_endpoints2.fcsv")
        slicer.util.saveNode(endPointsMarkupsNode2, endpoints_file_path2)
        print(f"Saved endpoints to {endpoints_file_path2}")

    # The un-snapped centerline goes to *_centerline.vtk so it cannot overwrite
    # the snapped *_centerline2.vtk written above.
    centerline_file_path = os.path.join(save_path, f"{ID}_centerline.vtk")
    write_vtkfile(centerlinePolyData, centerline_file_path)
    print(f"Saved centerline data to {centerline_file_path}")

    endpoints_file_path = os.path.join(save_path, f"{ID}_endpoints.fcsv")
    slicer.util.saveNode(endPointsMarkupsNode, endpoints_file_path)
    print(f"Saved endpoints to {endpoints_file_path}")

    return id_df

    

In [7]:
# Trim the extracted network (n_bifurcation=5) and collect its metrics:
# a one-row DataFrame indexed by the sample ID.
max_bifurcations = 5
id_df = trim_network_pa_tree(file_name, n_bifurcation=max_bifurcations)

Reading file ./vmtk_out_pa\5_out/example_patree_trimmed_net.vtp
analysing Gen 1
[104]
analysing Gen 2
[77, 103]
analysing Gen 3
[62, 76, 134, 135, 136, 137]
analysing Gen 4
[49, 61, 63, 64, 101, 102, 161, 162, 163, 164, 165, 166, 167, 168]
analysing Gen 5
[34, 48, 73, 74, 75, 78, 79, 80, 81, 131, 753, 132, 133, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206]


# Visualise the preprocessed mesh, full network, root branch and trimmed network

In [8]:
# Derive output locations from the inputs used above (file_name, max_bifurcations)
# rather than hard-coding them, so changing either cannot silently load stale files.
sample_id = os.path.basename(file_name).split('.nii.gz')[0]
vmtk_dir = os.path.join(os.path.dirname(file_name), 'vmtk_out_pa')
trim_dir = os.path.join(vmtk_dir, f'{max_bifurcations}_out')

mesh = read_vtkfile(os.path.join(vmtk_dir, f'{sample_id}_meshPreProcessed.vtk'))
net = read_vtpfile(os.path.join(vmtk_dir, f'{sample_id}_network.vtp'))
root = read_vtpfile(os.path.join(trim_dir, f'{sample_id}_root.vtp'))
trimmed_net = read_vtpfile(os.path.join(trim_dir, f'{sample_id}_trimmed_net.vtp'))

In [9]:
# Pass each stage to slicer_utils.visualise_slicer under its label, in pipeline order.
for _label, _polydata in [('mesh', mesh), ('net', net), ('root', root), ('trimmed_net', trimmed_net)]:
    slicer_utils.visualise_slicer(_polydata, _label)