In [None]:
import random
import sys

import autograd.numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.signal import convolve
import scipy.stats

sys.path.append("..")

import shared.format
import shared.tools

In [None]:
# Reproducibility: seed both the stdlib and NumPy global generators from one
# constant so Restart & Run All redraws the same white noise.
SEED = 14

random.seed(SEED)
np.random.seed(SEED)

In [None]:
def gauss_random_field(x, y, scale, kernel_cov=None):
    """Draw a spatially correlated (smoothed) Gaussian random field.

    Standard-normal white noise, drawn from NumPy's global generator, is
    convolved (``mode='same'``) with a zero-mean Gaussian density evaluated
    at the supplied coordinates, then multiplied by ``scale``.

    Parameters
    ----------
    x, y : array_like
        Coordinates of the sample points, e.g. the outputs of ``np.meshgrid``.
        The output takes the shape of ``x``; ``y`` must broadcast to it.
        Coordinate arrays of any rank (1-D, 2-D, ...) are accepted.
    scale : float
        Amplitude multiplier applied to the smoothed field.
    kernel_cov : array_like, optional
        Covariance of the Gaussian smoothing kernel, forwarded to
        ``scipy.stats.multivariate_normal``. Defaults to ``np.ones(2)``,
        i.e. unit variances with no correlation (identity covariance).

    Returns
    -------
    ndarray
        Field with the same shape as ``x``.

    Notes
    -----
    The kernel is centred on the coordinate origin, not on the grid centre,
    so a grid that is not centred on 0 gets a shifted smoothing window.
    """
    if kernel_cov is None:
        kernel_cov = np.ones(2)

    shape = np.shape(x)
    white_field = np.random.standard_normal(size=shape)

    # Pack the coordinates as (..., 2) points; `...` indexing works for
    # coordinate arrays of any rank and still lets `y` broadcast to `x`.
    pos = np.empty(shape + (2,))
    pos[..., 0] = x
    pos[..., 1] = y

    gauss_rv = scipy.stats.multivariate_normal([0, 0], cov=kernel_cov)
    # `pdf` squeezes length-1 axes out of its result (e.g. a (1, n) grid
    # comes back as (n,)); restore the grid shape so `convolve` receives
    # operands of equal dimensionality.
    gauss_pdf = np.reshape(gauss_rv.pdf(pos), shape)

    red_field = scale * convolve(white_field, gauss_pdf, mode='same')
    return red_field

def plot_cost_surface(cost, N, mesh_extent, ax=None):
    """Draw ``cost`` as a white, edge-lined 3-D surface over a square grid.

    Parameters
    ----------
    cost : callable
        Called once as ``cost(weights1, weights2)`` on the full meshgrid
        arrays; the result is passed as Z to ``plot_surface``.
    N : int
        Number of grid points per axis.
    mesh_extent : float
        The grid spans ``[-mesh_extent, mesh_extent]`` on both axes.
    ax : 3-D Axes, optional
        Target axes. If omitted, a new figure with one 3-D subplot is created.

    Returns
    -------
    The axes that were drawn on.
    """
    mesh = np.linspace(-mesh_extent, mesh_extent, N)
    weights1, weights2 = np.meshgrid(mesh, mesh)

    if ax is None:
        ax = plt.figure().add_subplot(111, projection='3d')

    # Hide axis lines/panes through the public API; the private
    # `_axis3don` flag is not honoured by every matplotlib version.
    ax.set_axis_off()

    ax.plot_surface(weights1, weights2, cost(weights1, weights2),
                    rstride=2, cstride=2, linewidth=0.5, edgecolor='C0',
                    alpha=1, color="white",
                    shade=True)

    # Give x, y and z the same data range, each centred on its current
    # midpoint (center=True).
    axis_equal_3d(ax, center=True)
    return ax

In [None]:
# Surface / grid configuration.
scale = 1.          # amplitude multiplier passed to gauss_random_field
N = 100             # grid points per axis (passed to plot_cost_surface)
mesh_extent = 10    # grid spans [-mesh_extent, mesh_extent] on each axis


def grf(x, y, scale=scale):
    """Sample a Gaussian random field on (x, y) at the configured ``scale``.

    The default for ``scale`` is captured when this cell runs, so a later
    reassignment of the global ``scale`` does not silently change ``grf``
    (a lambda or closure would read the global at call time instead).
    """
    return gauss_random_field(x, y, scale=scale)

In [None]:
def axis_equal_3d(ax, center=0):
    """Give the x, y and z axes of a 3-D plot the same data range.

    Every axis is assigned a span equal to the largest absolute span among
    its current x/y/z limits.

    Parameters
    ----------
    ax : 3-D Axes
        Any object exposing ``get_xlim``/``set_xlim`` (and the y/z variants).
    center : optional
        If it compares equal to 0 (the default, or False), the new limits are
        centred on the origin. Otherwise each axis is centred on the midpoint
        of its current limits.
    """
    # Adapted from Stack Overflow question 19933125.
    getters = (ax.get_xlim, ax.get_ylim, ax.get_zlim)
    setters = (ax.set_xlim, ax.set_ylim, ax.set_zlim)

    limits = np.array([get() for get in getters])  # rows x/y/z; cols low/high
    spans = np.abs(limits[:, 1] - limits[:, 0])
    half_width = max(spans) / 2

    midpoints = [0, 0, 0] if center == 0 else np.mean(limits, axis=1)
    for mid, set_lim in zip(midpoints, setters):
        set_lim(mid - half_width, mid + half_width)

In [None]:
# Final figure: the random-field surface on an 8x8-inch canvas, saved as a PDF
# cropped to the 6x4-inch box whose lower-left corner is 1 inch right and
# 2 inches up from the figure origin.
fig, ax = plt.subplots(figsize=(8, 8), subplot_kw={'projection': '3d'})
plot_cost_surface(grf, N, mesh_extent, ax=ax)

# Let the axes fill the whole canvas before cropping.
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
bbox = fig.bbox_inches.from_bounds(1, 2, 6, 4)
fig.savefig("2dgrf.pdf", bbox_inches=bbox)