Once a model that performs as hoped has been found, the goal of the interpolation NN is to interpolate a tabulated equation of state (EoS) of neutron stars, where the known x variables are temperature (T) and baryon number (n_b), and the y variables are pressure (P) and entropy (S).

We know the true values of both the x and y variables (tabulated EoS in the CompOSE archive). To check that the interpolation algorithm works well, we restrict ourselves to a small region of the range and perform local interpolation: we take the points x_1 and x_3, which we do know, and interpolate y_2 at the intermediate point x_2, with all points taken from the test set.

### INTERPOLATION

1. Select Known Points:
Let’s assume we have two known points, (T1, n1) and (T3, n3), and that (T2, n2) is the intermediate point. (T2, n2) needs to lie within the range spanned by the known points.

2. Interpolation Procedure
Treat the values at (T2, n2) as unknown during the actual interpolation process. During this phase, we use the model to predict P2 and S2.

3. Evaluation
Calculate the difference between the values of P2 and S2 predicted by the model and the exact known values from the tabulated EoS. Use metrics such as the Mean Absolute Error (MAE) or the Mean Squared Error (MSE) to quantify how close the interpolated values are to the actual known values.

In [None]:
# Map scaled, log10-space values back to physical values
def inv_log_scaled(values, scaler):
    """Undo the scaler and then the base-10 logarithm for a set of values.

    Args:
        values: Array-like of values in the scaler's transformed space; it is
            reshaped into a single column before being handed to the scaler.
        scaler: Object exposing ``inverse_transform`` for a one-column array.

    Returns:
        Flat array containing ``10 ** scaler.inverse_transform(values)``.
    """
    as_column = np.array(values).reshape(-1, 1)
    log_values = scaler.inverse_transform(as_column).flatten()
    return 10**log_values

# Build the (left, centre, right) index triplet around a centre index
def get_interpolation_indices(center_index, data_length, gap):
    """Return the indices of a triplet centred on ``center_index``.

    Args:
        center_index: Index of the point whose value will be interpolated.
        data_length: Number of samples available; valid indices are
            ``0 .. data_length - 1``.
        gap: Index distance between the centre and each neighbour.

    Returns:
        ``(center_index - gap, center_index, center_index + gap)`` when both
        neighbours fall inside the data, otherwise ``None``.
    """
    lower, upper = center_index - gap, center_index + gap
    if lower >= 0 and upper < data_length:
        return lower, center_index, upper
    return None

# Resolve an activation class from its textual description
def load_activation_fn(activation_str):
    """Return the ``torch.nn`` activation class named inside ``activation_str``.

    The match is a substring test, so any string that merely contains the
    class name (e.g. a ``repr`` of the class) is accepted.

    Args:
        activation_str: Text containing one of ``"LeakyReLU"``, ``"ReLU"`` or
            ``"SiLU"``.

    Returns:
        The matching activation class (not an instance).

    Raises:
        ValueError: If none of the supported names occurs in the string.
    """
    # Order matters: "LeakyReLU" contains "ReLU", so it has to be tested first.
    candidates = (
        ("LeakyReLU", nn.LeakyReLU),
        ("ReLU", nn.ReLU),
        ("SiLU", nn.SiLU),
    )
    for name, activation_cls in candidates:
        if name in activation_str:
            return activation_cls
    raise ValueError(f"Activation function '{activation_str}' not recognized.")

In [None]:
#######################################################
# NN INTERPOLATION
#######################################################
def interpolation_test_nn(model, inputs_set, outputs_set, scaler_P, scaler_S, device, indices, save_dir, gap=1):
    """Evaluate the trained network on local interpolation triplets.

    For every centre index in ``indices`` the triplet
    ``(centre - gap, centre, centre + gap)`` is taken from ``inputs_set`` and
    ``outputs_set``. The model predicts the outputs at the centre point, the
    prediction is compared with the tabulated value, and a two-panel figure
    (pressure and entropy) is saved as a PNG in ``save_dir``.

    Args:
        model: Trained network; it is switched to eval mode and called on the
            centre input with a leading batch dimension of 1.
        inputs_set: Indexable collection of input rows fed to the model.
        outputs_set: Indexable collection of target rows; element 0 of a row
            is used as pressure and element 1 as entropy.
        scaler_P: Scaler whose ``inverse_transform`` is applied to the
            predicted pressure before exponentiation.
        scaler_S: Scaler whose ``inverse_transform`` is applied to the
            predicted entropy before exponentiation.
        device: Torch device on which the model is evaluated.
        indices: Iterable of centre indices to test.
        save_dir: Directory (created if missing) that receives the figures.
        gap: Index distance between the centre point and each neighbour.

    Returns:
        A list with one dict per evaluated triplet, holding ``method``,
        ``index``, ``mse_P``, ``mae_P``, ``mse_S``, ``mae_S`` and
        ``elapsed_sec``. ``elapsed_sec`` is identical for all entries: the
        time from the start of the loop until the metrics of the last
        evaluated triplet were computed (plotting of earlier triplets
        included).
    """
    os.makedirs(save_dir, exist_ok=True)
    metrics = []
    elapsed = None
    start_time = time.time()
    model.eval()  # loop-invariant: inference mode only needs to be set once

    for i, center_index in enumerate(indices):
        # The helper returns None (not a tuple) when a neighbour falls outside
        # the data, so test the result before unpacking it.
        triplet = get_interpolation_indices(center_index, len(inputs_set), gap)
        if triplet is None:
            print(f"Skipping index {center_index} (gap={gap}) due to limits.")
            continue
        idx1, idx2, idx3 = triplet

        # Use the data passed in as arguments, not notebook-level globals.
        x2 = inputs_set[idx2]
        y1, y2_true, y3 = outputs_set[idx1], outputs_set[idx2], outputs_set[idx3]

        # predict y2 with the trained NN
        x_tensor = torch.FloatTensor(x2).unsqueeze(0).to(device)
        with torch.no_grad():
            y2_pred = model(x_tensor).cpu().numpy()[0]

        # Denormalize and inverse log transform
        # NOTE(review): the true values are only exponentiated while the
        # predictions are also passed through the scalers; this presumably
        # means outputs_set is unscaled log10 data and the model emits scaled
        # log10 values -- confirm against the preprocessing code.
        P_true = 10 ** np.array([y1[0], y2_true[0], y3[0]])
        S_true = 10 ** np.array([y1[1], y2_true[1], y3[1]])
        P_pred = inv_log_scaled([y2_pred[0]], scaler_P)[0]
        S_pred = inv_log_scaled([y2_pred[1]], scaler_S)[0]

        # Metrics
        mse_P = mean_squared_error([P_true[1]], [P_pred])
        mae_P = mean_absolute_error([P_true[1]], [P_pred])
        mse_S = mean_squared_error([S_true[1]], [S_pred])
        mae_S = mean_absolute_error([S_true[1]], [S_pred])
        metrics.append({"method": "NN", "index": int(center_index), "mse_P": mse_P, "mae_P": mae_P, "mse_S": mse_S, "mae_S": mae_S})

        # Remember the running total; it is written to the entries once,
        # after the loop, instead of rewriting every entry on each iteration.
        elapsed = time.time() - start_time

        # Plots
        fig, axs = plt.subplots(1, 2, figsize=(12, 5))
        x_labels = [f'Idx {idx1}', f'Idx {idx2}', f'Idx {idx3}']
        x_ticks = np.arange(3)

        panels = (
            (axs[0], P_true, P_pred, 'Pressure (P)', 'True P', 'Predicted P', 'green', 'red'),
            (axs[1], S_true, S_pred, 'Entropy (S)', 'True S', 'Predicted S', 'blue', 'orange'),
        )
        for ax, true_vals, pred_val, title, true_label, pred_label, true_color, pred_color in panels:
            ax.plot(x_ticks, true_vals, label=true_label, marker='o', color=true_color)
            ax.scatter(x_ticks[1], pred_val, label=pred_label, marker='x', s=100, color=pred_color)
            ax.set_title(title, fontsize=14)
            ax.set_xticks(x_ticks)
            ax.tick_params(axis='y', labelsize=14)
            ax.set_xticklabels(x_labels, fontsize=14)
            ax.grid(True)
            ax.legend(fontsize=14)
            ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))

        fig.suptitle(f"NN Interpolation Triplet Test #{i+1} (Index {center_index})", fontsize=14)
        fig.tight_layout()
        filename = f"NN_triplet_{i+1}_index_{center_index}.png"
        fig.savefig(os.path.join(save_dir, filename), dpi=300)
        plt.close(fig)

    for entry in metrics:
        entry["elapsed_sec"] = elapsed
    return metrics

In [None]:
#######################################################
# b-SPLINES INTERPOLATION
#######################################################
def interpolation_test_bspline(in_train, out_train, in_test, out_test, indices, save_dir, gap=1):
    """Evaluate a bicubic spline baseline on local interpolation triplets.

    Two ``SmoothBivariateSpline`` surfaces (one for each output column) are
    fitted on the training data. For every centre index in ``indices`` the
    triplet ``(centre - gap, centre, centre + gap)`` is taken from the test
    set, the splines are evaluated at the centre input, and the result is
    compared with the tabulated value. A two-panel figure (pressure and
    entropy) is saved as a PNG in ``save_dir``.

    Args:
        in_train: 2-D array of training inputs; columns 0 and 1 are the two
            spline coordinates.
        out_train: 2-D array of training targets; column 0 is used as
            pressure and column 1 as entropy.
        in_test: Indexable collection of test input rows.
        out_test: Indexable collection of test target rows (same column
            meaning as ``out_train``).
        indices: Iterable of centre indices to test.
        save_dir: Directory (created if missing) that receives the figures.
        gap: Index distance between the centre point and each neighbour.

    Returns:
        A list with one dict per evaluated triplet, holding ``method``,
        ``index``, ``mse_P``, ``mae_P``, ``mse_S``, ``mae_S`` and
        ``elapsed_sec``. ``elapsed_sec`` is identical for all entries: the
        time from the start of the loop until the metrics of the last
        evaluated triplet were computed (plotting of earlier triplets
        included, spline fitting excluded).
    """
    os.makedirs(save_dir, exist_ok=True)

    coord_0, coord_1 = in_train[:, 0], in_train[:, 1]
    z_P, z_S = out_train[:, 0], out_train[:, 1]

    # Create the cubic Bspline based on training data
    # NOTE(review): SmoothBivariateSpline is a smoothing fit with the library's
    # default smoothing factor, so the surface is not forced through the
    # training points -- confirm this is the intended baseline.
    bspline_P = SmoothBivariateSpline(coord_0, coord_1, z_P, kx=3, ky=3)
    bspline_S = SmoothBivariateSpline(coord_0, coord_1, z_S, kx=3, ky=3)

    metrics = []
    elapsed = None
    start_time = time.time()

    for i, center_index in enumerate(indices):
        # The helper returns None (not a tuple) when a neighbour falls outside
        # the data, so test the result before unpacking it.
        triplet = get_interpolation_indices(center_index, len(in_test), gap)
        if triplet is None:
            print(f"Saltando índice {center_index} (gap={gap}) debido a límites.")
            continue
        idx1, idx2, idx3 = triplet

        x2 = in_test[idx2]
        y1, y2_true, y3 = out_test[idx1], out_test[idx2], out_test[idx3]

        # Predictions for y2
        P_pred_log = bspline_P.ev(x2[0], x2[1])
        S_pred_log = bspline_S.ev(x2[0], x2[1])

        # Inverse log transform
        P_true = 10 ** np.array([y1[0], y2_true[0], y3[0]])
        S_true = 10 ** np.array([y1[1], y2_true[1], y3[1]])
        P_pred = 10 ** P_pred_log
        S_pred = 10 ** S_pred_log

        # Metrics
        mse_P = mean_squared_error([P_true[1]], [P_pred])
        mae_P = mean_absolute_error([P_true[1]], [P_pred])
        mse_S = mean_squared_error([S_true[1]], [S_pred])
        mae_S = mean_absolute_error([S_true[1]], [S_pred])
        metrics.append({"method": "B-Spline", "index": int(center_index), "mse_P": mse_P, "mae_P": mae_P, "mse_S": mse_S, "mae_S": mae_S})

        # Remember the running total; it is written to the entries once,
        # after the loop, instead of rewriting every entry on each iteration.
        elapsed = time.time() - start_time

        # Plot
        fig, axs = plt.subplots(1, 2, figsize=(12, 5))
        x_labels = [f'Idx {idx1}', f'Idx {idx2}', f'Idx {idx3}']
        x_ticks = np.arange(3)

        panels = (
            (axs[0], P_true, P_pred, 'Pressure (P)', 'True P', 'Predicted P', 'green', 'red'),
            (axs[1], S_true, S_pred, 'Entropy (S)', 'True S', 'Predicted S', 'blue', 'orange'),
        )
        for ax, true_vals, pred_val, title, true_label, pred_label, true_color, pred_color in panels:
            ax.plot(x_ticks, true_vals, label=true_label, marker='o', color=true_color)
            ax.scatter(x_ticks[1], pred_val, label=pred_label, marker='x', s=100, color=pred_color)
            ax.set_title(title, fontsize=14)
            ax.set_xticks(x_ticks)
            ax.set_xticklabels(x_labels, fontsize=14)
            ax.tick_params(axis='y', labelsize=14)
            ax.grid(True)
            ax.legend(fontsize=14)
            ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))

        fig.suptitle(f"B-Spline Interpolation Triplet Test #{i+1} (Index {center_index})", fontsize=14)
        fig.tight_layout()
        filename = f"SPLINE_triplet_{i+1}_index_{center_index}.png"
        fig.savefig(os.path.join(save_dir, filename), dpi=300)
        plt.close(fig)

    for entry in metrics:
        entry["elapsed_sec"] = elapsed
    return metrics

In [None]:
# Load the model and instantiate it.
# `top_path` must be set before running this cell; it is bound to an explicit
# placeholder so the cell is valid Python instead of a bare `top_path =`.
top_path = None  # TODO: set to the path where the top model is saved
config_path = os.path.join(top_path, "config.json")
weights_path = os.path.join(top_path, "model_top1.pth")

def load_model_from_config(config_path, in_size, out_size):
    """Build an untrained ``Interpolation`` model from a JSON config file.

    The config is read for ``hidden_layer_sizes``, ``dropout`` and
    ``activation_fn``; a missing key yields ``None`` for that setting.

    Args:
        config_path: Path of the JSON file holding the hyper-parameters.
        in_size: Number of input features passed to ``Interpolation``.
        out_size: Number of outputs passed to ``Interpolation``.

    Returns:
        A new ``Interpolation`` instance; no weights are loaded here.
    """
    with open(config_path, "r") as handle:
        settings = json.load(handle)

    # The activation is stored as text and resolved to a class.
    activation_cls = load_activation_fn(settings.get("activation_fn"))

    return Interpolation(
        in_size,
        out_size,
        settings.get("hidden_layer_sizes"),
        settings.get("dropout"),
        activation_cls,
    )

in_size, out_size = 2, 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = load_model_from_config(config_path, in_size, out_size)
# NOTE: torch.load unpickles the checkpoint; only load weights from a trusted source.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)

# 5 random index for both methods (fewer if the test set has fewer than 5 rows,
# where random.sample would otherwise raise ValueError)
max_index = len(in_test)
indices = random.sample(range(max_index), min(5, max_index))
print(indices)

# The output directories must be set before running this cell; they are bound
# to explicit placeholders so the cell is valid Python.
dir_nn = None  # TODO: set to the path for saving nn interpolation
os.makedirs(dir_nn, exist_ok=True)
dir_spl = None  # TODO: set to the path for saving bsplines interpolation
os.makedirs(dir_spl, exist_ok=True)

In [None]:
# Index distance between the centre point and its two neighbours; passed
# explicitly so both methods are evaluated on the same triplets.
gap = 1

# 1. NN interpolation
metrics_nn = interpolation_test_nn(model=model, inputs_set=in_test, outputs_set=out_test, scaler_P=scaler_P, scaler_S=scaler_S, device=device, indices=indices, save_dir=dir_nn, gap=gap)
for metric in metrics_nn:
    print(metric)

# 2. B-Spline interpolation
metrics_spline = interpolation_test_bspline(in_train=in_train, out_train=out_train, in_test=in_test, out_test=out_test, indices=indices, save_dir=dir_spl, gap=gap)
for metric in metrics_spline:
    print(metric)

# 3. Save metrics
# `save_dir` must be set before running this cell; it is bound to an explicit
# placeholder so the cell is valid Python.
save_dir = None  # TODO: set to the path to save metrics
os.makedirs(save_dir, exist_ok=True)

df_all = pd.DataFrame(metrics_nn + metrics_spline)
# Write the comparison into the metrics directory created above (it was
# previously written into the NN figure directory, leaving save_dir unused).
csv_path = os.path.join(save_dir, "interpolation_metrics_comparison.csv")
df_all.to_csv(csv_path, index=False)