In [1]:
import numpy as np

In [2]:
def readfile(filename):
    listdata = []
    myfile = open(filename, "r", encoding= "utf-8")
    for line in myfile:
        data = line.split()
        if data == []:
            break
        listdata.append(data)
    myfile.close()
    return listdata

In [3]:
def normalize_and_add_ones(X):
    X_max = np.array([[np.amax(X[:, column_id]) for column_id in range(X.shape[1])] for _ in range(X.shape[0])])
    X_min = np.array([[np.amin(X[:, column_id]) for column_id in range(X.shape[1])] for _ in range(X.shape[0])])
    X_normalized = (X-X_min)/(X_max-X_min)
    ones = np.array([[1] for _ in range(X_normalized.shape[0])])
    return np.column_stack((ones, X_normalized))

In [4]:
class RidgeRegression:
    def __init__(self):
        return

    def fit(self, X_train, Y_train, LAMBDA):
        assert len(X_train.shape) == 2 and X_train.shape[0] == Y_train.shape[0]
        W = np.linalg.inv(
            X_train.transpose().dot(X_train) + LAMBDA * np.identity(X_train.shape[1])
        ).dot(X_train.transpose()).dot(Y_train)
        return W

    def fit_gradident_descent(self, X_train, Y_train, LAMBDA, learning_rate, max_num_epoch = 100, batch_size = 20):
        W = np.random.randn(X_train.shape[1])
        W = np.expand_dims(W, axis = 1)
        last_loss = 1e9
        for ep in range(max_num_epoch):
            arr = np.array(range(X_train.shape[0]))
            np.random.shuffle(arr)
            X_train = X_train[arr]
            Y_train = Y_train[arr]
            total_minibatch = int(np.ceil(X_train.shape[0]/ batch_size))
            for i in range(total_minibatch):
                index = i*batch_size
                X_train_sub = X_train[index: min(index+batch_size, X_train.shape[0])]
                Y_train_sub = Y_train[index: min(index+batch_size, X_train.shape[0])]
                grad = X_train_sub.T.dot(X_train_sub.dot(W) - Y_train_sub) + LAMBDA*W
                W = W - learning_rate*grad
            new_loss = self.compute_RSS(self.predict(W, X_train), Y_train)
            if np.abs(new_loss - last_loss) < 1e-2:
                break
            last_loss = new_loss
        return W

    def predict(self, W, X_new):
        X_new = np.array(X_new)
        Y_new = X_new.dot(W)
        return Y_new

    def compute_RSS(self, Y_new, Y_predicted):
        loss = 1. / Y_new.shape[0] * np.sum((Y_new - Y_predicted)**2)
        return loss

    def get_the_best_LAMBDA(self, X_train, Y_train):
        def cross_validation(num_folds, LAMBDA):
            row_ids = np.array(range(X_train.shape[0]))
            # Redundant
            ending_ids = len(row_ids)-len(row_ids)%num_folds
            # Standard
            valid_ids = np.split(row_ids[ :ending_ids], num_folds)
            # Add redundant to last part
            valid_ids[-1] = np.append(valid_ids[-1], row_ids[ending_ids: ])
            # Create trainning parts
            train_ids = [[k for k in row_ids if k not in valid_ids[i]] for i in range(num_folds)]
            total_RSS = 0
            for i in range(num_folds):
                valid_part = {'X': X_train[valid_ids[i]], 'Y': Y_train[valid_ids[i]]}
                train_part = {'X': X_train[train_ids[i]], 'Y': Y_train[train_ids[i]]}
                # W = self.fit(X_train=train_part['X'], Y_train=train_part['Y'], LAMBDA=LAMBDA)
                W = self.fit_gradident_descent(X_train=train_part['X'], Y_train=train_part['Y'], LAMBDA=LAMBDA, learning_rate=1e-3)
                Y_predicted = self.predict(W, valid_part['X'])
                total_RSS += self.compute_RSS(valid_part['Y'], Y_predicted)
            return total_RSS/num_folds
            
        def range_scan(best_LAMBDA, minimum_RSS, LAMBDA_values):
            for current_LAMBDA in LAMBDA_values:
                aver_RSS = cross_validation(num_folds=5, LAMBDA=current_LAMBDA)
                if aver_RSS < minimum_RSS:
                    best_LAMBDA = current_LAMBDA
                    minimum_RSS = aver_RSS
                print(f"LAMBDA: {current_LAMBDA}\tRSS: {aver_RSS}\tBest LAMBDA: {best_LAMBDA}\tMinimum RSS: {minimum_RSS}")
            return best_LAMBDA, minimum_RSS
        
        # Initialize
        best_LAMBDA = 0
        minimum_RSS = 1e10

        # Scan with long steps
        LAMBDA_values = range(50)
        best_LAMBDA, minimum_RSS = range_scan(best_LAMBDA=best_LAMBDA, minimum_RSS=minimum_RSS, LAMBDA_values=LAMBDA_values)

        # Scan with short steps
        LAMBDA_values = np.array(range(max(0, (best_LAMBDA-1)*1000), (best_LAMBDA+1)*1000))/1000
        best_LAMBDA, minimum_RSS = range_scan(best_LAMBDA=best_LAMBDA, minimum_RSS=minimum_RSS, LAMBDA_values=LAMBDA_values)

        # Return
        return best_LAMBDA

In [5]:
# Load & process data
data = np.array(readfile("x28.txt")).astype('float')
print(f"Original data shape: {data.shape}")
X = data[:,1: data.shape[1]-1]
print(f"Original X shape: {X.shape}")
Y = data[:,data.shape[1]-1: ]
print(f"Original Y shape: {Y.shape}")
X = normalize_and_add_ones(X)
print(f"Normalized X shape: {X.shape}")

Original data shape: (60, 17)
Original X shape: (60, 15)
Original Y shape: (60, 1)
Normalized X shape: (60, 16)


In [6]:
# 50 data points for trainning, 10 data points for testing
X_train, Y_train = X[:50], Y[:50]
X_test, Y_test = X[50:], Y[50:]    

In [7]:
# Ridge regression
ridge_regression = RidgeRegression()

# Get best lambda for ridge
best_LAMBDA = ridge_regression.get_the_best_LAMBDA(X_train, Y_train)    
print('Best LAMBDA:',best_LAMBDA)

# Learn the weight
W_learned = ridge_regression.fit(X_train = X_train, Y_train=Y_train,LAMBDA=best_LAMBDA)

# Testing
Y_predicted = ridge_regression.predict(W = W_learned, X_new = X_test)

# RSS computation
print("RSS:", ridge_regression.compute_RSS(Y_new = Y_test, Y_predicted = Y_predicted))

LAMBDA: 0	RSS: 5390.1344283571425	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 1	RSS: 5666.892971446068	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 2	RSS: 6202.407305416955	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 3	RSS: 6998.171772901643	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 4	RSS: 8029.004945207195	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 5	RSS: 9274.336909932163	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 6	RSS: 10722.919914236658	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 7	RSS: 12379.651965287881	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 8	RSS: 14206.413404849975	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 9	RSS: 16184.77674809418	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 10	RSS: 18336.43497643368	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 11	RSS: 20537.934931285075	Best LAMBDA: 0	Minimum RSS: 5390.1344283571425
LAMBDA: 12	RSS: 22904.50144509531

LAMBDA: 0.244	RSS: 5435.09435789321	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.245	RSS: 5458.982278823816	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.246	RSS: 5437.5423591304925	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.247	RSS: 5458.024589839333	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.248	RSS: 5446.0952223406075	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.249	RSS: 5456.484759562086	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.25	RSS: 5431.51118802298	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.251	RSS: 5440.595451095603	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.252	RSS: 5445.4839984178425	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.253	RSS: 5441.678016123804	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.254	RSS: 5462.947064735745	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.255	RSS: 5444.565249729065	Best L

LAMBDA: 0.339	RSS: 5435.697325143538	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.34	RSS: 5450.110976417289	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.341	RSS: 5468.203076811044	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.342	RSS: 5447.611760000016	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.343	RSS: 5456.049569628858	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.344	RSS: 5469.279027414772	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.345	RSS: 5461.980409690486	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.346	RSS: 5440.609693215674	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.347	RSS: 5448.220527370686	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.348	RSS: 5454.390838018027	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.349	RSS: 5460.081108596636	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.35	RSS: 5462.961573914249	Best LAM

LAMBDA: 0.435	RSS: 5493.214150547649	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.436	RSS: 5502.1608908655535	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.437	RSS: 5482.009908393777	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.438	RSS: 5453.61229634697	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.439	RSS: 5486.4501471937965	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.44	RSS: 5489.573968963627	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.441	RSS: 5451.021885164138	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.442	RSS: 5495.455483749871	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.443	RSS: 5452.523124173994	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.444	RSS: 5487.283887334121	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.445	RSS: 5491.6183060373	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.446	RSS: 5477.879616841465	Best LAM

LAMBDA: 0.53	RSS: 5500.076517474556	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.531	RSS: 5490.797656573127	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.532	RSS: 5500.169816949623	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.533	RSS: 5502.620169477351	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.534	RSS: 5482.4776901737805	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.535	RSS: 5519.363068198998	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.536	RSS: 5538.83214832642	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.537	RSS: 5489.8762073368935	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.538	RSS: 5510.781071968661	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.539	RSS: 5492.265369904768	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.54	RSS: 5527.84804681152	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.541	RSS: 5499.977911831605	Best LAM

LAMBDA: 0.625	RSS: 5534.810376559512	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.626	RSS: 5541.126512634433	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.627	RSS: 5515.599043988457	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.628	RSS: 5531.341399969186	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.629	RSS: 5493.031089025713	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.63	RSS: 5516.687354016618	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.631	RSS: 5537.614167223323	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.632	RSS: 5521.786541440005	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.633	RSS: 5526.67144724898	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.634	RSS: 5504.970317634966	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.635	RSS: 5509.960396002849	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.636	RSS: 5508.895629151666	Best LAM

LAMBDA: 0.72	RSS: 5592.323019179106	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.721	RSS: 5562.105718223445	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.722	RSS: 5580.320350571759	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.723	RSS: 5595.360521816289	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.724	RSS: 5569.191265455351	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.725	RSS: 5562.217612674813	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.726	RSS: 5581.414717285969	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.727	RSS: 5562.325216174579	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.728	RSS: 5559.900777397455	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.729	RSS: 5586.454668196569	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.73	RSS: 5538.179440905297	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.731	RSS: 5564.595102161134	Best LAM

LAMBDA: 0.815	RSS: 5594.538972990131	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.816	RSS: 5563.986356525571	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.817	RSS: 5581.193361177595	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.818	RSS: 5580.919663422696	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.819	RSS: 5598.19916610533	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.82	RSS: 5578.90216300781	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.821	RSS: 5618.0394378047195	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.822	RSS: 5593.756179123411	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.823	RSS: 5568.395777063696	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.824	RSS: 5607.114607661027	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.825	RSS: 5592.265525330205	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.826	RSS: 5580.734240196177	Best LAM

LAMBDA: 0.91	RSS: 5601.255509261713	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.911	RSS: 5623.635945443016	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.912	RSS: 5639.279494328585	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.913	RSS: 5617.049257421453	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.914	RSS: 5649.0885132543945	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.915	RSS: 5650.900456805344	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.916	RSS: 5611.011517074354	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.917	RSS: 5606.839885042306	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.918	RSS: 5614.126748636666	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.919	RSS: 5635.537696056586	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.92	RSS: 5616.135605635594	Best LAMBDA: 0.014	Minimum RSS: 5367.681887971855
LAMBDA: 0.921	RSS: 5630.007574251702	Best LA