In [1]:
import numpy as np
import time
import sys
if "../" not in sys.path:
  sys.path.append("../") 
from lib.utils import read_data_from_file, sign
from lib.pla import pocket_pla

In [2]:
def linear_reg_closed_form(X, y):
    """
    Linear regression solved in closed form.

    The weights are obtained by applying the pseudo-inverse of X
    (np.linalg.pinv) to y.

    Args:
        X: data matrix
        y: target values
    Returns:
        w_lin: feature weights, i.e. pinv(X) applied to y
    """
    return np.dot(np.linalg.pinv(X), y)

In [3]:
def get_err_rate(X, y, w):
    """
    Fraction of samples whose predicted label differs from the given label.

    Args:
        X: data matrix
        y: labels to compare against
        w: weights used to form the predictions sign(X.dot(w))
    Returns:
        Mean of the element-wise mismatch mask between predictions and y.
    """
    predictions = sign(X.dot(w))
    mismatches = predictions != y
    return mismatches.mean()

In [4]:
# Seed NumPy's global random generator with a fixed value so that code drawing
# from np.random behaves the same on every fresh run of the notebook.
# NOTE(review): presumably pocket_pla is the consumer of this randomness -
# confirm in lib.pla.
np.random.seed(0)

In [5]:
# Load the training and test sets
data_train = read_data_from_file('hw1_18_train.dat')
print('data_train shape: ', data_train.shape)
data_test = read_data_from_file('hw1_18_test.dat')
print('data_test shape: ', data_test.shape)

# The last column of each array is taken as the label; the remaining columns
# are the features, with a column of ones prepended (constant input x0 = 1).
n_train = data_train.shape[0]
n_test = data_test.shape[0]
X = np.hstack((np.ones((n_train, 1)), data_train[:, :-1]))
y = data_train[:, -1]
X_test = np.hstack((np.ones((n_test, 1)), data_test[:, :-1]))
y_test = data_test[:, -1]

# Number of repetitions run by each experiment loop
times = 2000

data_train shape:  (500, 5)
data_test shape:  (500, 5)


# linear_reg_closed_form 测试

In [6]:
# Fit the closed-form linear regression `times` times, scoring each fit on the
# test set, and report the mean test error together with the elapsed time.
# NOTE(review): the fit looks deterministic, so the repetition is presumably
# there only to make the elapsed time comparable with the pocket PLA runs.
start = time.time()
err_rates = []
for trial in range(times):
    w_lin = linear_reg_closed_form(X, y)
    test_err = get_err_rate(X_test, y_test, w_lin)
    err_rates.append(test_err)
print('linear reg for classification:')
mean_err = np.mean(err_rates)
elapsed = time.time() - start
print("error rate: {}\tcost: {}".format(mean_err, elapsed))

linear reg for classification:
error rate: 0.10400000000000002	cost: 0.8960509300231934


# pocket_pla 测试,主要用于做对比

In [7]:
# Run pocket_pla `times` times with its default starting weights, scoring each
# returned weight vector on the test set, then report the mean test error and
# the elapsed time. The second value returned by pocket_pla is not used.
start = time.time()
err_rates = []
for trial in range(times):
    w_pocket, _ = pocket_pla(X, y)
    test_err = get_err_rate(X_test, y_test, w_pocket)
    err_rates.append(test_err)
print('pocket pla for classification:')
mean_err = np.mean(err_rates)
elapsed = time.time() - start
print("error rate: {}\tcost: {}".format(mean_err, elapsed))

pocket pla for classification:
error rate: 0.13280200000000003	cost: 19.535125494003296


# 初值为w_lin的pocket_pla

In [9]:
# Same experiment as the plain pocket PLA run, except every run starts from the
# closed-form linear regression weights (passed as w0). The timer is started
# before the regression fit, so the fit is included in the reported cost.
start = time.time()
err_rates = []
w_lin = linear_reg_closed_form(X, y)
for trial in range(times):
    w_pocket, _ = pocket_pla(X, y, w0=w_lin)
    test_err = get_err_rate(X_test, y_test, w_pocket)
    err_rates.append(test_err)
print('pocket pla with w0=w_lin for classification:')
mean_err = np.mean(err_rates)
elapsed = time.time() - start
print("error rate: {}\tcost: {}".format(mean_err, elapsed))

pocket pla with w0=w_lin for classification:
error rate: 0.10499300000000003	cost: 19.429277181625366
