In [None]:
import matplotlib.pyplot as plt
import torch

from data_loading import load_dataset_gino
from deepcardio.losses import LpLoss, H1Loss
from deepcardio.utils import plot_tri, create_gif_tri
from gino import GINO

# --- Configuration -------------------------------------------------------------
DATA_DIR = '../data'
CKPT_PATH = 'ckpt/model_snapshot_dict.pt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Data ----------------------------------------------------------------------
dltrain, dltest, data_processor = load_dataset_gino(
    folder_path=DATA_DIR, train_batch_sizes=[10], test_batch_sizes=[10])
data_processor = data_processor.to(device)

# --- Losses --------------------------------------------------------------------
l2loss = LpLoss(d=2, p=2)
h1loss = H1Loss(d=2)

train_loss = l2loss
eval_losses = {'l2': l2loss}

# --- Model ---------------------------------------------------------------------
model = GINO(
    in_channels=3,  # [dr_bl, nm_bl, src_trm]
    out_channels=1,
    gno_coord_dim=2,
    gno_coord_embed_dim=16,
    gno_radius=0.1,
    gno_transform_type='linear',
    fno_n_modes=[16, 16, 16],
    fno_hidden_channels=64,
    fno_use_mlp=True,
    fno_norm='instance_norm',
    fno_ada_in_features=32,
    fno_factorization='tucker',
    fno_rank=0.4,
    fno_domain_padding=0.125,
    fno_mlp_expansion=1.0,
    fno_output_scaling_factor=1,
)
model = model.to(device)

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects
# (arbitrary code execution if the file is untrusted). Only load checkpoints you
# produced yourself. TODO: check whether weights_only=True can read this snapshot.
checkpoint = torch.load(CKPT_PATH, map_location='cpu', weights_only=False)
# load_state_dict copies the CPU tensors into the parameters already on `device`.
model.load_state_dict(checkpoint['MODEL_STATE'])

# This notebook only runs inference. Switch BOTH the model and the data processor
# to eval mode; otherwise the model's modules run with training behaviour.
model.eval()
data_processor.eval()
assert not data_processor.training, 'data_processor should be in eval mode for evaluation'

In [None]:
from pathlib import Path
import json

# Per-epoch metrics saved next to the checkpoint: a JSON list with one dict per epoch.
save_dir = Path('./ckpt/')
with open(save_dir / 'metrics_dict.json', 'r') as f:
    list_epoch_metrics = json.load(f)

# 'avg_loss' is the training loss. '0_l2' is the l2 eval loss, presumably on test
# loader 0 (TODO confirm against the trainer's key naming).
epochs = [metrics_data['epoch'] for metrics_data in list_epoch_metrics]
training_losses = [metrics_data['avg_loss'] for metrics_data in list_epoch_metrics]
test_losses = [metrics_data['0_l2'] for metrics_data in list_epoch_metrics]

fig, ax = plt.subplots()
ax.plot(epochs, training_losses, label='train (avg_loss)')
ax.plot(epochs, test_losses, label='test (0_l2)')
ax.set(xlabel='epoch', ylabel='loss', title='Training curves')
ax.legend()
plt.show()

In [None]:
# Evaluate every test sample individually and record its L2 loss.
test_losses = []
# Inference only: torch.no_grad() avoids building (and holding) an autograd graph
# for every sample.
with torch.no_grad():
    for sample in dltest[0].dataset:
        sample = data_processor.preprocess(sample)
        output = model(**sample)
        output, sample = data_processor.postprocess(output, sample)
        test_loss = l2loss(output, sample['y']).item()
        test_losses.append(test_loss)

fig, ax = plt.subplots()
ax.plot(test_losses)
ax.set(xlabel='test sample index', ylabel='L2 loss', title='Per-sample test loss')
plt.show()


In [None]:
# Window of per-sample test losses to look at more closely when choosing a sample.
INSPECT_SLICE = slice(130, 140)
test_losses[INSPECT_SLICE]

In [None]:
# Test-set sample to visualize (presumably picked from the 130-140 loss window).
SAMPLE_IDX = 138
sample = dltest[0].dataset[SAMPLE_IDX]
sample['label']

In [None]:
# Run the model on the selected sample without tracking gradients.
# NOTE: this rebinds `sample` to its processed version. Re-running this cell on
# its own would feed an already-processed sample back through preprocess, so
# re-run the sample-selection cell first.
with torch.no_grad():
    sample = data_processor.preprocess(sample)
    output = model(**sample)
    output, sample = data_processor.postprocess(output, sample)

In [None]:
from matplotlib.tri import Triangulation
import numpy as np

# NOTE(review): the mesh case is hard-coded and NOT derived from `sample`.
# Check that it matches sample['label'] for the selected sample; otherwise the
# node positions will be triangulated with another case's connectivity.
CASE_ID = 481
TIME_STEP = 50  # column of y / output to plot (presumably a time index)

# Build the triangulation from the sample's node coordinates and the mesh cells.
pos = sample['input_geom'].detach().cpu()
cells = np.load(f'../data/cells/case_ID_{CASE_ID}.npy')
triang = Triangulation(pos[:, 0], pos[:, 1], cells)

t = TIME_STEP
# Move each field to CPU once instead of repeating .detach().cpu() per plot.
y_t = sample['y'][:, t].detach().cpu()
yhat_t = output[:, t].detach().cpu()
# Pointwise absolute error, normalized by the maximum ground-truth value at this column.
err_t = (yhat_t - y_t).abs() / y_t.max()

plot_tri(triang, y_t)     # ground truth
plot_tri(triang, yhat_t)  # prediction
plot_tri(triang, err_t)   # normalized absolute error

In [None]:
import os

# Animate ground truth, prediction and normalized error on the mesh.
y = sample['y'].detach().cpu()
yhat = output.detach().cpu()
# Pointwise |yhat - y| divided by the global max of y. Reuses the CPU copies above
# instead of recomputing them.
RMAE = (yhat - y).abs() / y.max()

gif_dirs = {'y': 'results/y/', 'output': 'results/output/', 'error': 'results/error/'}
# NOTE(review): it is not visible here whether create_gif_tri creates its output
# directory, so make sure each one exists first (no-op if already present).
for gif_dir in gif_dirs.values():
    os.makedirs(gif_dir, exist_ok=True)

create_gif_tri(triang, y, save_dir=gif_dirs['y'])
create_gif_tri(triang, yhat, save_dir=gif_dirs['output'])
create_gif_tri(triang, RMAE, save_dir=gif_dirs['error'])

