In [None]:
import numpy as np
from scipy.optimize import least_squares

In [None]:
def activation_prime(z):
    """Element-wise derivative of the default activation np.sin, i.e. cos(z)."""
    return np.cos(z)

In [None]:
def glorot_uniform(shape, rng):
    """Glorot/Xavier uniform initialisation.

    Draws from U(-bound, bound) with bound = sqrt(6 / (fan_in + fan_out)).
    fan_in is shape[0]; fan_out is shape[1] for shapes with at least two
    dimensions and 1 for a 1-D shape (e.g. an output weight vector).
    """
    fan_in = shape[0]
    fan_out = 1 if len(shape) <= 1 else shape[1]
    bound = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-bound, bound, size=shape)

In [None]:
def fit_two_layer_mlp_leastsq(x_values, y_values, H, *, activation=np.sin,
                              activation_prime=None, seed=0):
    """
    Fit y = w2 · phi(x W1 + b1) + b2  by non-linear least squares.

    All parameters are packed into one flat vector and optimised with
    ``scipy.optimize.least_squares`` (method='trf', xtol=ftol=1e-8).

    Args
    ----
    x_values : (B, N)  input samples
    y_values : (B,)    targets (any array holding B values is flattened)
    H        : int     hidden width
    activation : callable, element-wise non-linearity phi (default np.sin)
    activation_prime : callable or None, element-wise derivative phi'.
        If None, a matching analytic derivative is chosen for np.sin,
        np.cos and np.tanh; for any other activation the Jacobian falls
        back to 2-point finite differences rather than a wrong analytic one.
    seed : int, seed of the RNG used for the Glorot-uniform init (default 0)

    Returns
    -------
    params = { 'W1': (N,H), 'b1': (H,), 'w2': (H,), 'b2': float,
               'activation': the `activation` callable }

    Raises
    ------
    ValueError : if x_values is not 2-D with at least one row, or
                 y_values does not hold exactly B values.
    """
    x_values = np.asarray(x_values, dtype=float)
    if x_values.ndim != 2 or x_values.shape[0] == 0:
        raise ValueError(f"x_values must have shape (B, N) with B >= 1, got {x_values.shape}")
    B, N = x_values.shape
    y_values = np.asarray(y_values, dtype=float).reshape(-1)
    if y_values.size != B:
        raise ValueError(f"y_values must hold {B} targets, got {y_values.size}")

    # phi' must match phi: the analytic Jacobian below is only correct for the
    # pair.  Unknown activations without an explicit phi' use finite differences.
    if activation_prime is None:
        if activation is np.sin:
            activation_prime = np.cos
        elif activation is np.cos:
            activation_prime = lambda z: -np.sin(z)
        elif activation is np.tanh:
            activation_prime = lambda z: 1.0 - np.tanh(z) ** 2

    dim_W1 = N * H
    dim_b1 = H
    dim_w2 = H

    rng = np.random.default_rng(seed)
    W1 = glorot_uniform((N, H), rng)  # fan_in=N, fan_out=H
    b1 = np.zeros(H)
    w2 = glorot_uniform((H,), rng)    # fan_in=H, fan_out=1
    b2 = np.zeros(1)
    theta0 = np.concatenate([W1.ravel(), b1, w2, b2])

    def unpack(theta):
        """Split the flat parameter vector back into (W1, b1, w2, b2)."""
        idx = 0
        W1 = theta[idx : idx + dim_W1].reshape(N, H); idx += dim_W1
        b1 = theta[idx : idx + dim_b1];               idx += dim_b1
        w2 = theta[idx : idx + dim_w2];               idx += dim_w2
        b2 = float(theta[idx])
        return W1, b1, w2, b2

    def residuals(theta):
        """Prediction minus target, shape (B,)."""
        W1, b1, w2, b2 = unpack(theta)
        hidden = activation(x_values @ W1 + b1)  # (B, H)
        return hidden @ w2 + b2 - y_values

    def jac(theta):
        """Analytic Jacobian d(residuals)/d(theta), shape (B, N*H + 2*H + 1)."""
        W1, b1, w2, _ = unpack(theta)
        Z = x_values @ W1 + b1                     # (B, H)
        gate = w2[None, :] * activation_prime(Z)   # (B, H): dr/dZ
        # dr/dW1[n, h] = x[:, n] * gate[:, h], flattened in W1's row-major order
        J_W1 = (x_values[:, :, None] * gate[:, None, :]).reshape(B, -1)
        J_b1 = gate                                # (B, H)
        J_w2 = activation(Z)                       # (B, H)
        J_b2 = np.ones((B, 1))                     # (B, 1)
        return np.hstack([J_W1, J_b1, J_w2, J_b2])

    sol = least_squares(residuals, theta0,
                        jac=jac if activation_prime is not None else '2-point',
                        method='trf', xtol=1e-8, ftol=1e-8)

    W1_opt, b1_opt, w2_opt, b2_opt = unpack(sol.x)
    return {'W1': W1_opt, 'b1': b1_opt, 'w2': w2_opt, 'b2': b2_opt,
            'activation': activation}


def two_layer_mlp_forward(X, params):
    """
    Evaluate a fitted two-layer MLP: activation(X @ W1 + b1) @ w2 + b2.

    Args
    ----
    X      : (B, N) batch of inputs (a single (N,) vector yields a scalar)
    params : dict with keys 'W1', 'b1', 'w2', 'b2', 'activation', as
             returned by fit_two_layer_mlp_leastsq

    Returns
    -------
    (B,) array of predictions
    """
    W1, b1, w2, b2, act = (params[key] for key in ('W1', 'b1', 'w2', 'b2', 'activation'))
    hidden = act(X @ W1 + b1)  # (B, H)
    return hidden @ w2 + b2    # (B,)

In [None]:
def func1(x, y, a=1.5, b=1.0, c=0.5, d=0.5):
    """Paraboloid x^2 + y^2 minus two Gaussian wells.

    The well centred at (1, 0) has depth `a` and width parameter `c`; the
    well centred at (-1, 0) has depth `b` and width parameter `d`.
    Works element-wise on scalars or numpy arrays.
    """
    bowl = x**2 + y**2
    right_well = a * np.exp(-((x - 1)**2 + y**2) / c)
    left_well = b * np.exp(-((x + 1)**2 + y**2) / d)
    return bowl - right_well - left_well

In [None]:
def func2(x, y, a=1, b=2):
    """Rosenbrock-type function (a - x)^2 + b * (y - x^2)^2.

    For b >= 0 its minimum value 0 is attained at (x, y) = (a, a^2).
    Note the default b=2 is much milder than the classic b=100.
    Works element-wise on scalars or numpy arrays.
    """
    offset_term = (a - x)**2
    valley_term = b * (y - x**2)**2
    return offset_term + valley_term

In [None]:
input_dim = 2
H = 10

# Regular 100 x 100 grid over [0.01, 1] x [0.01, 1].
val_range = np.linspace(1e-2, 1, 100)

xx, yy = np.meshgrid(val_range, val_range)  # both shape (100, 100)

# Flatten the grid into B = 100*100 samples with input_dim features each.
x_values = np.stack([xx.ravel(), yy.ravel()], axis=1)
assert x_values.shape == (val_range.size ** 2, input_dim)

errors = []    # relative L2 fit error, one entry per (function, width)
settings = []  # matching (function name, hidden width) pairs

for func in [func2]:
    y_values = func(x_values[:, 0], x_values[:, 1])
    y_norm = np.linalg.norm(y_values)  # loop-invariant denominator
    for scalefactor in range(1, 7):
        H_new = H * scalefactor
        params = fit_two_layer_mlp_leastsq(x_values, y_values, H_new, activation=np.sin)

        y_fit = two_layer_mlp_forward(x_values, params)
        error = np.linalg.norm(y_values - y_fit) / y_norm
        print(f'H={H_new}, error = {error}')
        errors.append(error)
        settings.append((func.__name__, H_new))

H=10Error = 6.991048645563245e-05
H=20Error = 1.146296544559736e-05
H=30Error = 5.959944702412707e-06
H=40Error = 4.236680791898118e-06
H=50Error = 1.910689349602223e-06
H=60Error = 9.249514027234608e-07


In [None]:
import pickle as pkl

# Each result list is pickled to its own file; the /content/drive/MyDrive
# prefix presumably requires Google Drive to be mounted (Colab) — verify.
_result_files = (
    ('/content/drive/MyDrive/mlp_2layer_sine_multi_dim_settings_func2.pkl', settings),
    ('/content/drive/MyDrive/mlp_2layer_sine_multi_dim_errors_func2.pkl', errors),
)
for out_path, payload in _result_files:
    with open(out_path, 'wb') as f:
        pkl.dump(payload, f)