In [None]:
import matplotlib.pyplot as plt
import numpy as np

from src.base.forecasting.evaluation.cross_validation import ErrorBounds
from src.base.forecasting.models import (
    TabularMetric,
    TabularRegressorMLP,
)

from create_dataset import DataSetType, create_dataset
from enum import Enum
from typing import Dict, List, Union

In [None]:
# =================================================================================================
#  Settings
# =================================================================================================
class TestType(Enum):
    """Which CV experiment to run.

    Selects the grid built by ``get_param_grid`` and whether the CV results are
    shown as a 1-D (SINGLE_VAR) or 2-D plot.
    """

    # members get the values 0, 1, 2 in declaration order
    GRID_REGULAR, GRID_SKEWED, SINGLE_VAR = range(3)


# --- test setup ------------------------------------------
# which experiment to run: picks the grid in get_param_grid() and the CV-results plot type
test_type = TestType.GRID_SKEWED

# number of CV folds, passed to grid_search as n_splits
n_folds = 5
# passed straight through to TabularRegressorMLP(n_seeds=...)
n_seeds = 1

# scoring metric handed to grid_search (RMSE)
score_metric = TabularMetric.rmse()

In [None]:
# =================================================================================================
#  Construct dataset
# =================================================================================================
# SINE dataset; n=1000 is presumably the sample count and c=7.0 is forwarded unchanged
# to create_dataset (NOTE(review): document what `c` controls there)
x_train, y_train = create_dataset(
    DataSetType.SINE,
    c=7.0,
    n=1000,
)

In [None]:
# =================================================================================================
#  Some helpers
# =================================================================================================
def get_param_grid(test_kind: Union["TestType", None] = None) -> Union[List[Dict], Dict]:
    """Build the hyper-parameter grid handed to ``mlp.cv.grid_search`` for one kind of test.

    :param test_kind: which grid to build. ``None`` (the default) means "use the notebook-level
                      ``test_type`` setting"; that global is read at call time, so existing
                      ``get_param_grid()`` calls behave exactly as before.
    :return: dict mapping a parameter name to its list of candidate values. For GRID_SKEWED the
             key ``("wd", "n_epochs")`` maps to a list of (wd, n_epochs) pairs -- presumably
             grid_search tries only those listed combinations rather than the full product;
             confirm against grid_search.
    :raises NotImplementedError: if ``test_kind`` is not one of the handled TestType members.
    """
    if test_kind is None:
        # fall back to the Settings cell; looked up now rather than frozen as a default value
        test_kind = test_type

    # weight-decay sweep shared by the regular and single-variable grids (new list on every call)
    wd_sweep = [1e-4, 1e-3, 1e-2, 1e-1, 1, 10]

    if test_kind == TestType.GRID_REGULAR:
        return {"wd": wd_sweep, "n_epochs": [5, 10, 20, 50, 100], "n_hidden_layers": [5]}

    if test_kind == TestType.GRID_SKEWED:
        # hand-picked (wd, n_epochs) pairs instead of a full cartesian product
        return {
            ("wd", "n_epochs"): [
                (1e-4, 10),
                (1e-3, 10),
                (1e-2, 10),
                (1e-3, 20),
                (1e-2, 20),
                (1e-1, 20),
                (1e-2, 50),
                (1e-1, 50),
                (1.0, 50),
                (1e-2, 75),
                (1e-1, 75),
                (1.0, 75),
                (10, 60),
            ],
            "n_hidden_layers": [5],
        }

    if test_kind == TestType.SINGLE_VAR:
        return {"wd": wd_sweep, "n_epochs": [100], "n_hidden_layers": [5]}

    raise NotImplementedError(f"Unknown test type: {test_kind}")

In [None]:
# =================================================================================================
#  Actual simulation
# =================================================================================================

# --- set up test -----------------------------------------
param_grid = get_param_grid()

# NOTE(review): n_hidden_layers=3 / n_epochs=20 differ from the grid values
# (n_hidden_layers=[5], n_epochs varies); presumably grid_search overrides them per candidate -- confirm.
mlp = TabularRegressorMLP(
    layer_width=200,
    n_hidden_layers=3,
    n_epochs=20,
    n_seeds=n_seeds,
)

# --- run cv ----------------------------------------------
# shuffle_data=True with no explicit seed here: fold assignment may differ between runs
# unless grid_search seeds it internally -- verify.
mlp.cv.grid_search(
    x_train,
    y_train,
    param_grid,
    score_metric,
    n_splits=n_folds,
    shuffle_data=True,
    n_jobs=-1,
)

In [None]:
# =================================================================================================
#  Plot CV results
# =================================================================================================
# Grids that vary both wd and n_epochs get a 2-D plot; the single-variable sweep gets a
# 1-D plot over wd with ErrorBounds.STDEV error bounds.
if test_type != TestType.SINGLE_VAR:
    fig, ax = mlp.cv.results.plot_2d("wd", "n_epochs").create()  # type: plt.Figure, plt.Axes
else:
    fig, ax = mlp.cv.results.plot_1d("wd").set_error_bounds(ErrorBounds.STDEV).create()  # type: plt.Figure, plt.Axes

In [None]:
# =================================================================================================
#  Plot function fit
# =================================================================================================
def _padded_range(values, pad_frac: float = 0.1):
    """Return ``(lo, hi)``: the range of ``values`` along axis 0, widened by ``pad_frac`` of its span on each side.

    For 1-D input ``lo``/``hi`` are scalars; for 2-D input they hold one value per column.
    """
    lo = np.min(values, axis=0)
    hi = np.max(values, axis=0)
    pad = pad_frac * (hi - lo)
    return lo - pad, hi + pad


# plotting window: training range widened by 10% on each side
x_min, x_max = _padded_range(x_train)
y_min, y_max = _padded_range(y_train)

# dense, evenly spaced inputs across the padded x-range, used to draw the fitted curve
x_test = np.linspace(x_min, x_max, 1000)
y_test = mlp.predict(x_test)

# --- plot --------------------------------------------
fig, ax = plt.subplots(1, 1)  # type: plt.Figure, plt.Axes

# training data as grey crosses, model prediction as a line
ax.plot(x_train, y_train, ls="", marker="x", c=(0.6, 0.6, 0.6), label="training data")
ax.plot(x_test, y_test, ls="-", label="MLP prediction")

# clamp y to the padded target range so extrapolation outside the data cannot squash the view
ax.set_ylim(bottom=y_min, top=y_max)

ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("MLP prediction vs. training data")
ax.legend()

fig.set_size_inches(w=12, h=8)
fig.tight_layout()