In [None]:
import torch
from joblib import load
from visualization.visualization import *
from evaluation.evaluate_predictions import get_metrics
from data.data_processing import reshape_raveled_data, ravel_data
from utils.data import phys_unit_to_index
from utils.jmag import get_jmag
from globals.initialization import set_constants
import globals.constants as const

# Initialize the shared constants (read below through `const`) with CUDA disabled and noise off
set_constants(noise=False, no_cuda=True)

In [None]:
# Load the trained reconstruction model, mapping all tensors onto the CPU
model_path = '/path/to/model.pth'  # TODO: point at the checkpoint to evaluate

# PyTorch: the loaded object is called directly later, so the file holds the whole model.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load trusted files.
# Depending on the installed PyTorch version (the `weights_only` default changed in 2.6),
# a full-model checkpoint may additionally need `weights_only=False`.
model = torch.load(
    model_path,
    map_location=torch.device('cpu'),
)

# Alternative: kNNRegressor (scikit-learn) persisted with joblib
# model = load(model_path)

In [None]:
# Predict the MHD variables at the evaluation points.
# PyTorch: put the model in inference mode so dropout / batch-norm layers (if any)
# behave deterministically, and skip autograd bookkeeping -- the prediction is only
# scored and plotted in this notebook, never back-propagated. Without no_grad() the
# output keeps the whole autograd graph alive and has requires_grad=True.
model.eval()
with torch.no_grad():
    prediction = model(const.st_eval.float())

# Alternative: kNNRegressor (scikit-learn) -- drop the eval()/no_grad() lines above
# prediction = model.forward(const.st_eval)

# `prediction` is flat (points x variables); reshape it into the same
# per-variable layout as const.U_red (variables along dim 0).
pred_reshaped = reshape_raveled_data(prediction, const.U_red.shape[1:])

In [None]:
# Calculate the current-density magnitude (Jmag) and append it as an extra variable
# to both the reference data and the prediction.
# Each pair is only extended while Jmag is not yet present. Re-running this cell
# therefore no longer appends a second Jmag channel. It also no longer feeds an
# already-appended Jmag back into get_jmag through the open-ended [5:] slice.
# Jmag is appended as the last channel along dim 0. Later cells select it via
# phys_unit_to_index('Jmag'), which presumably points at that last channel, so once
# appended the channel count is that index + 1 (confirm).
jmag_channel = phys_unit_to_index('Jmag')

# Reference data
if const.U_red.shape[0] != jmag_channel + 1:
    # Channels 5 onward go to get_jmag (presumably Bx, By, Bz -- confirm)
    jmag_real = get_jmag(const.x_red, const.y_red, const.t_red, const.U_red[5:]).unsqueeze(0)
    # const.U_eval is flat (points x variables): add Jmag as a new column
    const.U_eval = torch.cat((const.U_eval, jmag_real.ravel().unsqueeze(1)), 1)
    # const.U_red stacks one slice per variable along dim 0: add Jmag as a new slice
    const.U_red = torch.cat((const.U_red, jmag_real))

# Model prediction (same two layouts); checked separately so re-running only the
# predict cell still gets Jmag appended to the fresh prediction
if pred_reshaped.shape[0] != jmag_channel + 1:
    jmag_pred = get_jmag(const.x_red, const.y_red, const.t_red, pred_reshaped[5:]).unsqueeze(0)
    prediction = torch.cat((prediction, jmag_pred.ravel().unsqueeze(1)), 1)
    pred_reshaped = torch.cat((pred_reshaped, jmag_pred))

In [None]:
# Calculate metrics
get_metrics(const.U_eval[:, index], prediction[:, index])

In [None]:
# Choose physical unit
phys_unit = 'Density'
# phys_unit = 'Vx'
# phys_unit = 'Vy'
# phys_unit = 'Vz'
# phys_unit = 'P'
# phys_unit = 'Bx'
# phys_unit = 'By'
# phys_unit = 'Bz'
# phys_unit = 'Jmag'
index = phys_unit_to_index(phys_unit)

In [None]:
# Matplotlib colormap name passed to the color and scatter plots ('hot' or 'rainbow')
cmap = 'hot'

In [None]:
# Color plot of the predicted field for the selected physical unit.
# No time argument is passed, so the helper decides which time slice to show.
color_plot_for_specific_time(
    const.x_red,
    const.y_red,
    pred_reshaped,
    index,
    cmap=cmap,
)

In [None]:
# Animate the predicted field for the selected physical unit over const.t_red
color_plot_animation(
    const.x_red,
    const.y_red,
    const.t_red,
    pred_reshaped,
    index,
    cmap=cmap,
)

In [None]:
# Scatter plot of the flat prediction at the evaluation points (const.st_eval)
scatter_plot(
    const.st_eval,
    prediction,
    index,
    cmap=cmap,
)

In [None]:
# Binned heatmap built from the flat reference data and the flat prediction
create_binned_heatmap_from_original_data(
    const.U_eval,
    prediction,
    index,
)

In [None]:
# Line plot comparing the reference (const.U_red) and predicted fields at t_val=40.
# NOTE(review): the horizontal axis is the y coordinate (const.y_red) even though the
# variable is called x_data -- confirm this is intended.
x_data = const.y_red
line_plot(
    x_data,
    const.U_red,
    pred_reshaped,
    index,
    t_val=40,  # presumably a time index/value on const.t_red -- confirm
)

In [None]:
# Kernel density plot of reference vs. predicted values; the return value is kept as `kdp`
kdp = kernel_density_plot(
    const.U_eval,
    prediction,
    index,
)

In [None]:
# Bar plot of MSE, MAE and PC with one entry per label in `xticks`.
# NOTE(review): the values look like placeholders (MSE/MAE = 0, PC = 1) -- fill in
# the measured metrics for each variant before drawing conclusions.
xticks = ['None', 'Cuboid', 'Cylinder', 'Physical', 'Trade-off', 'Coefficient', 'Num-Diff']
data = {
    'MSE': [0] * len(xticks),
    'MAE': [0] * len(xticks),
    'PC': [1] * len(xticks),
}

bar_plot(data, xticks)