In [None]:
%matplotlib inline

import numpy as np
try:
    import cma
except ImportError:
    # %pip installs into the running kernel's environment; import again
    # afterwards so the name `cma` is actually bound for later cells.
    %pip install cma
    import cma
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import ipywidgets

In [None]:
def optimize_rosenbrock(rosenbrock_dim, init_sigma, seed=None):
    """Minimize the Rosenbrock function (``cma.ff.rosen``) with CMA-ES.

    Starts from the origin in ``rosenbrock_dim`` dimensions and runs the
    ask/tell loop until ``es.stop()`` reports termination. Every iteration is
    recorded with ``es.logger.add()`` and printed with ``es.disp()``; at the
    end the result summary is printed and ``cma.plot()`` draws the log.

    Parameters
    ----------
    rosenbrock_dim : int
        Dimension of the search space.
    init_sigma : float
        Initial CMA-ES step size (``sigma0``).
    seed : optional
        Forwarded as the ``'seed'`` option of ``cma.CMAEvolutionStrategy`` so
        a run can be reproduced. ``None`` (the default) passes no options at
        all, i.e. cma's default seeding. NOTE(review): per cma's option docs,
        0 is treated like time-based seeding; use a positive integer for a
        fixed seed.

    Returns
    -------
    list
        One entry per generation: the list of candidate points that
        ``es.ask()`` returned for that generation, in order.
    """
    options = None if seed is None else {"seed": seed}
    es = cma.CMAEvolutionStrategy(rosenbrock_dim * [0], init_sigma, options)
    candidates = []
    while not es.stop():
        population = es.ask()
        candidates.append(population)
        es.tell(population, [cma.ff.rosen(x) for x in population])
        es.logger.add()  # record this iteration so cma.plot() below can draw it
        es.disp()
    es.result_pretty()
    cma.plot()
    return candidates

In [None]:
# Large default figure size; applies to every pyplot figure created after
# this line, including those drawn during the optimization below.
plt.rcParams["figure.figsize"] = [16, 9]
# 2-D CMA-ES run with initial step size 0.5; keeps each generation's
# sampled population for the interactive scatter plot.
two_candidates = optimize_rosenbrock(rosenbrock_dim=2, init_sigma=0.5)
def plot_2d(iter_num):
    """Scatter-plot the 2-D CMA-ES populations stored in ``two_candidates``.

    Parameters
    ----------
    iter_num : int or float
        ``-1`` draws every generation. Any other value draws only the
        generations ``i`` with ``i - 3 <= iter_num <= i + 1``, i.e. a window
        of up to five generations starting at ``iter_num - 1``.

    Each generation gets its own color from the cividis colormap, assigned in
    generation order. Both axes use a fixed tick grid from -0.5 to 1.5 in
    steps of 0.125. Setting ticks does not clip the data, so points outside
    that grid are still plotted.
    """
    # Fixed tick grid. The run starts at the origin and the standard
    # Rosenbrock optimum is at (1, 1), so this range brackets both.
    grid_min, grid_max, n_steps = -0.5, 1.5, 16
    ticks = np.linspace(grid_min, grid_max, n_steps + 1)

    # Draw into a fresh figure rather than whatever figure happens to be
    # current (e.g. one possibly still open from cma.plot()).
    fig, ax = plt.subplots()
    colors = cm.cividis(np.linspace(0, 1, len(two_candidates)))
    for i, (pts, c) in enumerate(zip(two_candidates, colors)):
        # A range test (rather than tuple membership) still selects a window
        # when the slider value is not an exact integer.
        if iter_num != -1 and not (i - 3 <= iter_num <= i + 1):
            continue
        points = np.asarray(pts)
        ax.scatter(points[:, 0], points[:, 1], color=c)

    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xlabel("x[0]")
    ax.set_ylabel("x[1]")
    if iter_num == -1:
        window = "all generations"
    else:
        window = f"generations {max(iter_num - 1, 0):g} to {iter_num + 3:g}"
    ax.set_title(f"CMA-ES candidates on the 2-D Rosenbrock function ({window})")
    plt.show()

# Generation slider for plot_2d: -1 shows every generation; any other value
# shows a five-generation window starting at value - 1. Integer steps of 1
# make every window reachable, including the earliest generations.
ipywidgets.interact(
    plot_2d,
    iter_num=ipywidgets.IntSlider(
        value=-1,
        min=-1,
        max=len(two_candidates),
        step=1,
        description="generation",
    ),
);

In [None]:
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 unused import
# Same optimization in three dimensions; plot_3d reads these per-generation populations.
three_candidates = optimize_rosenbrock(rosenbrock_dim=3, init_sigma=0.5)

def plot_3d(iter_num):
    """3-D scatter plot of the CMA-ES populations stored in ``three_candidates``.

    Parameters
    ----------
    iter_num : int or float
        ``-1`` draws every generation. Any other value draws only the
        generations ``i`` with ``i - 3 <= iter_num <= i + 1``, i.e. a window
        of up to five generations starting at ``iter_num - 1``.

    Each generation gets its own color from the inferno colormap, assigned in
    generation order. All three axes are fixed to [-0.5, 1.2]. Points with
    any coordinate more than 0.2 outside that range are not drawn.
    """
    axis_min, axis_max = -0.5, 1.2
    # Points up to this far outside the axis limits are still drawn.
    # NOTE(review): mplot3d does not clip to the limits by default, so such
    # points may render slightly outside the axes box.
    margin = 0.2
    keep_min, keep_max = axis_min - margin, axis_max + margin

    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    colors = cm.inferno(np.linspace(0, 1, len(three_candidates)))
    for i, (pts, c) in enumerate(zip(three_candidates, colors)):
        # A range test (rather than tuple membership) still selects a window
        # when the slider value is not an exact integer.
        if iter_num != -1 and not (i - 3 <= iter_num <= i + 1):
            continue
        points = np.asarray(pts)
        # Keep a point only if every coordinate lies within [keep_min, keep_max].
        inside = ((points >= keep_min) & (points <= keep_max)).all(axis=1)
        kept = points[inside]
        ax.scatter(kept[:, 0], kept[:, 1], kept[:, 2], color=c)

    ax.set_xlim([axis_min, axis_max])
    ax.set_ylim([axis_min, axis_max])
    ax.set_zlim([axis_min, axis_max])
    ax.set_xlabel("x[0]")
    ax.set_ylabel("x[1]")
    ax.set_zlabel("x[2]")
    if iter_num == -1:
        window = "all generations"
    else:
        window = f"generations {max(iter_num - 1, 0):g} to {iter_num + 3:g}"
    ax.set_title(f"CMA-ES candidates on the 3-D Rosenbrock function ({window})")
    plt.show()

# Generation slider for plot_3d: -1 shows every generation; any other value
# shows a five-generation window starting at value - 1. Integer steps of 1
# make every window reachable, including the earliest generations.
ipywidgets.interact(
    plot_3d,
    iter_num=ipywidgets.IntSlider(
        value=-1,
        min=-1,
        max=len(three_candidates),
        step=1,
        description="generation",
    ),
);