In [None]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn import preprocessing
import tensorflow as tf


from data.toy_regression import create_split_periodic_data, ground_truth_periodic_function
from core import MapEnsemble

In [None]:
# Eager execution is required by this notebook (checked below). Since everything runs on
# CPU anyway, make Keras default to float64 so it matches NumPy's default dtype.
assert tf.executing_eagerly()
tf.keras.backend.set_floatx('float64')

# Directory (relative to the working directory) for exported figures.
figure_dir = './figures'

In [None]:
# Reproducibility: NumPy drives the data generation below, but TensorFlow keeps its own
# global RNG (used e.g. by Keras weight initialisers) that np.random.seed does not touch.
# Seed both so a fresh Restart & Run All gives the same ensemble.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

n_networks = 5
n_train = 20
batchsize_train = 20  # equal to n_train, so presumably each training step sees the whole set
n_test = 500  # number of evenly spaced points on the evaluation grid

_x_train, y_train = create_split_periodic_data(n_train=n_train)

# Evaluate on an interval that extends half the training range beyond each end,
# so the plots show extrapolation behaviour on both sides of the data.
x_min, x_max = np.min(_x_train), np.max(_x_train)
d = x_max - x_min
lower_bound = x_min - d / 2
upper_bound = x_max + d / 2

# Standardise inputs with training-set statistics only.
# Underscore-prefixed arrays (_x_train, _x_test) stay in original units for plotting;
# the unprefixed ones are the scaled inputs fed to the networks.
scaler = preprocessing.StandardScaler(with_mean=True, with_std=True).fit(_x_train)
x_train = scaler.transform(_x_train)

_x_test = np.linspace(lower_bound, upper_bound, n_test).reshape(-1, 1)
y_test = ground_truth_periodic_function(_x_test)
x_test = scaler.transform(_x_test)

# Architecture: n_hidden_layers ReLU layers of hidden_units each, then a single linear output.
n_hidden_layers = 4
hidden_units = 500
layer_units = [hidden_units] * n_hidden_layers + [1]
layer_activations = ["relu"] * n_hidden_layers + ["linear"]

In [None]:
# Sanity-check the generated data: ground-truth curve over the test grid plus the training points.
fig, ax = plt.subplots()
ax.plot(_x_test, y_test, alpha=0.3, label="Ground truth")
ax.scatter(_x_train, y_train, label="Train data")
ax.set(xlabel="", ylabel="")
ax.legend();

In [None]:
# Build an ensemble of n_networks networks sharing the architecture defined above;
# input_shape is [1] because x is one-dimensional (reshaped to a single column).
ensemble = MapEnsemble(
    n_networks=n_networks,
    input_shape=[1],
    layer_units=layer_units,
    layer_activations=layer_activations,
)

In [None]:
# Train on the standardised inputs; batchsize_train equals n_train in this notebook.
ensemble.train(
    x_train=x_train,
    y_train=y_train,
    batchsize_train=batchsize_train,
)

In [None]:
# Predictions on the scaled test grid; the loop below assumes one entry per ensemble member.
predictions = ensemble.predict(x_test)

save_figure = False  # set to True to export the figure to figure_dir
y_limits = [-5, 5]  # fixed y-range, presumably so divergent extrapolations don't rescale the axis

fig, ax = plt.subplots(figsize=(8, 8))
ax.plot(_x_test, y_test, label="Ground truth", alpha=0.1)
for i, prediction in enumerate(predictions):
    ax.plot(_x_test, prediction, label=f"Model {i+1} prediction", alpha=0.8)
ax.scatter(_x_train, y_train, c='k', marker='x', s=100, label="Train data")
ax.set_xlabel("");
ax.set_ylabel("");
ax.set_ylim(y_limits)
ax.legend();

if save_figure:
    # figure_dir is not created anywhere else; make sure it exists before saving.
    os.makedirs(figure_dir, exist_ok=True)
    fig.savefig(os.path.join(figure_dir, f"{n_networks}_ml_ensemble.pdf"))