In [1]:
import numpy as np
import matplotlib.pyplot as plt 
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

np.random.seed(42)

In [2]:
iris_data = load_iris() 
x = iris_data.data
y_ = iris_data.target.reshape(-1, 1)

encoder = OneHotEncoder(sparse=False)
train_x, test_x, train_y, test_y = train_test_split(x, y_, test_size=0.20)
train_y = encoder.fit_transform(train_y)

In case you used a LabelEncoder before this OneHotEncoder to convert the categories to integers, then you can now use the OneHotEncoder directly.


In [3]:
def relu(x):
    x[x<0]=0.0
    return x

def softmax(arr):
    arr = arr/np.max(arr)
    return np.exp(arr)/(np.sum(np.exp(arr),axis=0))

def diff_relu(x):
    x[x>0]=1.0
    x[x<=0]=0.0
    return x

def glorot_initializer(out,inp):
    limit = np.sqrt(6*1.0/(inp+out))
    return np.random.uniform(-limit,limit,(out,inp))

def glorot_normal(out,inp):
    limit = np.sqrt(2*1.0/(inp+out))
    return np.random.normal(0,limit,(out,inp))

In [4]:
## Architecture and weights initializations
inp = 4
hidden1 = 10
hidden2 = 10
output = 3

W1 = glorot_normal(hidden1,inp)
Bi1 = glorot_normal(hidden1,1)
W2 = glorot_normal(hidden2,hidden1)
Bi2 = glorot_normal(hidden2,1)
W3 = glorot_normal(output,hidden2)
Bi3 = glorot_normal(output,1)

In [5]:
epochs = 2500
epsilon = 2*1e-6
rho = 0.4
delta = 1e-10

f = open("RMS.txt","w")

for i in range(epochs):
    loss = 0
    train_x, train_y = shuffle(train_x, train_y)
    gamma = [0.0,0.0,0.0,0.0,0.0,0.0]
    for j in range(len(train_x)):
        #Forward pass
        t = train_x[j]
        t1 = train_y[j]
        h1 = (np.matmul(W1,t.reshape(-1,1)).reshape(-1,1)+Bi1).reshape(-1,1)
        h1_r = (relu(h1)).reshape(-1,1)
        h2 = (np.matmul(W2,h1_r).reshape(-1,1)+Bi2).reshape(-1,1)
        h2_r = (relu(h2)).reshape(-1,1)
        out = (np.matmul(W3,h2_r.reshape(-1,1))+Bi3).reshape(-1,1)
        y = (softmax(out)).reshape(-1,1)
        
        #Backprop
        d3 = y - (t1).reshape(-1,1)
        d2 = np.matmul(W3.T,d3) * diff_relu(h2)
        d1 = np.matmul(W2.T,d2) * diff_relu(h1)
        
        gamma[0] = (1-rho)*gamma[0] + rho*(d3**2)
        gamma[1] = (1-rho)*gamma[1] + rho*(d2**2) 
        gamma[2] = (1-rho)*gamma[2] + rho*(d1**2) 
        gamma[3] = (1-rho)*gamma[3] + rho*(np.matmul(d3,h2_r.T)**2) 
        gamma[4] = (1-rho)*gamma[4] + rho*(np.matmul(d2,h1_r.T)**2) 
        gamma[5] = (1-rho)*gamma[5] + rho*(np.matmul(d1,t.reshape(-1,1).T)**2)
        
        Bi3 -= (epsilon*d3)/np.sqrt(delta+gamma[0])
        Bi2 -= (epsilon*d2)/np.sqrt(delta+gamma[1])
        Bi1 -= (epsilon*d1)/np.sqrt(delta+gamma[2])
        W3  -= (epsilon*np.matmul(d3,h2_r.T))/np.sqrt(delta+gamma[3])
        W2  -= (epsilon*np.matmul(d2,h1_r.T))/np.sqrt(delta+gamma[4])
        W1  -= (epsilon*np.matmul(d1,t.reshape(-1,1).T))/np.sqrt(delta+gamma[5])
        
        loss -= np.sum(t1*np.log(y))
    
    print(i,loss/len(train_x))
    
    f.write("%s" %(loss/len(train_x)))
    f.write("\n")
    
    
f.close()

0 4.381774431208479
1 4.381395954306188
2 4.3809656550381115
3 4.3807006296044735
4 4.380274501496548
5 4.379978136805089
6 4.379560431510168
7 4.379199185689751
8 4.378802132370507
9 4.378439348914514
10 4.378128908010077
11 4.37781595691069
12 4.3774038588727935
13 4.37706649008876
14 4.376659680190183
15 4.376329279837867
16 4.375965327780848
17 4.375581593145561
18 4.375307144852083
19 4.374969824591008
20 4.374631222580641
21 4.374321878970991
22 4.373964012821459
23 4.373636922151772
24 4.373286234826051
25 4.372967999002275
26 4.3725608255177395
27 4.372264889501445
28 4.371874170950916
29 4.3715485501513065
30 4.3712184384378086
31 4.37093042510905
32 4.370579415512755
33 4.370238023505179
34 4.369914057325604
35 4.369588777097358
36 4.369315991898719
37 4.368953214346257
38 4.368635168782555
39 4.368298026610579
40 4.3679560249176665
41 4.367652965832627
42 4.36739041740554
43 4.367059630953075
44 4.3667998372616985
45 4.366549562300674
46 4.366215324467716
47 4.36583550995798

387 4.2489475333332685
388 4.248672591050079
389 4.248318688821355
390 4.247949523293171
391 4.24765278077629
392 4.247297029932549
393 4.247077172704738
394 4.246762485165777
395 4.2464543467595925
396 4.246221452008957
397 4.245918320748098
398 4.245637801122398
399 4.245341980029575
400 4.244876917662897
401 4.2446167785701485
402 4.244339104619273
403 4.244076351040837
404 4.243739201353362
405 4.243412155535954
406 4.243002277463736
407 4.242761852430388
408 4.242358924264199
409 4.242151672000373
410 4.241756833355475
411 4.241464892744355
412 4.241256469163691
413 4.240976626707402
414 4.240691080401935
415 4.2402706180638035
416 4.2400364716034264
417 4.239591574593522
418 4.239384965379171
419 4.239078603897289
420 4.238667528260551
421 4.238255580336117
422 4.238014763272396
423 4.237544559825449
424 4.237261812953999
425 4.236921694293481
426 4.2365689134789735
427 4.2362289181157315
428 4.235783120851386
429 4.23534074277556
430 4.234942699135236
431 4.234487309454694
432 4

785 4.089307744712984
786 4.08884843230006
787 4.088526081680988
788 4.088170026601898
789 4.087868377520956
790 4.087554970391789
791 4.087136919721352
792 4.0868679449285334
793 4.086393656644474
794 4.0860227913219545
795 4.085682561851411
796 4.085410862537175
797 4.085032948611544
798 4.084680459298431
799 4.084333755366033
800 4.0839272681591785
801 4.083583524711505
802 4.0832969721281325
803 4.0829929629828285
804 4.082508226102582
805 4.082293136044451
806 4.081967441302066
807 4.081648717370163
808 4.081273591679302
809 4.080972224623239
810 4.080565585591492
811 4.080326254225548
812 4.079936469432398
813 4.079757769068922
814 4.07937749796326
815 4.079076161862533
816 4.078613153764282
817 4.078396614097142
818 4.078075525068431
819 4.077716219676672
820 4.077412393471509
821 4.077076169478122
822 4.0767757263679805
823 4.0762886677739605
824 4.076102718991679
825 4.075790722404949
826 4.0753878439835844
827 4.075154456576662
828 4.074727939513396
829 4.074508161417967
830 

1165 3.994141967263964
1166 3.9940444187849087
1167 3.993916775378702
1168 3.9937585902046253
1169 3.99360151211831
1170 3.993436795052052
1171 3.993252114068041
1172 3.9930483483922052
1173 3.99287442790852
1174 3.9927302888892666
1175 3.992570197263404
1176 3.992365804774847
1177 3.992236487467884
1178 3.9920488039326
1179 3.991884226905717
1180 3.9916912812395196
1181 3.9914992380678727
1182 3.9915185718482515
1183 3.991218917501553
1184 3.9911037068826993
1185 3.990919428383122
1186 3.990784055043999
1187 3.990594144754794
1188 3.990431434286197
1189 3.9903549318637954
1190 3.9901629049578102
1191 3.9899352589211174
1192 3.989877230227672
1193 3.9896164581441105
1194 3.989538195444773
1195 3.9893164144858835
1196 3.989101350485209
1197 3.9889483855720282
1198 3.9888019452876593
1199 3.988688795305775
1200 3.9885144192892907
1201 3.988408878111517
1202 3.9882510556612507
1203 3.9881789829780088
1204 3.987958103485185
1205 3.9877094432129097
1206 3.987593422770501
1207 3.987521978231

1550 3.9420608124735392
1551 3.9419012102446254
1552 3.941937107161602
1553 3.94171436512104
1554 3.9417544186316773
1555 3.941690517634983
1556 3.9416042653156746
1557 3.941535479650109
1558 3.9414068452381015
1559 3.9413564667402494
1560 3.9412621434692254
1561 3.941200652759011
1562 3.941183663880102
1563 3.94104003310771
1564 3.9410461980992113
1565 3.9410799893697073
1566 3.940939023643727
1567 3.940922303612368
1568 3.940844851729913
1569 3.940827211430351
1570 3.9406761520864153
1571 3.9406930975843126
1572 3.940698708672935
1573 3.94063176174894
1574 3.9407087454192533
1575 3.9407057369914766
1576 3.940685623958881
1577 3.9406639613942938
1578 3.9406035617029507
1579 3.940584105998561
1580 3.9405391950091415
1581 3.9404909921865934
1582 3.9405146965320137
1583 3.9405078384041965
1584 3.940530333913815
1585 3.94050282377056
1586 3.940520555408896
1587 3.9404353963538488
1588 3.940414975961205
1589 3.9403777090935095
1590 3.940365576652504
1591 3.940402730547913
1592 3.9404385523

1946 4.042540223734935
1947 4.043082046932335
1948 4.043528576453916
1949 4.044093515520221
1950 4.044652927744684
1951 4.045263174239362
1952 4.045794736998322
1953 4.046324390221069
1954 4.046880611657838
1955 4.047487494308835
1956 4.04794389969271
1957 4.048544178674073
1958 4.049043929166613
1959 4.049601546331385
1960 4.050078865931888
1961 4.050614456232588
1962 4.051178040738271
1963 4.051754052381386
1964 4.052310041743717
1965 4.052918982000687
1966 4.053391718196623
1967 4.053857528694949
1968 4.054479981791475
1969 4.054933939011734
1970 4.055384529233576
1971 4.055789920097918
1972 4.056353459477434
1973 4.056778781266201
1974 4.057329645665073
1975 4.057840935108407
1976 4.0583048121373615
1977 4.058763529778669
1978 4.059257802555906
1979 4.05974636261
1980 4.060224320058511
1981 4.0606899075167
1982 4.061188555673521
1983 4.0616899277224805
1984 4.062152480232716
1985 4.062680446970087
1986 4.063217523849925
1987 4.063735821168246
1988 4.064363140921743
1989 4.064839501

2359 4.224370850653247
2360 4.224763565115241
2361 4.225439533323281
2362 4.22599136674106
2363 4.226594716219127
2364 4.227048388310245
2365 4.2275696012581845
2366 4.22811840608942
2367 4.228638879483177
2368 4.22919130501417
2369 4.229692152482669
2370 4.230203025921893
2371 4.230689074110863
2372 4.231300175684227
2373 4.2318669008679946
2374 4.232342729431144
2375 4.232884790844922
2376 4.23329633634085
2377 4.233946693886819
2378 4.2344051837632115
2379 4.234952547117222
2380 4.235355141028665
2381 4.236039093882374
2382 4.236501987559218
2383 4.237042880207037
2384 4.237643644462404
2385 4.238154568378316
2386 4.238723047457809
2387 4.239321819036934
2388 4.239858386896204
2389 4.240352964725499
2390 4.240953403295362
2391 4.241403331018827
2392 4.24198082431005
2393 4.242668198537961
2394 4.2431538071299535
2395 4.24371952279133
2396 4.2443376880220365
2397 4.24488149014118
2398 4.245482066755835
2399 4.246019438331997
2400 4.246523209314691
2401 4.247234354270243
2402 4.247809

In [6]:
y_pred = []
for j in range(len(test_x)):
    t = test_x[j]
    h1 = (np.matmul(W1,t.reshape(-1,1)).reshape(-1,1)+Bi1).reshape(-1,1)
    h1_r = (relu(h1)).reshape(-1,1)
    h2 = (np.matmul(W2,h1_r).reshape(-1,1)+Bi2).reshape(-1,1)
    h2_r = (relu(h2)).reshape(-1,1)
    out = (np.matmul(W3,h2_r.reshape(-1,1))+Bi3).reshape(-1,1)
    y = (softmax(out)).reshape(-1,1)
    y_pred.append(y)

In [7]:
a = np.array(y_pred)
a = np.squeeze(a)
eww = []
for i in a:
    eww.append(np.argmax(i))
print(np.array(eww))
print(np.squeeze(test_y))
print(accuracy_score(np.squeeze(test_y), np.array(eww)))

[1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0]
[1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0]
1.0
