In [None]:
# Third-party imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Project imports
import twinlab as tl

In [None]:
# Parameters

# Identifiers: the campaign simply reuses the dataset's name
dataset_id = "wiggle"
campaign_id = dataset_id

# Sampling settings
random_seed = 42  # seed handed to NumPy's global random-number generator
n_train = 400     # number of noisy training samples
n_eval = 101      # number of evenly spaced evaluation points
err_sig = 0.25    # standard deviation of the Gaussian noise added to f(x)
xmin = 0          # lower edge of the sampled x range
xmax = 3.         # upper edge of the sampled x range

# Shape of the underlying function f(x)
grad = 0.3   # slope of the linear trend
n_cycle = 3  # sine oscillations per unit x

def f(x):
    """Noise-free target: a linear trend with a sine wiggle on 1 <= x <= 2.

    Reads the module-level settings `grad` (slope of the trend) and
    `n_cycle` (sine oscillations per unit x).

    Parameters
    ----------
    x : scalar or array-like of numbers
        Point(s) at which to evaluate the function.

    Returns
    -------
    numpy.ndarray
        `grad*x` wherever `|x - 1.5| > 0.5`, and
        `grad*x + sin(2*pi*x*n_cycle)` everywhere else.
    """
    trend = grad * x
    wiggle = np.sin(2 * np.pi * x * n_cycle)
    # The sine term only contributes inside the window centred on x = 1.5
    outside_window = np.abs(x - 1.5) > 0.5
    return np.where(outside_window, trend, trend + wiggle)

In [None]:
# Fix the global NumPy seed so the same synthetic data is drawn on every run
np.random.seed(random_seed)

# Training data: uniformly scattered inputs, then Gaussian noise on top of f.
# The inputs are drawn before the noise, so the random stream is consumed in
# a fixed order.
X = np.random.uniform(low=xmin, high=xmax, size=n_train)
noise = np.random.normal(loc=0., scale=err_sig, size=n_train)
y = f(X) + noise

# Collect the samples into a two-column frame, show it, and upload it
df_train = pd.DataFrame({'X': X, 'y': y})
display(df_train)
tl.upload_dataset(df_train, dataset_id, verbose=True)

In [None]:
# Evaluation data: n_eval evenly spaced points spanning [xmin, xmax].
# The frame is built directly rather than through an intermediate variable
# called `eval`, which would shadow Python's builtin `eval` for the rest of
# the kernel session.
df_test = pd.DataFrame({"X": np.linspace(xmin, xmax, n_eval)})
display(df_test)

In [None]:
# Training parameters
params = {
    "dataset_id": dataset_id,
    "inputs" : ["X"],
    "outputs": ["y"],
    "test_train_ratio": 1.,  # placeholder; overwritten for every model below
}

# Plot parameters
grid = df_test["X"].values             # x positions of the predictions
alpha_fill = 0.25                      # opacity of each uncertainty band
ns_train = [10, 20, 40, 80, 160, 320]  # training-set sizes to compare
nrow, ncol = 2, 3                      # layout of the panel grid
figx, figy = 4, 3                      # size of a single panel

# Fail before any training request is sent if there are not enough panels
if len(ns_train) > nrow*ncol:
    raise ValueError(
        f"{len(ns_train)} training sizes do not fit in a {nrow}x{ncol} grid of panels"
    )

# Loop over different numbers of training points and plot one panel for each.
# The Axes returned by plt.subplots are used directly; axes.flat walks them
# row by row.
fig, axes = plt.subplots(
    nrow, ncol, sharex=True, sharey=True, squeeze=False,
    figsize=(ncol*figx, nrow*figy),
)
for iplot, (ax, n) in enumerate(zip(axes.flat, ns_train)):

    # Train model
    print("Number of data points used for training:", n)
    # Ratio chosen so that it corresponds to n rows out of the uploaded data
    params["test_train_ratio"] = n/len(df_train)
    tl.train_campaign(params, campaign_id, verbose=True)

    # Predict
    df_mean, df_std = tl.predict_campaign(df_test, campaign_id)
    mean, err = df_mean["y"].values, df_std["y"].values

    # Plot
    color = f"C{iplot}"
    # NOTE(review): shows the first n rows of df_train, which assumes these
    # are the rows twinLab trains on for this ratio - confirm.
    ax.plot(df_train["X"].iloc[:n], df_train["y"].iloc[:n], ".", color="black")
    ax.plot(grid, mean, "-", color=color, label=f"N = {n}")
    # Bands of +/-1 and +/-2 times err around the mean; they overlap, so the
    # inner band appears darker
    for nsig in [1, 2]:
        ax.fill_between(grid, mean-nsig*err, mean+nsig*err, lw=0, color=color, alpha=alpha_fill)
    ax.set_xticks([]); ax.set_yticks([])
    ax.legend()

# Finalize plot
fig.tight_layout()
plt.show()

In [None]:
# Delete campaign and dataset if necessary
# Clean-up step: removes the campaign that was trained under `campaign_id`
# and then the dataset that was uploaded under `dataset_id`. Earlier cells
# that train or predict need the upload (and training) to be re-run after
# this cell has executed.
tl.delete_campaign(campaign_id, verbose=True)
tl.delete_dataset(dataset_id, verbose=True)