In [109]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys

import numpy 
import pandas
import sklearn.metrics
import sklearn.model_selection
import sklearn.linear_model
import sklearn.preprocessing


def load_train_test_data(train_ratio=.5):
    """Read the energy-efficiency CSV and split it into train and test parts.

    Dataset: https://archive.ics.uci.edu/ml/datasets/Energy+efficiency

    Args:
        train_ratio: Fraction of the rows to keep for training; the
            remaining ``1 - train_ratio`` is passed to scikit-learn as the
            test fraction.

    Returns:
        The result of ``sklearn.model_selection.train_test_split`` applied to
        the feature columns X1..X8 and the label column Y1, using a fixed
        ``random_state`` of 0 so the split is reproducible.
    """
    frame = pandas.read_csv('./ENB2012_data.csv')

    # Double brackets keep both selections as DataFrames (the label too).
    features = frame[['X1', 'X2', 'X3', 'X4', 'X5', 'X6', 'X7', 'X8']]
    target = frame[['Y1']]

    held_out_fraction = 1 - train_ratio
    return sklearn.model_selection.train_test_split(
        features,
        target,
        test_size=held_out_fraction,
        random_state=0,
    )


def scale_features(X_train, X_test, low=0, upp=1):
    """Min-max scale the features using statistics from the training set only.

    The scaler is fitted on ``X_train`` alone and then applied to both
    matrices. Fitting on the stacked train+test data would leak test-set
    minima/maxima into preprocessing, and stacking also discards DataFrame
    column names, which makes scikit-learn warn about mismatched feature
    names when transforming.

    Args:
        X_train: Training feature matrix (array-like or DataFrame).
        X_test: Test feature matrix with the same columns as ``X_train``.
        low: Lower bound of the target range for the training data.
        upp: Upper bound of the target range for the training data.

    Returns:
        A tuple ``(X_train_scaled, X_test_scaled)``. Training values lie in
        ``[low, upp]``; test values may fall slightly outside that range when
        the test set contains values not seen during fitting.
    """
    minmax_scaler = sklearn.preprocessing.MinMaxScaler(feature_range=(low, upp))
    minmax_scaler.fit(X_train)
    X_train_scaled = minmax_scaler.transform(X_train)
    X_test_scaled = minmax_scaler.transform(X_test)
    return X_train_scaled, X_test_scaled


def gradient_descent(X, y, alpha=1, iters=10000, eps=1e-4, verbose=True):
    """Fit linear-regression weights by AdaGrad-scaled batch gradient descent.

    Minimises the sum of squared errors ``||y - X @ theta||^2`` starting from
    ``theta = 0``. Every coordinate uses its own step size
    ``alpha / sqrt(sum of that coordinate's squared gradients so far)``.

    Args:
        X: Design matrix of shape ``(n, d)``. Include a column of ones if an
            intercept term is wanted; no column is treated specially.
        y: The ``n`` target values: a 1-D sequence, an ``(n, 1)`` array, a
            pandas Series or a single-column DataFrame (any column name).
        alpha: Base learning rate.
        iters: Maximum number of iterations.
        eps: Convergence tolerance. Iteration stops once the largest absolute
            change of any coefficient in one iteration drops below ``eps``.
            Pass ``0`` (or a negative value) to always run ``iters`` steps.
        verbose: When true, print the coefficients every 20th iteration.

    Returns:
        ``theta`` as a float array of shape ``(d, 1)``.

    Raises:
        ValueError: If ``X`` is not 2-D or ``y`` does not hold exactly one
            value per row of ``X``.
    """
    X = numpy.asarray(X, dtype=float)
    if X.ndim != 2:
        raise ValueError("X must be a 2-D array, got %d dimension(s)" % X.ndim)
    n, d = X.shape

    targets = numpy.asarray(y, dtype=float).reshape(-1, 1)
    if targets.shape[0] != n:
        raise ValueError(
            "y must hold one target per row of X (%d rows, %d targets)"
            % (n, targets.shape[0])
        )

    theta = numpy.zeros((d, 1))
    grad_sq_sum = numpy.zeros((d, 1))  # running sum of squared gradients

    for times in range(iters):
        residual = targets - numpy.dot(X, theta)
        grad = -2.0 * numpy.dot(X.T, residual)
        grad_sq_sum += grad ** 2

        # A coordinate whose gradient has always been exactly zero would give
        # 0 / sqrt(0) = NaN; leave its step at 0 instead. NaNs that come from
        # the data itself still propagate because NaN != 0 is true.
        step = numpy.zeros((d, 1))
        numpy.divide(
            alpha * grad,
            numpy.sqrt(grad_sq_sum),
            out=step,
            where=grad_sq_sum != 0,
        )
        theta -= step

        if verbose and times % 20 == 0:
            print(theta[:, 0],"   ", times)

        if step.size and numpy.max(numpy.abs(step)) < eps:
            break

    return theta


def predict(X, theta):
    """Return the linear model's predictions for the rows of ``X``.

    Args:
        X: Feature matrix, one sample per row.
        theta: Coefficients to multiply the features by.

    Returns:
        The product ``numpy.dot(X, theta)``.
    """
    predictions = numpy.dot(X, theta)
    return predictions


def main(argv):
    """Train the linear model on the energy data and print train/test R^2.

    Loads the data with a 50/50 split, scales the features to [0, 1],
    prepends a column of ones to each feature matrix, fits the coefficients
    with ``gradient_descent`` and prints the coefficients, the training
    predictions, the training labels and both R^2 scores.

    Args:
        argv: Command-line arguments; currently unused.
    """
    X_train, X_test, y_train, y_test = load_train_test_data(train_ratio=.5)
    X_train_scaled, X_test_scaled = scale_features(X_train, X_test, 0, 1)

    # Build the ones column separately for each matrix: reusing the
    # train-sized column for the test set only works when both splits happen
    # to have the same number of rows.
    train_ones = numpy.ones((X_train_scaled.shape[0], 1))
    test_ones = numpy.ones((X_test_scaled.shape[0], 1))
    X_train_scaled = numpy.c_[train_ones, X_train_scaled]
    X_test_scaled = numpy.c_[test_ones, X_test_scaled]

    theta = gradient_descent(X_train_scaled, y_train)
    y_hat = predict(X_train_scaled, theta)
    print('this is theta', theta)
    print(y_hat)
    print(y_train["Y1"])
    print("Linear train R^2: %f" % (sklearn.metrics.r2_score(y_train, y_hat)))
    y_hat = predict(X_test_scaled, theta)
    print("Linear test R^2: %f" % (sklearn.metrics.r2_score(y_test, y_hat)))


# Script entry point: call main with the command-line arguments, but only
# when this file is executed directly rather than imported.
if __name__ == "__main__":
    main(sys.argv)





[1. 1. 1. 1. 1. 1. 1. 1. 1.]     0
[4.28063566 5.20044484 2.94889481 4.78016953 1.84101022 5.65471669
 3.87174076 4.43517481 4.10882679]     20
[4.69480871 6.37431759 2.14666291 5.75841173 0.01626862 7.26870647
 3.71248552 5.00958305 4.25826823]     40
[ 4.90622134  7.00461993  1.64446093  6.44144997 -1.17180046  8.25969495
  3.36772123  5.3420371   4.18816238]     60
[ 5.05021116  7.3540301   1.39584251  6.98312228 -1.92478898  8.92196728
  2.99791698  5.58062429  4.05097236]     80
[ 5.15453576  7.53513646  1.30322226  7.42911224 -2.41089931  9.38410337
  2.63650866  5.76290677  3.88211729]     100
[ 5.2326084   7.61290162  1.30286841  7.80477419 -2.72998905  9.7189612
  2.29704541  5.90836569  3.69909788]     120
[ 5.2925777   7.62721628  1.35508973  8.12641264 -2.94285354  9.97091519
  1.98530967  6.02886635  3.51212972]     140
[ 5.33969825  7.60299427  1.43538138  8.40511201 -3.0873489  10.16779879
  1.70305142  6.13193332  3.32731186]     160
[ 5.37745038  7.55585954  1.52876796

[ 5.47842363  4.86252303  2.68821672  9.90006142 -2.91962251 12.99468255
 -0.4703785   7.48793979  0.63815777]     1400
[ 5.47683236  4.83162477  2.6865469   9.88438929 -2.90739481 13.02171773
 -0.47002023  7.48965812  0.63701336]     1420
[ 5.47526765  4.80088406  2.68484403  9.86874655 -2.89522252 13.04860142
 -0.46964132  7.49126248  0.63601723]     1440
[ 5.47372953  4.77029842  2.68311084  9.85313764 -2.88310656 13.07533547
 -0.46924381  7.49275993  0.63515874]     1460
[ 5.47221801  4.73986557  2.68134978  9.83756656 -2.87104775 13.10192156
 -0.46882952  7.49415711  0.63442804]     1480
[ 5.47073306  4.70958334  2.67956312  9.82203688 -2.85904681 13.12836133
 -0.46840014  7.49546026  0.63381595]     1500
[ 5.46927467  4.67944971  2.67775287  9.80655183 -2.84710434 13.1546563
 -0.46795716  7.49667525  0.63331393]     1520
[ 5.46784278  4.64946277  2.67592088  9.79111427 -2.83522088 13.1808079
 -0.46750195  7.49780759  0.63291406]     1540
[ 5.46643736  4.61962073  2.67406884  9.77

[ 5.42667629  3.03292709  2.54640169  8.95672933 -2.21216434 14.5594588
 -0.4355837   7.51044038  0.6700774 ]     2780
[ 5.42670654  3.01029726  2.54418694  8.94522564 -2.20380101 14.578226
 -0.43511313  7.51039124  0.67084774]     2800
[ 5.42675657  2.98776491  2.54196936  8.93377856 -2.19548571 14.596895
 -0.43464476  7.51034126  0.67161543]     2820
[ 5.42682626  2.96532953  2.53974898  8.92238779 -2.18721817 14.61546628
 -0.43417859  7.51029051  0.67238042]     2840
[ 5.42691553  2.94299065  2.5375258   8.91105306 -2.17899815 14.63394034
 -0.43371461  7.51023908  0.67314264]     2860
[ 5.42702427  2.92074777  2.53529985  8.89977408 -2.17082542 14.65231769
 -0.43325282  7.51018702  0.67390204]     2880
[ 5.42715238  2.89860042  2.53307115  8.88855058 -2.16269972 14.6705988
 -0.43279322  7.5101344   0.67465856]     2900
[ 5.42729977  2.87654812  2.53083971  8.87738226 -2.15462082 14.68878417
 -0.43233581  7.51008126  0.67541216]     2920
[ 5.42746633  2.8545904   2.52860555  8.866268

[ 5.4702601   1.67656285  2.38776768  8.28193906 -1.73589442 15.6479975
 -0.40794964  7.50677429  0.71598198]     4160
[ 5.47144133  1.65963269  2.38539301  8.27373946 -1.7303296  15.66103206
 -0.40761475  7.50672639  0.71654059]     4180
[ 5.47263639  1.6427715   2.38301647  8.26557978 -1.72479854 15.67399737
 -0.40728152  7.5066787   0.7170964 ]     4200
[ 5.47384519  1.62597895  2.38063809  8.2574598  -1.71930108 15.68689378
 -0.40694996  7.50663122  0.71764943]     4220
[ 5.47506768  1.6092547   2.37825786  8.24937935 -1.71383705 15.69972163
 -0.40662005  7.50658396  0.7181997 ]     4240
[ 5.47630378  1.59259839  2.37587581  8.2413382  -1.70840627 15.71248127
 -0.40629179  7.5065369   0.71874721]     4260
[ 5.47755341  1.5760097   2.37349194  8.23333618 -1.70300858 15.72517304
 -0.40596517  7.50649006  0.71929198]     4280
[ 5.4788165   1.55948828  2.37110626  8.22537308 -1.69764382 15.73779729
 -0.40564018  7.50644343  0.71983402]     4300
[ 5.48009299  1.5430338   2.36871879  8.2

[ 5.58069145  0.65358166  2.22010246  7.80011632 -1.42289172 16.40176563
 -0.38835902  7.50393101  0.74863167]     5540
[ 5.58267397  0.64068752  2.21762353  7.79424915 -1.4192979  16.41075504
 -0.38812186  7.50389602  0.7490264 ]     5560
[ 5.58466606  0.62784227  2.21514341  7.78841019 -1.41572789 16.41969556
 -0.38788589  7.50386118  0.74941916]     5580
[ 5.58666769  0.61504566  2.21266212  7.78259929 -1.41218155 16.42858744
 -0.38765109  7.50382651  0.74980993]     5600
[ 5.58867881  0.60229746  2.21017965  7.77681631 -1.40865877 16.43743093
 -0.38741745  7.50379199  0.75019875]     5620
[ 5.59069935  0.58959741  2.20769602  7.77106112 -1.40515943 16.44622626
 -0.38718499  7.50375763  0.75058561]     5640
[ 5.59272928  0.57694528  2.20521123  7.76533356 -1.40168342 16.45497368
 -0.38695368  7.50372342  0.75097052]     5660
[ 5.59476853  0.56434083  2.20272529  7.75963351 -1.3982306  16.46367343
 -0.38672353  7.50368936  0.7513535 ]     5680
[ 5.59681707  0.55178382  2.20023822  7.

[ 5.73496191 -0.12309323  2.04920926  7.45885741 -1.22715596 16.91306051
 -0.37465086  7.50187457  0.77141469]     6900
[ 5.73749123 -0.1331597   2.04666634  7.45461901 -1.22493425 16.91922863
 -0.37448196  7.50184869  0.77169486]     6920
[ 5.74002718 -0.14319128  2.04412271  7.45040066 -1.22272939 16.92536206
 -0.37431391  7.50182293  0.77197364]     6940
[ 5.74256972 -0.15318817  2.04157838  7.44620225 -1.2205413  16.93146096
 -0.37414668  7.50179727  0.772251  ]     6960
[ 5.74511881 -0.16315052  2.03903334  7.44202369 -1.21836989 16.93752553
 -0.37398029  7.50177173  0.77252697]     6980
[ 5.74767441 -0.17307851  2.0364876   7.43786487 -1.21621507 16.94355592
 -0.37381473  7.5017463   0.77280156]     7000
[ 5.7502365  -0.18297231  2.03394118  7.4337257  -1.21407677 16.94955231
 -0.37364999  7.50172098  0.77307476]     7020
[ 5.75280502 -0.19283208  2.03139407  7.42960608 -1.21195489 16.95551488
 -0.37348607  7.50169576  0.77334659]     7040
[ 5.75537995 -0.202658    2.02884627  7.

[ 5.9232642  -0.74418378  1.87235433  7.20858795 -1.10926203 17.26570796
 -0.3647641   7.50032499  0.78778166]     8280
[ 5.92617394 -0.75220611  1.86977439  7.20552171 -1.10802333 17.26984949
 -0.3646443   7.50030567  0.78797944]     8300
[ 5.92908818 -0.76020364  1.86719408  7.20246965 -1.10679649 17.27396653
 -0.3645251   7.50028643  0.78817622]     8320
[ 5.93200688 -0.76817648  1.8646134   7.19943171 -1.10558145 17.27805921
 -0.36440649  7.50026727  0.78837202]     8340
[ 5.93493003 -0.77612476  1.86203235  7.19640781 -1.10437815 17.28212764
 -0.36428847  7.50024818  0.78856682]     8360
[ 5.9378576  -0.78404859  1.85945094  7.19339788 -1.10318652 17.28617195
 -0.36417104  7.50022918  0.78876065]     8380
[ 5.94078956 -0.79194809  1.85686918  7.19040186 -1.10200651 17.29019226
 -0.36405419  7.50021025  0.78895349]     8400
[ 5.94372588 -0.79982339  1.85428706  7.18741966 -1.10083806 17.29418868
 -0.36393791  7.5001914   0.78914537]     8420
[ 5.94666655 -0.8076746   1.85170459  7.

[ 6.13333849 -1.24540608  1.6936601   7.02686629 -1.04869295 17.49998038
 -0.35774806  7.49915892  0.79933185]     9660
[ 6.13650428 -1.25197268  1.691063    7.0246296  -1.04814545 17.50269156
 -0.35766299  7.49914425  0.79947137]     9680
[ 6.13967306 -1.2585216   1.68846577  7.02240294 -1.04760628 17.50538546
 -0.35757834  7.49912962  0.79961019]     9700
[ 6.14284481 -1.2650529   1.68586841  7.02018628 -1.04707539 17.50806217
 -0.35749412  7.49911506  0.7997483 ]     9720
[ 6.14601952 -1.27156669  1.68327092  7.01797955 -1.04655275 17.51072178
 -0.3574103   7.49910055  0.79988572]     9740
[ 6.14919717 -1.27806303  1.6806733   7.01578272 -1.04603831 17.51336436
 -0.3573269   7.49908609  0.80002244]     9760
[ 6.15237773 -1.28454203  1.67807557  7.01359573 -1.04553203 17.51599001
 -0.35724392  7.4990717   0.80015847]     9780
[ 6.1555612  -1.29100376  1.67547771  7.01141853 -1.04503387 17.51859881
 -0.35716134  7.49905735  0.80029382]     9800
[ 6.15874755 -1.29744831  1.67287973  7.