# Mean-field theory of asymmetric Ising models with vector spins

In [None]:
import jax
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np


key = jax.random.PRNGKey(2666)
N = 64
# Random couplings J_ij ~ N(0, 1/N); the matrix is neither symmetric nor
# zero on the diagonal, so the graph is directed and has self-loops.
J = N**-0.5 * jax.random.normal(key, shape=(N, N))

# Directed graph whose edge (u, v) carries weight J[u, v].
G = nx.from_numpy_array(np.array(J), create_using=nx.DiGraph)
pos = nx.circular_layout(G)
edge_weights = [G[u][v]["weight"] for u, v in G.edges()]
# Symmetric colour limits so the diverging colormap is centred on zero and
# colour encodes the sign of each coupling.
weight_abs = np.abs(edge_weights)
vmax = weight_abs.max()

plt.figure(figsize=(8, 8))

nodes = nx.draw_networkx_nodes(
    G, pos, node_size=100, node_color="white", edgecolors="black"
)
edges = nx.draw_networkx_edges(
    G,
    pos,
    arrowstyle="->",
    arrowsize=10,
    edge_color=edge_weights,
    edge_cmap=plt.cm.RdBu_r,
    edge_vmin=-vmax,
    edge_vmax=vmax,
    # Width scales with |J_ij| (the sign is already shown by colour).
    # Multiplying the Python list by 2 would repeat it, not scale it, and
    # negative line widths are meaningless.
    width=2 * weight_abs,
)

for edge in edges:
    edge.set_alpha(0.9)

ax = plt.gca()
ax.set_box_aspect(1)
ax.set_axis_off()
plt.show()

In [None]:
from functools import partial

import jax
import jax.numpy as jnp

from jaxopt import AndersonAcceleration


def _gamma(x, beta, r):
    """See Eq. (39)."""
    return jnp.sqrt(1 + beta**2 * jnp.sum(x**2, axis=-1, keepdims=True) / r**2)


def _phi(theta, beta, r):
    """Map fields ``theta`` to magnetizations, see Eq. (38).

    Every vector along the last axis is multiplied by ``beta / (1 + gamma)``,
    with ``gamma`` from ``_gamma``; for ``beta > 0`` this keeps the direction
    of ``theta`` and only shrinks its norm.
    """
    shrink = beta / (1 + _gamma(theta, beta, r))
    return shrink * theta


def update_naive_mf(m0, _, x, J, beta, r):
    """One naive mean-field step for ``jax.lax.scan``, see Eq. (47).

    The new magnetizations are ``phi(x + J m0)``. The function returns
    ``(m1, m0)``: ``m1`` becomes the next scan carry and ``m0`` is the
    per-step scan output, so stacked outputs hold the state *before* each
    update. The second argument (the scan input) is ignored.
    """
    local_field = x + jnp.einsum("i j, j d -> i d", J, m0)
    return _phi(local_field, beta, r), m0


def _inv_phi(m, beta, r):
    """See Eq. (64)."""
    return 2 * r**2 / (beta * (r**2 - jnp.sum(m**2, axis=-1, keepdims=True))) * m


def _d2_m_d_alpha_2(m1, m0, x, J, beta, r):
    """Second derivative of the magnetizations with respect to alpha, see Eq. (58).

    Shapes follow from the einsum subscripts below: ``m1``, ``m0`` and ``x``
    are ``(N, D)`` and ``J`` is ``(N, N)``; the result is ``(N, D)``.

    Args:
        m1: candidate magnetizations at the new time step.
        m0: magnetizations at the previous time step.
        x: external fields.
        J: coupling matrix; only ``J**2`` enters the correction terms.
        beta: scale parameter forwarded to ``_gamma``/``_inv_phi``
            (presumably the inverse temperature).
        r: norm scale; ``_inv_phi`` requires ``|m| != r`` for ``m0`` and ``m1``.

    Returns:
        The sum of three ``(N, D)`` terms: two scalar-per-spin multiples of
        ``m1`` and one built from ``v``, ``m1`` and ``m0`` (see comments).
    """
    # gamma at the fields phi^{-1}(m) for the previous (g0) and candidate
    # new (g1) magnetizations; each has shape (N, 1).
    g0 = _gamma(_inv_phi(m0, beta, r), beta, r)
    g1 = _gamma(_inv_phi(m1, beta, r), beta, r)
    # Naive field x + J m0 minus the field phi^{-1}(m1) that would reproduce
    # m1 exactly; shape (N, D).
    v = -_inv_phi(m1, beta, r) + x + jnp.einsum("i j, j d -> i d", J, m0)

    # NOTE(review): g0 has shape (N, 1) and is applied outside the j-sums
    # (or alongside m1 on the `m1 / (1 + g0)` line), so it carries the row
    # index i, whereas |m0_j|^2 in term 2 is summed over j. Confirm against
    # Eq. (58) whether gamma_0 should be indexed by j inside the sums.
    return (
        # Term 1: beta^2 (1 + 3 g1) / (r^4 g1^3) multiplied by
        #   (m1_i . v_i)^2
        #   + |m1_i|^2 * sum_j J_ij^2 / (1 + g0)
        #   - sum_j J_ij^2 (m1_i . m0_j)^2 / (r^2 g0)
        # and then by m1.
        (beta**2 * (1 + 3 * g1))
        / (r**4 * g1**3)
        * (
            jnp.einsum("i d, i d -> i", m1, v)[:, None] ** 2
            + jnp.einsum(
                "i j, i d -> i d",
                J**2,
                jnp.sum(m1**2, axis=-1, keepdims=True),
            )
            / (1 + g0)
            - jnp.einsum(
                "i j, i d, j d, i e, j e -> i",
                J**2,
                m1,
                m0,
                m1,
                m0,
            )[:, None]
            / (r**2 * g0)
        )
        * m1
        # Term 2: - beta^2 / (r^2 (g1^2 + g1)) multiplied by
        #   |v_i|^2 + sum_j J_ij^2 (r^2 - |m0_j|^2)
        # and then by m1.
        - (beta**2)
        / (r**2 * (g1**2 + g1))
        * (
            jnp.sum(v**2, axis=-1, keepdims=True)
            + jnp.einsum(
                "i j, j -> i",
                J**2,
                r**2 - jnp.sum(m0**2, axis=-1),
            )[:, None]
        )
        * m1
        # Term 3: - 2 beta^2 / (r^2 (g1^2 + g1)) multiplied by the vector
        #   (v_i . m1_i) v_i
        #   + (sum_j J_ij^2) m1_i / (1 + g0)
        #   - sum_j J_ij^2 (m1_i . m0_j) m0_j / (r^2 g0).
        - 2.0
        * beta**2
        / (r**2 * (g1**2 + g1))
        * (
            jnp.einsum("i d, i d, i f -> i f", v, m1, v)
            + jnp.einsum("i j, i d -> i d", J**2, m1 / (1 + g0))
            - jnp.einsum(
                "i j, i d, j d, j f -> i f",
                J**2,
                m1,
                m0,
                m0,
            )
            / (r**2 * g0)
        )
    )


def _f(m1, m0, x, J, beta, r):
    """Effective field of the TAP update, see Eq. (61).

    Returns the naive field ``x + J m0`` plus the correction

        (1 + g1) / (2 beta) * (d2 + (m1 . d2) / (r^2 g1 / (1 + g1) - |m1|^2) * m1)

    where ``g1 = _gamma(_inv_phi(m1))``, ``d2`` comes from
    ``_d2_m_d_alpha_2`` and the dot product / norm are taken per spin over
    the last axis. ``m0`` and ``m1`` are ``(N, D)`` and ``J`` is ``(N, N)``,
    as fixed by the einsum subscripts.
    """
    gamma_new = _gamma(_inv_phi(m1, beta, r), beta, r)
    d2m = _d2_m_d_alpha_2(m1, m0, x, J, beta, r)

    # Per-spin projection of the second derivative onto m1, shape (N, 1).
    along_m1 = jnp.einsum("i d, i d -> i", m1, d2m)[:, None]
    denom = (r**2 * gamma_new) / (1 + gamma_new) - jnp.sum(
        m1**2, axis=-1, keepdims=True
    )
    correction = (1 + gamma_new) / (2 * beta) * (d2m + along_m1 / denom * m1)

    naive_field = jnp.einsum("i j, j d -> i d", J, m0)
    return x + naive_field + correction


def update_tap_mf(m0, _, x, J, beta, r, *, tol=1e-3, maxiter=100, verbose=False):
    """One TAP mean-field step for ``jax.lax.scan``, see Eq. (65).

    The new magnetizations ``m1`` solve the fixed-point equation
    ``m1 = phi(f(m1, m0))`` (``f`` from ``_f``). They are found with Anderson
    acceleration, starting from the naive mean-field update
    ``phi(x + J m0)``.

    Args:
        m0: magnetizations at the previous step (the scan carry).
        _: per-step scan input; ignored.
        x, J, beta, r: external field, couplings and the parameters forwarded
            to ``_phi`` / ``_f``.
        tol: stopping tolerance of the fixed-point solver.
        maxiter: maximum number of solver iterations.
        verbose: if True, print the solver's final error at every step.
            ``jax.debug.print`` is used, so printing also works under
            ``scan``/``vmap``/``jit``; it fires once per step and batch element.

    Returns:
        ``(m1, m0)``: the next carry and the per-step scan output.
    """

    def tap(m1, _m0, _x, _J, _beta, _r):
        return _phi(_f(m1, _m0, _x, _J, _beta, _r), _beta, _r)

    solver = AndersonAcceleration(fixed_point_fun=tap, tol=tol, maxiter=maxiter)
    initial_guess = _phi(x + J @ m0, beta, r)
    out = solver.run(initial_guess, m0, x, J, beta, r)

    if verbose:
        jax.debug.print("{error}", error=out.state.error)

    return out.params, m0


def time_evolution(m0, steps, update_fun):
    """Iterate ``update_fun`` over ``steps`` with ``jax.lax.scan``.

    ``update_fun(carry, step)`` must return ``(new_carry, output)``. The
    result is ``(final_carry, stacked_outputs)``, with outputs stacked along
    a new leading axis of length ``len(steps)``.
    """
    return jax.lax.scan(update_fun, init=m0, xs=steps)


def simulate(x, J, m0, steps, beta, r, update_fun=update_tap_mf):
    """Run ``time_evolution`` independently for every initial state in ``m0``.

    ``update_fun`` receives ``x``, ``J``, ``beta`` and ``r`` as keyword
    arguments. ``jax.vmap`` maps over the leading (batch) axis of ``m0``,
    so both returned arrays carry that batch axis first.

    Returns:
        ``(final_carry, stacked_outputs)``.
    """
    step = partial(update_fun, x=x, J=J, beta=beta, r=r)

    def evolve(initial_state):
        return time_evolution(initial_state, steps=steps, update_fun=step)

    final_carry, stacked_outputs = jax.vmap(evolve)(m0)
    return final_carry, stacked_outputs

In [None]:
import matplotlib.pyplot as plt
import optax


def _mf_similarity_traces(y, x, m0):
    """Cosine similarities of every state in a trajectory with three references.

    Args:
        y: trajectory with time on the leading axis, i.e. the output of
            ``simulate`` transposed to ``(T, B, N, D)``.
        x: external fields, broadcast against ``y`` (``(N, D)`` in this notebook).
        m0: initial magnetizations, broadcast against ``y`` (``(B, N, D)``).

    Returns:
        ``(to_previous, to_x, to_m0)``, each of shape ``y.shape[:-1]``.
        Entry ``t`` of ``to_previous`` compares ``y[t - 1]`` with ``y[t]``;
        entry 0 compares ``y[0]`` with itself.
    """
    # Shift the trajectory by one step (repeating y[0] at the front) so one
    # broadcast call covers every (t - 1, t) pair.
    previous = jnp.concatenate([y[:1], y[:-1]], axis=0)
    return (
        optax.cosine_similarity(previous, y),
        optax.cosine_similarity(x, y),
        optax.cosine_similarity(m0, y),
    )


def _plot_mean_band(ax, steps, values, color, label=None):
    """Plot the mean of ``values`` over its last axis with a min-max band.

    ``values`` has time on the first axis (matching ``steps``) and spins on
    the last axis.
    """
    ax.plot(steps, values.mean(axis=-1), color=color, linewidth=2, label=label)
    ax.fill_between(
        steps,
        values.min(axis=-1),
        values.max(axis=-1),
        color=color,
        alpha=0.2,
    )


def simulate_and_plot_only_naive(x, J, m0, steps, beta, r):
    """Simulate naive mean-field dynamics and plot one thin line per spin.

    Top panel: cosine similarity of each spin's state with its previous
    state, with ``m0`` and with ``x``. Bottom panel: Euclidean norm of each
    spin's state.

    Args:
        x, J, m0, steps, beta, r: forwarded to ``simulate`` with
            ``update_fun=update_naive_mf``; ``steps`` is also the time axis.
    """
    _, stacked_outputs_naive = simulate(
        x, J, m0, steps, beta, r, update_fun=update_naive_mf
    )
    # (B, T, N, D) -> (T, B, N, D): time first.
    y_naive = stacked_outputs_naive.transpose((1, 0, 2, 3))
    to_previous, to_x, to_m0 = _mf_similarity_traces(y_naive, x, m0)

    with plt.style.context("ggplot"):
        plt.rc("text.latex", preamble=r"\usepackage{amsmath}")
        fig, axes = plt.subplots(2, 1, sharex=True, sharey="row", figsize=(3.5, 5.5))
        thin = {"linewidth": 0.2, "alpha": 0.5}

        axes[0].set_ylabel(r"cosine similarity")
        # Step 0 is skipped for the previous-state curve: it compares y[0]
        # with itself.
        prev_lines = axes[0].plot(
            steps[1:], to_previous.squeeze()[1:], color="tab:green", **thin
        )
        m0_lines = axes[0].plot(steps, to_m0.squeeze(), color="tab:olive", **thin)
        x_lines = axes[0].plot(steps, to_x.squeeze(), color="tab:blue", **thin)
        # Each plot call draws one line per spin, so build the legend from one
        # representative line per group (labelling the artists would yield
        # one legend entry per spin).
        axes[0].legend(
            handles=[prev_lines[0], m0_lines[0], x_lines[0]],
            labels=[r"$\mathbf{m}_{t-1}$", r"$\mathbf{m}_{0}$", r"$\mathbf{x}$"],
            loc="lower center",
        )

        axes[1].set_xlabel(r"$t$")
        axes[1].set_ylabel(r"Euclidean norm")
        axes[1].plot(
            steps,
            jnp.linalg.norm(y_naive, axis=-1).squeeze(),
            color="tab:red",
            **thin,
        )

        fig.tight_layout()
        plt.show()


def simulate_and_plot(x, J, m0, steps, beta, r):
    """Compare naive and TAP mean-field dynamics from the same initial state.

    Runs ``simulate`` with ``update_naive_mf`` (left column) and
    ``update_tap_mf`` (right column). The top row shows cosine similarities
    with the previous state, ``m0`` and ``x``; the bottom row shows Euclidean
    norms. Curves are means over spins; shaded bands span the min-max range.

    Args:
        x, J, m0, steps, beta, r: forwarded to ``simulate``; ``steps`` is
            also the time axis.
    """
    runs = (("update_naive_mf", update_naive_mf), ("update_tap_mf", update_tap_mf))
    trajectories = []
    for title, update_fun in runs:
        _, stacked = simulate(x, J, m0, steps, beta, r, update_fun=update_fun)
        # (B, T, N, D) -> (T, B, N, D): time first.
        trajectories.append((title, stacked.transpose((1, 0, 2, 3))))

    with plt.style.context("ggplot"):
        plt.rc("text.latex", preamble=r"\usepackage{amsmath}")
        fig, axes = plt.subplots(2, 2, sharex=True, sharey="row", figsize=(6, 6))

        for col, (title, y) in enumerate(trajectories):
            to_previous, to_x, to_m0 = _mf_similarity_traces(y, x, m0)
            top, bottom = axes[0, col], axes[1, col]

            # Step 0 is skipped for the previous-state curve: it compares
            # y[0] with itself.
            _plot_mean_band(
                top,
                steps[1:],
                to_previous.squeeze()[1:],
                "tab:green",
                r"$\mathbf{m}_{t-1}$",
            )
            _plot_mean_band(top, steps, to_m0.squeeze(), "tab:olive", r"$\mathbf{m}_{0}$")
            _plot_mean_band(top, steps, to_x.squeeze(), "tab:blue", r"$\mathbf{x}$")
            top.legend(loc="lower center")
            top.set_title(title)

            _plot_mean_band(
                bottom, steps, jnp.linalg.norm(y, axis=-1).squeeze(), "tab:red"
            )
            bottom.set_xlabel(r"$t$")

        axes[0, 0].set_ylabel(r"cosine similarity")
        axes[1, 0].set_ylabel(r"Euclidean norm")

        fig.tight_layout()
        plt.show()

In [None]:
# Naive-only run: beta = 1, fields of norm r, couplings with variance 1/N.
N, D = 1024, 512
beta = 1.0
r = (D / 2 - 1) ** 0.5

key = jax.random.PRNGKey(2666)
x_key, J_key = jax.random.split(key)

# External fields: Gaussian directions rescaled to norm r per spin.
x = jax.random.normal(x_key, shape=(N, D))
x = r * x / jnp.linalg.norm(x, axis=-1, keepdims=True)

# J_ij = N^{-1/2} * standard normal; print 1/N next to the empirical
# variances, followed by 1/sqrt(N).
J = N**-0.5 * jax.random.normal(J_key, shape=(N, N))
print(N**-1, J.var(axis=-1).mean(), J.var(), N**-0.5)

# One batch element; every spin starts along (1, ..., 1) with unit norm.
m0 = jnp.ones((1, N, D))
m0 = m0 / jnp.linalg.norm(m0, axis=-1, keepdims=True)

simulate_and_plot_only_naive(x, J, m0, jnp.arange(20), beta, r)

In [None]:
# Naive vs TAP comparison at beta = 1 (same setup as the naive-only run).
N, D = 1024, 512
beta = 1.0
r = (D / 2 - 1) ** 0.5

key = jax.random.PRNGKey(2666)
x_key, J_key = jax.random.split(key)

# External fields: Gaussian directions rescaled to norm r per spin.
x = jax.random.normal(x_key, shape=(N, D))
x = r * x / jnp.linalg.norm(x, axis=-1, keepdims=True)

# J_ij = N^{-1/2} * standard normal; print 1/N next to the empirical
# variances, followed by 1/sqrt(N).
J = N**-0.5 * jax.random.normal(J_key, shape=(N, N))
print(N**-1, J.var(axis=-1).mean(), J.var(), N**-0.5)

# One batch element; every spin starts along (1, ..., 1) with unit norm.
m0 = jnp.ones((1, N, D))
m0 = m0 / jnp.linalg.norm(m0, axis=-1, keepdims=True)

simulate_and_plot(x, J, m0, jnp.arange(20), beta, r)

In [None]:
# Naive vs TAP comparison at beta = 2; everything else as in the beta = 1 run.
N, D = 1024, 512
beta = 2.0
r = (D / 2 - 1) ** 0.5

key = jax.random.PRNGKey(2666)
x_key, J_key = jax.random.split(key)

# External fields: Gaussian directions rescaled to norm r per spin.
x = jax.random.normal(x_key, shape=(N, D))
x = r * x / jnp.linalg.norm(x, axis=-1, keepdims=True)

# J_ij = N^{-1/2} * standard normal; print 1/N next to the empirical
# variances, followed by 1/sqrt(N).
J = N**-0.5 * jax.random.normal(J_key, shape=(N, N))
print(N**-1, J.var(axis=-1).mean(), J.var(), N**-0.5)

# One batch element; every spin starts along (1, ..., 1) with unit norm.
m0 = jnp.ones((1, N, D))
m0 = m0 / jnp.linalg.norm(m0, axis=-1, keepdims=True)

simulate_and_plot(x, J, m0, jnp.arange(20), beta, r)

In [None]:
# Naive vs TAP comparison with couplings twice as strong as the runs above.
N = 1024
D = 512
beta = 1.0
r = (D / 2 - 1) ** 0.5

key = jax.random.PRNGKey(2666)
x_key, J_key = jax.random.split(key)

# Unlike the previous runs, x is left at unit norm (not rescaled to r).
# NOTE(review): confirm this weak-field setting is intentional.
x = jax.random.normal(x_key, shape=(N, D))
x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)

# Couplings with standard deviation 2 / sqrt(N).
J_scale = 2 * N**-0.5
J = J_scale * jax.random.normal(J_key, shape=(N, N))
# Compare the empirical variances with the expected value J_scale**2 = 4 / N
# (not 1 / N as for the unit-scale runs), followed by the expected std.
print(J_scale**2, J.var(axis=-1).mean(), J.var(), J_scale)

# One batch element; every spin starts along (1, ..., 1) with unit norm.
m0 = jnp.ones((1, N, D))
m0 = m0 / jnp.linalg.norm(m0, axis=-1, keepdims=True)


simulate_and_plot(x, J, m0, jnp.arange(0, 20), beta, r)

In [None]:
def _phi_norm(theta, beta, r):
    """See Eq. (38)."""
    return beta / (1 + jnp.sqrt(1 + beta**2 * theta**2 / r**2)) * theta


x_values = np.linspace(0, 40, 200)
betas = [0.1, 0.5, 1.0, 2.0, 10.0]

D = 512
# Same radius convention as the simulation cells: r = sqrt(D / 2 - 1).
r = (D / 2 - 1) ** 0.5

with plt.style.context("ggplot"):
    plt.figure(figsize=(6, 6))
    plt.rc("text", usetex=True)
    plt.rc("text.latex", preamble=r"\usepackage{amsmath}")

    # |phi| as a function of |theta| for each beta, largest beta first.
    for beta in reversed(betas):
        plt.plot(
            x_values,
            _phi_norm(x_values, beta, r),
            label=r"$\beta = $" + f"${beta}$",
            lw=2.0,
        )

    # |phi(theta)| approaches r as |theta| grows.
    plt.axhline(y=r, linestyle="--", color="gray")
    plt.axvline(x=r, linestyle="--", color="gray")

    plt.xlabel(r"$\lVert \boldsymbol{\theta} \rVert$")
    plt.xlim(x_values.min(), x_values.max())
    plt.ylabel(r"$\lVert \varphi ( \boldsymbol{\theta} ) \rVert$")
    plt.legend(loc="best", bbox_to_anchor=(0.5, 0.0, 0.5, 0.5))

    plt.show()