## Notebook to check VTK functionality

In [1]:
import vtk
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


### Load CSD data and compute mesh

In [2]:
# Load the CSD samples from CSV with pandas and scatter them into a dense 2-D grid.
# Columns 0 and 1 are used (truncated to int) as grid row/column indices and
# column 2 as the value stored at that grid node; nodes with no sample stay 0.
# NOTE(review): hard-coded absolute local path — consider a configurable
# DATA_DIR constant so the notebook runs on other machines.
file_name = '/home/ntolley/Jones_Lab/lfp_reeb_local/data/gbarEvPyrAmpa_sweep/points/gbarEvPyrAmpa_sweep10.csv'

np_points = np.array(pd.read_csv(file_name))
# The grid must be large enough to hold the largest index found in each column.
point_bounds = np.max(np_points, axis=0) + 1
topography = np.zeros((int(point_bounds[0]), int(point_bounds[1])))

for row in np_points:
    topography[int(row[0]), int(row[1])] = row[2]

points = vtk.vtkPoints()
triangles = vtk.vtkCellArray()

# Per-point scalar field: each inserted point carries its topography value.
data = vtk.vtkFloatArray()
data.SetNumberOfComponents(1)
data.SetName("Function Value")

# Every grid cell (i, j) is split into two triangles. Corners are given as
# (di, dj) offsets from node (i, j); the order of the corners fixes the order
# in which points are inserted and the local ids 0, 1, 2 of each triangle.
CELL_TRIANGLES = (
    ((0, 0), (0, 1), (1, 0)),  # Triangle 1
    ((0, 1), (1, 1), (1, 0)),  # Triangle 2
)

# Build the meshgrid manually. Each triangle gets its own three freshly
# inserted points (corners are not shared here), so `count` is simply the id
# of the next point to be inserted.
n_rows, n_cols = topography.shape
count = 0
for i in range(n_rows - 1):
    for j in range(n_cols - 1):
        for corners in CELL_TRIANGLES:
            triangle = vtk.vtkTriangle()
            for local_id, (di, dj) in enumerate(corners):
                z = topography[i + di][j + dj]
                points.InsertNextPoint(i + di, j + dj, z)
                data.InsertNextValue(z)
                triangle.GetPointIds().SetId(local_id, count + local_id)
            triangles.InsertNextCell(triangle)
            count += 3

# Wrap the mesh in a vtkPolyData: the points give the geometry, the triangles
# give the topology, and the per-point values become the active scalars.
trianglePolyData = vtk.vtkPolyData()
trianglePolyData.SetPoints(points)
trianglePolyData.SetPolys(triangles)
trianglePolyData.GetPointData().SetScalars(data)

# Each triangle was built with its own copies of its corner points; clean the
# polydata so coincident points are merged and the edges are shared.
cleanPolyData = vtk.vtkCleanPolyData()
cleanPolyData.SetInputData(trianglePolyData)
cleanPolyData.Update()

grid = cleanPolyData.GetOutput()


### Simplify mesh with decimate

In [4]:
# Fraction of the mesh's triangles that vtkDecimatePro is asked to remove.
# With PreserveTopologyOn the filter may stop short of this target.
TARGET_REDUCTION = 0.99

decimate = vtk.vtkDecimatePro()
decimate.SetInputData(grid)
decimate.SetTargetReduction(TARGET_REDUCTION)
decimate.PreserveTopologyOn()
decimate.Update()

# Keep a standalone handle on the filter output for the Reeb graph build.
decimated = vtk.vtkPolyData()
decimated.ShallowCopy(decimate.GetOutput())

# NOTE(review): this array is never filled or referenced by any visible cell;
# it looks like a leftover and is a candidate for removal.
data_decimated = vtk.vtkFloatArray()
data_decimated.SetNumberOfComponents(1)
data_decimated.SetName("Function Value")


### Build Reeb Graph with decimated mesh

In [5]:
# Build the Reeb graph of the decimated mesh's scalar field.
reeb_graph = vtk.vtkReebGraph()
err = reeb_graph.Build(decimated, decimated.GetPointData().GetScalars())
# Build returns 0 on success and a negative vtkReebGraph error code otherwise;
# fail loudly here instead of silently iterating an invalid graph below.
assert err == 0, f"vtkReebGraph.Build failed with error code {err}"
err

0

### Iterate over all edges of the Reeb graph and store them in a list

In [6]:
iter_edge = vtk.vtkEdgeListIterator()
iter_edge.SetGraph(reeb_graph)

# Record every Reeb-graph arc as a [source, target] pair of graph vertex ids.
edge_list = []
while iter_edge.HasNext():
    arc = iter_edge.NextGraphEdge()
    edge_list.append([arc.GetSource(), arc.GetTarget()])

### Iterate over the graph nodes and store their x, y, z positions

In [7]:
iter_vertex = vtk.vtkVertexListIterator()
iter_vertex.SetGraph(reeb_graph)

# Vertex coordinates are stored under original mesh IDs: each graph vertex is
# mapped to a mesh vertex id, whose position is read from the decimated mesh.
point_data = decimated.GetPoints().GetData()
# NOTE(review): assumes vertex-data array 0 is the graph->mesh vertex id map — confirm.
vertex_mapping = reeb_graph.GetVertexData().GetAbstractArray(0)

# Store each position at its graph vertex id. The ids in edge_list index this
# list directly when plotting, so this stays correct regardless of the order
# in which the iterator yields vertices.
vertex_list = [None] * reeb_graph.GetNumberOfVertices()
while iter_vertex.HasNext():
    graph_vertex_id = iter_vertex.Next()
    mesh_vertex_id = int(vertex_mapping.GetTuple(graph_vertex_id)[0])

    vertex_pos = point_data.GetTuple(mesh_vertex_id)
    vertex_list[graph_vertex_id] = list(vertex_pos)

assert all(pos is not None for pos in vertex_list), "some graph vertices were not visited"


### Visualize Graph

In [8]:
# Reuse the samples already loaded for meshing rather than re-reading the CSV.
# This guarantees the plotted surface matches the data the graph was built from.
surface_points = np_points

# Graph node positions (row k = graph vertex k) and arcs as index pairs into them.
node_points = np.array(vertex_list)
node_connectivity = np.array(edge_list)

In [9]:
%matplotlib qt

num_pairs = node_connectivity.shape[0]
fig = plt.figure(figsize = (8,6))
ax = plt.axes(projection='3d')

for pair in range(num_pairs):
    pairID = node_connectivity[pair]
    xdata, ydata, zdata = node_points[pairID, 0], node_points[pairID, 1], node_points[pairID, 2]

    ax.plot(xdata,ydata,zdata, 'k', linewidth=0.2)

step_size = 50
ax.plot_trisurf(surface_points[::step_size,0],surface_points[::step_size,1],surface_points[::step_size,2],cmap='viridis',edgecolor='none', alpha=0.8)
plt.show()

In [10]:
# Sanity check: is the Reeb graph a tree?
# - A MultiGraph keeps parallel arcs between the same two nodes. Such arcs form
#   a loop; a plain Graph would collapse them and wrongly report a tree.
# - Every graph vertex is added explicitly, so an isolated node makes the
#   graph disconnected instead of being silently ignored.
G = nx.MultiGraph()
G.add_nodes_from(range(reeb_graph.GetNumberOfVertices()))
G.add_edges_from(edge_list)
nx.is_tree(G)

True