In [None]:
import os
import sys
import vtk
import numpy as np
import pandas as pd
import numpy as np
from vtk.util import numpy_support as vn
from sklearn.decomposition import PCA
import seaborn as sns
import open3d as o3d
sys.path.append("/Users/williamengelhardt/OneDrive/Yale_Lab_Summer/updatedCode")
import Visualize_Linear_Transform as vis

## Packages

In [None]:
def readFile(inputPath,fileName):
    """Read a polydata file with vtk.vtkPolyDataReader.

    Parameters
    ----------
    inputPath : directory that contains the file.
    fileName : name of the file inside ``inputPath``.

    Returns
    -------
    The reader's output object when the joined path exists. When the path
    does NOT exist, ``fileName`` itself (a string) is returned instead of
    raising, so callers that go straight to ``.GetPoints()`` will fail with
    an AttributeError on a missing file.
    """
    fullPath = os.path.join(inputPath, fileName)
    if os.path.exists(fullPath):
        polyReader = vtk.vtkPolyDataReader()
        polyReader.SetFileName(fullPath)
        polyReader.Update()
        return polyReader.GetOutput()
    # Missing file: hand the name back unchanged.
    return fileName

In [None]:
def writeFile(combined,output_file_name):
    """Write ``combined`` to ``output_file_name`` with vtk.vtkPolyDataWriter.

    Parameters
    ----------
    combined : object passed to the writer's ``SetInputData``.
    output_file_name : destination path passed to ``SetFileName``.

    Returns nothing; the only effect is the writer's ``Update()`` call.
    """
    polyWriter = vtk.vtkPolyDataWriter()
    # The two setters are independent, so their order does not matter.
    polyWriter.SetFileName(output_file_name)
    polyWriter.SetInputData(combined)
    polyWriter.Update()

In [None]:
#allows to store VTKs for later use
class store:
    """Minimal holder that keeps one object reachable as ``.number``.

    Despite the attribute name, the notebook stores meshes and lists of
    ``store`` instances in it, not numbers.
    """

    def __init__(self, number):
        # Keep the payload as-is; no copy is made.
        self.number = number

In [None]:
#change input path, output path doesn't matter unless you are exporting the csv
# Both locations can be overridden without editing the notebook by setting the
# PCA_INPUT_PATH / PCA_OUTPUT_PATH environment variables before starting the
# kernel; the original machine-specific paths remain the defaults.
INPUT_PATH = os.environ.get(
    "PCA_INPUT_PATH",
    "/Users/williamengelhardt/OneDrive/Yale_Lab_Summer/data/nonlinear_from_ref")
OUTPUT_PATH = os.environ.get(
    "PCA_OUTPUT_PATH",
    "/Users/williamengelhardt/OneDrive/Yale_Lab_Summer/updatedCode/Machine Learning")

# PCA

## Difference between Base and Reference

In [None]:
#read in reference
# Reference meshes for the "total" and "flr" sections. Later cells subtract
# these point arrays from every other mesh of the same section.
# NOTE: readFile returns the file name (a str) when the file is missing, in
# which case the .GetPoints() calls below raise AttributeError.
totalRef = readFile(INPUT_PATH,"a_17_total_trans_pre.vtk")
# Point coordinates of the reference as a numpy array.
pointsTotalRef = vn.vtk_to_numpy(totalRef.GetPoints().GetData())
flrRef = readFile(INPUT_PATH,"a_17_flr_trans_pre.vtk")
pointsFLRRef = vn.vtk_to_numpy(flrRef.GetPoints().GetData())

In [None]:
#get number of total and flr files
# File names are split on '_' and the third field selects the section
# ('total' or 'flr'). Names with fewer than three fields (for example hidden
# files in the directory) are reported instead of raising IndexError.
totalNum = 0
flrNum = 0
for f in os.listdir(INPUT_PATH):
    nameFields = f.split('_')
    section = nameFields[2] if len(nameFields) > 2 else None
    if section == 'total':
        totalNum += 1
    elif section == 'flr':
        flrNum += 1
    else:
        print("Filename %s is not in correct format" %f)

print("Total sections found: %i" %totalNum)
print("FLR sections found: %i" %flrNum)

In [None]:
# Load every non-reference mesh and store, per section, its flattened
# point-wise offset from the reference (one row per patient).
VTK_total = []
VTK_flr = []
# Pre-sized from the counts of the previous cell; trimmed to the rows that
# were actually filled once the loop is done.
pointsTotalDiffFlat = np.zeros((totalNum,np.size(pointsTotalRef)))
pointsFLRDiffFlat = np.zeros((flrNum,np.size(pointsFLRRef)))
# Patient labels are collected in lists so that ids longer than four
# characters are not truncated by a fixed-width string dtype.
patientTotalIDs = []
patientFLRIDs = []
totalCount = 0
flrCount = 0

# sorted() makes the row order reproducible; os.listdir order is arbitrary.
for filename in sorted(os.listdir(INPUT_PATH)):

    nameArray = filename.split('_')

    # Names without the three expected '_'-separated fields cannot be
    # classified; skip them instead of failing with IndexError below.
    if len(nameArray) < 3:
        print("Filename %s is not in correct format" % filename)
        continue

    # split() yields strings, so the patient id must be compared to '17'
    # (comparing to the integer 17 never matched and the reference was kept).
    if nameArray[1] == '17' and nameArray[0] == 'a':
        print("Skipping reference image")
        continue

    print('Looking at institution %s, patient %s, %s'
          %(nameArray[0], nameArray[1], nameArray[2]))

    if nameArray[2] == 'total':
        myVTK = readFile(INPUT_PATH,filename)
        VTK_total.append(store(myVTK))
        pointsBase = vn.vtk_to_numpy(myVTK.GetPoints().GetData())
        pointsDiff = pointsBase - pointsTotalRef
        pointsTotalDiffFlat[totalCount] = pointsDiff.ravel()
        patientTotalIDs.append(nameArray[0]+'_'+nameArray[1])
        totalCount += 1
    elif nameArray[2] == 'flr':
        myVTK = readFile(INPUT_PATH,filename)
        VTK_flr.append(store(myVTK))
        pointsBase1 = vn.vtk_to_numpy(myVTK.GetPoints().GetData())
        pointsDiff1 = pointsBase1 - pointsFLRRef
        pointsFLRDiffFlat[flrCount] = pointsDiff1.ravel()
        patientFLRIDs.append(nameArray[0]+'_'+nameArray[1])
        flrCount += 1


print('Finished reading %d patients this run' % (totalCount+flrCount))
# Drop the unused pre-allocated rows (e.g. the one reserved for the reference).
pointsTotalDiffFlat = pointsTotalDiffFlat[:totalCount]
pointsFLRDiffFlat = pointsFLRDiffFlat[:flrCount]
patientTotalNum = np.array(patientTotalIDs, dtype=str)
patientFLRNum = np.array(patientFLRIDs, dtype=str)


In [None]:
#change this to switch looking at total or flr
useTotal = False #false means use flr
# Bind the section-independent names used by the rest of the notebook.
# Only the tuple of the chosen section is evaluated.
(pointsFlat, patientNum, csvName, refPoints, ref, count, VTK_base) = (
    (pointsTotalDiffFlat, patientTotalNum, "pca_total.csv", pointsTotalRef,
     totalRef, totalCount, VTK_total)
    if useTotal else
    (pointsFLRDiffFlat, patientFLRNum, "pca_flr.csv", pointsFLRRef,
     flrRef, flrCount, VTK_flr)
)

In [None]:
# pointsFlat has one row per patient; after the transpose each ROW of `data`
# is one flattened coordinate and each COLUMN is one patient, so PCA treats
# coordinates as samples and patients as features.
data = pointsFlat.T
pca = PCA()
# pc: scores of every coordinate row on every component.
pc = pca.fit_transform(data)
print(pca.explained_variance_ratio_)

In [None]:
# Share of variance explained by the first 71 components (the slice bound is
# hard-coded; a shorter array is simply summed in full).
p = pca.explained_variance_ratio_
sum(p[0:71])

In [None]:
# Scatter of the first two component scores (first three are kept in the frame).
pc_df = pd.DataFrame(pc[:, :3], columns=["PC1", "PC2", "PC3"])
sns.lmplot(
    x="PC1",
    y="PC2",
    data=pc_df,
    fit_reg=False,
    legend=True,
    scatter_kws={"s": 0.5},
)

In [None]:
# Loadings: rows are patients (the PCA features), columns are components.
loadings = pca.components_.T
# One label per COLUMN of `loadings` (axis 1). Counting rows instead only
# works when the number of patients equals the number of components.
columnsPC = ["PC%s" % (i + 1) for i in range(np.size(loadings, axis=1))]
loading_matrix = pd.DataFrame(loadings,index=patientNum,columns=columnsPC)
loading_matrix

In [None]:
# Writes the loading matrix to OUTPUT_PATH/csvName every time this cell runs;
# comment the two lines below out to skip the export.
outputCSV = os.path.join(OUTPUT_PATH, csvName)
loading_matrix.to_csv(outputCSV)

# Visualize

In [None]:
#change weights of PCs
def _displaceReference(refMesh, displacement):
    """Return a deep copy of ``refMesh`` with ``displacement[i]`` added to point i.

    ``displacement`` must have one row per mesh point; ``refMesh`` itself is
    not modified.
    """
    mesh = vtk.vtkPolyData()
    mesh.DeepCopy(refMesh)
    meshPoints = mesh.GetPoints()
    for i in range(np.size(displacement,axis=0)):
        refCoord = np.zeros(3)
        meshPoints.GetPoint(i,refCoord)  # fills refCoord in place
        meshPoints.SetPoint(i,refCoord + displacement[i])
    return mesh


PCs = np.zeros((count,pca.n_components_))
VTK_PC_array = []   # one store(list of store(mesh)) per processed PC
pc_avg_array = []   # one store(mesh) per processed PC: the mean-weight shape
numObjects = 70 #should be even number, increase for video resolution
startVal = 0
# At most the first five PCs are swept; clamp so that a fit with fewer
# components does not index past the loading matrix.
endVal = min(5, pca.n_components_)
for n in range(startVal,endVal):
    PCs[n] = loading_matrix.iloc[::,n].values
    avg = np.mean(PCs[n])
    var = 2*np.std(PCs[n]) #multiplier is num of std devs
    rangeMin = avg - var
    rangeMax = avg + var
    weights = np.zeros(pca.n_components_) #all other PCs are 0
    # Step between consecutive shapes; the sweep runs from rangeMin up to
    # rangeMin + dx*(numObjects-1), i.e. it stops one step short of rangeMax.
    dx = (rangeMax - rangeMin)/numObjects
    VTK_PCs = []
    for j in range(numObjects):
        weights[n] = rangeMin+dx*j
        points = np.matmul(pc,weights)
        recDiff = np.reshape(points,np.shape(refPoints))
        varPC = _displaceReference(ref, recDiff)
        VTK_PCs.append(store(varPC))
    VTK_PC_array.append(store(VTK_PCs))

    #create average PC ref shape
    weights[n] = avg
    points = np.matmul(pc,weights)
    recDiff = np.reshape(points,np.shape(refPoints))
    pcAvg = _displaceReference(ref, recDiff)
    pc_avg_array.append(store(pcAvg))
    print("Stored PC #%d: dx = %s" %((n+1), dx))
print("Done")

In [None]:
#change this number to the PC you want to look at, then run the next two cells
# 1-based; the following cells index the stored sweeps with PC_num-1.
PC_num = 4

In [None]:
#color change in shapes
# Mean-weight shape of the selected PC; every sweep mesh is compared to it.
pcAvg = pc_avg_array[PC_num-1].number
maxVal = 12 #decrease to increase intensity
for j in range(numObjects):
    folPoints = vn.vtk_to_numpy(VTK_PC_array[PC_num-1].number[j].number.GetPoints().GetData())
    basePoints = vn.vtk_to_numpy(pcAvg.GetPoints().GetData())
    diffPoints = folPoints - basePoints
    # Euclidean length of each point's displacement from the mean-weight shape.
    dist = (diffPoints[::,0]**2+diffPoints[::,1]**2+diffPoints[::,2]**2)**0.5
    # presumably colours the sweep mesh by `dist`, scaled by maxVal -- confirm in Visualize_Linear_Transform
    vis.visChangeInShapes(VTK_PC_array[PC_num-1].number[j].number, pcAvg, dist, maxVal, "white")

In [None]:
#view first, middle, and last of PC variation
renWin, iren = vis.createWindow()

#blue is grew, red is shrunk
camPos = (1,0,-3)
# Three renderers in one window; the two numeric arguments look like the
# horizontal viewport bounds -- confirm in Visualize_Linear_Transform.
ren = vis.createRen(renWin,0,0.33)
# First mesh of the sweep.
vis.addActors(ren,VTK_PC_array[PC_num-1].number[0].number, pos=camPos)
ren1 = vis.createRen(renWin,0.33,0.66)
# Mesh at index numObjects/2 - 1.
vis.addActors(ren1,VTK_PC_array[PC_num-1].number[int(numObjects/2)-1].number, pos=camPos)
ren2 = vis.createRen(renWin,0.66,1)
# Last mesh of the sweep.
vis.addActors(ren2,VTK_PC_array[PC_num-1].number[numObjects-1].number, pos=camPos)


vis.startVis(renWin,iren)

In [None]:
# Overrides the PC_num chosen in the earlier cell for the animation below.
PC_num = 1
#look at PC video, sometimes crashes
class vtkTimerCallback():
    """Timer observer that steps through the stored sweep meshes back and forth.

    ``execute`` reads the module-level names ``mapper``, ``actor``, ``color``,
    ``array_num``, ``numObjects`` and ``VTK_PC_array``; they must exist before
    the first timer event fires.
    """

    def __init__(self):
        self.timer_count = 0  # number of timer events handled so far
        self.index = 0        # sweep mesh to show on the next event
        self.up = 1           # +1 while moving forward, -1 while moving back

    def execute(self,obj,event):
        """Show the mesh at ``self.index``, render, then advance the index.

        ``obj`` is the object that fired the event; its render window is
        re-rendered. ``event`` is unused.
        """
        frame = VTK_PC_array[array_num].number[self.index].number
        mapper.SetInputData(frame)
        actor.SetMapper(mapper)
        actor.GetProperty().SetDiffuseColor(color[0],color[1],color[2])
        obj.GetRenderWindow().Render()
        self.timer_count += 1

        # Advance; on running past either end, reverse direction and stay on
        # the end mesh for one more event.
        nextIndex = self.index + self.up
        if nextIndex == numObjects:
            self.up = -1
            nextIndex = numObjects-1
        elif nextIndex == -1:
            self.up = 1
            nextIndex = 0
        self.index = nextIndex
# Position of the chosen PC inside VTK_PC_array, which only holds the PCs
# from startVal onwards.
array_num = PC_num - startVal - 1
color = vis.getColor("darkGreen")
#Create a mapper and actor
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputData(VTK_PC_array[array_num].number[0].number)
actor = vtk.vtkActor()
actor.GetProperty().SetDiffuseColor(color[0],color[1],color[2])
actor.SetMapper(mapper)
# Not referenced again in this cell.
prop = actor.GetProperty()

# Setup a renderer, render window, and interactor
renWin, iren = vis.createWindow()
ren = vis.createRen(renWin,0,1)

#Add the actor to the scene
ren.AddActor(actor)
vis.addLabels(ren,"PC number %d" %(PC_num))

# Initialize must be called prior to creating timer events.
iren.Initialize()
renWin.Render()

# Sign up to receive TimerEvent
cb = vtkTimerCallback()
# These attributes are set on the callback, but execute() reads the
# module-level `actor` and `color` rather than self.actor / self.color.
cb.actor = actor
cb.color = color
iren.AddObserver(vtk.vtkCommand.TimerEvent, cb.execute)
# Repeating timer with an argument of 5 (presumably milliseconds -- confirm in the VTK docs).
cb.timerId = iren.CreateRepeatingTimer(5);

#start the interaction and timer
iren.Start()


## Recreate Shapes

In [None]:
#recover all shapes
# Rebuild every patient's mesh from the PCA scores and that patient's row of
# the loading matrix.
VTK_recBase = []
# One reconstruction per patient, i.e. per ROW of the loading matrix.
for n in range(len(loading_matrix)):
    patientWeights = loading_matrix.iloc[n].values
    # PCA centres each feature (patient column) before fitting, so the
    # per-patient mean has to be added back to undo the transform exactly.
    points = np.matmul(pc,patientWeights) + pca.mean_[n]
    recDiff = np.reshape(points,(len(refPoints),3))

    # Apply the reconstructed offsets to a copy of the reference mesh.
    recBase = vtk.vtkPolyData()
    recBase.DeepCopy(ref)
    recPoints = recBase.GetPoints()
    for i in range(np.size(recDiff,axis=0)):
        refCoord = np.zeros(3)
        recPoints.GetPoint(i,refCoord)  # fills refCoord in place
        pointI = refCoord + recDiff[i]
        recPoints.SetPoint(i,pointI)

    VTK_recBase.append(store(recBase))

In [None]:
# Overlay the last mesh of the first stored PC sweep (yellow) on the
# reference mesh (blue).
renWin, iren = vis.createWindow()
ren = vis.createRen(renWin,0.25,0.75)
#vis.addActors(ren,VTK_PC_array[0].number[0].number,vis.getColor("yellow"),0.9)
# Index -1 selects the last sweep mesh whatever numObjects is set to
# (previously hard-coded as 69, which is only valid for numObjects = 70).
vis.addActors(ren,VTK_PC_array[0].number[-1].number,vis.getColor("yellow"),0.9)
vis.addActors(ren,ref, vis.getColor("blue"), 0.9)

vis.startVis(renWin,iren)

In [None]:
#check that the reconstruction worked
# Overlays reconstructed (yellow) and loaded (blue) meshes, one window per
# patient; range(1) limits the check to the first patient.
for i in range(1):
    renWin, iren = vis.createWindow()
    ren = vis.createRen(renWin,0.25,0.75)
    vis.addActors(ren,VTK_recBase[i].number,vis.getColor("yellow"),0.6)
    vis.addActors(ren,VTK_base[i].number,vis.getColor("blue"),0.6)
    vis.addLabels(ren,"Patient number: %s" %patientNum[i])
    
    vis.startVis(renWin,iren)

In [None]:
# Index of the patient to display in the three panes below.
i=0
renWin, iren = vis.createWindow()
# Pane 1: reconstructed mesh only.
ren = vis.createRen(renWin,0,0.33)
vis.addActors(ren,VTK_recBase[i].number,[1,0,0],1)
vis.addLabels(ren,"Re-constructed shape")

# Pane 2: loaded mesh only.
ren2 = vis.createRen(renWin,0.33,0.66)
vis.addActors(ren2,VTK_base[i].number,[0,1,1],1)
vis.addLabels(ren2,"Original shape")

# Pane 3: both meshes in the same renderer.
ren3 = vis.createRen(renWin,0.66,1)
vis.addActors(ren3,VTK_recBase[i].number,[1,0,0],0.9)
vis.addActors(ren3,VTK_base[i].number,[0,1,1],0.9)
vis.addLabels(ren3,"Overlay")
    
vis.startVis(renWin,iren)