In [13]:
import warnings

import numpy as np
from scipy.optimize import minimize


def objective(params, X, Y, sigma_x, sigma_1, sigma_2, N_i, N_o, N_h):
    """
    Objective function minimised by SLSQP over the two weight matrices.

    Returns ``term1 + term2`` where

        term1 = 1/2 * sigma_x^2 * ( ||W_2 W_1||_F^2
                                    + sigma_1^2 * N_i * ||W_2||_F^2
                                    + sigma_2^2 * N_o * ||W_1||_F^2 )
        term2 = 1/(2 n) * ( sigma_1^2 * ||W_2||_F^2 * ||X||_F^2
                            + N_o * sigma_2^2 * ||W_1 X||_F^2
                            + N_h * sigma_1^2 * N_o * sigma_2^2 * ||X||_F^2 )

    with ``n = X.shape[-1]`` (number of columns/samples of X).

    Parameters
    ----------
    params : array_like, 1-D
        Flattened weights: the first ``num_hidden * num_input`` entries are
        W_1 (row-major, shape ``(num_hidden, num_input)``); the remaining
        entries are W_2 (shape ``(num_output, num_hidden)``).
    X : ndarray, shape (num_input, n_samples)
        Input matrix; ``num_input`` is taken from ``X.shape[0]``.
    Y : ndarray, shape (num_output, n_samples)
        Target matrix. Only its row count is used here (to infer
        ``num_output``); its values do not enter the objective.
    sigma_x, sigma_1, sigma_2, N_i, N_o, N_h : float
        Scalar coefficients of the terms above. N_h is used as given and is
        not checked against the inferred hidden size.

    Returns
    -------
    float
        Value of the objective.

    Raises
    ------
    ValueError
        If ``params.size`` is not a multiple of ``num_input + num_output``
        (i.e. it cannot be split into W_1 and W_2 of consistent shapes).
    """
    params = np.asarray(params)

    # Infer layer sizes from the data instead of reading module-level globals:
    # params.size == num_hidden * num_input + num_output * num_hidden.
    num_input = X.shape[0]
    num_output = Y.shape[0]
    num_hidden, remainder = divmod(params.size, num_input + num_output)
    if remainder:
        raise ValueError(
            f"params has {params.size} entries, which is not a multiple of "
            f"num_input + num_output = {num_input + num_output}"
        )
    split = num_hidden * num_input
    W_1 = params[:split].reshape(num_hidden, num_input)
    W_2 = params[split:].reshape(num_output, num_hidden)

    # Squared Frobenius norms, each computed once and reused in both terms.
    w1_sq = np.linalg.norm(W_1, "fro") ** 2
    w2_sq = np.linalg.norm(W_2, "fro") ** 2
    x_sq = np.linalg.norm(X, "fro") ** 2

    term1 = (
        0.5
        * sigma_x**2
        * (
            np.linalg.norm(W_2 @ W_1, "fro") ** 2
            + sigma_1**2 * N_i * w2_sq
            + sigma_2**2 * N_o * w1_sq
        )
    )
    # Averaged over the X.shape[-1] columns (samples) of X.
    term2 = (1 / (2 * X.shape[-1])) * (
        sigma_1**2 * w2_sq * x_sq
        + N_o * sigma_2**2 * np.linalg.norm(W_1 @ X, "fro") ** 2
        + N_h * sigma_1**2 * N_o * sigma_2**2 * x_sq
    )

    return term1 + term2


def constraint(params, X, Y):
    """
    Equality-constraint residual: ||W_2 W_1 X X^T - Y X^T||_F.

    The constraint W_2 W_1 X X^T = Y X^T holds exactly when this returns 0.

    Parameters
    ----------
    params : array_like, 1-D
        Flattened weights: the first ``num_hidden * num_input`` entries are
        W_1 (row-major, shape ``(num_hidden, num_input)``); the remaining
        entries are W_2 (shape ``(num_output, num_hidden)``).
    X : ndarray, shape (num_input, n_samples)
        Input matrix; ``num_input`` is taken from ``X.shape[0]``.
    Y : ndarray, shape (num_output, n_samples)
        Target matrix; ``num_output`` is taken from ``Y.shape[0]``.

    Returns
    -------
    float
        Frobenius norm of the residual (always >= 0).

    Raises
    ------
    ValueError
        If ``params.size`` is not a multiple of ``num_input + num_output``.

    NOTE(review): a scalar norm used as an equality constraint is non-smooth
    on its zero set (its gradient is undefined exactly where the constraint
    holds), which gives SLSQP a degenerate linearisation and may contribute to
    the "Iteration limit reached" result. Passing the flattened residual
    matrix as a vector-valued constraint is the usual alternative.
    """
    params = np.asarray(params)

    # Infer layer sizes from the data instead of reading module-level globals:
    # params.size == num_hidden * num_input + num_output * num_hidden.
    num_input = X.shape[0]
    num_output = Y.shape[0]
    num_hidden, remainder = divmod(params.size, num_input + num_output)
    if remainder:
        raise ValueError(
            f"params has {params.size} entries, which is not a multiple of "
            f"num_input + num_output = {num_input + num_output}"
        )
    split = num_hidden * num_input
    W_1 = params[:split].reshape(num_hidden, num_input)
    W_2 = params[split:].reshape(num_output, num_hidden)

    return np.linalg.norm(W_2 @ W_1 @ X @ X.T - Y @ X.T, "fro")


# Problem setup
np.random.seed(42)
num_input = 5
num_hidden = 5
num_output = 5
num_samples = 100

W_1_shape = (num_hidden, num_input)
W_2_shape = (num_output, num_hidden)

X = np.random.randn(num_input, num_samples)  # Input matrix, one column per sample
Y = np.random.randn(num_output, num_samples)  # Target matrix, one column per sample

# Scalar coefficients passed straight through to `objective`.
sigma_x = 1.0
sigma_1 = 1.0
sigma_2 = 1.0
# NOTE(review): N_h equals num_hidden, but N_i / N_o are 100 while num_input and
# num_output are 5 (100 is the sample count) — confirm these values are intended.
N_i = 100
N_o = 100
N_h = 5

# SLSQP iteration cap. 100 is SciPy's SLSQP default; the recorded run hit it.
MAX_ITER = 100

# Initial guess for W_1 and W_2, flattened into one vector (W_1 entries first,
# then W_2) — the layout that `objective` and `constraint` unpack.
W_1_init = np.random.randn(*W_1_shape)
W_2_init = np.random.randn(*W_2_shape)
initial_params = np.concatenate([W_1_init.ravel(), W_2_init.ravel()])

# Equality constraint W_2 W_1 X X^T = Y X^T (no bounds are used).
con = {"type": "eq", "fun": constraint, "args": (X, Y)}

# Optimize
result = minimize(
    objective,
    initial_params,
    args=(X, Y, sigma_x, sigma_1, sigma_2, N_i, N_o, N_h),
    constraints=[con],
    method="SLSQP",
    options={"disp": True, "maxiter": MAX_ITER},
)

# Make a failed optimisation visible instead of silently analysing its output below.
if not result.success:
    warnings.warn(
        f"SLSQP did not converge: {result.message}", RuntimeWarning, stacklevel=1
    )

# Unpack the optimized parameter vector using the same layout as initial_params.
num_w1 = num_hidden * num_input
optimized_params = result.x
W_1_optimized = optimized_params[:num_w1].reshape(W_1_shape)
W_2_optimized = optimized_params[num_w1:].reshape(W_2_shape)

Iteration limit reached    (Exit mode 9)
            Current function value: 1348.5364249903894
            Iterations: 100
            Function evaluations: 5267
            Gradient evaluations: 100


In [14]:
# Frobenius norm of the optimized first-layer weights.
np.linalg.norm(W_1_optimized, ord="fro")

np.float64(0.8746193108956525)

In [15]:
# Frobenius norm of the optimized second-layer weights.
np.linalg.norm(W_2_optimized, ord="fro")

np.float64(1.1810860124897589)

In [17]:
# trace(X X^T) == ||X||_F^2, the data factor appearing in `objective`'s term2.
# np.trace works on every NumPy version; np.linalg.trace requires NumPy >= 2.0.
np.trace(X @ X.T)

np.float64(480.4894888786302)