In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from typing import Callable, Union, Sequence
import math
import torch
from scipy.spatial.distance import cdist
import numpy as np


def peaks(meshgrid: torch.Tensor) -> torch.Tensor:
    """
    "Peaks" function that has multiple local minima.

    :params meshgrid: tensor of shape [..., 2], the (x, y) coordinates
    """
    meshgrid = torch.as_tensor(meshgrid, dtype=torch.float)
    xx = meshgrid[..., 0]
    yy = meshgrid[..., 1]
    return 0.25 * (
        3 * (1 - xx) ** 2 * torch.exp(-(xx**2) - (yy + 1) ** 2)
        - 10 * (xx / 5 - xx**3 - yy**5) * torch.exp(-(xx**2) - yy**2)
        - 1 / 3 * torch.exp(-((xx + 1) ** 2) - yy**2)
    )


def rastrigin(meshgrid: torch.Tensor, shift: float = 0) -> torch.Tensor:
    """
    "Rastrigin" function with `A = 3`, translated by `shift` along both axes.
    https://en.wikipedia.org/wiki/Rastrigin_function

    f(x, y) = 2A + sum_{v in (x, y)} [(v - shift)^2 - A cos(2 pi (v - shift))]

    :params meshgrid: tensor of shape [..., 2], the (x, y) coordinates
    :params shift: offset applied to both coordinates; the global minimum (value 0)
        is at (shift, shift). Integer and non-integer shifts are both supported.
    :return: float tensor of shape [...]
    """
    meshgrid = torch.as_tensor(meshgrid, dtype=torch.float)
    # Shift inside the cosine as well, so the whole surface is translated. Otherwise
    # a non-integer shift moves only the quadratic bowl and not the minimum.
    xx = meshgrid[..., 0] - shift
    yy = meshgrid[..., 1] - shift
    A = 3
    return A * 2 + (
        (xx**2 - A * torch.cos(2 * math.pi * xx))
        + (yy**2 - A * torch.cos(2 * math.pi * yy))
    )


def rosenbrock(meshgrid: torch.Tensor, a: float = 1, b: float = 100) -> torch.Tensor:
    """
    "Rosenbrock" function f(x, y) = (a - x)^2 + b (y - x^2)^2
    https://en.wikipedia.org/wiki/Rosenbrock_function

    It has a global minimum of 0 at $(x , y) = (a, a^2)$, which is $(1, 1)$ for the defaults.

    :params meshgrid: tensor of shape [..., 2], the (x, y) coordinates
    :params a: location parameter of the minimum (default 1)
    :params b: steepness of the curved valley (default 100)
    :return: float tensor of shape [...]
    """
    meshgrid = torch.as_tensor(meshgrid, dtype=torch.float)
    xx = meshgrid[..., 0]
    yy = meshgrid[..., 1]

    return (a - xx) ** 2 + b * (yy - xx**2) ** 2


def simple_fn(meshgrid: torch.Tensor) -> torch.Tensor:
    """
    Negative inverse-quadratic bowl: -1 / (1 + x^2 + y^2).
    Its minimum value of -1 is at the origin, and it tends to 0 far from the origin.

    :params meshgrid: tensor (or array-like) of shape [..., 2], the (x, y) coordinates
    :return: float tensor of shape [...]
    """
    coords = torch.as_tensor(meshgrid, dtype=torch.float)
    x, y = coords[..., 0], coords[..., 1]
    denominator = 1 + x**2 + y**2
    return -1 / denominator


def simple_fn2(meshgrid: torch.Tensor) -> torch.Tensor:
    """
    Smooth convex bowl: sqrt(1 + x^2 + y^2), with minimum value 1 at the origin.

    :params meshgrid: tensor (or array-like) of shape [..., 2], the (x, y) coordinates
    :return: float tensor of shape [...]
    """
    coords = torch.as_tensor(meshgrid, dtype=torch.float)
    x, y = coords[..., 0], coords[..., 1]
    radicand = 1 + x**2 + y**2
    return radicand ** (1 / 2)

In [None]:
# Plot styling used by the figure cell.
ALPHA = 1
SAMPLE_MARKER = "."
SAMPLE_MARKER_SIZE = 7.5


GREEN = "#3EB863"
PURPLE = "#6a4c93"
RED = "#CF294A"
BLUE = "#275299"


# Three anchor points given as (x, y) coordinates.
anchors = np.asarray(
    [
        [-1.25, 0.5],
        [0, 1.25],
        [1.2, 0],
    ]
)

# Weights of the three anchors; the sample `point` is their weighted sum.
A_w = 0.1
B_w = 0.1
C_w = 0.8
# Compare with a tolerance: exact float equality is brittle (0.1 + 0.2 == 0.3 is False).
assert math.isclose(A_w + B_w + C_w, 1.0), "anchor weights must sum to 1"

point = A_w * anchors[0] + B_w * anchors[1] + C_w * anchors[2]

# One colour per anchor, in the same order as `anchors`.
colors = [
    RED,
    PURPLE,
    GREEN,
]
assert len(colors) == len(anchors), "need exactly one colour per anchor"
sample_color = BLUE

AXIS_OFF = True

In [None]:
from tueplots import bundles
from tueplots import figsizes
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator
import numpy as np

N_ROWS = 1
N_COLS = 2
RATIO = 1  # NOTE(review): not used in this cell; the figure aspect is HEIGHT_TO_WIDTH_RATIO
HEIGHT_TO_WIDTH_RATIO = 0.7
import matplotlib

plt.style.use("default")
plt.rcParams.update(bundles.icml2022())
plt.rcParams.update(
    figsizes.icml2022_full(ncols=N_COLS, nrows=N_ROWS, height_to_width_ratio=HEIGHT_TO_WIDTH_RATIO)
)

# Left panel (`ax`): the "peaks" surface with the anchors and the sample on it.
# Right panel (`proj`): the sample re-expressed by its distances to each anchor.
fig, [ax, proj] = plt.subplots(
    N_ROWS,
    N_COLS,
    dpi=300,
    sharey=False,
    sharex=False,
    subplot_kw={"projection": "3d"},
)


# Evaluation grid for the surface.
x_values = np.arange(-2, 2, 0.01)
y_values = np.arange(-3, 3, 0.01)
X, Y = np.meshgrid(x_values, y_values)
meshgrid = np.stack((X, Y), -1)


# Plot surface (`peaks` returns a torch tensor; pass matplotlib a NumPy array explicitly).
Z = peaks(meshgrid).numpy()
COUNT = 100
surf = ax.plot_surface(
    X, Y, Z, cmap=cm.coolwarm, alpha=ALPHA, linewidth=1, antialiased=True, shade=True, rcount=COUNT, ccount=COUNT
)


# Lift the 2-D anchors and sample onto the surface by appending z = peaks(x, y).
# New names leave the config-cell `anchors`/`point` untouched, so re-running this
# cell does not append another column and silently change the distances below.
anchors_3d = np.concatenate((anchors, peaks(anchors)[:, None].numpy()), axis=-1)
point_3d = np.asarray([*point, peaks(point).item()])

ANCHOR_MARKER = "*"

# Plot anchors
for anchor, color in zip(anchors_3d, colors):
    ax.plot(
        [anchor[0]], [anchor[1]], [anchor[2]], c=color, marker=ANCHOR_MARKER, zorder=10, alpha=1, antialiased=True
    )

# Plot sample
ax.plot(
    [point_3d[0]],
    [point_3d[1]],
    [point_3d[2]],
    c=sample_color,
    markersize=SAMPLE_MARKER_SIZE,
    marker=SAMPLE_MARKER,
    zorder=10,
    antialiased=True,
)

# Plot dashed lines from the sample to each anchor
for anchor, color in zip(anchors_3d, colors):
    ax.plot(
        [point_3d[0], anchor[0]],
        [point_3d[1], anchor[1]],
        [point_3d[2], anchor[2]],
        c=color,
        markersize=0,
        zorder=8,
        linewidth=1,
        linestyle="--",
        antialiased=True,
    )

# Relative representation: Euclidean distance from the sample to each anchor.
anchors_dists = cdist(anchors_3d, point_3d[None]).squeeze()
dist_x, dist_y, dist_z = anchors_dists

# Plot relative axis (one axis per anchor, coloured like that anchor)
for dist_ax, color in zip(
    (
        ([0, dist_x], [0, 0], [0, 0]),
        ([0, 0], [0, dist_y], [0, 0]),
        ([0, 0], [0, 0], [0, dist_z]),
    ),
    colors,
):
    proj.plot(*dist_ax, c=color, markersize=0, zorder=8, linewidth=2, linestyle="--", antialiased=True)

# Plot anchors axis ends
for axis_end, color, zorder in zip(
    (([dist_x], [0], [0]), ([0], [dist_y], [0]), ([0], [0], [dist_z])),
    colors,
    (11, 9.5, 11),
):
    proj.plot(*axis_end, c=color, marker=ANCHOR_MARKER, markersize=10, zorder=zorder, alpha=1, antialiased=True)

# Plot sample at (dist_x, dist_y, dist_z); zorder is set explicitly rather than
# reusing the loop variable of the previous loop.
proj.plot(
    [dist_x],
    [dist_y],
    [dist_z],
    c=sample_color,
    markersize=SAMPLE_MARKER_SIZE,
    marker=SAMPLE_MARKER,
    zorder=11,
    antialiased=True,
)

# Plot the remaining edges of the box spanned by the sample. The three edges through
# the origin are the relative axes drawn above, and each edge is drawn exactly once.
for lines, zorder in zip(
    (
        # Polyline over six edges:
        # (dx,0,0)->(dx,dy,0)->(dx,dy,dz)->(0,dy,dz)->(0,0,dz)->(dx,0,dz)->(dx,dy,dz)
        (
            [dist_x, dist_x, dist_x, 0, 0, dist_x, dist_x],
            [0, dist_y, dist_y, dist_y, 0, 0, dist_y],
            [0, 0, dist_z, dist_z, dist_z, dist_z, dist_z],
        ),
        ([dist_x, dist_x], [0, 0], [0, dist_z]),
        ([0, dist_x], [dist_y, dist_y], [0, 0]),
        ([0, 0], [dist_y, dist_y], [0, dist_z]),
    ),
    (10, 9, 9, 9),
):
    proj.plot(*lines, c=sample_color, linestyle="--", linewidth=0.5, zorder=zorder, alpha=1, antialiased=True)


# Use equal data-unit scaling on all three axes: the box aspect matches the padded
# axis limits exactly, so the padding does not distort the scale.
AXIS_PAD = 0.1
x_max, y_max, z_max = dist_x + AXIS_PAD, dist_y + AXIS_PAD, dist_z + AXIS_PAD
proj.set_box_aspect((x_max, y_max, z_max))

proj.set_xlim3d(0, x_max)
proj.set_ylim3d(0, y_max)
proj.set_zlim3d(0, z_max)
proj.view_init(elev=17.0, azim=-50)
ax.view_init(elev=40.0, azim=200)


if AXIS_OFF:
    ax.axis("off")
    proj.axis("off")

In [None]:
fig.savefig("teaser.svg", bbox_inches="tight", pad_inches=0)
!rsvg-convert -f pdf -o teaser.pdf teaser.svg
!rm teaser.svg

In [None]:
# Disabled styling for the `proj` 3-D axes: repositions the axes and sets tick labels,
# axis and grid line styles, tick lengths and pane colour. Uncomment to use.
# NOTE(review): this relies on the underscore-prefixed `axis._axinfo` dict, which is not
# public matplotlib API and may change between versions.
# box = proj.get_position()
# proj.set_position([box.x0, box.y0, box.x1, box.y1])
# for axis in [proj.xaxis, proj.yaxis, proj.zaxis]:
#     axis.set_ticklabels([])
#     axis._axinfo['axisline']['linewidth'] = 1
#     axis._axinfo['axisline']['color'] = (0, 0, 0)
#     axis._axinfo['grid']['linewidth'] = 0.25
#     axis._axinfo['grid']['linestyle'] = "-"
#     axis._axinfo['grid']['color'] = (0, 0, 0)
#     axis._axinfo['tick']['inward_factor'] = 0.0
#     axis._axinfo['tick']['outward_factor'] = 0.0
#     axis.set_pane_color((0.95, 0.95, 0.95))