In [None]:
import os

import vtk
from vtk.util.numpy_support import numpy_to_vtk
import numpy as np
import tensorflow as tf
import IPython.display as display
import ipywidgets

import gcChebyshev
import notebookrender

## Load geometry

In [None]:
# Template path of the exported VTU mesh; `{:05d}` is filled with the run number.
filePath = '../Variationsanalyse/data/sto_2/run__{:05d}/m_1/vtk_Export/Assembly/Assembly_6.vtu'
# Run whose geometry serves as the canvas for the filter plots below.
RUN_ID = 1

vtuFile = filePath.format(RUN_ID)
# VTK readers do not raise on a missing file: they only log an error and return an
# empty grid, which would surface much later as a blank rendering. Fail early instead.
if not os.path.isfile(vtuFile):
    raise FileNotFoundError(vtuFile)

reader = vtk.vtkXMLUnstructuredGridReader()
reader.SetFileName(vtuFile)
reader.Update()  # parse the file; GetOutput() is only populated after Update()
data = reader.GetOutput()

# Guard against unreadable/corrupt files that still exist on disk.
if data.GetNumberOfPoints() == 0:
    raise ValueError('No points could be read from {}'.format(vtuFile))

## Load the Fourier basis

In [None]:
# Fourier basis used below to map the spectral filters back onto the mesh.
# NOTE(review): np.loadtxt splits on whitespace by default, so this file is presumably
# whitespace-delimited despite its .csv suffix — pass delimiter=',' if it is re-exported.
U = np.loadtxt(fname='U-30.csv')
print(np.shape(U))

## Load model

In [None]:
# Keras can only rebuild the custom graph-convolution layer from the HDF5 file
# if its class is registered under the name used in the saved model config.
customLayers = {'gcChebyshev': gcChebyshev.gcChebyshev}
new_model = tf.keras.models.load_model('gcChebyshev.h5', custom_objects=customLayers)
new_model.summary()

## Convert filters to spatial domain
The filters are stored in the PEN as coefficients of the Chebyshev polynomials. To display them on the geometry, they first have to be converted back to the spatial domain. Let $\theta_k$ be the Chebyshev coefficients, $T_k$ a matrix containing the Chebyshev polynomials, and $\Lambda$ the (diagonal matrix of) eigenvalues corresponding to the Fourier basis $U$. A filter $f_\theta$ in its spatial representation is then obtained with the Fourier basis as follows:
\begin{align}
    f_\theta &= U g_\theta(\Lambda) \tag{1} \\
    g_\theta(\Lambda) &= \sum\limits_{k=0}^{K-1} T_k(\Lambda) \theta_k \tag{2}
\end{align}

In [None]:
# The first layer of the loaded model is taken as the Chebyshev graph-convolution layer.
gcnn = new_model.get_layer(index=0)
# Chebyshev polynomial matrix stored on the layer, transposed so it can be
# left-multiplied onto the coefficients.
T = tf.transpose(gcnn.T)

# Equation (2): g_theta(Lambda) = sum_k T_k(Lambda) theta_k, evaluated for all filters at once.
# NOTE(review): F is indexed below as F[filter, spectral_index, channel]; Filters_real is
# presumably (n_filters, K, n_channels) — confirm against gcChebyshev.
F = tf.matmul(T, gcnn.Filters_real)
print(F.shape)

# U @ F[k] requires U's columns to line up with F's spectral axis; fail with a clear message.
assert U.shape[1] == F.shape[1], (
    'Fourier basis has {} columns but the filters have {} spectral entries'.format(
        U.shape[1], F.shape[1]))

# Equation (1): f_theta = U g_theta(Lambda), for every filter k.
# Multiplying by the whole (spectral x channel) slice projects all channels at once:
# column c of U @ F[k] equals U @ F[k, :, c], so the channel count is not hard-coded.
# convert_to_tensor makes every entry a tf tensor, whichever operand ends up handling `@`.
f = [tf.convert_to_tensor(U @ F[k]) for k in range(F.shape[0])]

print(np.shape(f))

## Visualize filters

In [None]:
# Column headers for the three filter channels shown side by side.
display.display(ipywidgets.HBox([
    ipywidgets.HTML('<h5 style="text-align:center; width:300px;">X</h5>'),
    ipywidgets.HTML('<h5 style="text-align:center; width:300px;">Y</h5>'),
    ipywidgets.HTML('<h5 style="text-align:center; width:300px;">Z</h5>')
]))


def render_filter_channel(grid, values, scalarRange, **renderKwargs):
    """Render one filter channel as a temporary point-data array on ``grid``.

    ``values`` is attached to ``grid`` as the point-data array 'U_res',
    ``notebookrender.rendering`` is called with ``scalarRange`` plus ``renderKwargs``,
    and the array is removed again afterwards -- also when rendering raises -- so
    ``grid`` is left as it was and re-running the cell starts from a clean state.

    Returns the object produced by ``notebookrender.rendering``.
    Raises ValueError if ``values`` does not have one entry per point of ``grid``.
    """
    # numpy_to_vtk only accepts contiguous arrays; a column slice of a NumPy matrix is not.
    column = np.ascontiguousarray(values)
    if column.shape[0] != grid.GetNumberOfPoints():
        raise ValueError(
            'Filter channel has {} values but the geometry has {} points'.format(
                column.shape[0], grid.GetNumberOfPoints()))

    vtkArray = numpy_to_vtk(column)
    vtkArray.SetName('U_res')
    grid.GetPointData().AddArray(vtkArray)
    try:
        return notebookrender.rendering(grid, scalarRange=scalarRange, **renderKwargs)
    finally:
        grid.GetPointData().RemoveArray('U_res')


# The first two panels share one camera and hide the colour bar; the last panel is wider,
# with the camera shifted so the colour bar fits next to the geometry.
CHANNEL_VIEW = dict(
    width=500,
    height=600,
    pos=[-50.0, 0.0, 4500.0],
    foc=[-50.0, 0.0, 1000.0],
    zoom=1.3,
    showColorBar=False,
)
COLORBAR_VIEW = dict(
    width=800,
    height=600,
    pos=[200.0, 0.0, 4500.0],
    foc=[200.0, 0.0, 1000.0],
    zoom=1.3,
    colorBarCoordinate=(0.7, 0.1),
)

for filt in f:
    # One colour range per filter, shared by its channels so they are directly comparable.
    scalarRange = [
        float(tf.math.reduce_min(filt)),
        float(tf.math.reduce_max(filt)),
    ]

    panels = [
        render_filter_channel(data, filt[:, 0], scalarRange, **CHANNEL_VIEW),
        render_filter_channel(data, filt[:, 1], scalarRange, **CHANNEL_VIEW),
        render_filter_channel(data, filt[:, 2], scalarRange, **COLORBAR_VIEW),
    ]

    display.display(ipywidgets.HBox(panels, layout=ipywidgets.Layout(height='300px')))