In [None]:
import matplotlib.pyplot as plt
from data_loading import load_dataset_graph
from deepcardio.utils import plot_tri
import torch
from deepcardio.losses import LpLoss, H1Loss

# Graph-structured train/test loaders plus the data processor whose
# preprocess/postprocess methods are used by the evaluation cells below.
dltrain, dltest, data_processor = load_dataset_graph(
    folder_path='../data',
    train_batch_sizes=[10],
    test_batch_sizes=[10],
)

# Prefer the GPU when one is available; the data processor follows the same device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data_processor = data_processor.to(device)

# Loss objects. Only the L2 loss is wired into train/eval here; the H1 loss is
# constructed but not used by the cells visible in this notebook.
l2loss = LpLoss(d=2, p=2)
h1loss = H1Loss(d=2)
train_loss, eval_losses = l2loss, {'l2': l2loss}

# Graph kernel network. Its hyper-parameters must match the checkpoint that is
# loaded later (load_state_dict is strict by default and rejects mismatched keys).
from gnn_model import KernelNN

model = KernelNN(
    node_prj_dim=32,
    edge_prj_dim=24,
    num_layers=6,
    edge_attrs_dim=3,
    node_ftrs_dim=5,
    out_dim=1,
).to(device)

In [None]:
from pathlib import Path
import json

save_dir = Path('./ckpt/')

# Per-epoch metrics saved during training: a list with one dict per epoch.
with open(save_dir / 'metrics_dict.json', 'r') as f:
    list_epoch_metrics = json.load(f)

epochs = [metrics['epoch'] for metrics in list_epoch_metrics]
training_losses = [metrics['avg_loss'] for metrics in list_epoch_metrics]
# NOTE(review): '0_l2' is presumably the 'l2' eval loss on test loader index 0 — confirm.
test_losses = [metrics['0_l2'] for metrics in list_epoch_metrics]

# Label both curves; without a legend train and test are indistinguishable.
fig, ax = plt.subplots()
ax.plot(epochs, training_losses, label='train (avg_loss)')
ax.plot(epochs, test_losses, label='test (0_l2)')
ax.set(xlabel='epoch', ylabel='loss', title='Training history')
ax.legend()
plt.show()

In [None]:
# Restore the trained weights.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load('ckpt/model_snapshot_dict.pt', map_location='cpu')['MODEL_STATE'])

# Switch BOTH the model and the data processor to inference mode; leaving the
# model in training mode would keep any dropout/batch-norm training behaviour active.
model.eval()
data_processor.eval()

# Display the mode flags as a sanity check (both should be False).
model.training, data_processor.training

In [None]:
# Per-sample L2 loss over the first test set, evaluated one sample at a time.
# NOTE: this rebinds `test_losses`, replacing the per-epoch curve read from
# metrics_dict.json earlier; later cells index into this per-sample list.
test_losses = []
with torch.no_grad():  # inference only: avoid building an autograd graph per sample
    for sample in dltest[0].dataset:
        sample = data_processor.preprocess(sample)
        output = model(**sample)
        output, sample = data_processor.postprocess(output, sample)
        test_losses.append(l2loss(output, sample['y']).item())

fig, ax = plt.subplots()
ax.plot(test_losses)
ax.set(xlabel='test sample index', ylabel='L2 loss', title='Per-sample test loss')
plt.show()

In [None]:
# Zoom in on a window of per-sample test losses.
# NOTE(review): bounds look hand-picked from the per-sample loss plot
# (presumably around an outlier) — update them if the data changes.
LOSS_WINDOW_START, LOSS_WINDOW_STOP = 230, 240
test_losses[LOSS_WINDOW_START:LOSS_WINDOW_STOP]

In [None]:
# Pick a single test sample for a closer look and show its label.
# NOTE(review): index looks hand-picked from the per-sample loss inspection — confirm.
SAMPLE_IDX = 230
sample = dltest[0].dataset[SAMPLE_IDX]
sample['label']

In [None]:
from matplotlib.tri import Triangulation
import numpy as np

# Run the model on the selected sample. Fresh names are used instead of rebinding
# `sample`, so re-running this cell does not feed an already-preprocessed sample
# back into preprocess (assuming preprocess returns a new object rather than
# mutating its input — TODO confirm).
with torch.no_grad():
    sample_pre = data_processor.preprocess(sample)
    output = model(**sample_pre)
    # Node coordinates are read from the preprocessed sample, before postprocess.
    pos = sample_pre['pos'].detach().cpu()
    output, sample_post = data_processor.postprocess(output, sample_pre)

# NOTE(review): mesh connectivity is hard-coded to case 140; presumably it must
# correspond to sample['label'] displayed above — confirm when changing the sample.
cells = np.load('../data/cells/case_ID_140.npy')
triang = Triangulation(pos[:, 0], pos[:, 1], cells)

y_true = sample_post['y'].detach().cpu()
y_pred = output.detach().cpu()
plot_tri(triang, y_true)
plot_tri(triang, y_pred)
# Absolute error scaled by the largest |y|; using abs() keeps the scale positive
# even when the field's maximum is zero or negative.
plot_tri(triang, (y_pred - y_true).abs() / y_true.abs().max())