In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from sklearn.metrics import mean_squared_error

In [18]:
# Ground-truth curve that the synthetic data is sampled from
def true_function(x):
    """Return ``1 + tanh(2 * (x - 2))`` evaluated at ``x`` (scalar or NumPy array)."""
    scaled_offset = 2 * (x - 2)
    return 1 + np.tanh(scaled_offset)

# Model definitions
def lin_0(x, b):
    """Line through the origin: slope ``b`` times ``x``, with no intercept term."""
    slope_term = b * x
    return slope_term

def lin(x, a, b):
    """Affine model: intercept ``a`` plus slope ``b`` times ``x``."""
    intercept, slope = a, b
    return intercept + slope * x

def par(x, a, b, c):
    """Quadratic model: ``a + b*x + c*x**2``."""
    affine_part = a + b * x
    quadratic_part = c * x**2
    return affine_part + quadratic_part

def cubic(x, a, b, c, d):
    """Cubic model: ``a + b*x + c*x**2 + d*x**3``."""
    # Accumulate the terms left to right, lowest degree first.
    total = a + b * x
    total = total + c * x**2
    total = total + d * x**3
    return total

def poly_4(x, a, b, c, d, e):
    """Quartic model: ``a + b*x + c*x**2 + d*x**3 + e*x**4``."""
    # Accumulate the terms left to right, lowest degree first.
    total = a + b * x
    total = total + c * x**2
    total = total + d * x**3
    total = total + e * x**4
    return total

def exp_model(x, a, b):
    """Exponential model centred on x = 2: ``a * exp(-b * (x - 2))``."""
    offset = x - 2
    exponential = np.exp(-b * offset)
    return a * exponential

def exp_plus(x, a, b, c):
    """Exponential model centred on x = 2 with a constant offset: ``a * exp(-b * (x - 2)) + c``."""
    offset = x - 2
    exponential = np.exp(-b * offset)
    return a * exponential + c

# Registry of candidate models: display name -> (model function, initial guess).
# The initial guess holds one entry per fit parameter (all ones) and is what the
# fitting cells pass to curve_fit as p0.
models = {
    label: (model_func, [1] * n_params)
    for label, model_func, n_params in (
        ("LIN_0", lin_0, 1),
        ("LIN", lin, 2),
        ("PAR", par, 3),
        ("CUBIC", cubic, 4),
        ("POLY-4", poly_4, 5),
        ("EXP", exp_model, 2),
        # ("EXP^+", exp_plus, 3),  # left disabled, as before; reason not recorded (TODO confirm)
    )
}

In [6]:
def generate_data(n_points=50, noise_std=0.1, seed=None):
    """Sample noisy observations of ``true_function`` on an evenly spaced grid.

    Parameters
    ----------
    n_points : int
        Number of grid points, evenly spaced over [0, 3.5] (both ends included).
    noise_std : float
        Standard deviation of the zero-mean Gaussian noise added to each value.
    seed : int, optional
        Seed for a dedicated ``np.random.default_rng`` generator, which makes the
        noise draw reproducible. When ``None`` (the default) the noise comes from
        NumPy's global random state, so existing calls behave as before.

    Returns
    -------
    x : numpy.ndarray
        The grid, shape ``(n_points,)``.
    y : numpy.ndarray
        ``true_function(x)`` plus noise, same shape as ``x``.
    """
    x = np.linspace(0, 3.5, n_points)
    # Both the np.random module and a Generator expose normal(loc, scale, size).
    noise_source = np.random if seed is None else np.random.default_rng(seed)
    y = true_function(x) + noise_source.normal(0, noise_std, size=x.shape)
    return x, y

# Generate synthetic data
# Seed NumPy's global random state first: generate_data draws its noise from it, so
# without a seed every kernel restart yields different data and different MSE values.
np.random.seed(42)
x, y = generate_data()

# Split data into calibration and generalization sets.
# The grid is ordered, so the first half (lower x) is used for fitting and the
# second half (higher x) is held out.
n_cal = len(x) // 2  # 25 with the default 50 points
x_cal, y_cal = x[:n_cal], y[:n_cal]
x_gen, y_gen = x[n_cal:], y[n_cal:]


In [None]:
# Single-model check: fit LIN_0 on the calibration half and score it on the
# generalization half.
func, p0 = models["LIN_0"]
x_train, y_train = x_cal, y_cal
x_test, y_test = x_gen, y_gen

popt, _ = curve_fit(f=func, xdata=x_train, ydata=y_train, p0=p0)
y_pred_test = func(x_test, *popt)
mse = mean_squared_error(y_test, y_pred_test)
mse  # last expression of the cell, shown as its output

In [None]:
def fit_and_evaluate(models, x_train, y_train, x_test, y_test):
    """Fit every candidate model on the training data and score it on the test data.

    Parameters
    ----------
    models : dict
        Maps a model name to a ``(func, p0)`` pair: ``func(x, *params)`` is the
        model and ``p0`` is the initial parameter guess handed to ``curve_fit``.
    x_train, y_train : array-like
        Data the parameters are fitted to.
    x_test, y_test : array-like
        Data used only to compute the mean squared error of the fitted model.

    Returns
    -------
    dict
        Maps each model name to ``(popt, mse)``: the fitted parameters and the
        mean squared error on the test data.
    """
    scores = {}
    for name, spec in models.items():
        model_func, initial_guess = spec
        print(name)  # trace of which model is currently being fitted
        fitted_params = curve_fit(model_func, x_train, y_train, p0=initial_guess)[0]
        predictions = model_func(np.array(x_test), *fitted_params)
        scores[name] = (fitted_params, mean_squared_error(y_test, predictions))
    return scores

# Fit models on the calibration set and evaluate them on the generalization set
results = fit_and_evaluate(models, x_cal, y_cal, x_gen, y_gen)

# Display results
for model, (params, mse) in results.items():
    print(f"{model}: Generalization MSE = {mse:.4f}")

# Plot results
# Draw the calibration-fitted curves stored in `results`, i.e. the same fits whose
# generalization MSE is printed above. Refitting on the full (x, y) here would
# show curves that have already seen the generalization points.
plt.scatter(x, y, label="Data", color='black')
plt.plot(x, true_function(x), label="True Function", linestyle="dashed")
for model, (func, p0) in models.items():
    popt, _ = results[model]
    plt.plot(x, func(x, *popt), label=model)
# Everything to the right of this line was not used for fitting.
plt.axvline(x=x_cal[-1], color='gray', linestyle='dotted', label="End of calibration data")
plt.legend()
plt.xlabel("x")
plt.ylabel("y")
plt.title("Model Fits")
plt.show()

In [None]:
# Plot results (replicating Figure 5)
# Each model is fitted to the full data set and then evaluated on a grid that
# continues past the last data point, so the region to the right of the boundary
# line actually shows how the fits extrapolate.
EXTRAPOLATION_MARGIN = 1.5  # how far beyond the data to draw; tune to match the target figure
x_boundary = x.max()  # last data point (3.5 with the default grid)
x_plot = np.linspace(x.min(), x_boundary + EXTRAPOLATION_MARGIN, 200)

plt.figure(figsize=(8, 6))
plt.scatter(x, y, label="Data", color='black')
plt.plot(x_plot, true_function(x_plot), label="True Function", linestyle="dashed")
for model, (func, p0) in models.items():
    popt, _ = curve_fit(func, x, y, p0=p0)
    plt.plot(x_plot, func(x_plot, *popt), label=model)
plt.axvline(x=x_boundary, color='gray', linestyle='dotted', label='Extrapolation Boundary')
# Diverging extrapolations would otherwise stretch the y-axis and flatten the data.
plt.ylim(y.min() - 1, y.max() + 1)
plt.legend()
plt.xlabel("x")
plt.ylabel("y")
plt.title("Model Fits with Extrapolation (Replicating Figure 5)")
plt.show()