This notebook contains code to run simulations for feature learning in various infinite-width neural network regimes. In particular, we focus on:
1. Mean field
2. Mean field Langevin dynamics
3. Neural tangent kernel


In [20]:
import numpy as np
import cupy as cp

import matplotlib.pyplot as plt

### Learning k-sparse parities (refer to [Suzuki et al., 2023](https://proceedings.neurips.cc/paper_files/paper/2023/hash/6cc321baf0a8611b1d1bdbd18822667b-Abstract-Conference.html))

Consider the XOR problem where $k=2$ (Wei et al., 2019), we can generate our data from hypercube $\{ \pm 1 / \sqrt{d} \}^d$

In [21]:
def generate_xor_data(n, d):
    # Generate data from {±1/sqrt(d)}^d and labels as XOR of first 2 features.
    X = cp.random.choice([1, -1], size=(n, d)) / cp.sqrt(d)
    y = cp.sign(X[:, 0] * X[:, 1])
    return X, y

Note that we define our neuron $h_z(x)$ with parameters $x = (x_1, x_2, x_3) \in \mathbb{R}^{d + 1 + 1}$ as such:
$$h_{x}(z) = \bar{R} \cdot \tanh(z^T x_1 + x_2) + \frac{2 \tanh(x_3)}{3}$$



In [22]:
def neuron_activation(x, z, R_bar=15):
    # x is in R^(d+2): x[:d] are weights, x[d] is a bias for tanh(z^T x[:d]), and x[d+1] is a further bias.
    d = len(z)
    a = cp.dot(z, x[:d]) + x[d]
    b = x[d+1]
    return R_bar * (cp.tanh(a) + 2 * cp.tanh(b)) / 3.0

Some derivatives for grad calculations

In [23]:
def dtanh(x):
    return 1.0 - cp.tanh(x)**2

In [24]:
def neuron_grad(x, z, R_bar=15):
    # Compute gradients of h_x with respect to the parameters x.
    d = len(z)
    a = cp.dot(z, x[:d]) + x[d]
    b = x[d+1]
    grad_x1 = (R_bar/3.0) * dtanh(a) * z          # gradient w.r.t. x1
    grad_bias = cp.array([R_bar/3.0 * dtanh(a)])      # gradient w.r.t. bias x[d]
    grad_bias2 = cp.array([2*R_bar/3.0 * dtanh(b)])   # gradient w.r.t. bias x[d+1]
    return cp.concatenate([grad_x1, grad_bias, grad_bias2])

In [25]:
def logistic_loss(f, y):
    return cp.log(1 + cp.exp(-y * f))

In [26]:
def d_logistic_loss(f, y):
    return - y / (1 + cp.exp(y * f))

Recall our neural network formulation: 

$$f(x; \theta) = \frac{1}{M} \sum_{i=1}^{M} \sigma \langle x, \theta_i \rangle$$

where $\sigma \langle x, \theta_i \rangle = h_{x}(z_i)$

In [27]:
def predict(particles, z, R_bar=15):
    outputs = cp.array([neuron_activation(x, z, R_bar) for x in particles])
    return cp.mean(outputs)

In [28]:
def evaluate_accuracy(particles, X, y, R_bar=15):
    preds = cp.array([cp.sign(predict(particles, x, R_bar)) for x in X])
    return cp.mean(preds == y)

Below follows from our update function in MFLD:

$$dX_t = - \nabla \frac{\delta F(\mu_t)}{\delta \mu}(X_t)dt + \sqrt{2 \lambda}dW_t$$

Recall:
$\frac{\delta F(\mu)}{\delta \mu}(x) = \frac{1}{n} \sum_{i=1}^{n} l'(y_i f_{\mu}(z_i))y_i h_x(z_i) + \lambda (\lambda_1 \|x\|^2)$

Hence:

$X_{\tau+1}^{i} = X_{\tau}^{i} - \eta \nabla \frac{\delta F(\mu_r)}{\delta \mu}(X_r^i) + \sqrt{2\lambda\eta} \xi_{\tau}^{i}$

where $\xi_{\tau}^{i} \sim N(0, I)$

In [32]:
def update_particles(particles, X, y, eta, lam, lambda_1, R_bar=15):
    """
    Note that the update step follows from eq (4) of (Suzuki et al., 2023)
    Our L2 regularization update is implicity included in \nabla \frac{\delta F(\mu_i)}}{\delta \mu} of eq (4)
    args:
        lam: denotes the variance of the noise updates in Langevin dynamics (i.e. \sqrt{2 \lambda \eta} W, W \sim Brownian)
        lambda_1: denotes strength of L2 reg.
    """
    n, d = X.shape
    N = len(particles)
    # Compute network output for each data point
    f_preds = cp.zeros(n)
    for i in range(n):
        f_preds[i] = cp.mean(cp.array([neuron_activation(x, X[i], R_bar) for x in particles]))
    # Compute derivative of the loss for each data point
    loss_derivs = cp.array([d_logistic_loss(f_preds[i], y[i]) for i in range(n)])
    
    new_particles = []
    for x in particles:
        grad_loss = cp.zeros_like(x)
        # Accumulate gradient contributions from all training data
        for i in range(n):
            grad_loss += (loss_derivs[i] * y[i]) * neuron_grad(x, X[i], R_bar)
        grad_loss = grad_loss / n
        grad_reg = 2 * lambda_1 * x
        grad = grad_loss + grad_reg
        noise = cp.random.randn(*x.shape)
        x_new = x - eta * grad + cp.sqrt(2 * lam * eta) * noise
        new_particles.append(x_new)
    return new_particles

Recall: the annealing process described in [Suzuki et al., 2023]. Convergence rate is dependent on LSI constant hence by decreasing $\lambda$,
 
$\lambda^{(k)} = 2^{(-\kappa)}\lambda^{(0)}$ 

In [33]:
def run_simulation(d=20, n_train=500, n_test=200, 
                   num_particles=1000, eta=0.05, T_per_round=200, 
                   num_rounds=6, lam_init=0.1, R_bar=15):
    # Generate training and test data
    X_train, y_train = generate_xor_data(n_train, d)
    X_test, y_test = generate_xor_data(n_test, d)
    # Initialize particles randomly
    particles = [cp.random.randn(d+2) for _ in range(num_particles)]
    
    accuracy_history = []
    lam_history = []
    
    lam_current = lam_init
    lambda_1_value = 0.1 # Suzuki et al., 2023
    for round in range(num_rounds):
        print(f"Annealing Round {round+1}, lambda = {lam_current:.5f}")
        for t in range(T_per_round):
            particles = update_particles(particles, X_train, y_train, eta, lam_current, lambda_1_value, R_bar)
        acc = evaluate_accuracy(particles, X_test, y_test, R_bar)
        print(f"Test accuracy after round {round+1}: {acc:.3f}")
        accuracy_history.append(acc)
        lam_history.append(lam_current)
        lam_current *= 0.5  # Anneal: reduce lambda by factor of 2.
    return accuracy_history, lam_history

In [35]:
def run_simulation_and_store(d=20, n_train=500, n_test=200,
                             num_particles=1000, eta=0.05, T_per_round=200,
                             num_rounds=6, lam_init=0.1, lambda_1=0.1, R_bar=15):

    X_train, y_train = generate_xor_data(n_train, d)
    X_test, y_test = generate_xor_data(n_test, d)
    particles = [cp.random.randn(d+2, dtype=cp.float32) for _ in range(num_particles)]

    particle_history = {} # Dictionary to store particle states
    accuracy_history = []
    lam_history = []

    print("Storing initial particle state...")
    particle_history[0] = [p.copy() for p in particles] # Store initial state (round 0)

    lam_current = lam_init
    for round_num in range(num_rounds):
        print(f"Annealing Round {round_num+1}, lambda = {lam_current:.5f}")
        for t in range(T_per_round):
            # Pass lambda_1 to update_particles
            particles = update_particles(particles, X_train, y_train, eta, lam_current, lambda_1, R_bar)
            print(f"Finished timestep: {t+1} / {T_per_round}")
            # Optional: could store particles more frequently within a round if needed
            # if (t+1) % 50 == 0: print(f"  Iteration {t+1}/{T_per_round}")

        # Store particles after the round
        print(f"Storing particle state after round {round_num+1}...")
        particle_history[round_num+1] = [p.copy() for p in particles]

        # Evaluate accuracy
        acc = evaluate_accuracy(particles, X_test, y_test, R_bar)
        # Transfer acc to CPU if it's a cupy scalar
        acc_cpu = cp.asnumpy(acc) if isinstance(acc, cp.ndarray) else acc
        print(f"Test accuracy after round {round_num+1}: {acc_cpu:.4f}")
        accuracy_history.append(acc_cpu)
        lam_history.append(lam_current)

        # Anneal
        lam_current *= 0.5

    return particle_history, accuracy_history, lam_history

Note, I am slightly confused why Suzuki et al., 2023 choose $\bar R = 15$ for their experimental results when they mention in paper that $\bar R = k$ 

In [36]:
cp.random.seed(42)

particle_history, accuracy_history, lam_history = run_simulation_and_store(
    d=20, n_train=500, n_test=200,
    num_particles=200, # Reduced particles for faster testing
    eta=0.05, T_per_round=50, # Reduced steps for faster testing
    num_rounds=6, lam_init=0.1,
    lambda_1=0.1, # Provide lambda_1
    R_bar=15
)

# Plot the test accuracy evolution over annealing rounds
plt.figure(figsize=(8,6))
plt.plot(cp.arange(1, len(accuracy_history)+1), accuracy_history, marker='o')
plt.xlabel('Annealing Round')
plt.ylabel('Test Accuracy')
plt.title('Test Accuracy over Annealing Rounds')
plt.grid(True)
plt.show()

# Plot lambda evolution versus annealing round
plt.figure(figsize=(8,6))
plt.plot(cp.arange(1, len(lam_history)+1), lam_history, marker='o', color='r')
plt.xlabel('Annealing Round')
plt.ylabel('Lambda')
plt.title('Regularization Parameter over Annealing Rounds')
plt.grid(True)
plt.show()

Storing initial particle state...
Annealing Round 1, lambda = 0.10000
Finished timestep: 0 / 50
Finished timestep: 1 / 50
Finished timestep: 2 / 50
Finished timestep: 3 / 50
Finished timestep: 4 / 50
Finished timestep: 5 / 50
Finished timestep: 6 / 50
Finished timestep: 7 / 50
Finished timestep: 8 / 50
Finished timestep: 9 / 50
Finished timestep: 10 / 50
Finished timestep: 11 / 50
Finished timestep: 12 / 50
Finished timestep: 13 / 50
Finished timestep: 14 / 50
Finished timestep: 15 / 50
Finished timestep: 16 / 50
Finished timestep: 17 / 50
Finished timestep: 18 / 50


KeyboardInterrupt: 

Let's see if there is any way to visualize the feature learning here. Perhaps we can PCA the neuron weights and visualize how they might change with training (following similar methodology to what was presented in Yang et al., 2022 in their abc-param paper)

In [None]:
# PCA neuron weights over time

print("\n--- Generating PCA Visualization ---")
plt.figure(figsize=(10, 8))
num_rounds_to_plot = len(particle_history)
colors = plt.cm.viridis(np.linspace(0, 1, num_rounds_to_plot))

pca = PCA(n_components=2)
scaler = StandardScaler() 

for round_num, particles_cp in particle_history.items():
    print(f"Processing PCA for round {round_num}")

    particle_matrix_cp = cp.stack(particles_cp) # dims: (num_particles, d+2)

    # Transfer to CPU for scikit-learn
    particle_matrix_np = cp.asnumpy(particle_matrix_cp)

    particle_matrix_scaled = scaler.fit_transform(particle_matrix_np)

    # Fit PCA and transform data
    pca.fit(particle_matrix_scaled)
    particles_pca = pca.transform(particle_matrix_scaled)

    # Plot
    plt.scatter(particles_pca[:, 0], particles_pca[:, 1],
                alpha=0.6, label=f'Round {round_num}', color=colors[round_num], s=10)

plt.title('PCA of MFLD Particle Parameters Over Annealing Rounds')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Visualize how certain weights change over time

print("\n--- Generating Weight Component Visualization ---")
# compare initial and final rounds
rounds_to_compare = [0, len(particle_history) - 1]
d_value = 20
num_bins = 50

fig, axes = plt.subplots(1, len(rounds_to_compare), figsize=(12, 5), sharey=True)
fig.suptitle('Distribution of Relevant vs. Irrelevant Weight Components')

for i, round_num in enumerate(rounds_to_compare):
    print(f"Processing weight distributions for round {round_num}")
    ax = axes[i]
    particles_cp = particle_history[round_num]
    particle_matrix_cp = cp.stack(particles_cp) # shape (num_particles, d+2)

    # Extract weight vectors (first d components)
    weights_cp = particle_matrix_cp[:, :d_value]

    # Separate relevant (first 2) and irrelevant weights
    # recalll, XOR problem looks at first two input dimensions (perhaps first two feature weights are also changed more)
    relevant_weights_cp = weights_cp[:, :2].flatten()
    irrelevant_weights_cp = weights_cp[:, 2:].flatten()

    # Transfer to CPU for plotting
    relevant_weights_np = cp.asnumpy(relevant_weights_cp)
    irrelevant_weights_np = cp.asnumpy(irrelevant_weights_cp)

    # Plot histograms
    bin_min = min(relevant_weights_np.min(), irrelevant_weights_np.min())
    bin_max = max(relevant_weights_np.max(), irrelevant_weights_np.max())
    bins = np.linspace(bin_min, bin_max, num_bins)

    ax.hist(relevant_weights_np, bins=bins, alpha=0.7, label='Relevant Weights (Dims 1-2)', density=True)
    ax.hist(irrelevant_weights_np, bins=bins, alpha=0.7, label=f'Irrelevant Weights (Dims 3-{d_value})', density=True)
    ax.set_title(f'Round {round_num}')
    ax.set_xlabel('Weight Value')
    if i == 0:
        ax.set_ylabel('Density')
    ax.legend()
    ax.grid(True)

plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap
plt.show()

print("\n--- Generating Scatter Plot of Relevant Weights ---")
plt.figure(figsize=(10, 8))

for round_num, particles_cp in particle_history.items():
    print(f"Processing scatter plot for round {round_num}")
    particle_matrix_cp = cp.stack(particles_cp) # shape (num_particles, d+2)
    weights_cp = particle_matrix_cp[:, :d_value]
    relevant_weights_cp = weights_cp[:, :2] # Shape (num_particles, 2)

    # Transfer to CPU
    relevant_weights_np = cp.asnumpy(relevant_weights_cp)

    # Plot
    plt.scatter(relevant_weights_np[:, 0], relevant_weights_np[:, 1],
                alpha=0.3, label=f'Round {round_num}', color=colors[round_num], s=10)

plt.title('Scatter Plot of First Two Weight Components Over Annealing Rounds')
plt.xlabel('Weight Component 1')
plt.ylabel('Weight Component 2')
# plt.legend()
plt.grid(True)
plt.axhline(0, color='grey', lw=0.5)
plt.axvline(0, color='grey', lw=0.5)
plt.show()