In [41]:
import numpy as np
import math
import matplotlib as mpl
mpl.use("agg")
import matplotlib.cm as cm
import matplotlib.pyplot as plt

In [36]:
class Logger:
    """Accumulate per-generation optimisation statistics and plot them.

    The log is a dict of equally long lists. Scalar quantities are stored under
    their own key ("g", "evals", "fval", "sigma"); vector quantities are stored
    one list per coordinate under "<key><i>" (e.g. "mean0", "sqrteig3").

    Args:
        dim: Number of coordinates tracked for every vector quantity.
    """

    def __init__(self, dim):
        self.dim = dim
        self.index_key = ["g", "evals"]
        self.scalar_key = ["fval", "sigma"]
        self.vector_key = ["mean", "sqrteig"]
        self.log = self.init_log_dict()

    def init_log_dict(self):
        """Return a fresh log dict with an empty list for every tracked series."""
        log = {key: [] for key in self.index_key + self.scalar_key}
        for key in self.vector_key:
            for i in range(self.dim):
                log[f"{key}{i}"] = []
        return log

    def generate_log(self, g, evals, fval, sigma, mean, B):
        """Append one record to the log.

        Args:
            g: Generation counter.
            evals: Number of evaluations spent so far.
            fval: Objective value to record.
            sigma: Current step-size.
            mean: Column vector indexed as ``mean[i, 0]`` for ``i < dim``.
            B: Square matrix; the square roots of the eigenvalues of
                ``sigma**2 * B @ B.T`` are recorded as "sqrteig<i>".
        """
        self.log["g"].append(g)
        self.log["evals"].append(evals)
        self.log["fval"].append(fval)
        self.log["sigma"].append(sigma)
        # eigvalsh returns the eigenvalues in ascending order.
        sqrt_eigs = np.sqrt(np.linalg.eigvalsh(sigma**2 * B.dot(B.T)))
        for i in range(self.dim):
            self.log[f"mean{i}"].append(mean[i, 0])
            self.log[f"sqrteig{i}"].append(sqrt_eigs[i])

    def plot(self, dir_path):
        """Render every logged series against "evals" and save ``<dir_path>/log.pdf``.

        The matplotlib style settings are applied inside ``mpl.rc_context()`` so
        they are rolled back afterwards instead of leaking into later plots, and
        the figure is always closed once it has been written (or on failure).
        ``text.usetex`` is switched on, so a working LaTeX installation is needed.

        Args:
            dir_path: Existing directory that receives ``log.pdf``.
        """
        from cycler import cycler

        variable_list = self.scalar_key + self.vector_key
        log = self.log
        nfigs = len(variable_list)
        # Near-square grid that is large enough for one panel per variable.
        ncols = int(np.ceil(np.sqrt(nfigs)))
        nrows = int(np.ceil(nfigs / ncols))

        # (log key, panel title, logarithmic y-axis?, one curve per coordinate?)
        panels = [
            ("fval", r"$f(m)$", True, False),
            ("mean", r"$m$", False, True),
            ("sigma", r"$\sigma$ (step-size)", True, False),
            ("sqrteig", r"$\sqrt{\rm eig}$", True, True),
        ]

        with mpl.rc_context():
            mpl.rc("lines", linewidth=0.5, markersize=8)
            mpl.rc("font", size=12)
            mpl.rc("grid", color="0.75", linestyle=":")
            mpl.rc("ps", useafm=True)  # force to use
            mpl.rc("pdf", use14corefonts=True)  # only Type 1 fonts
            mpl.rc("text", usetex=True)  # for a paper submission
            mpl.rc("xtick", labelsize=6)
            mpl.rc("axes", prop_cycle=cycler(color="bgrcmyk"))

            fig = plt.figure(figsize=(4 * ncols, 3 * nrows))
            try:
                idx = 1
                for key, title, log_y, per_coordinate in panels:
                    if key not in variable_list:
                        continue
                    ax = fig.add_subplot(nrows, ncols, idx)
                    idx += 1
                    ax.set_title(title)
                    ax.grid(True)
                    ax.grid(which="major", linewidth=0.50)
                    ax.grid(which="minor", linewidth=0.25)
                    if log_y:
                        ax.set_yscale("log")
                    if per_coordinate:
                        for i in range(self.dim):
                            ax.plot(
                                log["evals"],
                                log[f"{key}{i}"],
                                color=cm.hsv(float(i) / self.dim),
                            )
                    else:
                        ax.plot(log["evals"], log[key])

                # Save while the rc settings are still active.
                fig.tight_layout()
                fig.savefig(f"{dir_path}/log.pdf")
            finally:
                plt.close(fig)
        return

class Solution:
    """One sampled candidate.

    Attributes are initialised to placeholders and are meant to be overwritten
    by the optimiser:

    * ``f`` - objective value, NaN until evaluated.
    * ``x`` - (dim, 1) column vector of zeros.
    * ``z`` - (dim, 1) column vector of zeros.
    """

    def __init__(self, dim):
        column_shape = (dim, 1)
        self.f = float("nan")
        # Two separate arrays: x and z must never alias each other.
        self.x = np.zeros(column_shape)
        self.z = np.zeros(column_shape)

def expm_numpy(mat: np.ndarray) -> np.ndarray:
    """Matrix exponential of a symmetric / Hermitian matrix.

    The matrix is diagonalised with ``np.linalg.eigh`` as ``U diag(d) U^H`` and
    the result is ``U diag(exp(d)) U^H``.

    Args:
        mat: Square real-symmetric or complex-Hermitian matrix. ``eigh`` does
            not check this and reads only one triangle of the input, so the
            result is NOT the matrix exponential of a non-symmetric matrix.

    Returns:
        ``exp(mat)`` with the same shape as ``mat``.
    """
    eigvals, eigvecs = np.linalg.eigh(mat)
    # The conjugate transpose is required for complex Hermitian input; for real
    # input it is the plain transpose, so real results are unchanged.
    return eigvecs @ np.diag(np.exp(eigvals)) @ eigvecs.conj().T

In [37]:
def xnes_main(obj_func, dim, lamb, mean, sigma, max_evals, criterion=1e-8, seed=123):
    """Minimise ``obj_func`` with an xNES loop and write a convergence plot.

    Each generation draws ``lamb`` candidates ``mean + sigma * B @ z`` with
    standard-normal ``z``, ranks them by objective value (smallest first) and
    updates ``mean``, ``sigma`` and ``B`` from rank-based weights.

    Args:
        obj_func: Callable taking a (dim, 1) column vector and returning a
            value that can be ordered and compared with ``criterion``.
        dim: Problem dimension (``int``).
        lamb: Number of candidates per generation (``int``).
        mean: Initial value used for every coordinate of the mean (``float``).
        sigma: Initial step-size (``float``).
        max_evals: Evaluation budget (``int``). Only the ``lamb`` sampled
            evaluations per generation are counted; the extra ``obj_func(mean)``
            call made for monitoring is not.
        criterion: The loop stops as soon as ``obj_func(mean)`` drops below
            this value. That final generation is not added to the log.
        seed: Seed passed to ``np.random.seed`` (global NumPy RNG).

    Returns:
        None.

    Side effects:
        Reseeds the global NumPy RNG, prints progress every 1000 generations
        and writes ``./log.pdf`` through ``Logger.plot``.

    Raises:
        AssertionError: If an argument does not have the type listed above.
    """
    # validation
    assert isinstance(dim, int)
    assert isinstance(lamb, int)
    assert isinstance(mean, float)
    assert isinstance(sigma, float)
    assert isinstance(seed, int)
    assert isinstance(max_evals, int)

    np.random.seed(seed)
    logger = Logger(dim=dim)

    # constant
    # Rank-based weights: clipped at zero, normalised, then shifted by 1/lamb.
    wrh = math.log(lamb / 2.0 + 1.0) - np.log(np.arange(1, lamb + 1))
    w_hat = np.maximum(0.0, wrh)
    w = w_hat / sum(w_hat) - 1.0 / lamb
    eta_m = 1.0
    eta_B = 3.0 * (3.0 + np.log(dim)) / 5.0 / dim / np.sqrt(dim)
    eta_sigma = eta_B
    I = np.eye(dim, dtype=float)

    # dynamic
    mean = np.array([mean] * dim).reshape(dim, 1)
    B = np.eye(dim, dtype=float)
    evals = 0
    g = 0
    sols = [Solution(dim) for _ in range(lamb)]

    while evals < max_evals:
        g += 1

        # sample and evaluate the population
        for i in range(lamb):
            sols[i].z = np.random.randn(dim, 1)
            sols[i].x = mean + sigma * B.dot(sols[i].z)
            sols[i].f = obj_func(sols[i].x)
        evals += lamb

        # best (smallest f) first, so that sols[i] is paired with w[i]
        sols = sorted(sols, key=lambda s: s.f)
        fm = obj_func(mean)
        if g % 1000 == 0:
            print("#evals:{}, f(m):{}".format(evals, fm))

        if fm < criterion:
            break

        logger.generate_log(
            g=g, evals=evals, fval=fm, sigma=sigma, mean=mean, B=B
        )

        # natural gradient estimation
        G_delta = np.sum([w[i] * sols[i].z for i in range(lamb)], axis=0)
        G_M = np.sum(
            [w[i] * (np.outer(sols[i].z, sols[i].z) - I) for i in range(lamb)], axis=0
        )
        # split G_M into its trace part (step-size) and the trace-free rest (shape)
        G_sigma = G_M.trace() / dim
        G_B = G_M - G_sigma * I

        # update parameters
        mean += eta_m * sigma * np.dot(B, G_delta)
        sigma *= math.exp((eta_sigma / 2.0) * G_sigma)
        B = B.dot(expm_numpy((eta_B / 2.0) * G_B))

    logger.plot(dir_path=".")
    return

In [38]:
def sphere(x):
    """Sphere benchmark: the sum of the squared entries of ``x``."""
    squares = x**2
    return np.sum(squares)


def ktablet(x):
    """k-tablet benchmark.

    The first ``k`` coordinates enter as ``x_i**2`` and the remaining ones as
    ``(100 * x_i)**2``, where ``k = max(1, n // 4)`` and ``n = len(x)``.

    Args:
        x: Array whose first axis has length ``n >= 2``.

    Returns:
        The scalar function value.

    Raises:
        ValueError: If ``len(x) < 2``.
    """
    n = len(x)
    if n < 2:
        raise ValueError("dimension must be greater one")
    # Keep at least one unscaled coordinate: n // 4 is 0 for n in {2, 3},
    # which would otherwise turn the function into a scaled sphere.
    k = max(1, n // 4)
    return np.sum(x[0:k] ** 2) + np.sum((100.0 * x[k:n]) ** 2)


def ellipsoid(x):
    """Ellipsoid benchmark: ``sphere`` applied to an axis-scaled copy of ``x``.

    Coordinate ``i`` is multiplied by ``10 ** (3 * i / (n - 1))``, so the scale
    factors run from 1 (first coordinate) to 1000 (last coordinate).

    Raises:
        ValueError: If ``len(x) < 2``.
    """
    n = len(x)
    if n < 2:
        raise ValueError("dimension must be greater one")
    exponents = [3 * i / (n - 1) for i in range(n)]
    scaling = np.diag([10**e for e in exponents])
    return sphere(scaling @ x)


def rastrigin(x):
    """Rastrigin benchmark: ``10 * n + sum(x_i**2 - 10 * cos(2 * pi * x_i))``.

    Args:
        x: Array whose first axis has length ``n >= 2`` (1-D vector or
            (n, 1) column vector).

    Returns:
        The scalar function value.

    Raises:
        ValueError: If ``len(x) < 2``.
    """
    n = len(x)
    if n < 2:
        raise ValueError("dimension must be greater one")
    # np.sum reduces over every axis. The builtin sum() only iterates the first
    # axis and would return a shape-(1,) array for an (n, 1) column vector.
    return 10 * n + np.sum(x**2 - 10 * np.cos(2 * np.pi * x))


In [39]:
# experimental setup
obj_func = sphere  # objective function to minimise
dim = 40  # problem dimension
lamb = 20  # number of candidates sampled per generation
mean = 3.0  # initial value of every coordinate of the mean (must be a float)
sigma = 2.0  # initial step-size (must be a float)
max_evals = 100_000  # evaluation budget

In [40]:
# Run the optimiser with the configuration above; it writes ./log.pdf.
xnes_main(
    obj_func=obj_func,
    dim=dim,
    lamb=lamb,
    mean=mean,
    sigma=sigma,
    max_evals=max_evals,
)

#evals:20000, f(m):1.7228791070613596
#evals:40000, f(m):0.0666288399209203
#evals:60000, f(m):0.0043471226570394724
#evals:80000, f(m):0.00016786735775892653
#evals:100000, f(m):8.309585704241058e-06
