In [None]:
import pandas as pd
from matplotlib.colors import ListedColormap, BoundaryNorm
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import numpy as np
import scipy as sp
import pyvista as pv
import torch

from skfem import ElementTriP1, FacetBasis
from data_generation.utils_data import convert_mesh, angle_between, plot_circular_space_time_cylinder

# Build a ListedColormap with `lut_vals` entries by resampling the colour table
# stored in the CSV file below.
lut_vals = 1024
rgb = pd.read_csv(
    "/home/haas/cardiac/inverse_problem_ECGi/src/demo_2D/colormap/coolwarm_extended.csv"
)

# Divide the table by 255 (presumably 0-255 RGB columns -- TODO confirm) and
# append a column of ones as the fourth (alpha) channel.
rgb_scaled = rgb.to_numpy() / 255
opaque_column = np.ones([rgb.shape[0], 1])
rgba = np.concatenate([rgb_scaled, opaque_column], axis=1)

# Resample every channel from the table's own row count onto `lut_vals`
# equally spaced positions in [0, 1] using interp1d with its default settings.
lut_source_grid = np.linspace(0, 1, num=rgba.shape[0])
lut_target_grid = np.linspace(0, 1, num=lut_vals)
channel_interpolator = interp1d(lut_source_grid, rgba.T)
rgba_interp = channel_interpolator(lut_target_grid).T

cmap_new = ListedColormap(rgba_interp)

In [None]:
# Load the torso mesh and store a running index 0..n_points-1 under "pt_ids",
# so that points of an extracted sub-mesh can be traced back to the full mesh.
torso = pv.UnstructuredGrid("data/meshes/torso2d.vtu")
torso["pt_ids"] = np.arange(torso.n_points)

# Extract the cells whose "gmsh:geometrical" value lies in [2, 3] (used as the
# heart mesh below). The extraction is done once and reused for both the
# coordinates and the converted mesh instead of thresholding twice.
heart_pv = torso.threshold((2, 3), scalars="gmsh:geometrical")
pts_coarse = heart_pv.points[:, :2]  # keep only the first two coordinates
heart = convert_mesh(heart_pv)

# "pt_ids" values carried over to the extracted sub-mesh.
pts_coarse_inds = heart_pv.point_data["pt_ids"]
# Per-cell 'G' data viewed as one 2x2 matrix per cell.
cond = torso.cell_data['G'].reshape(-1,2,2)
torso_msh = convert_mesh(torso)
# Integer table from the text file, reshaped to three columns per row.
elec_inds = np.loadtxt("data/meshes/elec_inds.txt", dtype=int).reshape(-1,3)

# P1 facet basis restricted to the boundary facets of the heart mesh.
heart_boundary_facets = heart.boundary_facets()
heart_surf_basis = FacetBasis(heart, ElementTriP1(), facets=heart_boundary_facets)

# DOF indices reported by the facet basis and their coordinates, one row per DOF.
epi_inds = heart_surf_basis.get_dofs().all()
epi_doflocs = heart_surf_basis.doflocs[:, epi_inds].T

# Angle of every DOF location relative to `heart_center`, evaluated by
# angle_between against the direction (1, 0); the exact angle convention is
# defined by that helper -- TODO confirm range/orientation.
heart_center = np.array([20, -40])
reference_direction = np.array([[1, 0]])
centered_doflocs = epi_doflocs - heart_center[np.newaxis]
angle = angle_between(centered_doflocs, reference_direction)

# Permutation that sorts the DOFs by increasing angle.
epi_order = np.argsort(angle)

In [None]:
# Select sample number `it` and load the matching .npz archive.
it = 0
exp_dir = "data/data_functions/heart_potential_"
file_dir = f"{exp_dir}{it}.npz"
data = np.load(file_dir)

# 'u' is the stored function, 'dt' the stored step size used for the time axis.
func, dt = data['u'], data['dt']

In [None]:
# One time stamp per entry along the last axis of `func`, spaced by `dt`.
n_time_steps = func.shape[-1]
t_all = dt * np.arange(n_time_steps)

# Draw `func` with the project helper; the colour limits span the full value
# range of `func`.
plot_circular_space_time_cylinder(
    func,
    None,
    t_all,
    angle,
    epi_order,
    cmap_new,
    vmin=func.min(),
    vmax=func.max(),
    ax_label=True,
    colorbar=True,
)

In [None]:
# Load Model and compute filters
from utils import load_model

# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the notebook still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
fname = 'MFoE/'
model = load_model(fname, device=device)

In [None]:
data = np.load('/home/haas/temporal_foe/data/data_fixed/fixed_data.npz')


def to_tensor(x):
    """Return NumPy array `x` as a float64 torch tensor placed on `device`."""
    return torch.from_numpy(x).to(device=device, dtype=torch.float64)


# Indexing with [None, None] prepends two singleton axes to each tensor.
Ks = to_tensor(data['Ks'])[None, None]
proj_p1 = to_tensor(data['proj_p1'])[None, None]

# Ask the model's operator for its filters. Note that dt is passed as the
# literal 1 here, not as a value read from a data file.
img = model.l_op.get_filters(proj_p1, Ks, dt=1, d=2)

# Define nodal points (example for 1D FEM): integer positions centred on zero,
# one per sample along the last axis of `img`.
N = img.shape[-1]
s = N // 2
# np.arange(N) - s always has exactly N entries. For odd N it equals
# np.arange(-s, s + 1); for even N that form would give N + 1 nodes and the
# length mismatch with the N filter values would make np.interp / scatter fail.
nodes = np.arange(N) - s

# NOTE(review): skip_kernels is not applied anywhere in this cell -- the first
# 16 kernels are plotted regardless. Confirm whether filtering was intended.
skip_kernels = {0,2,3,4,5,9,13,15,16,23}
all_indices = list(range(img.shape[1]))
plot_indices = all_indices[:16]

fig, axes = plt.subplots(4, 4, figsize=(10, 8), sharex=True, sharey=True)
axes = axes.flatten()

# The dense abscissa is identical for every kernel, so build it once.
x_fine = np.linspace(nodes[0], nodes[-1], 100)

for ax, i in zip(axes, plot_indices):
    values = img[0, i].detach().cpu().numpy()

    # Piecewise-linear curve through the nodal values.
    y_fine = np.interp(x_fine, nodes, values)

    ax.plot(x_fine, y_fine, linewidth=1.5)
    ax.scatter(nodes, values, color='r', s=15)

    ax.grid(True, alpha=0.3)


# Hide unused axes
for ax in axes[len(plot_indices):]:
    ax.axis("off")

# Global axis labels
fig.supxlabel("Node index")
fig.supylabel("Filter value")

plt.show()
