In [None]:
import numpy as np
from scipy.optimize import fsolve, least_squares

def generate_two_layer_phases(input_dim, G1, H, G2):
    """
    Build the fixed phase offsets used by a two-layer SineKAN.

    Each table is an outer sum of a per-row offset r * pi / R
    (r = 0..R-1) and a per-grid-point offset (k + 1) / (G + 1)
    (k = 0..G-1).  Layer 1 uses R = input_dim, G = G1; layer 2 uses
    R = H, G = G2.

    Args:
        input_dim (int): Number of input dimensions (N).
        G1 (int): Grid size for Layer 1 (number of sines per input coordinate).
        H (int): Number of hidden units in Layer 1.
        G2 (int): Grid size for Layer 2 (number of sines per hidden unit).

    Returns:
        dict with:
            'alpha_1': np.ndarray of shape (input_dim, G1)
            'beta_2': np.ndarray of shape (H, G2)
    """
    def phase_table(n_rows, n_grid):
        # Rows are shifted by multiples of pi / n_rows; grid points sit
        # evenly inside the open interval (0, 1).
        row_offsets = np.arange(n_rows) * np.pi / n_rows
        grid_offsets = np.arange(1, n_grid + 1) / (n_grid + 1)
        return row_offsets[:, None] + grid_offsets[None, :]

    return {
        'alpha_1': phase_table(input_dim, G1),
        'beta_2': phase_table(H, G2),
    }

def fit_two_layer_sinekan_leastsq(x_values, y_values, bases, G1, H, G2,
                                  *, method='trf', xtol=1e-8, ftol=1e-8):
    """
    Fit a two-layer SineKAN by nonlinear least squares.

    Model (alpha_1, beta_2 are the fixed phases from ``bases``; every other
    quantity, including the frequencies omega_1 / omega_2, is fitted):

        H1[i, h] = sum_{n,k} A1[n,k,h] * sin(omega_1[k] * x[i,n] + alpha_1[n,k]) + b1[h]
        y[i]     = sum_{h,k2} A2[h,k2] * sin(omega_2[k2] * H1[i,h] + beta_2[h,k2]) + b2

    Args:
        x_values : np.ndarray, shape (B, input_dim)
            Input samples (B >= 1).
        y_values : np.ndarray, shape (B,)
            Target outputs.
        bases : dict with keys 'alpha_1', 'beta_2'
            Phase offsets of shapes (input_dim, G1) and (H, G2), e.g. as
            produced by generate_two_layer_phases().
        G1 : int, grid size for layer 1.
        H  : int, hidden-layer dimension.
        G2 : int, grid size for layer 2.
        method, xtol, ftol : keyword-only, forwarded to
            scipy.optimize.least_squares.  'trf' (default) works for any B;
            'lm' additionally requires B >= number of unknowns.

    Returns:
        params : dict containing fitted arrays:
            'A1' of shape (input_dim, G1, H),
            'b1' of shape (H,),
            'A2' of shape (H, G2),
            'b2' a float,
            'omega_1' of shape (G1,),
            'omega_2' of shape (G2,).

    Raises:
        ValueError: if the shapes of x_values, y_values or bases are
            inconsistent with each other or with (G1, H, G2).
    """
    x_values = np.asarray(x_values, dtype=float)
    y_values = np.asarray(y_values, dtype=float)
    if x_values.ndim != 2 or x_values.shape[0] == 0:
        raise ValueError(
            f"x_values must have shape (B, input_dim) with B >= 1, got {x_values.shape}")
    B, input_dim = x_values.shape
    if y_values.shape != (B,):
        raise ValueError(f"y_values must have shape ({B},), got {y_values.shape}")

    alpha_1 = np.asarray(bases['alpha_1'], dtype=float)   # (input_dim, G1)
    beta_2 = np.asarray(bases['beta_2'], dtype=float)     # (H, G2)
    if alpha_1.shape != (input_dim, G1):
        raise ValueError(
            f"bases['alpha_1'] must have shape ({input_dim}, {G1}), got {alpha_1.shape}")
    if beta_2.shape != (H, G2):
        raise ValueError(
            f"bases['beta_2'] must have shape ({H}, {G2}), got {beta_2.shape}")

    # Layout of the flat parameter vector theta:
    #   [A1 (input_dim*G1*H) | b1 (H) | A2 (H*G2) | b2 (1) | omega_1 (G1) | omega_2 (G2)]
    # unpack_params and jac must both follow this order.
    dim_A1 = input_dim * G1 * H
    dim_b1 = H
    dim_A2 = H * G2
    dim_b2 = 1
    total_dim = dim_A1 + dim_b1 + dim_A2 + dim_b2 + G1 + G2
    split_points = np.cumsum([dim_A1, dim_b1, dim_A2, dim_b2, G1])

    # Weights/biases start at 0.01; frequencies start at omega_1[k] = pi*k and
    # omega_2[k2] = pi*k2/G2.
    theta0 = np.full(total_dim, 0.01)
    freq_start = total_dim - (G1 + G2)
    theta0[freq_start:freq_start + G1] = np.pi * np.arange(G1)
    theta0[freq_start + G1:] = np.pi * np.arange(G2) / G2

    def unpack_params(theta):
        """Split theta into (A1, b1, A2, b2, omega_1, omega_2)."""
        A1, b1, A2, b2, omega_1, omega_2 = np.split(theta, split_points)
        return (A1.reshape(input_dim, G1, H), b1, A2.reshape(H, G2),
                float(b2[0]), omega_1, omega_2)

    def forward(theta):
        """Evaluate the model on x_values; return parameters and intermediates."""
        A1, b1, A2, b2, omega_1, omega_2 = unpack_params(theta)
        arg_1 = x_values[:, :, None] * omega_1[None, None, :] + alpha_1[None, :, :]  # (B, N, G1)
        phi_1 = np.sin(arg_1)
        n_feat = input_dim * G1
        H1 = phi_1.reshape(B, n_feat) @ A1.reshape(n_feat, H) + b1                 # (B, H)
        arg_2 = H1[:, :, None] * omega_2[None, None, :] + beta_2[None, :, :]        # (B, H, G2)
        phi_2 = np.sin(arg_2)
        y_pred = (A2[None, :, :] * phi_2).sum(axis=(1, 2)) + b2                     # (B,)
        return {'A1': A1, 'A2': A2, 'omega_2': omega_2, 'arg_1': arg_1,
                'phi_1': phi_1, 'H1': H1, 'arg_2': arg_2, 'phi_2': phi_2,
                'y_pred': y_pred}

    def residuals(theta):
        """r_i = y_pred(i; theta) - y_values[i], shape (B,)."""
        return forward(theta)['y_pred'] - y_values

    def jac(theta):
        """Analytic Jacobian d r / d theta, shape (B, total_dim), in theta order."""
        fw = forward(theta)
        A1, A2, omega_2, H1 = fw['A1'], fw['A2'], fw['omega_2'], fw['H1']
        cos_2 = np.cos(fw['arg_2'])                                          # (B, H, G2)

        # d2[i, h] = d y_pred[i] / d H1[i, h]
        d2 = (A2[None, :, :] * omega_2[None, None, :] * cos_2).sum(axis=2)  # (B, H)

        J_A1 = (fw['phi_1'][:, :, :, None] * d2[:, None, None, :]).reshape(B, dim_A1)
        J_b1 = d2
        J_A2 = fw['phi_2'].reshape(B, dim_A2)
        J_b2 = np.ones((B, dim_b2))
        # d phi_1 / d omega_1 = x * cos(arg_1), chained through A1 and d2
        # (sum over input dims n and hidden units h).
        dphi1_domega1 = np.cos(fw['arg_1']) * x_values[:, :, None]          # (B, N, G1)
        J_omega_1 = np.einsum('ink,nkh,ih->ik', dphi1_domega1, A1, d2)       # (B, G1)
        # d phi_2 / d omega_2 = H1 * cos(arg_2), summed over hidden units.
        J_omega_2 = (A2[None, :, :] * H1[:, :, None] * cos_2).sum(axis=1)   # (B, G2)

        return np.hstack([J_A1, J_b1, J_A2, J_b2, J_omega_1, J_omega_2])

    sol = least_squares(residuals, theta0, jac=jac,
                        method=method, xtol=xtol, ftol=ftol)

    A1_opt, b1_opt, A2_opt, b2_opt, omega1_opt, omega2_opt = unpack_params(sol.x)
    return {
        'A1': A1_opt,
        'b1': b1_opt,
        'A2': A2_opt,
        'b2': b2_opt,
        'omega_1': omega1_opt,
        'omega_2': omega2_opt
    }

# Forward with fitted parameters
def two_layer_forward(X, bases, params, H=None):
    """
    Evaluate a fitted two-layer SineKAN at the points X.

    Args:
        X : np.ndarray, shape (B, input_dim)
            Points at which to evaluate the model.
        bases : dict with 'alpha_1' (input_dim, G1) and 'beta_2' (H, G2).
        params : dict with keys 'A1', 'b1', 'A2', 'b2', 'omega_1', 'omega_2'
            (as returned by fit_two_layer_sinekan_leastsq).
        H : int, optional
            Hidden-layer width; defaults to params['A1'].shape[-1].

    Returns:
        np.ndarray of shape (B,): model output for each row of X.

    Raises:
        ValueError: if X is not 2-D.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError(f"X must have shape (B, input_dim), got {X.shape}")
    B = X.shape[0]
    alpha_1 = bases['alpha_1']
    beta_2 = bases['beta_2']
    A1, b1 = params['A1'], params['b1']
    A2, b2 = params['A2'], params['b2']
    omega_1, omega_2 = params['omega_1'], params['omega_2']
    if H is None:
        H = A1.shape[-1]

    # Evaluate on the X argument itself (not on any module-level sample array).
    Phi_1 = np.sin(X[:, :, None] * omega_1 + alpha_1)            # (B, N, G1)
    n_feat = Phi_1.shape[1] * Phi_1.shape[2]
    H1 = Phi_1.reshape(B, n_feat) @ A1.reshape(n_feat, H) + b1   # (B, H)

    sin_H = np.sin(H1[:, :, None] * omega_2 + beta_2)            # (B, H, G2)
    y_pred = (A2[None, :, :] * sin_H).sum(axis=(1, 2)) + b2      # (B,)
    return y_pred

In [None]:
def func1(x, y, a=1.5, b=1.0, c=0.5, d=0.5):
    """
    Quadratic bowl x**2 + y**2 minus two Gaussian-shaped wells.

    The first well (depth a, width parameter c) is centred at (1, 0); the
    second (depth b, width parameter d) at (-1, 0).  Works elementwise on
    scalars or NumPy arrays.
    """
    bowl = x**2 + y**2
    right_well = np.exp(-((x - 1)**2 + y**2) / c)
    left_well = np.exp(-((x + 1)**2 + y**2) / d)
    return bowl - a * right_well - b * left_well

In [None]:
def func2(x, y, a=1, b=2):
    """
    Rosenbrock-type function (a - x)**2 + b * (y - x**2)**2.

    For b > 0 its minimum value 0 is attained at (x, y) = (a, a**2).
    Works elementwise on scalars or NumPy arrays.
    """
    valley = y - x**2
    return (a - x)**2 + b * valley**2

In [None]:
input_dim = 2
G1, H, G2 = 2, 5, 1
# The base (unscaled) architecture is checked against these size inequalities.
assert G1 * H >= 2 * input_dim**2 + input_dim
assert H * G2 >= 2 * input_dim + 1

# 100 x 100 evaluation grid over [0.01, 1]^2, flattened to samples of shape (10000, 2).
val_range = np.linspace(1e-2, 1, 100)
xx, yy = np.meshgrid(val_range, val_range)
x_values = np.stack([xx.ravel(), yy.ravel()], axis=1)

errors, settings = [], []

for func in [func2]:
    y_values = func(x_values[:, 0], x_values[:, 1])
    # Sweep: grid sizes G1/G2 scale together by `scalefactor`; hidden width H by `sf`.
    for scalefactor in [1, 2, 3]:
        G1_new, G2_new = G1 * scalefactor, G2 * scalefactor
        for sf in [1, 2, 3]:
            H_new = H * sf

            bases = generate_two_layer_phases(input_dim, G1_new, H_new, G2_new)
            params = fit_two_layer_sinekan_leastsq(
                x_values, y_values, bases, G1_new, H_new, G2_new)
            y_fit = two_layer_forward(x_values, bases, params, H_new)

            # Relative L2 error of the fit over the whole grid.
            residual_norm = np.linalg.norm(y_values - y_fit)
            error = residual_norm / np.linalg.norm(y_values)
            print(f'G1={G1_new} H={H_new} G2={G2_new} Error = {error}')

            errors.append(error)
            settings.append((func.__name__, G1_new, H_new, G2_new))

In [None]:
import pickle as pkl

# Persist the sweep configurations and their matching relative errors,
# in that order, to the same two files.
for out_path, payload in (
    ('/content/drive/MyDrive/sinekan_multi_dim_settings_func2.pkl', settings),
    ('/content/drive/MyDrive/sinekan_multi_dim_errors_func2.pkl', errors),
):
    with open(out_path, 'wb') as f:
        pkl.dump(payload, f)