In [2]:
%matplotlib widget
from LinearRegression import (sigmoid_basis_functions,
                              gaussian_basis_functions,
                              polynomial_basis_functions,
                              BayesianLinearRegression,
                              LinearRegression)
from matplotlib import pyplot as plt
import ipywidgets as widgets
import numpy as np

In [5]:
# Shared window: used for the axis limits and for the prediction grid below.
LIMS = [-25, 25]
COLOR = '#1f77b4'


def f(x):
    """Noise-free target curve: 7*sin(x/5) + x."""
    return 7*np.sin(x/5) + x


# 500 evenly spaced inputs spanning LIMS; fitted models are evaluated on these.
xx = np.linspace(LIMS[0], LIMS[1], num=500)

# Close every open pyplot figure before the new one is created below
# (presumably so re-running this cell does not accumulate figures — confirm).
plt.close('all')

def fit(nv, pv, dof, BLR, basis, xs, ys):
    """Fit a (Bayesian) linear regression on (xs, ys) and evaluate it on ``xx``.

    Parameters
    ----------
    nv : forwarded as ``sig`` to ``BayesianLinearRegression``; unused for the
        plain model. (Presumably the noise variance — confirm against the
        ``LinearRegression`` module whether ``sig`` expects a variance or a std.)
    pv : scale of the identity prior covariance (``np.eye(dof) * pv``);
        only used when ``BLR`` is truthy.
    dof : number of model parameters. ``dof - 1`` centers are spread evenly
        between ``min(xs)`` and ``max(xs)``; ``dof == 1`` uses no centers.
    BLR : truthy selects ``BayesianLinearRegression``, otherwise
        ``LinearRegression``.
    basis : ``'Polynomial'`` or ``'Gaussian'``; any other value selects the
        sigmoid basis.
    xs, ys : training inputs and targets passed to the model's ``fit``.

    Returns
    -------
    (preds, conf_int) : ``preds`` is ``predict(xx)`` on the module-level grid;
        ``conf_int`` is ``predict_std(xx)`` for the Bayesian model and an
        all-zero array of ``len(preds)`` otherwise.
    """
    # No centers for the single-parameter model, otherwise dof-1 evenly spaced.
    centers = np.array([]) if dof == 1 else np.linspace(np.min(xs), np.max(xs), dof-1)

    if basis == 'Polynomial':
        # Degree follows the number of centers (dof - 1, or 0 when dof == 1).
        basis_funcs = polynomial_basis_functions(degree=len(centers))
    elif dof == 1:
        # Non-polynomial bases fall back to a degree-0 polynomial when dof == 1.
        basis_funcs = polynomial_basis_functions(degree=0)
    elif basis == 'Gaussian':
        basis_funcs = gaussian_basis_functions(centers, beta=2)
    else:
        basis_funcs = sigmoid_basis_functions(centers)

    if BLR:
        # Zero-mean prior with isotropic covariance scaled by pv.
        prior_mean = np.zeros(dof)
        prior_cov = np.eye(dof)*pv
        model = BayesianLinearRegression(theta_mean=prior_mean, theta_cov=prior_cov,
                                         sig=nv, basis_functions=basis_funcs)
    else:
        model = LinearRegression(basis_funcs)

    model = model.fit(xs, ys)
    preds = model.predict(xx)
    # Only the Bayesian model provides predict_std; use a zero-width band otherwise.
    conf_int = model.predict_std(xx) if BLR else np.zeros(len(preds))
    return preds, conf_int

# Single figure/axes that the interactive callback draws into; both axes are
# pinned to the LIMS window.
fig, ax = plt.subplots(figsize=(6, 4))
ax.set_xlim(LIMS[0], LIMS[1])
ax.set_ylim(LIMS[0], LIMS[1])

@widgets.interact(
    nv=(0.001, 5, 0.1),
    pv=(0.01, 10, 0.1),
    dof=(1,10,1),
    BLR=False,
    basis=['Polynomial',
           'Gaussian',
           'Sigmoid'],
    N=(1, 100, 1))
def update(nv=0.1, pv=1, dof=1, BLR=False, basis='Polynomial', N=10):
    """Redraw the demo on the module-level ``ax`` for the current widget values.

    Parameters
    ----------
    nv : variance of the Gaussian noise added to the samples (its square
        root scales ``randn``); also forwarded to ``fit``.
    pv, dof, BLR, basis : forwarded unchanged to ``fit``.
    N : number of training points to draw.

    Side effects: reseeds NumPy's global RNG with 0, clears ``ax`` and draws
    the samples, the fitted curve and a band of ``pred +/- CI`` on it.
    """
    # Clear the demo axes explicitly. plt.cla() would clear whatever pyplot
    # currently treats as the active axes, which need not be `ax` once any
    # other figure has been created.
    ax.cla()
    # Clearing resets the view limits, so pin them to the LIMS window again.
    ax.set_ylim([LIMS[0], LIMS[1]])
    ax.set_xlim([LIMS[0], LIMS[1]])

    # Reseed on every call so the same random draws are reused and moving a
    # slider does not reshuffle the sample locations.
    np.random.seed(0)
    xs = 40*np.random.rand(N) - 20  # uniform on [-20, 20)
    ys = f(xs) + np.sqrt(nv)*np.random.randn(N)
    ax.scatter(xs, ys, 20, color=COLOR)

    pred, CI = fit(nv, pv, dof, BLR, basis, xs, ys)
    ax.plot(xx, pred, lw=2, color=COLOR)
    # CI is all zeros for the non-Bayesian model, so the band collapses onto the curve.
    ax.fill_between(x=xx, y1=pred-CI, y2=pred+CI, color=COLOR, alpha=.5)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

interactive(children=(FloatSlider(value=0.1, description='nv', max=5.0, min=0.001), FloatSlider(value=1.0, des…