In [None]:
import matplotlib.pyplot as plt
from data_loading import load_dataset_for_gino
from deepcardio.meshdata import BipartiteData
import torch
from deepcardio.losses import LpLoss, H1Loss

# Build the train/test loaders and the accompanying data processor from the
# preprocessed dataset file. One train batch size and two test batch sizes are
# requested; later cells index the results as dltrain[0], dltest[0], dltest[1].
dltrain, dltest, data_processor = load_dataset_for_gino(
    folder_path='../data_processed/data.pt',
    train_batch_sizes=[1],
    test_batch_sizes=[1, 1],
    use_distributed=False,
    dataset_format=BipartiteData,
)

# Bare expression in the middle of a cell: Jupyter only echoes the last
# expression of a cell, so this line displays nothing.
dltrain

# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
data_processor = data_processor.to(device)

# Loss objects; only l2loss is wired into train_loss / eval_losses below.
# NOTE(review): h1loss uses d=2 while l2loss uses d=3 -- confirm this is intended.
l2loss, h1loss = LpLoss(d=3, p=2), H1Loss(d=2)

train_loss = l2loss
eval_losses = dict(l2=l2loss)

from gino import GINO
# Instantiate the GINO architecture. These hyper-parameters must match the ones
# used when the checkpoint below was trained, otherwise load_state_dict fails.
model = GINO(
    in_channels=5,  # [ploc_bool, D_iso, ef_vector]
    out_channels=1,
    gno_coord_dim=3,
    gno_coord_embed_dim=16,
    gno_radius=0.1,
    gno_transform_type='linear',
    fno_n_modes=[16, 16, 16, 1],  # x_1, x_2, x_3, t
    fno_hidden_channels=64,
    fno_use_mlp=True,
    fno_norm='instance_norm',
    fno_ada_in_features=32,
    fno_factorization='tucker',
    fno_rank=0.4,
    fno_domain_padding=0.125,
    fno_mlp_expansion=1.0,
    fno_output_scaling_factor=1,
)

# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects. Only load checkpoints that come from a trusted source.
best_train_state = torch.load('ckpt/best_model_snapshot_dict.pt', map_location='cpu', weights_only=False)
model = model.to(device)
model.load_state_dict(best_train_state['MODEL_STATE'])

# Modules are created in training mode; switch both the model and the data
# processor to eval mode because this notebook only runs inference.
model.eval()
data_processor.eval()

# Inner quotes differ from the outer ones so the f-string also parses on
# Python versions older than 3.12.
print(f"EPOCH: {best_train_state['CURRENT_EPOCH']}, LOSS: {best_train_state['BEST_LOSS']}")

# Displayed as the cell output: the data processor's training flag after eval().
data_processor.training

In [None]:
from deepcardio.neuralop_core.utils import count_model_params
# Show the value returned by count_model_params for the loaded model.
n_model_params = count_model_params(model)
n_model_params

In [None]:
from pathlib import Path
import json
# Read the per-epoch metrics that were saved next to the checkpoint.
save_dir = Path('./ckpt/')
metrics_path = save_dir / 'metrics_dict.json'
with open(metrics_path.as_posix(), 'r') as f:
    list_epoch_metrics = json.load(f)

# One entry per epoch record: epoch number, average training loss, and the
# value stored under '0_l2' (plotted as "Validation" below).
epochs = [record['epoch'] for record in list_epoch_metrics]
training_losses = [record['avg_loss'] for record in list_epoch_metrics]
test_losses = [record['0_l2'] for record in list_epoch_metrics]

# Draw both curves against the epoch axis on a logarithmic loss scale.
for curve, curve_label in ((training_losses, "Training"), (test_losses, "Validation")):
    plt.plot(epochs, curve, label=curve_label)
plt.yscale("log")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Scan the training dataset for the first sample whose label equals '0' and
# print its position. `i` and `sample` intentionally stay bound after the loop
# so that following cells can inspect the sample that was found.
for i, sample in enumerate(dltrain[0].dataset):
    if not (sample['label'] == '0'):
        continue
    print(i)
    break

In [None]:
# Display the label of whatever `sample` is currently bound to in the kernel.
# NOTE(review): depends on which cell last assigned `sample` -- exploratory check.
sample['label']

In [None]:
# Display the label of training sample 1863 (hard-coded index).
# NOTE(review): presumably the index printed by the label search -- TODO confirm.
dltrain[0].dataset[1863]['label']

In [None]:
# Display the label of whatever `sample` is currently bound to in the kernel.
# NOTE(review): same expression as an earlier cell; result depends on execution order.
sample['label']

In [None]:
# Evaluate the model on every batch of the first test loader and record the
# l2loss value of each batch in val_losses (one float per batch).
val_losses = []

# Inference only: without no_grad() every forward pass builds an autograd graph
# that is never used, which wastes memory on each batch.
with torch.no_grad():
    for sample in dltest[0]:
        sample = data_processor.preprocess(sample)
        output = model(**sample)
        output, sample = data_processor.postprocess(output, sample)
        test_loss = l2loss(output, sample['y']).item()
        val_losses.append(test_loss)
        # Drop the references before the next batch and release cached GPU
        # blocks so peak memory stays bounded.
        del sample, output
        torch.cuda.empty_cache()

plt.plot(val_losses)
plt.xlabel("Batch index")
plt.ylabel("L2 loss")
plt.show()


In [None]:
# Mean of the per-batch losses collected from the first test loader.
torch.mean(torch.tensor(val_losses))

In [None]:
# The five largest per-batch losses together with their positions in val_losses.
torch.tensor(val_losses).topk(k=5)

In [None]:
# Evaluate the model on every batch of the second test loader and record the
# l2loss value of each batch in test_losses (one float per batch).
# NOTE: this rebinds `test_losses`, which the metrics-plotting cell also assigns.
test_losses = []

# Inference only: without no_grad() every forward pass builds an autograd graph
# that is never used, which wastes memory on each batch.
with torch.no_grad():
    for sample in dltest[1]:
        sample = data_processor.preprocess(sample)
        output = model(**sample)
        output, sample = data_processor.postprocess(output, sample)
        test_loss = l2loss(output, sample['y']).item()
        test_losses.append(test_loss)
        # Drop the references before the next batch and release cached GPU
        # blocks so peak memory stays bounded.
        del sample, output
        torch.cuda.empty_cache()

plt.plot(test_losses)
plt.xlabel("Batch index")
plt.ylabel("L2 loss")
plt.show()

In [None]:
# Mean of the per-batch losses collected from the second test loader.
torch.mean(torch.tensor(test_losses))

In [None]:
# Bare tuple of five integers, displayed as the cell output.
# NOTE(review): presumably sample indices copied from the top-k cell -- TODO confirm.
# The last one (174) is the index loaded from dltest[0].dataset in the next cell.
109, 505, 986, 620, 174

In [None]:
# Pick one sample from the first test dataset for closer inspection and show
# its label next to the loss recorded at the same position in val_losses.
# NOTE(review): val_losses[k] only corresponds to dataset[k] if the loader
# iterates in dataset order with one sample per batch -- confirm.
inspect_idx = 174
sample = dltest[0].dataset[inspect_idx]
sample['label'], val_losses[inspect_idx]

In [None]:
# Run the model on the currently selected `sample` and build the tensors used
# by the mesh-export cells.
# NOTE: `sample` is overwritten with its pre/post-processed version, so this
# cell is not idempotent -- re-select the sample before running it again.

# Inference only: no autograd graph is needed for the forward pass.
with torch.no_grad():
    sample = data_processor.preprocess(sample)
    output = model(**sample)
    output, sample = data_processor.postprocess(output, sample)

# Move ground truth and prediction to the CPU for the NumPy / meshio steps.
y = sample['y'].cpu()
output = output.detach().cpu()

# Relative error: dims 0 and 1 are kept, everything from dim 2 on is flattened
# and reduced with a 2-norm, giving ||output - y|| / ||y|| with a trailing
# singleton dim (keepdim=True).
# NOTE: entries where the norm of y is zero yield inf or nan.
residual_flat = torch.flatten(output, start_dim=2) - torch.flatten(y, start_dim=2)
target_flat = torch.flatten(y, start_dim=2)
error = torch.linalg.vector_norm(
    residual_flat, ord=2, dim=-1, keepdim=True
) / torch.linalg.vector_norm(
    target_flat, ord=2, dim=-1, keepdim=True
)

# Dim 1 of y is treated as the time axis by the export loop.
num_timesteps = y.shape[1]
data_points = sample['input_geom'].cpu()
case_ID = sample['label']
# `sample` is deliberately kept alive: the XDMF export cell reads sample['a'].

In [None]:
import meshio
# Input mesh and output XDMF paths are both derived from the case identifier.
meshfile = ''.join(['../data/mesh/case', case_ID, '.vtk'])
xdmffile = ''.join(['./results/xdmf/case', case_ID, '.xdmf'])

# Read the mesh and keep its vertex coordinates and tetrahedral connectivity.
mesh = meshio.read(meshfile)
meshio_points, cells = mesh.points, mesh.cells_dict["tetra"]

In [None]:
from scipy.spatial import cKDTree

# Nearest-neighbour lookup: for every mesh vertex, find the index of the
# closest point in data_points (k=1 is the single nearest neighbour).
# NOTE(review): `distances` is never checked, so a poor match goes unnoticed.
tree = cKDTree(data_points)
distances, indices = tree.query(meshio_points, k=1)

# Gather ground truth and error along dim 0 so rows follow the mesh vertex order.
reordered_y, reordered_error = y[indices], error[indices]

In [None]:
from pathlib import Path

# Make sure the output directory exists before the writer opens its files.
Path(xdmffile).parent.mkdir(parents=True, exist_ok=True)

# Loop-invariant gathers, done once instead of once per timestep: prediction
# and model inputs reordered to the mesh vertex order.
output_on_mesh = output[indices]
inputs_on_mesh = sample['a'].cpu()[indices]

# Write the mesh once, then one set of point data per timestep. The timestep
# index itself is used as the time value.
with meshio.xdmf.TimeSeriesWriter(xdmffile) as writer:
    writer.write_points_cells(mesh.points, mesh.cells)
    for i in range(num_timesteps):
        writer.write_data(
            i,
            point_data={
                "y_true": reordered_y[:, i, 0].numpy(),
                "y_est": output_on_mesh[:, i, 0].numpy(),
                "error": reordered_error[:, i, 0].numpy(),
                # Input channels: 0 -> ploc_bool, 1 -> D_iso, 2 onwards -> ef.
                "ploc_bool": inputs_on_mesh[:, i, 0].numpy(),
                "D_iso": inputs_on_mesh[:, i, 1].numpy(),
                "ef": inputs_on_mesh[:, i, 2:].numpy(),
            },
        )