## Notebook to check VTK Reeb-graph functionality on CSD data

In [9]:
import vtk
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


### Load CSD data and compute mesh

In [2]:
# Load the CSD samples into a VTK table. The CSV is headerless; columns 0 and 1 are
# used as grid coordinates and column 2 as the function value.
# Reader setup follows https://github.com/Kitware/VTK/blob/master/Examples/Infovis/Python/tables3.py
file_name = '/home/ntolley/Jones_Lab/lfp_reeb_local/data/gbarEvPyrAmpa_sweep/points/gbarEvPyrAmpa_sweep1.csv'

csv_source = vtk.vtkDelimitedTextReader()
csv_source.SetFieldDelimiterCharacters(",")
csv_source.SetHaveHeaders(False)
csv_source.SetDetectNumericColumns(True)
csv_source.SetFileName(file_name)
csv_source.Update()

T = csv_source.GetOutput()
n_rows = T.GetNumberOfRows()
x_col, y_col, z_col = T.GetColumn(0), T.GetColumn(1), T.GetColumn(2)

# vtkReebGraph's Build() relies on a vtkPolyData made only of triangles
# (vtkUnstructuredGrid could be used for tetrahedra).
# See http://www.vtk.org/doc/nightly/html/classvtkReebGraph.html#details
grid = vtk.vtkPolyData()

points = vtk.vtkPoints()
points.SetNumberOfPoints(n_rows)
data = vtk.vtkFloatArray()
data.SetNumberOfComponents(1)
data.SetName("Function Value")

for r in range(n_rows):
    z_value = z_col.GetValue(r)
    points.InsertPoint(r, x_col.GetValue(r), y_col.GetValue(r), z_value)
    # The scalar field for the Reeb graph is the height value (column 2).
    data.InsertNextValue(z_value)

grid.SetPoints(points)
grid.GetPointData().SetScalars(data)

# Derive the grid size from the number of distinct x / y values rather than from the
# coordinate maxima: with 0-based integer coordinates, 0..max holds max+1 samples per
# row, so using max as the row stride would skew every triangle across rows.
nx = len({x_col.GetValue(r) for r in range(n_rows)})
ny = len({y_col.GetValue(r) for r in range(n_rows)})
assert nx * ny == n_rows, f"points do not form a complete {nx} x {ny} grid ({n_rows} rows)"

# Two triangles per grid cell, (nx - 1) * (ny - 1) cells in total.
# NOTE(review): assumes rows are ordered with x varying fastest (point id = j * nx + i),
# the same ordering the triangle indices have always assumed -- TODO confirm against the CSV writer.
tris = vtk.vtkCellArray()
for j in range(ny - 1):
    for i in range(nx - 1):
        p00 = j * nx + i  # (i,   j)
        p10 = p00 + 1     # (i+1, j)
        p01 = p00 + nx    # (i,   j+1)
        p11 = p01 + 1     # (i+1, j+1)
        tris.InsertNextCell(3, [p00, p10, p01])
        tris.InsertNextCell(3, [p10, p01, p11])

grid.SetPolys(tris)

### Simplify the mesh with decimation (`vtkDecimatePro`)

In [3]:
# Reduce the triangle count before building the Reeb graph, keeping the mesh topology.
TARGET_REDUCTION = 0.99  # requested fraction of triangles to remove

decimate = vtk.vtkDecimatePro()
decimate.SetInputData(grid)
decimate.SetTargetReduction(TARGET_REDUCTION)
decimate.PreserveTopologyOn()
decimate.Update()

# Copy the filter output into a standalone vtkPolyData for the following cells.
decimated = vtk.vtkPolyData()
decimated.ShallowCopy(decimate.GetOutput())

# The Reeb graph is built on the point scalars, so fail early if decimation dropped them.
assert decimated.GetPointData().GetScalars() is not None, "decimated mesh has no point scalars"


### Build the Reeb graph on the decimated mesh

In [4]:
# Build the Reeb graph of the "Function Value" point scalars over the decimated triangle mesh.
reeb_graph = vtk.vtkReebGraph()
err = reeb_graph.Build(decimated, decimated.GetPointData().GetScalars())
# NOTE(review): per the vtkReebGraph docs, Build() returns negative ERR_* codes on failure
# (e.g. scalar/point count mismatch, non-triangle mesh); stop here rather than plot a bad graph.
if err < 0:
    raise RuntimeError(f"vtkReebGraph.Build failed with error code {err}")

### Iterate over all graph edges and store them in a list

In [5]:
# Collect every Reeb-graph edge as a [source, target] pair of graph vertex ids.
iter_edge = vtk.vtkEdgeListIterator()
iter_edge.SetGraph(reeb_graph)

edge_list = []
while iter_edge.HasNext():
    graph_edge = iter_edge.NextGraphEdge()
    edge_list.append([graph_edge.GetSource(), graph_edge.GetTarget()])

### Iterate over graph nodes and store their x, y, z positions

In [6]:
# Reeb-graph nodes carry the id of the mesh vertex they came from rather than coordinates,
# so resolve each node back to its (x, y, z) position on the decimated mesh.
iter_vertex = vtk.vtkVertexListIterator()
iter_vertex.SetGraph(reeb_graph)

point_data = decimated.GetPoints().GetData()
# NOTE(review): assumes array 0 of the graph's vertex data is the node -> mesh-vertex id map; confirm
vertex_mapping = reeb_graph.GetVertexData().GetAbstractArray(0)

vertex_list = []
while iter_vertex.HasNext():
    graph_vertex_id = iter_vertex.Next()
    mesh_vertex_id = int(vertex_mapping.GetTuple(graph_vertex_id)[0])
    vertex_list.append([*point_data.GetTuple(mesh_vertex_id)])


### Visualize Graph

In [7]:
# Surface samples for the background plot. The CSV has no header row (the VTK reader
# above uses SetHaveHeaders(False)), so pandas must not consume the first sample as
# column names.
surface_points = pd.read_csv(file_name, header=None).to_numpy()

# One xyz row per graph node; row k is expected to be graph vertex k
# (vtkVertexListIterator visiting order) -- the edge ids below index into it.
node_points = np.array(vertex_list)
# One [source, target] row of graph vertex ids per edge.
node_connectivity = np.array(edge_list)

In [14]:
# Plot the Reeb graph edges (thin black lines) over a subsampled triangulated view of
# the CSD surface. Switch to `%matplotlib qt` for an interactive, rotatable window.
step_size = 50  # plot only every step_size-th surface sample

fig = plt.figure(figsize=(8, 6))
ax = plt.axes(projection='3d')

for pairID in node_connectivity:
    # pairID is a [source, target] pair of row indices into node_points.
    xdata, ydata, zdata = node_points[pairID, 0], node_points[pairID, 1], node_points[pairID, 2]
    ax.plot(xdata, ydata, zdata, 'k', linewidth=0.2)

surface_subset = surface_points[::step_size]
ax.plot_trisurf(surface_subset[:, 0], surface_subset[:, 1], surface_subset[:, 2],
                cmap='viridis', edgecolor='none', alpha=0.8)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('Function Value')
ax.set_title('Reeb graph over CSD surface')
plt.show()