In [None]:
import matplotlib.pyplot as plt
from data_loading import load_dataset_gino
import torch
from deepcardio.losses import LpLoss, H1Loss

# Load the test data and the trained GINO checkpoint for evaluation.
dltrain, dltest, data_processor = load_dataset_gino(
    folder_path='../data/npy',
    train_batch_sizes=[1], test_batch_sizes=[1], query_res=[32, 32, 32],
    use_distributed=False)

# Only the test loader(s) are used in this notebook; release the training loader.
del dltrain

device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_processor = data_processor.to(device)

l2loss = LpLoss(d=3, p=2)
# NOTE(review): d=2 differs from l2loss's d=3, and h1loss is not used anywhere
# in this notebook — confirm whether it is still needed / which d is intended.
h1loss = H1Loss(d=2)

train_loss = l2loss
eval_losses = {'l2': l2loss}

from gino import GINO
model = GINO(
    in_channels=3,  # [dr_bl, nm_bl, src_trm]
    out_channels=1,
    gno_coord_dim=3,
    gno_coord_embed_dim=16,
    gno_radius=0.1,
    gno_transform_type='linear',
    fno_n_modes=[16, 16, 16, 16],  # x_1, x_2, x_3, t
    fno_hidden_channels=32,
    fno_use_mlp=True,
    fno_norm='instance_norm',
    fno_ada_in_features=32,
    fno_factorization='tucker',
    fno_rank=0.4,
    fno_domain_padding=0.125,
    fno_mlp_expansion=1.0,
    fno_output_scaling_factor=1,
)
model = model.to(device)
# The checkpoint is a dict whose 'MODEL_STATE' entry is the model state dict.
# It is mapped to CPU first; load_state_dict copies the tensors into the
# parameters that already live on `device`.
# SECURITY: weights_only=False unpickles arbitrary objects — only load
# checkpoints from a trusted source.
model.load_state_dict(
    torch.load('ckpt/model_snapshot_dict.pt', map_location='cpu',
               weights_only=False)['MODEL_STATE'])

# Inference mode for BOTH the model and the data processor (previously only the
# data processor was switched, leaving the model in training mode).
model.eval()
data_processor.eval()
assert not model.training and not data_processor.training

In [None]:
from deepcardio.neuralop_core.utils import count_model_params

# Parameter count of the loaded GINO model, shown as the cell output.
n_model_params = count_model_params(model)
n_model_params

In [None]:
from pathlib import Path
import json
# Plot the per-epoch training and test losses recorded during training.
save_dir = Path('./ckpt/')
# A list of per-epoch metric dicts; the keys 'epoch', 'avg_loss' and '0_l2' are used below.
with open(save_dir / 'metrics_dict.json', 'r') as f:
    list_epoch_metrics = json.load(f)

epochs = [metrics['epoch'] for metrics in list_epoch_metrics]
training_losses = [metrics['avg_loss'] for metrics in list_epoch_metrics]
# Named distinctly from the per-sample `test_losses` list built by the evaluation loop.
epoch_test_losses = [metrics['0_l2'] for metrics in list_epoch_metrics]

fig, ax = plt.subplots()
ax.plot(epochs, training_losses, label='train (avg_loss)')
ax.plot(epochs, epoch_test_losses, label='test (0_l2)')
ax.set(xlabel='epoch', ylabel='loss', title='Training history')
ax.legend()
plt.show()

In [None]:
# Per-sample L2 loss over the first test loader (dltest[0]).
# The model is (re)put in eval mode here so this cell is correct on its own.
model.eval()
test_losses = []
with torch.no_grad():  # inference only: no autograd graph is built
    for sample in dltest[0]:
        sample = data_processor.preprocess(sample)
        output = model(**sample)
        output, sample = data_processor.postprocess(output, sample)
        test_losses.append(l2loss(output, sample['y']).item())
        # Free device memory between samples.
        del sample, output
        torch.cuda.empty_cache()

fig, ax = plt.subplots()
ax.plot(test_losses)
ax.set(xlabel='test sample', ylabel='L2 loss', title='Per-sample test loss')
plt.show()


In [None]:
# First ten per-sample test losses.
test_losses[:10]

In [None]:
# Index (into the first test dataset) of the case to inspect and export.
SAMPLE_INDEX = 3
sample = dltest[0].dataset[SAMPLE_INDEX]
sample['label']

In [None]:
# Run the model on the selected sample and keep CPU copies of the target,
# prediction, error and point coordinates for the mesh export.
# NOTE: this consumes `sample` (it is deleted at the end), so re-select the
# sample before re-running this cell.
model.eval()
with torch.no_grad():  # inference only: no autograd graph is built
    sample = data_processor.preprocess(sample)
    output = model(**sample)
    output, sample = data_processor.postprocess(output, sample)
y = sample['y'].cpu()
output = output.detach().cpu()
# Absolute error scaled by the global maximum of the target (not a per-point relative error).
error = torch.abs(output - y) / y.max()
# Assumes y is (n_points, n_timesteps, channels) — TODO confirm.
num_timesteps = y.shape[1]
data_points = sample['input_geom'].cpu()
case_ID = sample['label']
# Drop the processed sample, which may hold device tensors.
del sample

In [None]:
import meshio

# Source mesh for the selected case and the XDMF time series written later;
# both file names are derived from the case label.
meshfile = "".join(("../data/mesh/case_ID_", case_ID, ".msh"))
xdmffile = "".join(("./results/xdmf/case_ID_", case_ID, ".xdmf"))

mesh = meshio.read(meshfile)
meshio_points = mesh.points           # vertex coordinates, matched to the point cloud below
cells = mesh.cells_dict["tetra"]      # tetrahedral connectivity

In [None]:
from scipy.spatial import cKDTree

# Map every mesh vertex to its nearest point in the model's point cloud, so the
# per-point target and error can be written in mesh-vertex order.
tree = cKDTree(data_points)
distances, indices = tree.query(meshio_points)

reordered_y = y[indices]
reordered_error = error[indices]

# Sanity check on the mapping: the largest nearest-neighbour distance.
# NOTE(review): if the point cloud and mesh share vertices this should be ~0;
# a large value suggests the coordinates differ (e.g. normalised vs. physical).
float(distances.max())

In [None]:
from pathlib import Path

# Stored timesteps per unit of time: timestep i is written at t = i / 5.
TIMESTEPS_PER_UNIT_TIME = 5

# Ensure ./results/xdmf exists; the writer presumably does not create
# missing directories on a fresh checkout.
Path(xdmffile).parent.mkdir(parents=True, exist_ok=True)

# Write the mesh once, then the target and error fields at every timestep.
with meshio.xdmf.TimeSeriesWriter(xdmffile) as writer:
    writer.write_points_cells(mesh.points, mesh.cells)
    for step in range(num_timesteps):
        t = step / TIMESTEPS_PER_UNIT_TIME
        writer.write_data(
            t,
            point_data={
                "y": reordered_y[:, step, 0].numpy(),
                "error": reordered_error[:, step, 0].numpy(),
            },
        )