In [1]:
import numpy as np

from numpy.linalg import inv
from math import sqrt

from sklearn.model_selection import LeaveOneOut

# Follow instructions here to get cvxopt working. --- painful!!! 
#https://stackoverflow.com/questions/46009925/how-to-install-cvxopt-on-on-windows-10-on-python-3-6
from cvxopt import matrix, solvers

# Validation

Linear regression model

In [2]:
class LinearRegression:
    """ Linear Regression class. """
    
    def __init__(self, transform=True):
        """Initializer
        Args:
        transform (bool): Apply non-linear transformation to the data
        k (int): Dimensionality of non-linear transformation, 3 <= k <= 7
        lambda (float): Regularization parameter
        Returns:
        np.ndarray: Weights
        """           
        self.non_linear = transform 
    
    def train(self, train_data, k=3, const_hypothesis=False, llambda=0.0):    
        """Train the Linear Regression Algorithm
        Args:
        train_data (np.ndarray): Training data
        k (int): Dimensionality of non-linear transformation, 3 <= k <= 7
        const_hypothesis (bool): If True, use constant hypothesis h(x) = b, 
                                 If False, use linear hypothesis h(x) = ax + b
        lambda (float): Regularization parameter
        Returns:
        np.ndarray: Weights
        """      
        # apply non-linear transformation to inputs
        inputs = train_data[:, 0:np.size(train_data, axis=1) - 1]
        Z = self.transform_inputs(inputs, k) # transform inputs
        
        # get outputs
        y = train_data[:, np.size(train_data, axis=1) - 1]
        
        # linear regression solution with regularization
        if const_hypothesis:
            weights = np.mean(y)
        else:
            Z_trans = np.transpose(Z)
            normZ = np.matmul(Z_trans, Z)  
            I = np.identity(np.size(normZ, 0))
            weights = np.matmul(inv(np.add(normZ, llambda * I)), np.matmul(Z_trans, y))
        return weights
    
    def sample_error(self, weights, data, k=3):    
        """Compute sample classification error
        Args:
        weights (np.ndarray): Weights of hypothesis
        data (np.ndarray): Training/validation/testing data
        k (int): Dimensionality of non-linear transformation, 3 <= k <= 7
        Returns:
        float: Sample error (fraction of misclassified points)
        """
        z = self.transform_inputs(data[:, 0:np.size(data, axis=1) - 1], k)
        y = data[:, np.size(data, axis=1) - 1]
        
        num_errors = 0
        for n in range(0, len(z)):
            #print('weights.T) = ', weights.T)
            #print('z[n] = ', z[n])
            #print('y[n] = ', y[n])
            if np.sign(np.dot(weights.T, z[n])) != y[n]:
               num_errors += 1
        return num_errors / len(z)
    
    def cross_val_error(self, data, const_hypothesis=False, k=3):
        """Compute cross-validation error using squared error
        Args:
        data (np.ndarray): Data
        const_hypothesis (bool): If True, use constant hypothesis h(x) = b, 
                                 If False, use linear hypothesis h(x) = ax + b
        k (int): Dimensionality of non-linear transformation, 3 <= k <= 7
        Returns:
        float: Cross-validation error (using squared error)
        """
        #print('data = ', data)
        squared_errors = []
        loo = LeaveOneOut()
        for train_ind, val_ind in loo.split(data):
            #print('=========================================================')
            train_set = data[train_ind]
            val_set = data[val_ind]
            #print('train_set = ', train_set)
            #print('val_set = ', val_set)
            weights = self.train(train_set, k, const_hypothesis)
            
            #print('weights = ', weights) 
            
            #z = self.transform_inputs(data[:, 0:np.size(data, axis=1) - 1], k)
            #y = data[:, np.size(data, axis=1) - 1]
            z = self.transform_inputs(val_set, k)
            y = val_set[:, np.size(val_set, axis=1) - 1]
            #print('z = ', z)
            #print('y = ', y)

            squared_errors.append((np.dot(weights.T, z[0]) - y[0])**2)
        
        return np.mean(squared_errors)       
         
    # privates
    def transform_inputs(self, inputs, k):
        ones = np.ones(len(inputs)).reshape(len(inputs), 1)       
        if self.non_linear:      
            x1 = inputs[:, 0].reshape(len(inputs), 1)
            x2 = inputs[:, 1].reshape(len(inputs), 1)
            transform = (ones, x1, x2, x1**2, x2**2, x1*x2, np.abs(x1-x2), np.abs(x1+x2))
        else:
            x1 = inputs[:, 0].reshape(len(inputs), 1)
            transform = (ones, x1)
        trans_inputs = np.concatenate(transform[:k+1], axis=1)
        #print('trans_inputs = ', trans_inputs)
        return trans_inputs

Compute validation and out-of-sample (test) errors

In [3]:
def val_test_error(train_set, val_set, test_set):
    low_k_val = 2
    low_k_out = 2
    low_error_val = 999
    low_error_out = 999
    for k in range(3, 8):
        lr = LinearRegression()
        weights = lr.train(train_set, k)
        error_val = lr.sample_error(weights, val_set, k)
        error_out = lr.sample_error(weights, test_set, k)
        if error_val < low_error_val:
            low_error_val = error_val
            low_k_val = k
        if error_out < low_error_out:
            low_error_out = error_out
            low_k_out = k
    print('The lowest validation error is {0} occuring at k = {1}'.format(round(low_error_val, 1), low_k_val))
    print('The lowest out-of-sample (test) error is {0} occuring at k = {1}'.format(round(low_error_out, 1), low_k_out))

Load in.dta and out.dta

In [4]:
in_dta = np.array(
[ -7.7947021e-01,   8.3822138e-01,   1.0000000e+00,
   1.5563491e-01,   8.9537743e-01,   1.0000000e+00,
  -5.9907703e-02,  -7.1777995e-01,   1.0000000e+00,
   2.0759636e-01,   7.5893338e-01,   1.0000000e+00,
  -1.9598312e-01,  -3.7548716e-01,  -1.0000000e+00,
   5.8848947e-01,  -8.4255381e-01,   1.0000000e+00,
   7.1985874e-03,  -5.4831650e-01,  -1.0000000e+00,
   7.3883852e-01,  -6.0339369e-01,   1.0000000e+00,
   7.0464808e-01,  -2.0420052e-02,   1.0000000e+00,
   9.6992666e-01,   6.4137120e-01,  -1.0000000e+00,
   4.3543099e-01,   7.4477254e-01,  -1.0000000e+00,
  -8.4425822e-01,   7.4235423e-01,   1.0000000e+00,
   5.9142471e-01,  -5.4602118e-01,   1.0000000e+00,
  -6.9093124e-02,   3.7659995e-02,  -1.0000000e+00,
  -9.5154865e-01,  -7.3305502e-01,  -1.0000000e+00,
  -1.2988138e-01,   7.5676096e-01,   1.0000000e+00,
  -4.9534647e-01,  -5.6627908e-01,  -1.0000000e+00,
  -9.0399413e-01,   5.0922150e-01,   1.0000000e+00,
   2.9235128e-01,   1.6089015e-01,  -1.0000000e+00,
   6.4798552e-01,  -7.7933769e-01,   1.0000000e+00,
   3.7595574e-01,   7.8203087e-02,  -1.0000000e+00,
   2.4588993e-01,   4.5146739e-03,  -1.0000000e+00,
  -4.5719155e-01,   4.2390461e-01,   1.0000000e+00,
  -4.4127876e-01,   7.0571892e-01,   1.0000000e+00,
   5.0744669e-01,   7.5872586e-01,  -1.0000000e+00,
  -1.3258381e-01,  -5.8178837e-01,  -1.0000000e+00,
  -4.4749067e-01,   1.9576364e-01,   1.0000000e+00,
   8.1658199e-01,  -4.5449182e-01,   1.0000000e+00,
  -9.4422408e-01,   8.8273421e-01,   1.0000000e+00,
   4.6265533e-01,   3.5583605e-01,  -1.0000000e+00,
   8.8311642e-01,  -1.9930013e-01,   1.0000000e+00,
   1.0016050e+00,   5.2575476e-01,  -1.0000000e+00,
   6.0370095e-01,  -5.4553701e-01,   1.0000000e+00,
  -1.4858757e-01,  -2.1308372e-01,  -1.0000000e+00,
   1.1652163e-02,   8.8923931e-01,   1.0000000e+00 ]
).reshape((35, 3))

out_dta = np.array(
[ -1.0600562e-01,  -8.1467034e-02,  -1.0000000e+00,
   1.7792951e-01,  -3.4595141e-01,  -1.0000000e+00,
   1.0216153e-01,   7.1825825e-01,   1.0000000e+00,
   6.9407831e-01,   6.2339743e-01,  -1.0000000e+00,
   2.3541068e-02,   7.2743221e-01,   1.0000000e+00,
  -3.1972776e-01,  -8.3411411e-01,  -1.0000000e+00,
  -1.8674372e-01,   5.3887798e-01,   1.0000000e+00,
  -6.3696719e-01,   1.5268485e-01,   1.0000000e+00,
  -4.7446260e-01,   8.5434436e-01,   1.0000000e+00,
  -3.5627652e-02,  -2.7158819e-01,  -1.0000000e+00,
  -1.4860269e-01,   1.6176177e-01,  -1.0000000e+00,
  -1.8065154e-01,  -1.2873906e-01,  -1.0000000e+00,
  -6.0241113e-01,   9.2550746e-01,   1.0000000e+00,
   6.9808093e-01,   7.9474170e-01,  -1.0000000e+00,
   8.8150888e-01,  -2.0124822e-01,   1.0000000e+00,
  -9.2384936e-01,   3.8662506e-01,   1.0000000e+00,
  -7.6571338e-01,  -1.1281293e-02,   1.0000000e+00,
   1.3559198e-01,   3.1705057e-02,  -1.0000000e+00,
  -1.5515148e-01,  -3.3141997e-01,  -1.0000000e+00,
   4.8517476e-01,   2.9903104e-01,  -1.0000000e+00,
  -6.0290010e-01,   3.3323420e-01,   1.0000000e+00,
  -5.7285816e-01,   8.2835226e-01,   1.0000000e+00,
  -6.3539998e-01,  -4.7456571e-01,  -1.0000000e+00,
   9.0931654e-01,  -7.8488922e-01,   1.0000000e+00,
   2.5210516e-01,  -8.9393713e-01,   1.0000000e+00,
  -5.1763422e-01,   9.6044364e-01,   1.0000000e+00,
  -3.8587212e-01,  -3.1786995e-01,  -1.0000000e+00,
   8.2316659e-01,  -1.2779654e-01,   1.0000000e+00,
   8.2248638e-01,  -8.7684295e-01,   1.0000000e+00,
  -5.0366208e-01,   9.8027386e-01,   1.0000000e+00,
   5.3387369e-01,   8.2123374e-01,  -1.0000000e+00,
  -8.9497008e-01,  -2.4011485e-01,   1.0000000e+00,
   3.4287096e-01,   4.7497683e-01,  -1.0000000e+00,
   7.0928861e-01,   5.6220681e-01,  -1.0000000e+00,
  -1.0004317e+00,   6.0457567e-02,   1.0000000e+00,
   5.2428439e-01,   7.3519522e-01,  -1.0000000e+00,
  -5.6033040e-01,   7.5583819e-01,   1.0000000e+00,
   6.9752202e-01,  -6.7198955e-01,   1.0000000e+00,
   4.9042314e-01,   7.8508662e-01,  -1.0000000e+00,
  -3.2677396e-01,   3.4337193e-01,   1.0000000e+00,
  -2.9342093e-03,  -4.1518173e-01,  -1.0000000e+00,
  -6.3123891e-01,   3.5263395e-01,   1.0000000e+00,
   9.1388134e-01,   5.9305320e-01,  -1.0000000e+00,
   2.1828280e-01,   3.9683537e-02,  -1.0000000e+00,
  -6.1618517e-01,  -8.8657924e-01,  -1.0000000e+00,
  -5.2852914e-01,   2.8690210e-02,   1.0000000e+00,
  -4.0652261e-01,   1.0451480e+00,   1.0000000e+00,
  -2.2979506e-01,   7.1425065e-02,  -1.0000000e+00,
  -5.0212113e-01,   8.3373782e-01,   1.0000000e+00,
  -5.0808021e-01,   7.9326965e-01,   1.0000000e+00,
  -7.9067809e-01,   1.8780315e-01,   1.0000000e+00,
  -3.8251119e-01,   8.2474204e-01,   1.0000000e+00,
   8.2232787e-01,   4.0148690e-01,  -1.0000000e+00,
   9.8596447e-01,  -3.2916872e-01,   1.0000000e+00,
  -1.4046951e-02,  -1.5238711e-01,  -1.0000000e+00,
  -5.4165125e-02,   9.1428474e-01,   1.0000000e+00,
  -1.0724684e+00,  -7.2028556e-01,  -1.0000000e+00,
  -2.4298545e-01,  -1.0426515e+00,   1.0000000e+00,
  -3.2448585e-01,  -2.8317976e-01,  -1.0000000e+00,
   2.4774928e-01,  -2.5565619e-01,  -1.0000000e+00,
  -1.7221056e-01,  -8.4939971e-01,   1.0000000e+00,
  -4.1726341e-01,  -3.9327101e-01,  -1.0000000e+00,
  -3.4783792e-01,  -5.7380853e-01,  -1.0000000e+00,
  -8.5183380e-01,  -7.2266417e-01,  -1.0000000e+00,
  -7.2524439e-01,  -3.7370735e-01,  -1.0000000e+00,
   3.4532731e-01,  -2.2271792e-02,  -1.0000000e+00,
   7.4242107e-01,   7.4085732e-01,  -1.0000000e+00,
  -1.3712343e-01,  -3.4725580e-01,  -1.0000000e+00,
   1.0591518e-01,   6.3378770e-01,   1.0000000e+00,
   3.3240671e-01,  -5.6552847e-01,   1.0000000e+00,
  -4.1754141e-01,   9.4856269e-01,   1.0000000e+00,
  -4.0488945e-01,  -6.1346877e-01,  -1.0000000e+00,
  -7.9715796e-01,   9.0701039e-01,   1.0000000e+00,
   8.7592113e-01,   3.6021089e-01,  -1.0000000e+00,
   5.4354035e-01,  -1.8109142e-01,   1.0000000e+00,
   7.5379749e-02,  -5.1108329e-01,  -1.0000000e+00,
   5.6404879e-01,   7.7190965e-01,  -1.0000000e+00,
   8.1699117e-01,   5.2558687e-01,  -1.0000000e+00,
  -3.7661111e-01,   1.0599395e-01,   1.0000000e+00,
   4.3602851e-01,   1.5045963e-01,  -1.0000000e+00,
   3.9691855e-01,  -5.4893287e-01,   1.0000000e+00,
  -2.7417659e-01,   6.0223846e-01,   1.0000000e+00,
  -9.8918073e-01,   1.5764882e-01,   1.0000000e+00,
  -5.1644146e-01,  -8.2457351e-01,  -1.0000000e+00,
   9.8064342e-01,   5.4613841e-01,  -1.0000000e+00,
   7.7755709e-01,  -8.9346511e-01,   1.0000000e+00,
  -2.5936394e-01,  -6.4448037e-01,  -1.0000000e+00,
  -2.3771804e-01,  -9.0677112e-01,  -1.0000000e+00,
  -6.0395095e-01,   8.8173896e-02,   1.0000000e+00,
  -2.8011575e-01,  -1.5129025e-02,  -1.0000000e+00,
  -2.0373733e-01,   7.9798618e-01,   1.0000000e+00,
  -1.6372581e-01,   4.3600528e-01,   1.0000000e+00,
   7.4444529e-01,   4.1089362e-01,  -1.0000000e+00,
  -3.3226757e-01,  -4.5875081e-01,  -1.0000000e+00,
  -2.7664943e-02,  -3.6080053e-01,  -1.0000000e+00,
   7.0696684e-01,   7.5427653e-01,  -1.0000000e+00,
  -8.2301562e-01,  -3.0287426e-01,   1.0000000e+00,
  -9.8538078e-01,  -3.8441488e-01,   1.0000000e+00,
  -4.9099732e-01,   7.0280304e-01,   1.0000000e+00,
  -5.2247192e-01,   3.0000586e-01,   1.0000000e+00,
  -5.6967959e-01,   1.0446241e-01,   1.0000000e+00,
  -3.2370497e-01,   7.2171478e-01,   1.0000000e+00,
   9.1968888e-01,  -3.2574194e-01,   1.0000000e+00,
   8.1817666e-01,   3.4900834e-01,  -1.0000000e+00,
  -7.1264726e-01,  -4.9117684e-01,  -1.0000000e+00,
   5.3684649e-01,   1.0477039e+00,  -1.0000000e+00,
   4.8856574e-02,   1.2522709e-01,  -1.0000000e+00,
   4.0044649e-01,  -1.3359043e-01,  -1.0000000e+00,
   7.2957688e-01,  -3.6594220e-01,   1.0000000e+00,
  -8.4640286e-01,   8.7034709e-01,   1.0000000e+00,
   8.3074310e-01,   5.8471902e-01,  -1.0000000e+00,
   4.4642585e-01,   2.8335762e-01,  -1.0000000e+00,
   6.3503752e-01,   8.5261243e-01,  -1.0000000e+00,
   1.3470715e-01,   8.4030406e-01,   1.0000000e+00,
  -5.8076298e-01,  -1.4677565e-02,   1.0000000e+00,
  -4.2655272e-01,   4.9157607e-01,   1.0000000e+00,
  -3.0916589e-02,   1.0815951e+00,   1.0000000e+00,
   4.8724109e-01,  -7.5002965e-01,   1.0000000e+00,
  -5.0642751e-01,  -9.1039293e-01,  -1.0000000e+00,
   2.4941012e-01,   3.8309361e-01,  -1.0000000e+00,
  -4.4334283e-01,  -7.6390321e-01,  -1.0000000e+00,
  -1.0454160e-01,  -8.8349573e-01,   1.0000000e+00,
  -2.8387756e-01,  -5.7173283e-01,  -1.0000000e+00,
   1.0121137e+00,   3.7187620e-01,   1.0000000e+00,
   6.1823198e-02,  -6.0798489e-01,   1.0000000e+00,
   8.7139506e-02,   3.6089253e-01,  -1.0000000e+00,
   8.7154983e-01,   4.1356362e-01,  -1.0000000e+00,
  -4.2224414e-01,   5.2139273e-01,   1.0000000e+00,
  -5.2995701e-01,   6.9943942e-01,   1.0000000e+00,
   5.9019103e-01,   4.4697246e-01,  -1.0000000e+00,
   8.4047379e-01,  -8.5034269e-01,   1.0000000e+00,
   9.1856967e-02,  -2.3185106e-01,  -1.0000000e+00,
  -8.2210059e-02,  -4.0215845e-01,  -1.0000000e+00,
   9.7316077e-01,   6.4157939e-01,  -1.0000000e+00,
   4.3572065e-01,   2.7636108e-01,  -1.0000000e+00,
   5.2452847e-01,  -5.4354500e-01,   1.0000000e+00,
   5.5473340e-01,   3.8297967e-01,  -1.0000000e+00,
   9.5598815e-01,  -8.0157282e-01,   1.0000000e+00,
  -7.7089155e-01,   4.3341968e-01,   1.0000000e+00,
   8.8986401e-01,  -5.3114670e-01,   1.0000000e+00,
  -2.6199343e-01,   6.6586981e-02,   1.0000000e+00,
   6.4157196e-01,  -1.8224977e-01,   1.0000000e+00,
  -4.1531211e-01,   7.7832594e-01,   1.0000000e+00,
   4.6745673e-01,   8.1829167e-01,  -1.0000000e+00,
   4.7789055e-01,  -2.4897772e-01,   1.0000000e+00,
  -8.2623795e-01,   9.4397856e-01,   1.0000000e+00,
  -9.4592941e-01,  -4.2810919e-01,   1.0000000e+00,
   5.0662354e-01,  -8.0927375e-01,   1.0000000e+00,
  -5.3612852e-01,   2.7239961e-02,   1.0000000e+00,
   4.1017996e-01,   4.3030737e-01,  -1.0000000e+00,
  -2.6162250e-01,  -5.8031367e-01,  -1.0000000e+00,
   3.1911280e-01,   2.1557907e-01,  -1.0000000e+00,
   1.4756822e-01,  -8.1499284e-01,   1.0000000e+00,
   6.2989832e-01,   6.9969730e-01,  -1.0000000e+00,
  -9.5407922e-01,   9.2649187e-01,   1.0000000e+00,
  -2.4492709e-01,   1.8247041e-02,  -1.0000000e+00,
   5.8108886e-01,   3.1610360e-01,  -1.0000000e+00,
   4.1018695e-02,  -4.5715509e-01,  -1.0000000e+00,
   5.8396336e-01,   6.8267605e-01,  -1.0000000e+00,
  -4.8730486e-01,   8.6463229e-01,   1.0000000e+00,
   8.2589939e-01,  -4.9791886e-02,   1.0000000e+00,
   5.2174069e-01,  -8.8946540e-01,   1.0000000e+00,
   8.3882489e-01,  -8.2646301e-01,   1.0000000e+00,
  -6.2960544e-01,  -1.4889383e-01,  -1.0000000e+00,
   8.5244440e-01,  -1.0574672e+00,   1.0000000e+00,
   3.8323766e-01,   6.7882270e-01,  -1.0000000e+00,
  -5.6987331e-02,   6.0584030e-01,   1.0000000e+00,
   3.0412897e-01,  -1.0666795e+00,   1.0000000e+00,
  -8.6141937e-01,  -2.2624830e-01,   1.0000000e+00,
  -8.6320225e-01,   2.0290163e-01,   1.0000000e+00,
   5.5946340e-01,   1.5353255e-01,  -1.0000000e+00,
  -4.1758769e-01,  -3.2695057e-01,  -1.0000000e+00,
   1.2212703e-01,  -2.4896540e-01,  -1.0000000e+00,
   7.4665959e-01,   5.5546242e-01,  -1.0000000e+00,
   2.0615224e-01,   5.7418978e-01,   1.0000000e+00,
  -8.9098492e-01,   4.9858238e-01,   1.0000000e+00,
  -6.8592982e-01,  -4.6911031e-01,  -1.0000000e+00,
  -1.9780710e-01,  -1.3615995e-01,  -1.0000000e+00,
   6.1697403e-01,   1.2927644e-01,  -1.0000000e+00,
  -7.9341609e-01,   3.6140632e-01,   1.0000000e+00,
  -1.0101707e+00,   7.8496772e-02,   1.0000000e+00,
  -8.6176046e-01,  -5.7865892e-01,  -1.0000000e+00,
  -1.9720509e-01,   2.7524501e-01,  -1.0000000e+00,
   5.7561847e-01,   9.3704466e-01,   1.0000000e+00,
  -6.1354207e-01,  -9.4116391e-01,  -1.0000000e+00,
  -5.5181353e-01,  -2.6804314e-01,  -1.0000000e+00,
   1.6003740e-01,  -3.4197787e-01,  -1.0000000e+00,
  -3.9716418e-01,   6.5651695e-01,   1.0000000e+00,
  -5.7886995e-01,  -8.7340266e-01,  -1.0000000e+00,
   8.3179248e-01,  -8.5337474e-02,   1.0000000e+00,
  -2.9893087e-01,  -4.2892168e-01,  -1.0000000e+00,
  -1.4926315e-01,   6.5394038e-01,   1.0000000e+00,
  -7.4495752e-01,  -7.1933493e-01,  -1.0000000e+00,
   1.5562362e-01,   9.2067567e-01,   1.0000000e+00,
   5.2893831e-01,   9.1665351e-01,  -1.0000000e+00,
   8.2824454e-02,  -6.2795455e-01,   1.0000000e+00,
  -9.3987636e-01,  -6.6272682e-01,  -1.0000000e+00,
   7.1418089e-01,  -2.5977397e-01,   1.0000000e+00,
  -1.1125764e-02,  -8.4275573e-01,   1.0000000e+00,
   5.4275886e-01,   1.1839569e-01,  -1.0000000e+00,
   7.3410595e-01,  -8.9072193e-01,   1.0000000e+00,
   3.7896208e-01,  -1.1608314e-01,  -1.0000000e+00,
  -1.6656364e-01,  -4.1054076e-01,  -1.0000000e+00,
  -7.8201587e-01,   3.7672979e-01,   1.0000000e+00,
   2.7177275e-01,   8.1096707e-01,   1.0000000e+00,
  -8.6780139e-01,  -6.7610958e-01,  -1.0000000e+00,
  -1.8120946e-01,   6.8091241e-01,   1.0000000e+00,
  -4.4472888e-02,   4.4147201e-04,  -1.0000000e+00,
   4.2931482e-01,   8.2890920e-01,  -1.0000000e+00,
  -8.3797535e-01,  -7.6967158e-02,   1.0000000e+00,
   6.9977931e-01,  -2.0817600e-01,   1.0000000e+00,
   7.7403264e-01,   5.1157000e-01,  -1.0000000e+00,
  -6.8813838e-01,   7.9310664e-01,   1.0000000e+00,
  -4.2531784e-01,  -8.5040532e-01,  -1.0000000e+00,
   4.4357173e-01,  -2.4214875e-01,   1.0000000e+00,
  -2.1860655e-03,   6.9743662e-01,   1.0000000e+00,
   3.2546528e-01,  -1.8578580e-01,  -1.0000000e+00,
   2.7126831e-01,  -8.5202415e-01,   1.0000000e+00,
   2.0834261e-01,  -8.2931143e-01,   1.0000000e+00,
  -3.3803178e-01,  -8.9404250e-01,  -1.0000000e+00,
  -2.4260246e-02,  -5.5091178e-01,  -1.0000000e+00,
   2.5541631e-01,  -2.8836090e-01,  -1.0000000e+00,
  -7.1661879e-01,   4.1985982e-04,   1.0000000e+00,
   1.3171741e-01,  -4.5983099e-01,  -1.0000000e+00,
   3.4455847e-01,  -1.2877727e-01,  -1.0000000e+00,
   8.2438408e-02,  -9.7262595e-01,   1.0000000e+00,
   5.3327355e-01,   2.9512885e-01,  -1.0000000e+00,
  -3.3881875e-01,   9.1995587e-01,   1.0000000e+00,
   5.5124163e-01,  -8.4634282e-01,   1.0000000e+00,
  -4.1056066e-01,   5.1225154e-01,   1.0000000e+00,
   4.6292501e-01,  -7.3563881e-01,   1.0000000e+00,
   5.7568892e-01,  -5.8952332e-01,   1.0000000e+00,
  -6.3209364e-01,  -9.8042189e-01,  -1.0000000e+00,
  -1.6764181e-01,  -5.2920071e-01,  -1.0000000e+00,
   7.2036792e-01,  -1.0400766e+00,   1.0000000e+00,
   7.5002909e-01,  -5.3809352e-01,   1.0000000e+00,
   2.5173000e-01,  -9.6077128e-01,   1.0000000e+00,
  -7.2459762e-01,   7.4403736e-02,   1.0000000e+00,
  -7.2042750e-01,  -5.5720908e-01,  -1.0000000e+00,
  -9.5384434e-01,   4.7741916e-01,   1.0000000e+00,
   7.1149028e-01,  -9.9011601e-01,   1.0000000e+00,
   2.9083625e-01,  -4.4320360e-01,   1.0000000e+00,
   3.1974619e-01,  -4.0070957e-01,   1.0000000e+00,
   2.3353842e-01,   6.3672589e-01,   1.0000000e+00,
  -1.9596922e-01,  -9.9010507e-01,   1.0000000e+00,
  -4.3776733e-01,   1.1653144e-02,   1.0000000e+00,
  -3.5482986e-01,   8.1982000e-01,   1.0000000e+00,
   3.4704463e-01,  -5.4508368e-01,   1.0000000e+00,
   8.3637625e-01,   3.4383080e-01,  -1.0000000e+00,
  -7.1385080e-01,  -6.4057506e-01,  -1.0000000e+00 ]
).reshape((250, 3))
test_set = out_dta

Split in_dta into a training set (first 25 examples) and validation set (last 10 examples) and evaluate validation and test errors

In [5]:
train_set = in_dta[:25]
val_set = in_dta[25:]

val_test_error(train_set, val_set, test_set)

The lowest validation error is 0.0 occuring at k = 6
The lowest out-of-sample (test) error is 0.1 occuring at k = 7


Split in_dta into a training set (first 10 examples) and validation set (last 25 examples) and evaluate validation and test errors

In [6]:
train_set = in_dta[25:]
val_set = in_dta[:25]

val_test_error(train_set, val_set, test_set)

The lowest validation error is 0.1 occuring at k = 6
The lowest out-of-sample (test) error is 0.2 occuring at k = 6


# Cross Validation

Create data points

In [7]:
rhos = [sqrt(sqrt(3) + 4), sqrt(sqrt(3) - 1), sqrt(9 + 4*sqrt(6)), sqrt(9 - sqrt(6))]
const_hypotheses = {False:'h(x)=ax+b', True:'h(x)=b'}

Do leave one-out cross-validation (LOOCV) using squared error

In [8]:
def cross_validation(rhos, const_hypotheses):
    for const_hypothesis in const_hypotheses.keys():
        for rho in rhos:
            data = np.array([-1., 0., 
                 rho, 1., 
                 1., 0.
                ]).reshape(3, 2)
            lr = LinearRegression(transform=False)
            error = lr.cross_val_error(data, const_hypothesis, k=2)
            (print('For ρ = {0} and const_hypothesis {1} the cross-validation error is {2}.'
                   .format(rho, const_hypotheses[const_hypothesis], round(error, 2))))

In [9]:
cross_validation(rhos, const_hypotheses)

For ρ = 2.3941701709713277 and const_hypothesis h(x)=ax+b the cross-validation error is 1.14.
For ρ = 0.8555996771673521 and const_hypothesis h(x)=ax+b the cross-validation error is 64.66.
For ρ = 4.335661307243996 and const_hypothesis h(x)=ax+b the cross-validation error is 0.5.
For ρ = 2.5593964634688433 and const_hypothesis h(x)=ax+b the cross-validation error is 0.99.
For ρ = 2.3941701709713277 and const_hypothesis h(x)=b the cross-validation error is 0.5.
For ρ = 0.8555996771673521 and const_hypothesis h(x)=b the cross-validation error is 0.5.
For ρ = 4.335661307243996 and const_hypothesis h(x)=b the cross-validation error is 0.5.
For ρ = 2.5593964634688433 and const_hypothesis h(x)=b the cross-validation error is 0.5.


# PLA vs SVM

Create line class

In [10]:
class Line:
    """Class for representing lines in 2-dimensional coordinates."""
    def __init__(self, points):
        """ Create a new line 
        Args:
        points (2 x 2 np.ndarray): Points 
        """
        self.x1x2 = points
    
    @staticmethod
    def transform_inputs(points):
        """Transform points to include the constant term 1.0
        Args:
        points (np.ndarray): Points 
        Returns:
        np.ndarray: Points
        """
        ones = np.ones(len(points)).reshape(len(points), 1)
        x0x1x2 = np.concatenate((ones, points), axis=1)
        return x0x1x2
        
    def evaluate(self, points):
        """Evaluate to see which side of line the points lie on
        Args:
        points (np.ndarray): Points 
        Returns:
        np.ndarray: Signs (+1/-1) indicating which side of line a point lies on 
        """
        ones = np.ones(len(points)).reshape(len(points), 1)
        x = np.concatenate((ones, points), axis=1)
        y = np.array([np.sign(np.dot(self.weights().T, xn)) for xn in x]) 
        return y
    
    def weights(self):
        """Evaluate weights describing parametric equation of the line
        Args: 
        Returns:
        np.ndarray: Weights of the line 
        """
        y1_times_x2 = self.x1x2[0][1] * self.x1x2[1][0]
        x1_times_y2 = self.x1x2[0][0] * self.x1x2[1][1]
        y2_minus_y1 = self.x1x2[1][1] - self.x1x2[0][1]
        x2_minus_x1 = self.x1x2[1][0] - self.x1x2[0][0]   
        return np.array([y1_times_x2 - x1_times_y2, y2_minus_y1, -x2_minus_x1])

Create Perceptron Learning Algorithm (PLA) class

In [11]:
class PLA:
    """ Perceptron Learning Algorithm (PLA) class. """
    
    def __init__(self, inputs, outputs, line):
        """ Create a new Perceptron Learning Algorithm (PLA). 
        Args:
        inputs (np.ndarray): Input points.
        outputs (np.ndarray): Targets.
        line (Line): Line.
        """
        self.line = line
        self.xn = Line.transform_inputs(inputs)
        self.yn = outputs
        self.N = len(self.xn)
    
    def run(self, initial_weights=[0., 0., 0.], max_iterations=100000):    
        """Run of Perceptron Learning Algorithm (PLA).
        Args:
        initial_weights (list): Initial weights.
        max_iterations (int): Max number of iterations.
        Returns:
        np.ndarray : Weights
        """                
        # PLA iterations
        converged = False
        weights = np.copy(initial_weights) # initialize weights
        for i in range(0, max_iterations):
            
            # misclassified points
            mis_point_idxs = self.misclassified_points(weights)
        
            # break out of loop if converged
            if len(mis_point_idxs) == 0:
                converged = True
                break
            
            # pick misclassified point at random
            mis_point_idx = np.random.choice(mis_point_idxs)
            #print('Chosen misclassified point index ', mis_point_idx)
            
            # update weights
            weights = np.add(weights, np.multiply(self.yn[mis_point_idx], self.xn[mis_point_idx]))
            #print('Weights ', weights)
            
        if not converged:
            raise Exception('Perceptron algorithm failed to converge after maximum number of iterations.')
        
        return weights
    
    # private methods
    def misclassified_points(self, weights):
        mis_point_idxs = []      
        for i in range(0, self.N):
            if np.sign(np.dot(weights, self.xn[i])) != self.yn[i]:
                mis_point_idxs.append(i)
        #print('Misclassified point indexes', mis_point_idxs)
        return mis_point_idxs

Create Support Vector Machine (SVM) class

In [12]:
class SVM:
    """ Support Vector Machine (SVM) class. """
    
    def __init__(self, inputs, outputs, line, problem='dual'):
        """ Create a new Support Vector Machine Algorithm (SVM). 
        Args:
        inputs (np.ndarray): Input points.
        outputs (np.ndarray): Targets.
        line (Line): Line.
        problem (string): SVM Problem to solve, 'primal' or 'dual'
        """
        if problem != 'primal' and problem != 'dual':
            raise ValueError('Problem must be "primal" or "dual"')
        
        self.problem = problem
        if self.problem == 'dual':
            self.xn = inputs 
        else: # problem = # primal
            self.xn = Line.transform_inputs(inputs)
        self.yn = outputs
        self.line = line
        self.N = len(self.xn)
        
    def run(self):    
        """Run of Support Vector Machine Algorithm (SVM).
        Args:
        Returns:
        Tuple(np.ndarray, int): Weights, number of support vectors
        """
        mat = []
        if self.problem == 'dual':
            for row_idx in range(0, self.N):
                for col_idx in range(0, self.N):
                    val = self.yn[row_idx] * self.yn[col_idx] * np.dot(self.xn[row_idx].T, self.xn[col_idx])
                    mat.append(val)
            mat = np.array(mat).reshape((self.N, self.N))    
        elif self.problem == 'primal':
            for n in range(0, self.N):
                val = [-self.yn[n] * x for x in self.xn[n]]
                mat.append(val)
            mat = np.array(mat)
        
        # form matrices for quadratic programming solver
        dim = len(self.xn[0])
        if self.problem == 'dual':
            P = matrix(mat, tc='d')
            q = matrix(-np.ones(self.N), tc='d')
            b = matrix(0.0, tc='d')
            A = matrix(self.yn, tc='d')
            A = A.trans()
            h = matrix(np.zeros(self.N), tc='d')
            G = matrix(-np.identity(self.N), tc='d')
        elif self.problem == 'primal':
            P = matrix(np.identity(dim), tc='d') 
            q = matrix(0.0, (dim, 1), tc='d')
            h = matrix(-1.0, (self.N, 1), tc='d')
            G = matrix(mat, tc='d')
        
        #print('P = ', P)
        #print('q = ', q)
        #print('G = ', G)
        #print('h = ', h)
        #print('A = ', A)
        #print('b = ', b)
                
        # call qp solver to compute weights
        solvers.options['show_progress'] = False # supress solver output
        
        if self.problem == 'dual':
            sol = solvers.qp(P, q, G, h, A, b)
            alpha = np.array(list(sol['x']))
            #print('alpha = ', sol['x'])
            
            weights = np.zeros(dim)
            sv_idxs = []
            for n in range(0, self.N):
                weights += alpha[n] * self.yn[n] * self.xn[n]
                if alpha[n] > 0.001: # => self.xn[n] is support vector
                    sv_idxs.append(n)
            
            bs = []
            for n in sv_idxs:
                b = (1. / self.yn[n]) - np.dot(weights.T, self.xn[n])
                bs.append(b)
            bs_round = np.round(bs, 1)          
            #print('sv_idx = ', sv_idxs)
            #print('bs = ', bs)
            if (len(np.unique(bs_round)) != 1):
                raise Exception('All support vectors must produce the same value of b.')
            
            weights = np.insert(weights, 0, b)
        
        elif self.problem == 'primal':          
            sol = solvers.qp(P, q, G, h)
            weights = np.array(list(sol['x']))
        
        #print('weights = ', weights)
            
        # compute number of support vectors
        num_sv = 0   
        if self.problem == 'dual':
            num_sv = len(bs)
        elif self.problem == 'primal':
            for n in range(0, self.N):
                if np.abs(self.yn[n] * np.dot(weights.T, self.xn[n]) - 1.0) < 0.001: # => self.xn[n] is support vector
                    num_sv += 1
        
        return weights, num_sv

Calculate disagreement probability

In [13]:
def disagreement_probability(weights, line, points):
        x = Line.transform_inputs(points)
        y = line.evaluate(points)
        g = np.array([np.sign(np.dot(weights, xn)) for xn in x])
        #print('y = ', y)
        #print('g = ', g)
        return 100. * np.sum(y != g) / len(y)

Create target line

In [14]:
def create_line():
    points = np.array([np.random.uniform(-1, 1, 2) for n in range(0, 2)])
    line = Line(points)
    return line

Create dataset

In [15]:
def create_dataset(line, N=10):
    inputs = None
    outputs = None
    for n in range(0, N*1000):
        inputs = np.array([np.random.uniform(-1, 1, 2) for n in range(0, N)]).reshape(N, 2)
        outputs = line.evaluate(inputs)
        if np.abs(np.sum(outputs)) != N: # all points do not have same sign
            break
    return inputs, outputs

Run experiments

In [16]:
def run_experiments(N=10, num_runs=1000, problem='dual'):
    num_svm_better = 0
    ave_num_sv = 0
    
    for i in range(0, num_runs):
        # create target line
        line = create_line()
        
        # create dataset
        inputs, outputs = create_dataset(line, N)
        #print('outputs = ', outputs)
        
        # PLA run
        pla = PLA(inputs, outputs, line)
        weights_pla = pla.run()
        #print('weights_pla = ', weights_pla)
        
        # SVM run
        svm = SVM(inputs, outputs, line, problem)
        weights_svm, num_sv = svm.run()
        #print('weights_svm = ', weights_svm)
        #print('num_sv = ', num_sv)
        ave_num_sv += num_sv
        
        # calculate disagreement probabilities
        points = np.array([np.random.uniform(-1, 1, 2) for n in range(0, N*100)]).reshape(N*100, 2)
        disagreement_prob_pla = disagreement_probability(weights_pla, line, points)
        disagreement_prob_svm = disagreement_probability(weights_svm, line, points)
        #print('svm prob = ', disagreement_prob_svm, 'pla prob = ', disagreement_prob_pla)
        
        if disagreement_prob_svm < disagreement_prob_pla:
            num_svm_better += 1
    
    ave_num_sv = ave_num_sv / num_runs
    percentage_svm_better = 100. * num_svm_better / num_runs
    
    (print('Percentage g_SVM better in approximating f than g_PLA = {0}% (for {1} problem)'
          .format(percentage_svm_better, problem)))
    print('Average number of support vectors = {0} (for {1} problem)'.format(round(ave_num_sv), problem))

Calculate the percentage of times g_SVM is better in approximating f than g_PLA using 10 points and 1000 runs

In [17]:
run_experiments(N=10, num_runs=1000, problem='dual')
run_experiments(N=10, num_runs=1000, problem='primal')

Percentage g_SVM better in approximating f than g_PLA = 60.3% (for primal problem)
Average number of support vectors = 2.862 (for primal problem)


Calculate the percentage of times g_SVM is better in approximating f than g_PLA using 100 points and 1000 runs

In [18]:
run_experiments(N=100, num_runs=1000, problem='dual')
run_experiments(N=100, num_runs=1000, problem='primal')

Percentage g_SVM better in approximating f than g_PLA = 64.8% (for primal problem)
Average number of support vectors = 2.998 (for primal problem)
