## Import libraries

In [1]:
from data_loader import DataLoader
from generalized_linear_model import GeneralizedLinearModel
import matplotlib.pyplot as plt
%matplotlib notebook

## Load data

In [2]:
dataset_root_directory = './regression-dataset'
full_dataset = DataLoader.load_full_dataset(dataset_root_directory)

# Fail loudly with a useful message instead of an IndexError on full_dataset[0].
assert len(full_dataset) > 0, 'no data subsets loaded from %s' % dataset_root_directory

# Each subset is indexed so that subset[0] holds its training examples
# (presumably subset[1] holds the matching true values -- confirm against DataLoader).
examples_per_subset = [len(subset[0]) for subset in full_dataset]

# The first line reports the size of the first subset; the total is summed over
# all subsets so it stays correct even if subsets are not all the same size.
print('%d subsets of %d training examples with true values' % (len(full_dataset), examples_per_subset[0]))
print('Total %d training examples and true values' % sum(examples_per_subset))

10 subsets of 20 training examples with true values
Total 200 training examples and true values


## Initialize parameters

In [3]:
input_vector_size: int = 2  # dimensionality of each input vector passed to GeneralizedLinearModel

## Train one model per basis function degree

In [4]:
# Basis-function degrees to sweep: 1, 2, 3 and 4.
basis_function_degrees = list(range(1, 5))

# Per-degree results, aligned index-for-index with basis_function_degrees
# so the plotting cells below can zip them against the degrees.
MSE_errors = []
running_times = []

for degree in basis_function_degrees:
    # A fresh model per degree; input size is fixed, feature degree varies.
    model = GeneralizedLinearModel(
        input_vector_degree=input_vector_size,
        feature_vector_degree=degree,
    )
    print('Training in progress for basis function degree =', degree)
    # report_error=True appears to print the MSE and training time (see cell output).
    model.learn(full_dataset, report_error=True)
    print()  # blank line between per-degree reports

    # Collect the metrics the model exposes after learn().
    MSE_errors.append(model.mse_error)
    running_times.append(model.training_time)

Training in progress for basis function degree = 1
Mean Square Error = 1.301 
Training time = 3.81 seconds

Training in progress for basis function degree = 2
Mean Square Error = 1.024 
Training time = 4.77 seconds

Training in progress for basis function degree = 3
Mean Square Error = 0.066 
Training time = 6.58 seconds

Training in progress for basis function degree = 4
Mean Square Error = 0.067 
Training time = 8.43 seconds



## Plot mean squared error vs. basis function degree

In [7]:
# Create a dedicated figure: with `%matplotlib notebook` figures stay live across
# cells, so pyplot state-machine calls could otherwise draw onto an earlier figure.
fig, ax = plt.subplots()
ax.plot(basis_function_degrees, MSE_errors, '-o', c='r')
# Degrees are discrete integers; show one tick per trained degree.
ax.set_xticks(basis_function_degrees)
ax.set(xlabel='degree', ylabel='E_MSE', title='E_MSE vs degree')
plt.show()

<IPython.core.display.Javascript object>

In [9]:
# Create a dedicated figure: with `%matplotlib notebook` figures stay live across
# cells, so this plot would otherwise risk being drawn onto the MSE figure above.
fig, ax = plt.subplots()
ax.plot(basis_function_degrees, running_times, '-o', c='b')
# Degrees are discrete integers; show one tick per trained degree.
ax.set_xticks(basis_function_degrees)
ax.set(xlabel='degree', ylabel='time (seconds)', title='Running time vs degree')
plt.show()

<IPython.core.display.Javascript object>