In [1]:
from __future__ import print_function, division, absolute_import

import gc
from functools import partial

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('darkgrid')

import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

from jax.experimental import stax # neural network library
from jax.experimental.stax import Dense, Relu # neural network layers
from jax.experimental import optimizers

ModuleNotFoundError: No module named 'jax'

In [None]:
from sklearn.datasets import make_moons

# Explicit seeds so a fresh "Restart & Run All" reproduces the same data;
# train and test use different seeds so they are independent samples.
TRAIN_SEED, TEST_SEED = 0, 1
x_train, y_train = make_moons(n_samples = 5000, noise = 0.10, random_state = TRAIN_SEED)
x_test, y_test = make_moons(n_samples = 2000, noise = 0.10, random_state = TEST_SEED)
# Labels are cast to float32 so they can be used directly as regression targets.
y_test = y_test.astype(np.float32)
y_train = y_train.astype(np.float32)

In [None]:
def getMeanStd(predict, params, x, Npred = 50, rng = None):
    """Return the per-example mean and standard deviation over repeated predictions.

    Calls ``predict(params, x)`` ``Npred`` times, stacks the outputs along a new
    leading axis and reduces over that axis.

    Args:
        predict: callable ``predict(params, x, **kwargs)`` returning an array.
        params: parameters forwarded unchanged to ``predict``.
        x: inputs forwarded unchanged to ``predict``.
        Npred: number of forward passes; must be >= 1.
        rng: optional ``jax.random`` PRNG key. When given it is split into
            ``Npred`` subkeys and each call receives ``rng=subkey`` (the keyword
            stax apply functions use to drive stochastic layers such as Dropout).
            When ``None``, ``predict`` is called without a key.

    Returns:
        ``(mean, std)``, each with the shape of a single ``predict`` output.

    Raises:
        ValueError: if ``Npred < 1``.

    Note: if ``predict`` is deterministic, every pass is identical and ``std``
    is (numerically) zero; pass ``rng`` to a stochastic network to get a spread.
    """
    if Npred < 1:
        raise ValueError("Npred must be >= 1, got {}".format(Npred))
    if rng is None:
        samples = [predict(params, x) for _ in range(Npred)]
    else:
        samples = [predict(params, x, rng = key) for key in random.split(rng, Npred)]
    # (Npred, *output_shape): one slice per forward pass
    p = np.stack(samples, axis = 0)
    return np.mean(p, axis = 0), np.std(p, axis = 0)


def makeModel(x_train, y_train, layers = (200, 100, 50, 5), batch_size = 20,
              Nepochs = 500, step_size = 1e-3, seed = 0):
    """Build a fully connected ReLU regressor and train it with Adam on mean squared error.

    The network is ``Dense(n) -> Relu`` for every ``n`` in ``layers``, followed by a
    final ``Dense(1)``. Training runs ``Nepochs`` passes over the data in consecutive
    (unshuffled) minibatches of ``batch_size`` rows. Trailing rows that do not fill
    a whole batch are skipped. No weight decay or dropout is applied.

    Args:
        x_train: 2-D array of inputs, one row per example.
        y_train: 1-D array of targets aligned with the rows of ``x_train``.
        layers: hidden layer widths.
        batch_size: rows per minibatch. If it exceeds the dataset size, a single
            batch with every row is used.
        Nepochs: number of passes over the training data.
        step_size: Adam learning rate.
        seed: seed for the ``jax.random`` key used to initialise the weights.

    Returns:
        ``(net_apply, params)``: the stax apply function, which maps a batch of
        shape ``(n, d)`` to ``(n, 1)``, and the trained parameters.
    """
    # Dense/Relu pairs for each hidden width, then a single linear output unit.
    hidden = [layer for width in layers for layer in (Dense(width), Relu)]
    net_init, net_apply = stax.serial(*(hidden + [Dense(1)]))

    # stax init functions take (rng, input_shape); -1 marks the batch dimension.
    _, net_params = net_init(random.PRNGKey(seed), (-1, x_train.shape[1]))

    def loss(params, x, y):
        # Mean squared error (minimised by Adam). net_apply returns (batch, 1);
        # take column 0 so it lines up with y of shape (batch,) rather than
        # broadcasting to a (batch, batch) difference matrix.
        y_pred = net_apply(params, x)[:, 0]
        return np.mean((y - y_pred) ** 2)

    opt_init, opt_update, get_params = optimizers.adam(step_size = step_size)

    @jit
    def step(i, opt_state, x_batch, y_batch):
        # opt_update / get_params are closed over: jit cannot trace Python
        # callables passed as arguments.
        g = grad(loss)(get_params(opt_state), x_batch, y_batch)
        return opt_update(i, g, opt_state)

    n_batches = max(1, len(x_train) // batch_size)
    opt_state = opt_init(net_params)
    it = 0  # global step counter; Adam's bias correction depends on it
    for _ in range(Nepochs):
        for b in range(n_batches):
            lo, hi = b * batch_size, (b + 1) * batch_size
            opt_state = step(it, opt_state, x_train[lo:hi], y_train[lo:hi])
            it += 1

    return net_apply, get_params(opt_state)



In [None]:
# Train on the moons data: `predict` is the network's apply function, `params` its trained weights.
predict, params = makeModel(x_train = x_train, y_train = y_train)

In [None]:
def plot_contour(model, x, y, getFunction, Npred = 50, n_grid = 50):
    """Plot predictive uncertainty (top) and mean (bottom) over the bounding box of ``x``.

    A regular ``(n_grid + 1) x (n_grid + 1)`` grid spanning the min/max of both
    input columns is evaluated with ``getFunction(model, inputs, Npred=Npred)``.
    That call must return ``(mean, std)`` with one value per grid point. Both are
    drawn as filled contours, with the points of ``x`` overlaid and coloured by ``y``.

    Args:
        model: first argument forwarded to ``getFunction`` (e.g. network params).
        x: array of shape (n, 2) of input points to overlay.
        y: labels for ``x``; ``y == 1`` points are drawn red, ``y == 0`` blue.
        getFunction: callable ``(model, inputs, Npred=...) -> (mean, std)``.
        Npred: forwarded to ``getFunction``.
        n_grid: number of grid intervals along each axis.
    """
    # make contour grid
    x_lo, y_lo = np.min(x[:, 0]), np.min(x[:, 1])
    x_hi, y_hi = np.max(x[:, 0]), np.max(x[:, 1])
    # linspace hits both end points exactly and uses each axis's own spacing.
    gx = np.linspace(x_lo, x_hi, n_grid + 1)
    gy = np.linspace(y_lo, y_hi, n_grid + 1)
    bx, by = np.meshgrid(gx, gy, indexing = 'ij')
    inputs = np.vstack([bx.flatten(), by.flatten()]).T.astype(np.float32)

    pred_m, pred_s = getFunction(model, inputs, Npred = Npred)
    pred_m_2d = pred_m.reshape(bx.shape)
    pred_s_2d = pred_s.reshape(bx.shape)

    fig, ax = plt.subplots(nrows = 2, ncols = 1, sharex = True, figsize = (10, 8))
    cmap = sns.diverging_palette(250, 12, s=85, l=25, as_cmap=True)
    contour_s = ax[0].contourf(bx, by, pred_s_2d, cmap = cmap)
    cbar_s = plt.colorbar(contour_s, ax = ax[0])
    cbar_s.ax.set_ylabel('Unc.')
    contour_m = ax[1].contourf(bx, by, pred_m_2d, cmap = cmap)
    cbar_m = plt.colorbar(contour_m, ax = ax[1])
    cbar_m.ax.set_ylabel('Mean')
    for a in ax:
        a.scatter(x[y == 1, 0], x[y == 1, 1], color = 'r', marker = 's', s = 5, label = 'y = 1')
        a.scatter(x[y == 0, 0], x[y == 0, 1], color = 'b', marker = 's', s = 5, label = 'y = 0')
        a.set(xlabel = 'A', ylabel = 'B', title = '')
        a.set_xlim([x_lo, x_hi])
        a.set_ylim([y_lo, y_hi])
        a.legend(frameon = True)
    ax[0].set_xlabel('')
    fig.tight_layout()
    plt.show()


In [None]:
# getMeanStd takes the network's apply function as its first argument, while
# plot_contour calls getFunction(model, inputs, Npred=...); bind `predict` up front.
plot_contour(params, x_test, y_test, getFunction = partial(getMeanStd, predict))