In [None]:
import torch
import torch.nn as nn

from torch_geometric.nn import knn_graph
from torch_geometric.loader import DataLoader

import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import gc

from plotly.subplots import make_subplots


In [None]:

from segnn.segnn import SEGNN
from e3nn.o3 import Irreps, spherical_harmonics
from segnn.balanced_irreps import BalancedIrreps, WeightBalancedIrreps
from segnn.instance_norm import InstanceNorm

# use it for input features similar to hemodynamics paper
from Utility_functions import print_3D_graph, manual_print_3D_graph, Graph_dataset_with_equiv_features, rotate_graph_coords, translate_graph_coords, denormalize_predictions


In [None]:
import os
import params

# Echo the run configuration and the torch build so the notebook output records them.
for setting in ('DATADIR', 'MODEL_DIR', 'NSIM', 'BATCH_SIZE'):
    print(setting, getattr(params, setting))

print('torch.__version__', torch.__version__)
print('torch.cuda.is_available()', torch.cuda.is_available())

In [None]:
# Use the first CUDA GPU when one is visible, otherwise evaluate on the CPU.
if torch.cuda.is_available():
    dev = torch.device("cuda:0")
else:
    dev = torch.device("cpu")
# dev = "cpu"
print(dev)

# Model

In [None]:
gc.collect()
torch.cuda.empty_cache()

# change the path accordingly
path = os.path.join(params.MODEL_DIR, "best_model.pt")
print(f"Loading model checkpoint from {str(path)}")

# Tensors are mapped to the CPU on load so this also works on machines without a
# GPU (e.g. integrated graphics only); the model is moved to `dev` at the end.
# The checkpoint stores a whole pickled model (checkpoint['model'] is used as the
# module below), which torch>=2.6 refuses to unpickle with its default
# weights_only=True -- hence weights_only=False.
# SECURITY: unpickling can run arbitrary code; only load checkpoints you trust.
try:
    checkpoint = torch.load(path, map_location=torch.device('cpu'), weights_only=False)
except TypeError:  # torch < 1.13 has no `weights_only` argument
    checkpoint = torch.load(path, map_location=torch.device('cpu'))

# Architecture hyper-parameters saved alongside the model; the irreps are also
# needed later to build the InstanceNorm layers.
input_irreps = checkpoint['input_irreps']
hidden_irreps = checkpoint['hidden_irreps']
output_irreps = checkpoint['output_irreps']
edge_attr_irreps = checkpoint['edge_attr_irreps']
node_attr_irreps = checkpoint['node_attr_irreps']
task = checkpoint['task']
norm = checkpoint['norm']
num_layers = checkpoint['num_layers']
additional_message_irreps = checkpoint['additional_message_irreps']

# The checkpoint holds the complete trained module, so use it directly instead of
# building a fresh SEGNN(...) from the hyper-parameters only to overwrite it.
# If the checkpoint ever stores a state_dict instead, rebuild with SEGNN(...) and
# call model.load_state_dict(...).
model = checkpoint['model']

# loss_func = nn.MSELoss()
loss_func = nn.L1Loss()


def count_parameters(model):
    """Return the number of trainable (requires_grad=True) parameters of `model`."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(count_parameters(model))

model.to(dev)
model.eval()

In [None]:
# Optional reload of the checkpoint stored under DATADIR/checkpoints.
# Set RELOAD_FROM_DATADIR = True to actually load that file; with False (default)
# the checkpoint already loaded by the model-loading cell is reused and only the
# notebook-level model objects are refreshed from it.
RELOAD_FROM_DATADIR = False

gc.collect()
torch.cuda.empty_cache()

# change the path accordingly
path = os.path.join(params.DATADIR, "checkpoints/best_model.pt")

if RELOAD_FROM_DATADIR:
    print(f"Loading model checkpoint from {str(path)}")
    # weights_only=False: the checkpoint contains a pickled model (trusted files only).
    try:
        checkpoint = torch.load(path, map_location=torch.device('cpu'), weights_only=False)
    except TypeError:  # torch < 1.13 has no `weights_only` argument
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
else:
    print(f"Reusing the checkpoint already in memory; set RELOAD_FROM_DATADIR = True to load {str(path)}")

input_irreps = checkpoint['input_irreps']
hidden_irreps = checkpoint['hidden_irreps']
output_irreps = checkpoint['output_irreps']
edge_attr_irreps = checkpoint['edge_attr_irreps']
node_attr_irreps = checkpoint['node_attr_irreps']
task = checkpoint['task']
norm = checkpoint['norm']
num_layers = checkpoint['num_layers']
additional_message_irreps = checkpoint['additional_message_irreps']

# The checkpoint stores the complete trained module; no need to build a SEGNN first.
model = checkpoint['model']

# loss_func = nn.MSELoss()
loss_func = nn.L1Loss()

# count_parameters is defined in the model-loading cell above.
print(count_parameters(model))

model.to(dev)
model.eval()

In [None]:
# change the root path accordingly
print(f"Loading dataset from {str(params.DATADIR)}")
dataset = Graph_dataset_with_equiv_features(params.DATADIR)
# Fixed mesh connectivity, assigned as edge_index to every sample below
# (instead of recomputing a knn_graph per sample).
graph_connectivity = torch.tensor(np.load(os.path.join(params.DATADIR, "connectivity.npy")))

In [None]:
# Dataset split 80-10-10 (ordered, unshuffled).
# NOTE(review): presumably this must reproduce the training split exactly -- confirm.
num_workers = 8

dataset_length = len(dataset)
if params.NSIM >= 10:
    # First 80% of the samples -> train; the remaining 20% are held out.
    train_idx, holdout_idx = train_test_split(range(dataset_length), test_size=0.2, shuffle=False)
    # Split the held-out *indices* themselves into validation / test halves.
    # (Splitting range(len(holdout_idx)) would give 0..k-1, i.e. the first --
    # training -- samples of the dataset.)
    val_idx, test_idx = train_test_split(holdout_idx, test_size=0.5, shuffle=False)
else:
    # Too few simulations to split: every subset is just sample 0.
    train_idx, val_idx, test_idx = [0], [0], [0]

test_dataset = dataset[test_idx]

loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=num_workers)

# Test loss

In [None]:
# Flag read by every evaluation cell below: when True, node features and edge
# attributes are passed through InstanceNorm before the model sees them.
instance_norm = True

# NOTE(review): these layers are freshly constructed here rather than restored from
# the checkpoint; if InstanceNorm carries learned state, confirm this matches training.
# Built in the same order as before (input first, then edges).
input_norm, edge_norm = InstanceNorm(input_irreps), InstanceNorm(edge_attr_irreps)
# Node attributes and outputs are not normalised:
# node_norm = InstanceNorm(node_attr_irreps)
# output_norm = InstanceNorm(output_irreps)

In [None]:
outputs = []
loss = []

# Only needed by the commented-out knn_graph alternative to the fixed connectivity.
neighbours = params.NEIGHBOURS

# s is one sample graph (loader uses batch_size=1)
for s in tqdm(loader):

    # edge_index = knn_graph(s.pos, neighbours, s.batch)
    edge_index = graph_connectivity
    s.edge_index = edge_index
    # Edge features: [|d|, d] with d = pos[edge_index[1]] - pos[edge_index[0]].
    edge_relativePos = (torch.index_select(s.pos, 0, edge_index[1]) - torch.index_select(s.pos, 0, edge_index[0]))
    edge_relativeDist = torch.norm(edge_relativePos, dim=-1, keepdim=True)
    edge_attr = torch.cat([edge_relativeDist, edge_relativePos], dim=-1)
    s.edge_attr = edge_attr

    if instance_norm:  # <----------------------------------------------- VERY IMPORTANT!!!
        s.x = input_norm(s.x, s.batch)
        s.edge_attr = edge_norm(s.edge_attr, s.edge_index[1, :])

    s = s.to(dev)
    with torch.no_grad():
        # Restrict the loss to the sample's masked nodes when a mask is present;
        # the fallback all-nodes mask is created on the same device as the data.
        if hasattr(s, "mask"):
            mask = s.mask
        else:
            mask = torch.ones(s.x.shape[0], dtype=torch.bool, device=s.x.device)

        out = model(s)
        loss_val = loss_func(out[mask], s.y[mask])

    # Keep predictions on the CPU so GPU memory does not grow with the test set.
    outputs.append(out.cpu())
    loss.append(loss_val.item())

In [None]:
save = True

if save:
    # np.save returns None, so its result is not kept.
    np.save('loss.npy', loss)

In [None]:
# Label the metric after the loss actually used (nn.L1Loss -> MAE, nn.MSELoss -> MSE).
loss_name = {nn.L1Loss: "MAE", nn.MSELoss: "MSE"}.get(type(loss_func), type(loss_func).__name__)

plt.figure(figsize=(15,8))
plt.plot(range(len(loss)), loss, marker='o', markersize=5, linewidth=0)
plt.title(f'{loss_name} error across test samples', size=30)
plt.grid()
plt.show()

# Test set plots

In [None]:
loadloss = np.load('loss.npy')

# Metric names follow the loss function in use (nn.L1Loss -> MAE, nn.MSELoss -> MSE).
# NOTE(review): assumes loss.npy was written with the same loss_func as this session.
loss_short, loss_long = {
    nn.L1Loss: ("MAE", "Mean Absolute Error"),
    nn.MSELoss: ("MSE", "Mean Squared Error"),
}.get(type(loss_func), (type(loss_func).__name__, type(loss_func).__name__))

plt.figure(figsize=(16,8))

plt.plot(range(len(loadloss)), loadloss, marker='o', markersize=5,
         linewidth=1, label=f'{loss_short} on test sample')

plt.xlabel('Test samples', size=25)
plt.ylabel(f'{loss_long} ({loss_short})', size=25)
plt.title(f'{loss_short} error across test samples', size=30)
plt.grid()
legend = plt.legend()
legend.get_frame().set_facecolor('lightgray')

print("Number of test samples:", len(loadloss))

plt.show()

# Single test plot

In [None]:

# First sample of the test loader (steady state) used for the single-graph plots below.
test_graph = next(iter(loader))
print(test_graph.x)

with torch.no_grad():

    # edge_index = knn_graph(test_graph.pos, neighbours, test_graph.batch)
    edge_index = graph_connectivity
    test_graph.edge_index = edge_index

    # Edge features: [|d|, d] with d = pos[edge_index[1]] - pos[edge_index[0]].
    edge_relativePos = (torch.index_select(test_graph.pos, 0, edge_index[1]) - torch.index_select(test_graph.pos, 0, edge_index[0]))
    edge_relativeDist = torch.norm(edge_relativePos, dim=-1, keepdim=True)
    edge_attr = torch.cat([edge_relativeDist, edge_relativePos], dim=-1)
    test_graph.edge_attr = edge_attr

    if instance_norm:  # <----------------------------------------------- VERY IMPORTANT!!!
        test_graph.x = input_norm(test_graph.x, test_graph.batch)
        test_graph.edge_attr = edge_norm(test_graph.edge_attr, test_graph.edge_index[1, :])

    test_graph = test_graph.to(dev)
    print(f"test_graph: {test_graph}")

    # Raw (normalised) prediction, one row per node: press, vel_x, vel_y, vel_z.
    pred = model(test_graph)


In [None]:
# Remember: to plot the graphs while ignoring the ground-truth velocities and pressures, you have
# to plot pred[sample.mask], not simply the "pred" tensor. "pred" contains predictions on all nodes,
# including the ones that were masked out during training because their true values were given as
# input, and which therefore did not contribute to backpropagation.

# Ground truth

In [None]:



### GROUND TRUTH DENORMALIZATION <------------------------------------------------------

# Move the sample to the CPU once, explicitly, instead of relying on a `.cpu()`
# call buried inside the dict literal below.
test_graph = test_graph.cpu()

ground_truth_to_denormalize = {
    "pressure": test_graph.y[..., 0],
    "vel_x": test_graph.y[..., 1],
    "vel_y": test_graph.y[..., 2],
    "vel_z": test_graph.y[..., 3]
}

denormalized_ground_truth = denormalize_predictions(ground_truth_to_denormalize, params.DATADIR+'/normalization_params.json')
true_press = denormalized_ground_truth['pressure']
# Ground-truth velocity components (these are targets, not predictions).
true_vel_x = denormalized_ground_truth['vel_x']
true_vel_y = denormalized_ground_truth['vel_y']
true_vel_z = denormalized_ground_truth['vel_z']

# Per-node velocity magnitude |v|.
true_vel = torch.norm(torch.stack([true_vel_x, true_vel_y, true_vel_z], dim=1), dim=-1).cpu()

###

# edges = knn_graph(test_graph.pos, neighbours)
edges = graph_connectivity

true_magvel = true_vel

print_3D_graph(test_graph.pos.cpu(), edges=edges, color=true_magvel)
print_3D_graph(test_graph.pos.cpu(), edges=edges, color=true_press)
# print_3D_graph(test_graph.pos.cpu(), edges = edges, color = test_graph.edge_index) # edge_index is MIS, wrong for testing purposes

# Prediction

In [None]:



### DENORMALIZATION <------------------------------------------------------
# The physical-unit prediction goes into `pred_denorm`, leaving `pred` as the raw
# model output. Overwriting `pred` would denormalize it twice when the cell is re-run.

pred_to_denormalize = {
    "pressure": pred[:, 0],
    "vel_x": pred[:, 1],
    "vel_y": pred[:, 2],
    "vel_z": pred[:, 3]
}

denormalized_prediction = denormalize_predictions(pred_to_denormalize, params.DATADIR+'/normalization_params.json')
pred_pressure = denormalized_prediction['pressure']
pred_vel_x = denormalized_prediction['vel_x']
pred_vel_y = denormalized_prediction['vel_y']
pred_vel_z = denormalized_prediction['vel_z']

# Columns: press, vel_x, vel_y, vel_z (physical units).
pred_denorm = torch.stack([pred_pressure, pred_vel_x, pred_vel_y, pred_vel_z], dim=1)

###

# edges = knn_graph(test_graph.pos, neighbours)
edges = graph_connectivity

# Show the first few nodes (slicing also copes with graphs of fewer than 5 nodes).
for row in pred_denorm[:5]:
    print(row.tolist())
pred_magvel = torch.norm(pred_denorm[..., 1:4], dim=-1)
# Keep the pressure's sign. torch.norm over the single pressure column gives |p|,
# which is not comparable to the signed ground truth `true_press`.
pred_press = pred_denorm[..., 0]

print_3D_graph(test_graph.pos.cpu(), edges=edges, color=pred_magvel.cpu())
print_3D_graph(test_graph.pos.cpu(), edges=edges, color=pred_press.cpu())


In [None]:

global_min = torch.min(torch.cat((true_magvel.cpu(), pred_magvel.cpu())))
print(global_min)
global_max = torch.max(torch.cat((true_magvel.cpu(), pred_magvel.cpu())))
print(global_max)

fig1_colors = true_magvel
fig1 = manual_print_3D_graph(test_graph.pos.cpu(), edges.cpu(), fig1_colors.cpu(),
                              cmin=global_min.item(), cmax=global_max.item(), colorscale='Viridis') # manual_print_3D_graph needs fig.show() below

fig2_colors = pred_magvel
fig2 = manual_print_3D_graph(test_graph.pos.cpu(), edges.cpu(), fig2_colors.cpu(),
                              cmin=global_min.item(), cmax=global_max.item(), colorscale='Viridis')

fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
                    subplot_titles=("True vel", "Pred vel"))

for trace in fig1['data']:
    fig.add_trace(trace, row=1, col=1)

for trace in fig2['data']:
    fig.add_trace(trace, row=1, col=2)

fig.show()


In [None]:

print(true_press.shape)
print(pred_press.shape)

global_min = torch.min(torch.cat((true_press.cpu(), pred_press.cpu())))
print(global_min)
global_max = torch.max(torch.cat((true_press.cpu(), pred_press.cpu())))
print(global_max)

fig3_colors = true_press
fig3 = manual_print_3D_graph(test_graph.pos.cpu(), edges.cpu(), fig3_colors, 
                              cmin=global_min.item(), cmax=global_max.item(), colorscale='Viridis')

fig4_colors = pred_press
fig4 = manual_print_3D_graph(test_graph.pos.cpu(), edges.cpu(), fig4_colors.cpu(), 
                              cmin=global_min.item(), cmax=global_max.item(), colorscale='Viridis')

fig_press = make_subplots(rows=1, cols=2, specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
                          subplot_titles=("True press", "Pred press"))

for trace in fig3['data']:
    fig_press.add_trace(trace, row=1, col=1)

for trace in fig4['data']:
    fig_press.add_trace(trace, row=1, col=2)

fig_press.show()


# Prediction with INPUT

In [None]:

# Same first test sample, re-processed so that the free-stream input features can be overridden.
test_graph_with_input = next(iter(loader))

with torch.no_grad():

    # edge_index = knn_graph(test_graph_with_input.pos, neighbours, test_graph_with_input.batch)
    edge_index = graph_connectivity
    test_graph_with_input.edge_index = edge_index

    edge_relativePos = (torch.index_select(test_graph_with_input.pos, 0, edge_index[1]) - torch.index_select(test_graph_with_input.pos, 0, edge_index[0]))
    edge_relativeDist = torch.norm(edge_relativePos, dim=-1, keepdim=True)
    edge_attr = torch.cat([edge_relativeDist, edge_relativePos], dim=-1)
    test_graph_with_input.edge_attr = edge_attr

    if instance_norm:  # <----------------------------------------------- VERY IMPORTANT!!!
        test_graph_with_input.x = input_norm(test_graph_with_input.x, test_graph_with_input.batch)
        test_graph_with_input.edge_attr = edge_norm(test_graph_with_input.edge_attr, test_graph_with_input.edge_index[1, :])

    test_graph_with_input = test_graph_with_input.to(dev)

    ### INPUT TEST
    # NOTE(review): the overrides are applied AFTER InstanceNorm, i.e. in the
    # normalised feature space, while the comments quote physical ranges. Both
    # values also lie outside those ranges. Confirm the intended space for columns 8/9.
    test_graph_with_input.x[:, 8] = 100 # velocity_range = (320, 360) # m/s
    test_graph_with_input.x[:, 9] = -40 # angle_range = (-10, 10) # degrees
    print(test_graph_with_input.x)

    pred_with_input = model(test_graph_with_input)


### DENORMALIZATION <------------------------------------------------------

pred_to_denormalize_with_input = {
    "pressure": pred_with_input[:, 0],
    "vel_x": pred_with_input[:, 1],
    "vel_y": pred_with_input[:, 2],
    "vel_z": pred_with_input[:, 3]
}

denormalized_prediction_with_input = denormalize_predictions(pred_to_denormalize_with_input, params.DATADIR+'/normalization_params.json')
# Read from the with-input result, not from the baseline `denormalized_prediction`.
# Reading the baseline would make the two predictions identical.
pred_pressure_with_input = denormalized_prediction_with_input['pressure']
pred_vel_x_with_input = denormalized_prediction_with_input['vel_x']
pred_vel_y_with_input = denormalized_prediction_with_input['vel_y']
pred_vel_z_with_input = denormalized_prediction_with_input['vel_z']

# Columns: press, vel_x, vel_y, vel_z (physical units).
pred_with_input = torch.stack([pred_pressure_with_input, pred_vel_x_with_input, pred_vel_y_with_input, pred_vel_z_with_input], dim=1)
print(pred_with_input)

###


In [None]:
# edges = knn_graph(test_graph_with_input.pos, neighbours)
edges = graph_connectivity

# Show the first few nodes (slicing also copes with graphs of fewer than 5 nodes).
for row in pred_with_input[:5]:
    print(row.tolist())
pred_magvel_with_input = torch.norm(pred_with_input[..., 1:4], dim=-1)
# Signed pressure (torch.norm over the single pressure column would give |p|).
pred_press_with_input = pred_with_input[..., 0]

# print_3D_graph(test_graph_with_input.pos.cpu(), edges = edges, color = pred_magvel_with_input.cpu())
# print_3D_graph(test_graph_with_input.pos.cpu(), edges = edges, color = pred_press_with_input.cpu())


In [None]:

# Common colour limits so the baseline and input-override velocity plots share one scale.
magvel_pair = torch.cat((pred_magvel.cpu(), pred_magvel_with_input.cpu()))
global_min = torch.min(magvel_pair)
print(global_min)
global_max = torch.max(magvel_pair)
print(global_max)

shared_scale = dict(cmin=global_min.item(), cmax=global_max.item(), colorscale='Viridis')

fig5_colors = pred_magvel
fig5 = manual_print_3D_graph(test_graph_with_input.pos.cpu(), edges.cpu(), fig5_colors.cpu(), **shared_scale)

fig6_colors = pred_magvel_with_input
fig6 = manual_print_3D_graph(test_graph_with_input.pos.cpu(), edges.cpu(), fig6_colors.cpu(), **shared_scale)

# Side-by-side velocity comparison (baseline left, overridden input right).
fig_vel_input = make_subplots(rows=1, cols=2, specs=[[{'type': 'scatter3d'}, {'type': 'scatter3d'}]],
                              subplot_titles=("Vel pred", "INPUT Vel pred"))

for column, source_fig in ((1, fig5), (2, fig6)):
    for trace in source_fig['data']:
        fig_vel_input.add_trace(trace, row=1, col=column)

fig_vel_input.show()


In [None]:
# Per-node relative change (%) of the predicted velocity magnitude caused by the input override.
# NOTE(review): nodes where pred_magvel == 0 produce inf/nan here.
loss_vel_with_without_input = 100 * torch.abs((pred_magvel.cpu() - pred_magvel_with_input.cpu()) / pred_magvel.cpu())

plt.figure(figsize=(15,10))

# No per-point colour data is passed, so no cmap/colorbar. A colorbar here would
# only show a meaningless default scale.
plt.scatter(range(len(loss_vel_with_without_input)), loss_vel_with_without_input, marker=".", s=10)

plt.title("Input - Non Input VEL difference", size=15)
plt.xlabel("Node idx", size=15)
plt.ylabel("Loss in %", size=15)
# plt.semilogy()
plt.grid()

plt.show()


In [None]:
# Per-node relative change (%) of the predicted pressure caused by the input override.
# NOTE(review): nodes where pred_press == 0 produce inf/nan here.
loss_press_with_without_input = 100 * torch.abs((pred_press.cpu() - pred_press_with_input.cpu()) / pred_press.cpu())

plt.figure(figsize=(15,10))

# Plain scatter: there is no per-point colour data, so a cmap would be ignored.
plt.scatter(range(len(loss_press_with_without_input)), loss_press_with_without_input, marker=".", s=10)

plt.title("Input - Non Input PRESS difference", size=15)
plt.xlabel("Node idx", size=15)
plt.ylabel("Loss in %", size=15)
# plt.semilogy()
plt.grid()

plt.show()


# Test equivariance

In [None]:

# Sanity check: prediction and ground-truth magnitude should cover the same number of nodes.
for tensor_to_count in (pred, true_magvel):
    print(len(tensor_to_count))
# print_3D_graph(test_graph.pos, edges = edges, color = true_magvel.cpu())


In [None]:

# Random rigid motion: integer translation in [0, 30) per axis, then a rotation
# about x by an angle in [40, 80) (units as expected by rotate_graph_coords,
# presumably degrees). Seeded so the check is reproducible.
EQUIV_SEED = 0
equiv_rng = np.random.default_rng(EQUIV_SEED)
translation_vector = equiv_rng.integers(0, 30, size=3)
rotation_angle = int(equiv_rng.integers(40, 80))  # rotation is anti clockwise
translated_test_graph = translate_graph_coords(test_graph.cpu(), translation_vector)
rotated_test_graph = rotate_graph_coords(translated_test_graph.cpu(), rotation_angle, axis="x")
# NOTE(review): edge_attr holds relative position vectors computed before the
# transform. Unless rotate_graph_coords also rotates them, the model sees unrotated
# edge features -- confirm.

with torch.no_grad():  # inference only; do not build an autograd graph
    rotated_graph_pred = model(rotated_test_graph.to(dev))
print(rotated_graph_pred[:10])
rotated_graph_magvel_pred = torch.norm(rotated_graph_pred[..., 1:4], dim=-1)
print(len(rotated_graph_magvel_pred))

# Plot from the CPU copy of the positions (the graph was moved to `dev` above).
print_3D_graph(rotated_test_graph.pos.cpu(), edges=edges, color=true_magvel)


In [None]:

# NOTE(review): for a like-for-like check, `pred` must hold the raw (normalised)
# model output, the same space as `rotated_graph_pred`. For a rotation-equivariant
# model, vector outputs (columns 1:4 if they are velocities -- confirm) are expected
# to rotate with the graph. Only the invariant quantities reported at the end
# should then match.
EQUIV_ATOL = 1e-5  # differences below this are treated as float noise

# Compare on the CPU so the two tensors are on the same device.
pred_cpu = pred.detach().cpu()
rotated_cpu = rotated_graph_pred.detach().cpu()

equivariance_loss = pred_cpu - rotated_cpu
# Sum of absolute differences: a signed sum would let errors cancel out.
print(f"Total error is: {equivariance_loss.abs().sum()}\n")

for row in range(pred_cpu.size(0)):
    diff = equivariance_loss[row]
    if (diff.abs() > EQUIV_ATOL).any():
        print(f"On node {row}: {diff.tolist()}")

print(" ")
print(f"Graph has: {len(pred)} nodes")

# Rotation-invariant quantities: pressure and velocity magnitude.
press_diff = (pred_cpu[:, 0] - rotated_cpu[:, 0]).abs().max().item()
magvel_diff = (pred_cpu[:, 1:4].norm(dim=-1) - rotated_cpu[:, 1:4].norm(dim=-1)).abs().max().item()
print(f"Max |pressure diff|: {press_diff}  Max ||v| diff|: {magvel_diff}")


# Loss analysis

In [None]:
import glob
import re

numbers = re.compile(r'(\d+)')

def numericalSort(value):
    """Natural-sort key: split `value` into text chunks and integer chunks.

    Digit runs become ints, so 'probe10.npy' sorts after 'probe2.npy',
    e.g. 'probe10.npy' -> ['probe', 10, '.npy'].
    """
    # With a capturing group, re.split puts the digit runs at the odd positions.
    return [int(chunk) if position % 2 else chunk
            for position, chunk in enumerate(numbers.split(value))]

probes_path = os.path.join(params.DATADIR, 'probes')
probe_files = glob.glob(os.path.join(probes_path, '*.npy'))
# Natural (numeric) file-name order, so probe10 comes after probe2.
probes = sorted(probe_files, key=numericalSort)

# One tensor per probe file, in the order above.
probe_data_list = [torch.from_numpy(np.load(data_file)) for data_file in probes]

pos_list = probe_data_list

In [None]:

# Per-sample evaluation over the whole test loader. Collects the relative errors (%)
# of velocity magnitude and pressure in physical units, the denormalized ground
# truth, and probe column 3 for each sample.

loss_vel = []
loss_press = []
true_mis = []

true_vel = []
true_press = []

# NOTE(review): the probe index starts at 1 and advances once per test sample,
# i.e. it assumes pos_list[1], pos_list[2], ... line up with the loader order.
# The loader only covers the test split, so presumably the probe should follow that
# split's dataset indices instead -- confirm against how the probe files were written.
i = 1
# Loop variable is `sample_graph` so the notebook-level `test_graph` used by the
# single-sample cells is not overwritten.
for sample_graph in tqdm(loader):

    with torch.no_grad():

        edge_index = graph_connectivity
        sample_graph.edge_index = edge_index
        edge_relativePos = (torch.index_select(sample_graph.pos, 0, edge_index[1]) - torch.index_select(sample_graph.pos, 0, edge_index[0]))
        edge_relativeDist = torch.norm(edge_relativePos, dim=-1, keepdim=True)
        edge_attr = torch.cat([edge_relativeDist, edge_relativePos], dim=-1)
        sample_graph.edge_attr = edge_attr

        if instance_norm:  # <----------------------------------------------- VERY IMPORTANT!!!
            sample_graph.x = input_norm(sample_graph.x, sample_graph.batch)
            sample_graph.edge_attr = edge_norm(sample_graph.edge_attr, sample_graph.edge_index[1, :])

        model_pred_before_test = model(sample_graph.to(dev))

    ### DENORMALIZATION <------------------------------------------------------

    pred_to_denormalize_before_test = {
        "pressure": model_pred_before_test[:, 0],
        "vel_x": model_pred_before_test[:, 1],
        "vel_y": model_pred_before_test[:, 2],
        "vel_z": model_pred_before_test[:, 3]
    }

    denormalized_prediction_before_test = denormalize_predictions(pred_to_denormalize_before_test, params.DATADIR+'/normalization_params.json')
    # Columns: press, vel_x, vel_y, vel_z (physical units).
    model_pred_before_test = torch.stack([
        denormalized_prediction_before_test['pressure'],
        denormalized_prediction_before_test['vel_x'],
        denormalized_prediction_before_test['vel_y'],
        denormalized_prediction_before_test['vel_z'],
    ], dim=1)

    ###

    # model prediction
    out = model_pred_before_test.cpu()
    out_vel = torch.norm(out[..., 1:4], dim=-1)
    out_press = out[..., 0]

    ### GROUND TRUTH DENORMALIZATION <------------------------------------------------------

    ground_truth_to_denormalize = {
        "pressure": sample_graph.y[..., 0],
        "vel_x": sample_graph.y[..., 1],
        "vel_y": sample_graph.y[..., 2],
        "vel_z": sample_graph.y[..., 3]
    }

    denormalized_ground_truth = denormalize_predictions(ground_truth_to_denormalize, params.DATADIR+'/normalization_params.json')
    true_press_i = denormalized_ground_truth['pressure']
    # Ground-truth velocity components (targets, not predictions).
    true_vel_x = denormalized_ground_truth['vel_x']
    true_vel_y = denormalized_ground_truth['vel_y']
    true_vel_z = denormalized_ground_truth['vel_z']
    true_vel_i = torch.norm(torch.stack([true_vel_x, true_vel_y, true_vel_z], dim=1), dim=-1).cpu()

    ###

    true_vel.append(true_vel_i)
    true_press.append(true_press_i)

    # Per-node relative errors in %.
    # NOTE(review): nodes whose ground truth is exactly 0 give inf/nan.
    loss_vel.append(100 * torch.abs((true_vel_i - out_vel) / true_vel_i))
    loss_press.append(100 * torch.abs((true_press_i.cpu() - out_press) / true_press_i.cpu()))

    true_mis.append(pos_list[i][..., 3])
    i += 1
    

In [None]:
plt.figure(figsize=(15,10))

# Flattened per-node probe values (column 3 of each probe file); kept at notebook level.
true_mis_flat = np.concatenate(true_mis)
loss_vel_flat = np.concatenate(loss_vel)
# Plain scatter: no per-point colour data is given, so a cmap would be ignored.
plt.scatter(range(len(loss_vel_flat)), loss_vel_flat, marker=".", s=10)

plt.title("Velocity error on test dataset", size=15)
plt.xlabel("Test Node idx", size=15)
plt.ylabel("Loss in %", size=15)
# plt.semilogy()
plt.grid()

plt.show()


In [None]:

plt.figure(figsize=(15,10))

# Per-node pressure errors of every test sample, flattened into one series.
loss_press_flat = np.concatenate(loss_press)
# Plain scatter: no per-point colour data is given, so a cmap would be ignored.
plt.scatter(range(len(loss_press_flat)), loss_press_flat, marker=".", s=10)

# Covers every sample in the test loader, matching the velocity-error plot.
plt.title("Pressure error on test dataset", size=15)
plt.xlabel("Test Node idx", size=15)
plt.ylabel("Loss in %", size=15)
# plt.semilogy()
plt.grid()

plt.show()


# TEST DATASET

In [None]:

plt.figure(figsize=(15,10))

# Ground-truth velocity magnitude of every test node, flattened across samples.
true_vel_flat = np.concatenate(true_vel)
# Plain scatter: no per-point colour data is given, so a cmap would be ignored.
plt.scatter(range(len(true_vel_flat)), true_vel_flat, marker=".", s=10)

plt.title("True velocity", size=15)
plt.xlabel("Test Node idx", size=15)
plt.ylabel("velocity (flow rate)", size=15)
# plt.semilogy()
plt.grid()

plt.show()


In [None]:

plt.figure(figsize=(15,10))

# Ground-truth pressure of every test node; each tensor is moved to the CPU before flattening.
true_press_cpu = [tensor.cpu().numpy() for tensor in true_press]
true_press_flat = np.concatenate(true_press_cpu)
# Plain scatter: no per-point colour data is given ('Viridis' was also not a valid
# matplotlib colormap name), so no cmap is passed.
plt.scatter(range(len(true_press_flat)), true_press_flat, marker=".", s=10)

plt.title("True pressure", size=15)
plt.xlabel("Test Node idx", size=15)
plt.ylabel("Pressure", size=15)
# plt.semilogy()
plt.grid()

plt.show()
