In [1]:
import numpy as np
import matplotlib.pyplot as plt 
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

np.random.seed(42)

In [2]:
iris_data = load_iris() 
x = iris_data.data
y_ = iris_data.target.reshape(-1, 1)

encoder = OneHotEncoder(sparse=False)
train_x, test_x, train_y, test_y = train_test_split(x, y_, test_size=0.20)
train_y = encoder.fit_transform(train_y)

In case you used a LabelEncoder before this OneHotEncoder to convert the categories to integers, then you can now use the OneHotEncoder directly.


In [3]:
def relu(x):
    x[x<0]=0.0
    return x

def softmax(arr):
    arr = arr/np.max(arr)
    return np.exp(arr)/(np.sum(np.exp(arr),axis=0))

def diff_relu(x):
    x[x>0]=1.0
    x[x<=0]=0.0
    return x

def glorot_initializer(out,inp):
    limit = np.sqrt(6*1.0/(inp+out))
    return np.random.uniform(-limit,limit,(out,inp))

In [4]:
## Architecture and weights initializations
inp = 4
hidden1 = 10
hidden2 = 10
output = 3

W1 = glorot_initializer(hidden1,inp)
Bi1 = glorot_initializer(hidden1,1)
W2 = glorot_initializer(hidden2,hidden1)
Bi2 = glorot_initializer(hidden2,1)
W3 = glorot_initializer(output,hidden2)
Bi3 = glorot_initializer(output,1)

In [5]:
epochs = 2000
lr = 4*1e-6
decay  = 1e-9
f = open("SGD.txt", "w")

for i in range(epochs):
    loss = 0
    lr_ini = 4*1e-6
#     train_x, train_y = shuffle(train_x, train_y)
    for j in range(len(train_x)):
        #Forward pass
        t = train_x[j]
        t1 = train_y[j]
        h1 = (np.matmul(W1,t.reshape(-1,1)).reshape(-1,1)+Bi1).reshape(-1,1)
        h1_r = (relu(h1)).reshape(-1,1)
        h2 = (np.matmul(W2,h1_r).reshape(-1,1)+Bi2).reshape(-1,1)
        h2_r = (relu(h2)).reshape(-1,1)
        out = (np.matmul(W3,h2_r.reshape(-1,1))+Bi3).reshape(-1,1)
        y = (softmax(out)).reshape(-1,1)
        
        #Backprop
        d3 = y - (t1).reshape(-1,1)
        d2 = np.matmul(W3.T,d3) * diff_relu(h2)
        d1 = np.matmul(W2.T,d2) * diff_relu(h1)
        
        Bi3 -= lr*d3
        Bi2 -= lr*d2
        Bi1 -= lr*d1
        W3  -= lr*np.matmul(d3,h2_r.T)
        W2  -= lr*np.matmul(d2,h1_r.T)
        W1  -= lr*np.matmul(d1,t.reshape(-1,1).T)
        
        loss -= np.sum(t1*np.log(y))
        
        lr -= lr*decay
    print(i,loss/len(train_x))
    f.write("%s" %(loss/len(train_x)))
    f.write("\n")
    
    
f.close()

0 4.7622930584255885
1 4.7618176470203135
2 4.761339469198391
3 4.760858269877703
4 4.760375346173194
5 4.759891414359902
6 4.759405314197985
7 4.758916765656905
8 4.758426434133693
9 4.757935246428448
10 4.757441904169485
11 4.756945481014091
12 4.756448256354437
13 4.755949176317599
14 4.755448937084143
15 4.754946911658617
16 4.754443191103968
17 4.753937570739562
18 4.753430794305479
19 4.752922429605545
20 4.752412654310733
21 4.751901428514079
22 4.751388444894308
23 4.750874995007459
24 4.750360378435334
25 4.749844106890437
26 4.74932618262257
27 4.748806845666375
28 4.748282978495049
29 4.7477564056073565
30 4.747229352129
31 4.746700987072071
32 4.746171030364102
33 4.745639484828325
34 4.745107183152431
35 4.744572134132919
36 4.744035549029963
37 4.743497720445067
38 4.74295799820585
39 4.742416754902293
40 4.7418732282539136
41 4.741326909103412
42 4.7407773882456565
43 4.740226442467847
44 4.7396740756583196
45 4.7391187910663
46 4.738559728003779
47 4.737999308450592
48 

413 4.1440371384132435
414 4.14010451539358
415 4.136120997428363
416 4.132180472574844
417 4.128308332883046
418 4.124531916596448
419 4.120832095184369
420 4.117221583372854
421 4.113663886891575
422 4.110197101667116
423 4.1067685249215415
424 4.103290539026438
425 4.099688440745293
426 4.096114795759572
427 4.0923935125732935
428 4.088693211570542
429 4.085053980875539
430 4.081452372860174
431 4.077893410616999
432 4.074373774874334
433 4.070895459163632
434 4.06745569927698
435 4.064069923795335
436 4.060739115886836
437 4.057446265154506
438 4.054190686097756
439 4.050969112841881
440 4.047782113144709
441 4.04461572245604
442 4.04148650325349
443 4.038392032234959
444 4.035331707929178
445 4.032307702923727
446 4.029314385857547
447 4.0263529423586
448 4.023276393556933
449 4.020175720299067
450 4.017107473677289
451 4.014071080708954
452 4.011066092722474
453 4.008092553327024
454 4.00514923027959
455 4.002235605692126
456 3.9993511744193126
457 3.996495443664281
458 3.9936670

814 3.69978356233047
815 3.6995089836992263
816 3.6992353134491522
817 3.698962549641432
818 3.6986906903426955
819 3.69841973362493
820 3.6981496775654716
821 3.697880520246942
822 3.69761225975722
823 3.697344894189384
824 3.6970784216416854
825 3.6968128402174982
826 3.696548148025281
827 3.6962843431785415
828 3.6960214237957913
829 3.695759388000509
830 3.6954982339211
831 3.695237959690857
832 3.6949785634479317
833 3.694720043335281
834 3.694462397500658
835 3.6942056240965337
836 3.6939497212801
837 3.6936946872132155
838 3.6934405200623757
839 3.693187217998666
840 3.692934779197754
841 3.692683201839825
842 3.6924324841095673
843 3.692182624196136
844 3.691933620293121
845 3.6916854705985043
846 3.6914381733146437
847 3.6911917266482264
848 3.690948483255297
849 3.690706980563889
850 3.6904663172178633
851 3.690226491459905
852 3.6899875015367414
853 3.689749345699131
854 3.689512451624377
855 3.6892769074961764
856 3.68904218844957
857 3.688808292767058
858 3.688575218734946

1175 3.6412249198789897
1176 3.6411489189816084
1177 3.641073345641521
1178 3.6409981989409994
1179 3.640923477963492
1180 3.640849181793624
1181 3.6407753095171986
1182 3.6407018602211947
1183 3.64062883299377
1184 3.6405562269242595
1185 3.6404842235353225
1186 3.6404134017691296
1187 3.6403429958862903
1188 3.6402730049894925
1189 3.640203428182599
1190 3.640134264570628
1191 3.640065513259756
1192 3.6399971733573357
1193 3.6399292439718716
1194 3.639861724213041
1195 3.639794504080593
1196 3.639727276803014
1197 3.6396606908697304
1198 3.6395945088359074
1199 3.639528729819949
1200 3.6394633529414278
1201 3.6393990048983085
1202 3.6393354892473284
1203 3.639272401807809
1204 3.6392039954089683
1205 3.6391322021775
1206 3.6390608302958776
1207 3.638989878813064
1208 3.6389199339415375
1209 3.638850833213638
1210 3.638782146035531
1211 3.6387138714753053
1212 3.638646008602441
1213 3.6385785564878126
1214 3.6385114949456803
1215 3.638445878385458
1216 3.6383814838794932
1217 3.638317

1548 3.6298282393014216
1549 3.6298265217148695
1550 3.6298245265121034
1551 3.629821455689803
1552 3.629818598486102
1553 3.6298159541691746
1554 3.6298135220093135
1555 3.629811301278916
1556 3.629809291252497
1557 3.62980749120666
1558 3.6298059004201235
1559 3.62980527329295
1560 3.6298070196361447
1561 3.629807548498114
1562 3.6298079433302064
1563 3.6298072364920784
1564 3.629806727161627
1565 3.629806414648168
1566 3.6298068480117323
1567 3.6298085670357945
1568 3.6298104772091357
1569 3.6298125778652843
1570 3.6298148683397695
1571 3.6298173479700777
1572 3.629820016095677
1573 3.629822872057995
1574 3.6298259152004237
1575 3.6298291448683133
1576 3.6298325604089596
1577 3.6298361611716197
1578 3.629839946507486
1579 3.6298439157696887
1580 3.629848068313308
1581 3.6298524034953363
1582 3.629856920674713
1583 3.62986161921229
1584 3.6298664984708346
1585 3.629871557815037
1586 3.6298767966114966
1587 3.6298822142287173
1588 3.6298878100370997
1589 3.6298935834089514
1590 3.6298

1903 3.6353731383594576
1904 3.6354017216266774
1905 3.635430384579121
1906 3.63545912685764
1907 3.635487948104586
1908 3.6355168479638063
1909 3.6355458260806213
1910 3.6355748821018428
1911 3.635604015675741
1912 3.6356332264520805
1913 3.635662514082061
1914 3.6356918782183616
1915 3.6357213185150994
1916 3.6357508346278453
1917 3.635780426213614
1918 3.635810092930853
1919 3.635839834439439
1920 3.6358696504006836
1921 3.6358995404773125
1922 3.635929504333463
1923 3.6359595416346924
1924 3.6359896520479578
1925 3.636019835241617
1926 3.6360500908854183
1927 3.6360804186505087
1928 3.6361108182094157
1929 3.6361412892360385
1930 3.636171831405661
1931 3.636202444394931
1932 3.6362331278818636
1933 3.6362638815458275
1934 3.636294705067554
1935 3.6363255981291114
1936 3.6363565604139216
1937 3.6363875916067414
1938 3.6364186913936596
1939 3.6364498594620995
1940 3.6364810955008045
1941 3.6365123991998405
1942 3.636543770250579
1943 3.636575208345712
1944 3.636606713179226
1945 3.63

In [6]:
y_pred = []
for j in range(len(test_x)):
    t = test_x[j]
    h1 = (np.matmul(W1,t.reshape(-1,1)).reshape(-1,1)+Bi1).reshape(-1,1)
    h1_r = (relu(h1)).reshape(-1,1)
    h2 = (np.matmul(W2,h1_r).reshape(-1,1)+Bi2).reshape(-1,1)
    h2_r = (relu(h2)).reshape(-1,1)
    out = (np.matmul(W3,h2_r.reshape(-1,1))+Bi3).reshape(-1,1)
    y = (softmax(out)).reshape(-1,1)
    y_pred.append(y)

In [7]:
a = np.array(y_pred)
a = np.squeeze(a)
eww = []
for i in a:
    eww.append(np.argmax(i))
print(np.array(eww))
print(np.squeeze(test_y))
print(accuracy_score(np.squeeze(test_y), np.array(eww)))

[1 0 1 1 1 0 1 1 1 1 1 0 0 0 0 1 1 1 1 1 0 1 0 1 2 1 1 1 0 0]
[1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0]
0.6666666666666666
