In [8]:
from base.closed_form import ClosedFormRegression
from data import data
import matplotlib
import numpy as np

# Turn off interactive mode. To use a specific backend, call matplotlib.use(...) here, before pyplot is imported.
matplotlib.interactive(False)
import matplotlib.pyplot as plt
% matplotlib inline


def plot_data_point(X_train, y_train, X_test, y_test):
    """Scatter the train and test samples onto matplotlib figure number 2.

    Figure 2 is activated (created if needed) with an 8x6 size, 80 dpi, white
    face and black edge; y-limits are pinned to [-2, 2]. Train points are drawn
    as small circles, test points as small crosses, each with a legend label.
    """
    figure_options = dict(figsize=(8, 6), dpi=80, facecolor='w', edgecolor='k')
    plt.figure(2, **figure_options)
    plt.ylim([-2, 2])
    plt.xlabel("x")
    plt.ylabel("y")
    series = (
        (X_train, y_train, 'o', 'Train data'),
        (X_test, y_test, 'x', 'Test data'),
    )
    for xs, ys, marker, name in series:
        plt.plot(xs, ys, marker, ms=3, label=name)
    plt.title('Data point/regression model')


def plot_regression_model(model, fmt):
    """Draw a model's prediction curve on figure 2 and refresh the legend.

    The curve is evaluated on np.arange(-2, 2, 0.01), i.e. from -2 inclusive to
    2 exclusive in steps of 0.01, and the y-limits are pinned to [-2, 2].

    :param model: object exposing predict(X) and description().
    :param fmt: matplotlib format string for the curve (e.g. 'r').
    """
    plt.figure(2)
    plt.ylim([-2, 2])
    grid = np.arange(-2, 2, 0.01)
    predictions = model.predict(grid)
    curve_label = 'Model for ' + model.description()
    plt.plot(grid, predictions, fmt, label=curve_label)
    # Rebuilt on every call so it lists every curve drawn so far.
    legend_font = {'size': 10}
    plt.legend(fancybox=True, loc='lower right', framealpha=0.9, prop=legend_font)


def plot_mse(model, mse, fig, fmt, color):
    """Print the train and test mean squared error of a model.

    Despite its name, this function does not plot anything; it only writes
    two lines to stdout.

    :param model: object exposing description(), used to label the output.
    :param mse: pair (train_mse, test_mse).
    :param fig: unused; kept so existing call sites keep working.
    :param fmt: unused; kept so existing call sites keep working.
    :param color: unused; kept so existing call sites keep working.
    """
    name = model.description()
    # Each print gets a single pre-formatted string. That prints identically
    # under the Python 2 print statement and the Python 3 print function,
    # unlike the original `print a, b` statement form (a Python 3 syntax error).
    print("Train MSE for model %s is %s" % (name, mse[0]))
    print("Test MSE for model %s is %s" % (name, mse[1]))


if __name__ == '__main__':
    X_train, y_train, X_test, y_test = data.load(train_coefficient=0.2, normalize=True)

    # Polynomial orders to compare, each paired with the matplotlib colour
    # used for its fitted curve.
    orders_and_colors = [(3, 'r'), (5, 'y'), (7, 'b')]

    # Draw the data first so the model curves land on the same figure (2).
    # No separate plt.figure() call here: every helper targets figure 2, so an
    # extra figure would only show up empty.
    plot_data_point(X_train, y_train, X_test, y_test)

    results = []
    for order, color in orders_and_colors:
        model = ClosedFormRegression(order=order, lambda_var=0)
        model.fit(X_train, y_train)
        # (train MSE, test MSE) pair, the shape plot_mse expects.
        mse = model.mse(X_train, y_train), model.mse(X_test, y_test)
        plot_regression_model(model, color)
        results.append((model, mse, color))

    # Report train/test MSE after all curves are drawn; the figure index
    # (1, 2, 3) matches what was passed originally.
    for fig_index, (model, mse, color) in enumerate(results, start=1):
        plot_mse(model, mse, fig_index, '-', color)

    # show plots
    plt.show()