In [2]:
import json

from dataset import NormalizedProfilesDataset
from utils import load_normalization_metadata
from models import RNN_New, BasicRNN
from torch.utils.data import DataLoader
import torch

In [3]:
# Paths
model_params_path = "Data/model/model_parameters.json"
model_save_path = "Data/model/best_model.pth"
data_folder = "data/normalize_profiles"
save_path = "figures"

# Load the model parameters
with open(model_params_path, 'r') as f:
    model_params = json.load(f)

# Display the model parameters
print("Loaded model parameters:")
print(json.dumps(model_params, indent=4))

Loaded model parameters:
{
    "model_type": "BasicRNN",
    "RNN_type": "LSTM",
    "nx": 4,
    "ny": 1,
    "nx_sfc": 0,
    "nneur": [
        32,
        32
    ],
    "outputs_one_longer": false,
    "concat": false,
    "batch_size": 4,
    "learning_rate": 0.0001,
    "epochs": 500,
    "input_variables": [
        "pressure",
        "temperature",
        "Tstar",
        "flux_surface_down"
    ],
    "target_variables": [
        "net_flux"
    ]
}


In [4]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Retrieve model_type
model_type = model_params.get("model_type", "RNN_New")

# Initialize the model dynamically based on model_type
if model_type == 'BasicRNN':
    model = BasicRNN(
        RNN_type=model_params['RNN_type'],
        nx=model_params['nx'],
        ny=model_params['ny'],
        nneur=tuple(model_params['nneur']),
        outputs_one_longer=model_params['outputs_one_longer'],
        concat=model_params['concat']
    )
elif model_type == 'RNN_New':
    model = RNN_New(
        RNN_type=model_params['RNN_type'],
        nx=model_params['nx'],
        ny=model_params['ny'],
        nneur=tuple(model_params['nneur']),
        outputs_one_longer=model_params['outputs_one_longer'],
        concat=model_params['concat']
    )
else:
    raise ValueError(f"Unknown model type: {model_type}")

# Load the trained model weights
model.load_state_dict(torch.load(model_save_path, map_location=device))
model.to(device)
model.eval()

print(f"model '{model_type}' loaded and ready for evaluation.")


  model.load_state_dict(torch.load(model_save_path, map_location=device))


Model 'BasicRNN' loaded and ready for evaluation.


In [5]:
# Load normalization metadata
normalization_metadata = load_normalization_metadata()

# Expected length of profiles
expected_length = 50

# Initialize the test dataset
test_dataset = NormalizedProfilesDataset(
    data_folder,
    expected_length=50,
    input_variables=['pressure', 'temperature', 'Tstar', 'flux_surface_down'],
    target_variables=['net_flux']
)

# Create DataLoader
test_loader = DataLoader(test_dataset,batch_size=1,shuffle=False)

print("Test dataset loaded.")


Test dataset loaded.


In [6]:
import torch
import numpy as np
import pickle as pkl
from torch.utils.data import DataLoader
from pysr import PySRRegressor

# Ensure model is in evaluation mode
model.eval()

# We'll gather inputs and outputs from the model
X_all = []
Y_all = []
Preds_all = []

# Iterate over a portion of the test data to build a dataset for symbolic regression
# For symbolic regression, a few thousand samples might be enough.
max_samples = 2000
count = 0

with torch.no_grad():
    for X_batch, Y_batch in test_loader:
        X_batch = X_batch.to(device)  # Move to GPU if available
        Y_batch = Y_batch.to(device)

        # Get model predictions
        Y_pred = model(X_batch)

        # Move data back to CPU and convert to numpy
        X_np = X_batch.cpu().numpy()
        Y_np = Y_batch.cpu().numpy()
        Y_pred_np = Y_pred.cpu().numpy()

        X_all.append(X_np)
        Y_all.append(Y_np)
        Preds_all.append(Y_pred_np)

        count += X_np.shape[0]
        if count >= max_samples:
            break

# Concatenate all arrays
X_all = np.concatenate(X_all, axis=0)  # shape: (N, seq_len, nx)
Y_all = np.concatenate(Y_all, axis=0)  # shape: (N, ny)
Preds_all = np.concatenate(Preds_all, axis=0)  # shape: (N, ny)

# If your model output is multi-dimensional (ny > 1), select one dimension for demonstration
if Preds_all.ndim > 1 and Preds_all.shape[1] > 1:
    Preds_all = Preds_all[:, 0]
    Y_all = Y_all[:, 0]

# Now we have:
# X_all: shape (N, seq_len, nx)
# Preds_all: shape (N,)
# We need a suitable input representation for PySR.
# PySR typically expects a 2D array of inputs (N, features).
# One common approach:
# Flatten sequence input or extract features from it.
# For demonstration, let's just flatten the sequence dimension:
N, seq_len, nx = X_all.shape
X_for_pysr = X_all.reshape(N, seq_len * nx)

Y_for_pysr = Preds_all  # We'll try to fit PySR to match model predictions.

# Optionally, choose a random subset for PySR
subset_size = 1000
rstate = np.random.RandomState(0)
idx = rstate.choice(N, size=subset_size, replace=False)
X_for_pysr = X_for_pysr[idx]
Y_for_pysr = Y_for_pysr[idx]

# Save the recorded data if desired
with open("rnn_recordings.pkl", "wb") as f:
    pkl.dump({"X_for_pysr": X_for_pysr, "Y_for_pysr": Y_for_pysr}, f)

# Perform symbolic regression with PySR
model_pysr = PySRRegressor(
    niterations=50,
    binary_operators=["+", "-", "*", "/"],
    unary_operators=["cos", "sin", "exp", "log", "square"],
    progress=True,
    random_state=0,
    deterministic=True,
    parallelism='serial'
)
model_pysr.fit(X_for_pysr, Y_for_pysr)

print("Discovered equations:")
print(model_pysr.equations_)

best_equation = model_pysr.get_best()
print("Best equation found by PySR:")
print(best_equation)




Compiling Julia backend...


[ Info: Started!



Expressions evaluated per second: 5.820e+04
Progress: 330 / 1550 total iterations (21.290%)
════════════════════════════════════════════════════════════════════════════════════════════════════
───────────────────────────────────────────────────────────────────────────────────────────────────
Complexity  Loss       Score      Equation
1           9.950e-01  1.594e+01  y = -0.0025429
2           9.810e-01  1.424e-02  y = cos(x₇₅)
3           7.175e-04  7.220e+00  y = x₁₃₉ * -0.99716
5           7.111e-04  4.528e-03  y = (x₁₄₇ * -0.99716) - 0.0025433
6           7.006e-04  1.482e-02  y = square(x₁₁₉ * 0.054007) - x₁₉₅
7           6.135e-04  1.327e-01  y = square(square(x₁₉₅) * -0.037592) - x₁₄₃
8           5.032e-04  1.981e-01  y = square(square(square(x₁₅₁) * 0.083594)) - x₁₉₅
10          4.803e-04  2.336e-02  y = square(square(square(0.083594 - x₁₅₁) * 0.083594)) - x...
                                      ₁₉₅
11          4.581e-04  4.718e-02  y = square(square(x₁₉₅ * (x₁₉₅ * -0.04943

[ Info: Final population:
[ Info: Results saved to:



Expressions evaluated per second: 5.380e+04
Progress: 933 / 1550 total iterations (60.194%)
════════════════════════════════════════════════════════════════════════════════════════════════════
───────────────────────────────────────────────────────────────────────────────────────────────────
Complexity  Loss       Score      Equation
1           9.950e-01  1.594e+01  y = -0.0025429
2           9.810e-01  1.424e-02  y = cos(x₇₅)
3           7.175e-04  7.220e+00  y = x₈₇ * -0.99716
5           7.111e-04  4.528e-03  y = (x₁₄₇ * -0.99716) - 0.0025433
6           6.951e-04  2.269e-02  y = (exp(x₇) * 0.0014811) - x₇
7           4.466e-04  4.423e-01  y = (exp(square(x₁₇₉)) * 7.8481e-05) - x₇
9           4.176e-04  3.355e-02  y = (exp(square(x₁₇₉)) * 7.8481e-05) - (x₇ + 0.0053842)
12          4.052e-04  1.005e-02  y = ((exp(square(x₁₄₃)) * 7.75e-05) - (0.0053793 + x₁₄₃)) ...
                                      / cos(-0.063419)
15          4.044e-04  6.642e-04  y = ((exp(square(x₁₄₃ / cos(-0

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe2 in position 4095: unexpected end of data

Error in callback _flush_stdio (for post_execute):


UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe2 in position 4095: unexpected end of data

Error in callback _flush_stdio (for post_execute):


UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe2 in position 4095: unexpected end of data

Error in callback _flush_stdio (for post_execute):


UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe2 in position 4095: unexpected end of data

Error in callback _flush_stdio (for post_execute):


UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe2 in position 4095: unexpected end of data