# Example 2: High-dimensional function fitting
## SI Section 2.4: Regression of deterministic physical functions
### Reference: INN paper https://www.nature.com/articles/s41467-025-63790-8

In this example, we fit the 10-input, 5-output function described in SI Section 2.4 of the INN paper.

From this function, 1,000,000 data points are generated by Latin hypercube sampling and split into 70% for training, 15% for validation, and 15% for testing.
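A minimal NumPy sketch of the sampling and splitting step follows. It assumes inputs drawn in the unit hypercube and a shuffled 70/15/15 split. The helper names `latin_hypercube` and `split_indices` are hypothetical and are not pyinn functions. SciPy also offers a ready-made sampler, `scipy.stats.qmc.LatinHypercube`. The sketch only illustrates the technique and does not show how pyinn generates this dataset.

```python
import numpy as np

def latin_hypercube(n_samples, n_dims, rng):
    """Draw n_samples points in [0, 1)^n_dims by Latin hypercube sampling.

    Each axis is cut into n_samples equal strata. Every stratum is hit exactly
    once per axis, and strata are paired across axes by random permutations.
    """
    # One random permutation of stratum indices per dimension
    strata = np.stack([rng.permutation(n_samples) for _ in range(n_dims)], axis=1)
    # Jitter uniformly inside each stratum, then rescale to [0, 1)
    return (strata + rng.random((n_samples, n_dims))) / n_samples

def split_indices(n, ratios, rng):
    """Shuffle 0..n-1 and cut it into train/val/test chunks with the given ratios."""
    idx = rng.permutation(n)
    n_train = round(ratios[0] * n)
    n_val = round(ratios[1] * n)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]

rng = np.random.default_rng(0)
X = latin_hypercube(1_000_000, 10, rng)  # 10 inputs, as in this example
train_idx, val_idx, test_idx = split_indices(len(X), [0.7, 0.15, 0.15], rng)
print(len(train_idx), len(val_idx), len(test_idx))  # 700000 150000 150000
```

Latin hypercube sampling guarantees that every one-dimensional slice of the input space is covered evenly. Plain uniform sampling can leave gaps or clusters along individual axes.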

Import JAX (with 64-bit precision enabled) and the pyinn modules.

In [1]:
from jax import config as jax_config  # aliased so the `config` dict below does not shadow it
import jax.numpy as jnp
jax_config.update("jax_enable_x64", True)  # use float64 throughout
import os, sys

# Add pyinn to path for development
sys.path.insert(0, '../pyinn')
import dataset_regression, train, plot

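Set up GPUs (optional). Uncomment the lines below to make only the selected GPU visible. These variables take effect only if they are set before JAX initializes its GPU backend.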

In [None]:
# gpu_idx = 0
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_idx)

Create the dataset. Because `bool_data_generation` is `True`, the data are generated if the file does not exist yet. The file is saved as `data/10D_5D_physics_1000000.csv`.

In [2]:
data_name = '10D_5D_physics'
config = {}
config["DATA_PARAM"] = {
    "data_name": data_name,
    "input_col": [0,1,2,3,4,5,6,7,8,9],
    "output_col": [10,11,12,13,14],
    "bool_data_generation": True,
    "data_size": 1000000,
    "split_ratio": [0.7, 0.15, 0.15],
    "bool_normalize": True,
    "bool_shuffle": True
}
config["TRAIN_PARAM"] = {
    "num_epochs_INN": 100,
    "num_epochs_MLP": 100,
    "batch_size": 128,
    "learning_rate": 1e-3,
    "validation_period": 10,
    "patience": 10,
    "stopping_loss_train": 4e-4
}

data = dataset_regression.Data_regression(data_name, config)

Data file data/10D_5D_physics_1000000.csv does not exist. Creating data...
Loaded 1000000 datapoints from 10D_5D_physics dataset
  Train: 700000, Val: 150000, Test: 150000
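The configuration enables `bool_normalize`, but pyinn's exact normalization scheme is not shown in this notebook. A common choice, assumed in the sketch below, is per-column min-max scaling fitted on the training split. The sketch is written in NumPy, and the helpers `fit_minmax`, `normalize`, and `denormalize` are hypothetical, not part of pyinn.

```python
import numpy as np

def fit_minmax(train_data):
    """Per-column minimum and range, computed from the training split only."""
    lo = train_data.min(axis=0)
    span = train_data.max(axis=0) - lo
    # Guard against constant columns to avoid division by zero
    span = np.where(span > 0, span, 1.0)
    return lo, span

def normalize(data, lo, span):
    """Map each column affinely so that training values fall in [0, 1]."""
    return (data - lo) / span

def denormalize(data_n, lo, span):
    """Invert normalize(), e.g. to report predictions in physical units."""
    return data_n * span + lo

# Hypothetical output values with column scales similar to u1, u2 and u5
u_train = np.array([[-208.1, 0.30, 252.3],
                    [-235.0, 0.50, 357.3],
                    [-76.1, 0.37, 218.8]])
lo, span = fit_minmax(u_train)
u_n = normalize(u_train, lo, span)
print(u_n.min(axis=0), u_n.max(axis=0))  # [0. 0. 0.] [1. 1. 1.]
```

Scaling matters here because the outputs differ by about three orders of magnitude (u2 is below 1, while u1 and u5 are in the hundreds). Without scaling, an unweighted loss would be dominated by the largest outputs. Fitting the statistics on the training split alone keeps validation and test data from leaking into preprocessing.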


Inspect the first rows of the generated data.

In [3]:
import pandas as pd

data_size = int(config["DATA_PARAM"]["data_size"])
df = pd.read_csv(f'./data/{data_name}_{data_size}.csv')
df.head()

| Unnamed: 0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | x9 | x10 | u1 | u2 | u3 | u4 | u5 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.401902 | 0.718854 | 0.232017 | 0.161808 | 0.884976 | 0.317746 | 0.982725 | 0.995127 | 0.297908 | 0.783719 | -208.080275 | 0.30252 | 5.975549 | 2.51571 | 252.29111 |
| 1 | 0.758739 | 0.616715 | 0.929224 | 0.980533 | 0.778406 | 0.205109 | 0.82132 | 0.972133 | 0.572294 | 0.893597 | -235.012683 | 0.498989 | 5.145345 | 1.239977 | 357.321692 |
| 2 | 0.955968 | 0.734802 | 0.820269 | 0.735265 | 0.477393 | 0.686022 | 0.281908 | 0.539758 | 0.197026 | 0.908514 | -254.875432 | 0.489113 | 5.044315 | 0.256249 | 330.661354 |
| 3 | 0.753676 | 0.176305 | 0.015002 | 0.073013 | 0.258551 | 0.901161 | 0.419439 | 0.129414 | 0.805084 | 0.276157 | -76.062427 | 0.369208 | 3.545304 | 1.549309 | 218.844478 |
| 4 | 0.373259 | 0.718595 | 0.326246 | 0.065567 | 0.07453 | 0.364601 | 0.876408 | 0.697464 | 0.868685 | 0.931163 | -173.567354 | 0.341459 | 6.258758 | 1.420536 | 266.979629 |


Besides the unnamed index column written with the CSV, the data file contains 15 columns: x1 to x10 (inputs) and u1 to u5 (outputs).

Define the INN hyperparameters according to Figure 4(a) of the INN paper:
* nmode: 14
* nseg: 10
* s_patch: 2
* INNactivation: polynomial
* p_order: 2
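The training log later reports 7,700 trainable parameters. That count can be reproduced with simple arithmetic, assuming the CP-format INN stores one nodal value per mode, input dimension, grid node, and output. This layout is inferred from the logged number rather than taken from pyinn's source. The patch size and polynomial order do not enter the formula, which is consistent with them shaping the interpolation functions rather than adding trainable values.

```python
# Problem size: 10 inputs, 5 outputs; hyperparameters as listed above
n_in, n_out = 10, 5
nmode, nseg = 14, 10
nnode = nseg + 1  # nseg segments per input axis give nseg + 1 grid nodes

# Assumed layout: one trainable value per (mode, input dim, node, output)
n_params = nmode * n_in * nnode * n_out
print(n_params)  # 7700
```

For comparison, this is 14 × 10 × 11 × 5 = 7,700, which grows linearly in the number of inputs. A full tensor-product grid over 10 inputs would instead need 11^10 nodes per output.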

In [4]:
config["MODEL_PARAM"] = {
    "nmode": 14,
    "nseg": 10,
    "s_patch": 2,
    "INNactivation": "polynomial",
    "p_order": 2,
    "radial_basis": "cubicSpline",
    "alpha_dil": 20
}

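Finalize the configuration. A positive patch size (`s_patch` > 0) selects nonlinear interpolation, and otherwise linear interpolation is used. The tensor decomposition type `TD_type` is set to CP (CANDECOMP/PARAFAC).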

In [5]:
if config["MODEL_PARAM"]["s_patch"] > 0:
    config['interp_method'] = "nonlinear"
else:
    config['interp_method'] = "linear"
config['TD_type'] = "CP"
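To build intuition for `TD_type = "CP"`, here is a minimal NumPy sketch of a CP-format forward pass. Each output is a sum over `nmode` modes of products of one-dimensional functions, one per input dimension: u_k(x) = Σ_m Π_d Σ_j N_j(x_d) θ[m, d, j, k]. Each 1D function is interpolated from nodal values θ with hat-shaped shape functions N_j on a uniform grid of `nseg + 1` nodes. The sketch covers only the linear-interpolation branch (`s_patch = 0`). The nonlinear interpolation this notebook actually uses (`s_patch = 2`, polynomial activation) is not reproduced. The `[0, 1]` grid, the parameter layout `(nmode, n_in, nnode, n_out)`, and the function names are assumptions, not pyinn's API.

```python
import numpy as np

def linear_shape_functions(xd, grid):
    """Hat-function weights of every grid node at points xd: shape (batch, nnode)."""
    h = grid[1] - grid[0]  # uniform node spacing assumed
    return np.clip(1.0 - np.abs(xd[:, None] - grid[None, :]) / h, 0.0, None)

def inn_cp_forward(params, x, grid):
    """CP-format forward pass with linear interpolation.

    params: (nmode, n_in, nnode, n_out) nodal values
    x:      (batch, n_in) inputs inside the grid range
    returns (batch, n_out)
    """
    n_in = params.shape[1]
    # N[b, d, j]: weight of node j for input dimension d at sample b
    N = np.stack([linear_shape_functions(x[:, d], grid) for d in range(n_in)], axis=1)
    # f[b, m, d, k]: 1D interpolant of mode m, dimension d, output k at x[b, d]
    f = np.einsum("bdj,mdjk->bmdk", N, params)
    # CP structure: product over input dimensions, then sum over modes
    return f.prod(axis=2).sum(axis=1)

rng = np.random.default_rng(0)
nmode, n_in, nseg, n_out = 14, 10, 10, 5
grid = np.linspace(0.0, 1.0, nseg + 1)
params = rng.normal(size=(nmode, n_in, nseg + 1, n_out))
x = rng.random((4, n_in))
print(inn_cp_forward(params, x, grid).shape)  # (4, 5)
```

Because the model is a sum of separable products, evaluation cost grows linearly with the number of inputs. This separability is what keeps the parameter count small for 10D inputs.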

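Train the INN. After training, the regressor exposes the trained parameters and the training, validation, and test errors.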

In [6]:
regressor = train.Regression_INN(data, config)
regressor.train()

params = regressor.params
errors_train = regressor.errors_train
errors_val = regressor.errors_val
error_test = regressor.error_test

edex_max / ndex_max: 6 / 5
------------ INN CP nonlinear, nmode: 14, nseg: 10, s=2, P=2 -------------
# of training parameters: 7700
Reached target training loss at epoch 2
Training completed in 31.07 seconds
Test rmse: 1.7020e-02
Inference time: 0.000000 seconds
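The log shows training stopping at epoch 2 because the training loss reached `stopping_loss_train` (4e-4). The sketch below shows one plausible way the three stopping settings in `TRAIN_PARAM` could interact, assuming these rules:

* Training stops as soon as an epoch's training loss falls below `stopping_loss_train`.
* Validation runs every `validation_period` epochs.
* `patience` counts validation checks without improvement.

These rules are assumptions, not pyinn's confirmed behavior. In particular, pyinn may count patience in epochs instead of checks. The function `stopping_epoch` and the loss values are hypothetical.

```python
def stopping_epoch(train_losses, val_losses, num_epochs, validation_period,
                   patience, stopping_loss_train):
    """Return (epoch, reason) at which a loop following the assumed rules stops.

    train_losses[e - 1] and val_losses[e - 1] are the losses of epoch e (1-based).
    """
    best_val = float("inf")
    checks_without_improvement = 0
    for epoch in range(1, num_epochs + 1):
        # Rule 1: stop once the training loss is below the target
        if train_losses[epoch - 1] < stopping_loss_train:
            return epoch, "target training loss"
        # Rule 2: periodic validation with patience-based early stopping
        if epoch % validation_period == 0:
            if val_losses[epoch - 1] < best_val:
                best_val = val_losses[epoch - 1]
                checks_without_improvement = 0
            else:
                checks_without_improvement += 1
                if checks_without_improvement >= patience:
                    return epoch, "patience exhausted"
    return num_epochs, "max epochs"

# Hypothetical losses that cross the 4e-4 target at epoch 2
train_losses = [1.2e-3, 3.5e-4] + [3.0e-4] * 98
val_losses = [1.0e-3] * 100
print(stopping_epoch(train_losses, val_losses, num_epochs=100, validation_period=10,
                     patience=10, stopping_loss_train=4e-4))
# (2, 'target training loss')
```

Under these rules, the patience check never fires when the loss target is met early, as in the run above.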
