In [1]:
import matplotlib.pyplot as plt
from neural_network import build_model,predict,plot_decision_boundary
import numpy as np
from sklearn.datasets import make_moons

In [2]:
X, y = make_moons(200,noise=0.20)
nn_hdim =4 

In [3]:
def softmax(z):
    exp_z = np.exp(z - np.max(z))
    soft_max = exp_z / np.sum(exp_z)
    return soft_max

In [4]:
def weight_initialization(x_shape, nndim, y_shape):
    W1 = np.random.randn(x_shape, nndim)
    b1 = np.zeros((1,nndim))
    W2 = np.random.randn(nndim,y_shape)
    b2 = np.zeros((1,y_shape))
    parameters = {
        "W1": W1,
        "b1" : b1,
        "W2": W2,
        "b2" : b2
      }
    return parameters

In [5]:
parameters = weight_initialization(X.shape[1], nn_hdim, 2)
print(parameters['W1'].shape)
print(parameters['b1'].shape)
print(parameters['W2'].shape)
print(parameters['b2'].shape)

(2, 4)
(1, 4)
(4, 2)
(1, 2)


In [6]:
def feedfoward(X,W1,W2,b1,b2):
        a = np.dot(X,W1) + b1
        #print(a.shape)
        h = np.tanh(a)
        #print(h.shape)
        z = np.dot(h,W2) + b2
        #print(z.shape)
        y_pred = softmax(z)
        #print(y_pred.shape)
        return a,h,z,y_pred

In [7]:
print(X[1].shape)

(2,)


In [8]:
a,h,z,y_pred = feedfoward(X[4],parameters['W1'],parameters['W2'],parameters['b1'],parameters['b2'])
print(a,h,z,y_pred)

[[-1.27097066 -0.18445111 -0.28900313 -0.23782213]] [[-0.85406051 -0.18238738 -0.28121704 -0.2334376 ]] [[1.01533451 0.33043433]] [[0.66483149 0.33516851]]


In [9]:
def one_hot_encoding(y):
    if y==1:
        y=np.array([0,1])
    else:
        y=np.array([1,0])
    return y

In [10]:
def calculate_loss(model,X, y):
    W1, W2, b1, b2 = model['W1'], model['W2'], model['b1'], model['b2']
    a,h,z,y_pred = feedfoward(X,W1,W2,b1,b2)
    y = one_hot_encoding(y)
    #print(y)
    #if y_pred==0:
    #    logY = np.multiply(y, 0)
    #    logY_= np.multiplty(1-y,np.log2(1-y_pred))
    #    loss = -np.sum(logY + logY_)/2 
    #elif y_pred==1:
    #    logY = np.multiply(y, np.log2(y_pred))
    #    logY_ = np.multiply(1-y,0)
    #    loss = -np.sum(logY + logY_)/2 
    #else:
    #print(X.shape[0])
    loss = -np.sum(np.multiply(y, np.log(y_pred)) +  np.multiply(1-y, np.log(1-y_pred)))/X.shape[0]
    
    #print (loss)
    #cost = -np.sum(np.multiply(Y, np.log(A2)) +  np.multiply(1-Y, np.log(1-A2)))/m
    loss = np.squeeze(loss)
    cost = {
    "a": a,
    "h": h,
    "z": z,
    "y_pred": y_pred,
    "loss":loss
    }
    return cost

In [11]:
cost = calculate_loss(parameters,X[0], y[0])
print(cost['loss'])

0.4394011193838229


In [12]:
def backward_prop(X, Y, cost, parameters):
    a = cost['a']
    h = cost['h'] 
    y_pred = cost['y_pred']
    W2 = parameters['W2']
    X = np.reshape(X,(1,2))
    Y = one_hot_encoding(Y)
    
    dZ2 = np.subtract(y_pred,Y)
    dW2 = np.dot(h.T,dZ2)
    db2 = dZ2#np.sum(dZ2, axis=1, keepdims=True)/m
    dZ1 = np.multiply(np.dot(dZ2,W2.T), 1-np.power(np.tanh(a),2))
    dW1 = np.dot(X.T,dZ1)
    db1 = dZ1#np.sum(dZ1, axis=1, keepdims=True)/m

    grads = {
    "dW1": dW1,
    "db1": db1,
    "dW2": dW2,
    "db2": db2
    }

    return grads

In [13]:
grad = backward_prop(X[1], y[1], cost, parameters)
print(grad['dW1'].shape)
print(grad['db1'].shape)
print(grad['dW2'].shape)
print(grad['db2'].shape)

(2, 4)
(1, 4)
(4, 2)
(1, 2)


In [14]:
def update_parameters(parameters, grads, learning_rate):
    W1 = parameters["W1"]
    b1 = parameters["b1"]
    W2 = parameters["W2"]
    b2 = parameters["b2"]

    dW1 = grads["dW1"]
    db1 = grads["db1"]
    dW2 = grads["dW2"]
    db2 = grads["db2"]

    W1 = W1 - learning_rate*dW1
    b1 = b1 - learning_rate*db1
    W2 = W2 - learning_rate*dW2
    b2 = b2 - learning_rate*db2

    new_parameters = {
    "W1": W1,
    "W2": W2,
    "b1" : b1,
    "b2" : b2
    }

    return new_parameters

In [15]:
def build_model(X, y, nn_hdim, num_passes=20000, print_loss=False):
    c= 0
    learning_rate = 0.001
    parameters= weight_initialization(X.shape[1], nn_hdim, 2)
    for i in range(num_passes):
        if i % X.shape[0]-1==0:
            c = 0
        cost = calculate_loss(parameters,X[c],y[c])
        if print_loss == True:
            print(cost['loss'])
        grads = backward_prop(X[c], y[c], cost, parameters)
        parameters= update_parameters(parameters, grads, learning_rate)
        #print(c)
        #print(i)
        c = c +1
    return parameters

In [16]:
nnhdim =4
model  =  build_model(X,y ,nnhdim,print_loss=True) 

2.7461714784875304
2.7378978440960644
2.9199301088673897
1.1853572279635007
2.1485378626551244
2.2105457056679736
2.4897838233586573
2.7390415158828554
0.5569212235327956
0.3838329978505488
0.12485356755043614
1.848498183725196
2.679796825980935
0.8319877730094067
0.4076923681645806
1.5099701534677257
2.209653834723256
2.2226931396615965
2.121206301872789
0.4968066002873118
1.757036857761543
0.18688770653147968
0.10385260125769671
1.4925914434733023
2.801572810541595
1.4117764709098157
2.5731203025259695
1.5903719122671354
2.5468907711136017
0.38864380568866486
0.18950017667505745
1.4317363473346392
2.6628860659817946
0.8343528757437817
0.6579190425200735
1.235747825745752
0.8773774365579229
2.226198900852214
0.20897131124161555
0.26343685911416714
2.392823642243175
0.0921955404961555
0.07960899399147142
2.6889401175671255
2.0169402922885
0.4284541053607647
0.0706792227286318
0.6423030217462706
2.3292545880550426
0.6536031257861112
2.0023112467745716
0.062122094184915135
1.996372615719

1.1034145468391028
0.680524730012492
0.7657309463151828
0.6934141345124297
0.3600244732662029
1.0354112579772996
0.11598483538496183
0.14563330145966114
0.4527521626155497
0.47008266940228816
1.0215681367124751
0.2622965698828272
1.5077450611300933
0.11111677116432438
1.1026782013628298
0.10321733266143242
0.8752384302611088
0.2506125861712142
0.23921937576291413
0.16114382144038852
0.1674915686361177
1.0278574999469567
0.09304253922595895
0.16583520689600767
0.7687935512897356
0.12764961331327476
0.7343644137331332
0.7234439706661291
1.4349758579063034
0.15780352769623518
1.3152301980084697
0.07937721334793771
0.4419257752678654
1.1711082250927354
0.6802547413329032
0.9364445322552613
0.7022755870161446
0.7845885197710221
0.09292080599866874
1.0313204550571962
0.8841657615535059
0.48469161427672713
1.2014051867884985
0.9873063494065828
0.4482839663144837
0.9830136397259925
0.31496492898442874
0.09757367975796202
0.1480307199806529
0.5755499570187381
0.7188663182332732
0.12178591859654

0.12567720435872992
0.6461071317583659
0.1966267803585893
1.511520738745583
0.23099400817474786
1.1954694576478535
0.8821622847636503
0.3776762768389176
0.5438299130944052
0.7734420076141157
0.11243476129234342
0.1137190099952882
0.3789343536567372
0.4196780480765823
0.3159545488856952
0.05782963146190137
0.10376788664876091
0.6804393649786082
0.7935293019340113
0.2672511556298361
0.11018963407875987
0.525164398672892
0.4890786507119079
0.07585009531897083
0.3995959440203447
0.2654495912798338
1.2182063406656765
0.09907436393225914
0.07186926109153571
0.18580521631552982
1.7300685868998618
0.08722636728603068
0.5535875986099094
0.1838856292712283
0.05000264008641822
0.18778944163198036
0.32832469757595584
0.07366676481493414
0.1473711930389897
0.06797107357077675
0.20616127029044592
0.08480266189251945
0.0972144745605574
0.17325714952640986
0.11323844149102785
1.2597541331503412
0.060634594302172815
0.0932233606352145
0.1923076744338655
0.1112962787382846
0.5747442221265598
0.066290574

0.07734730247999305
0.8078919238136801
0.08235268155214855
0.1503984034602256
0.16294911873878667
1.1885012802719366
0.28964493694174454
0.09352685011347558
0.08377801127242113
0.057625583810216464
1.3582466644507956
0.1944502363881519
0.04331755667651818
0.43384026060063297
1.0642763227289933
0.1466602069012624
0.3339429932256823
0.14818110814358623
0.3249286589297192
0.03169583341265828
0.25367400192286227
0.10229395573248376
0.18487771844793294
0.9396825307074621
0.06086164711514818
1.4895855620772351
0.26367184128319004
0.9060129996849862
0.05596200215517928
0.27011920248798965
0.05329647613182367
0.051738655391139435
0.04665805869068297
0.038458340555109716
1.3357224069607732
0.14792270882410674
0.4292287920722322
0.4875897390931222
0.27987536535193935
0.11740240345572037
0.12551905319059276
1.2564071227867806
0.10689984289876162
1.1675884550125044
0.09908647819423447
0.08315911172394487
2.0921544229771776
0.16069326003400647
0.2407379908502404
1.4420491437860914
0.166130880822369

0.1155552791034823
1.184292107029723
0.10096163298451595
0.05698608183333319
2.1864544736132494
0.12570934646204324
0.18349879700185676
1.5514481240053608
0.12383493002869428
0.048005915902872946
0.2550903030667394
0.13877164844059955
0.10862555197024132
0.6615357273795668
0.419728449395222
0.06102179554078122
0.2242729611429073
0.14057710039770868
0.10848835285127234
0.19831747859440557
0.07127875647694544
0.1204041464535211
0.026749881156325112
0.4979297942670025
0.25363404787045957
0.07801279372506087
0.03936311308960405
0.07042987846367274
0.0325721075206065
0.05318898138154609
0.2665122873532517
2.1149951936434532
0.18995340167841818
0.10973173802567213
0.06715020333704518
0.12612710692127033
0.2852330500757696
0.07020188373678095
1.0887200768523764
1.5468313538247533
0.03311385395276306
0.12152140649746092
0.15640053673617527
0.03643173002112361
1.6722614806658598
0.07847306260430251
0.08585050267402922
0.06218176801973648
0.1391554932908883
0.8601017923065261
0.07915349176674277

0.8117121292608138
0.07734650738174548
0.17080261132895586
0.0798633863915622
0.1608531151408238
0.26510879422241906
0.08042024352877009
0.5271362619956655
0.060170745576695134
0.05305391010196419
0.04632692915998315
0.6035334688273141
1.656486955362644
0.025758308466802132
0.6160397794734764
0.03880296970322206
0.08857524098416159
1.38706550239344
0.09387547641002914
0.08604483948179777
0.08929058003187455
0.03154080115761702
0.3792308139338686
0.2912005526688311
1.7805524305678353
0.08285618708863493
1.145996803478004
0.7681066303492712
0.13123595420400982
0.17067316683305428
0.526883330452276
0.18486843278298218
0.048361904538809564
0.10956869754133462
0.18888243121824075
0.11144977568887132
0.040221312621430645
0.06563257727806115
0.4100193224528899
0.5766279868992069
0.08513252268254082
0.055038989894521245
0.22280895554072497
0.2537907868559911
0.061178139391266546
0.1976482002557119
0.08338969129467991
1.2454934098316164
0.040994398563383196
0.022598580348460254
0.06608710407404

0.16698239257133587
0.0947648762142745
0.03715438610353878
0.062154772509359905
0.38017911975258056
0.5407175140212049
0.07066522223589589
0.04879597962570155
0.1937013206156636
0.2297223851831318
0.057706104439160966
0.17782741161637156
0.06908262016356845
1.2618045591233091
0.03548367349759877
0.018968819022603917
0.0568446484984534
1.0749396106476192
0.04397138626792128
0.2647382014123514
0.05312467160074084
0.05958495195694185
0.05582270864046639
0.09172487520614897
0.06508587139487015
0.03521021568308313
0.08566159230934905
0.10237097133549644
0.0253176375740322
0.07325860761652996
0.20785576232318104
0.11160072008414626
1.5693723979179972
0.02845745560722432
0.03803086304141393
0.05901259750756778
0.261502834534363
0.2714411602918322
0.04858571764007538
0.019861396807307235
0.022772837733184538
0.4735521632697741
0.18617910090477405
0.28055260164218165
0.2778566090924687
0.109169012090987
0.06544057164046124
0.04956928114198629
0.15994436449654062
0.04387394963078203
0.0391967786

0.055047474379629965
1.0568767914011925
0.04322446172478735
0.25921521665895375
0.05136198731697005
0.06000129961477169
0.0538699386350593
0.08829192365862167
0.06500507831561414
0.03380945907168964
0.08641846862889613
0.1005246001828719
0.024450755604537647
0.07227528725671611
0.20780254569849546
0.11086201395884598
1.5767942285747987
0.02775607600032636
0.0371639744785381
0.057140791049432565
0.26823697036132105
0.265632157769051
0.047865208404894714
0.019087969916333564
0.02197987340673183
0.4893376544472548
0.19076624616119867
0.2818969319610043
0.2727892866891688
0.10542042164599441
0.06381882584979358
0.04791331040461392
0.15493108274484124
0.04277260960973053
0.03862534374771259
0.16851810028992456
0.4994434578043042
0.2134655619729464
2.0326316904765482
0.20491084506180463
0.01998420044542531
0.09038444165041817
0.177170271963815
0.041376156072842785
0.373436855593702
0.08581425779078922
1.0381576062871498
0.02967780255402231
0.11751910711672346
0.03899667689471486
0.1679716548

0.10089951959038042
0.0330837494829461
0.03252859508530675
0.02268436938790257
0.017378102357559287
1.2429816838455992
0.3050508875149126
0.6259126355945029
0.2844287524306859
0.37313867804471856
0.0467131835198918
0.052468744036744966
1.308044038814225
0.1340337177583794
1.2366728679141383
0.09608375835907637
0.030020449238838606
2.335144267694707
0.08305635606686575
0.11418415355496853
1.7324440051860388
0.074813125071051
0.05019115607136402
0.16936479719202896
0.08020802361324031
0.06312459770014445
0.5907708218621971
0.3482768552960469
0.039484954476454635
0.14509330092683664
0.09987543164051407
0.14402467638991506
0.1942422765668026
0.03582532190888629
0.07530287947281247
0.017002553382757097
0.41988630433474305
0.14053508332961512
0.05550967197872994
0.02682101381588204
0.05085722102658886
0.022168572917097396
0.02736886437384444
0.4279014896801017
2.44240673656177
0.10376233418453795
0.06374398754397023
0.0587051393619734
0.08117040263334474
0.15663455659518863
0.043648240917656

0.10908192018802769
0.7377462124518638
0.06953184034882273
0.16939945015896818
0.05320189249203883
0.1333041172479116
0.22197906531117112
0.07419897827376376
0.4362590322026465
0.043567804817402536
0.038169367403934396
0.03553427742452892
0.7083268711089034
1.7012557481538155
0.018582262698905387
0.5615169424270674
0.025256942867994316
0.05788429823127329
1.4121043159098263
0.06643955685358652
0.058117979893205544
0.06254194993008191
0.01901247327948067
0.31128577046053285
0.31155818545322883
1.924366127420598
0.0551661236671483
1.1589486877916049
0.711998250682762
0.08546566965540835
0.10510310963207364
0.441004129387885
0.2280420555548519
0.03609479608949028
0.06655880407868096
0.13903079895194972
0.07452193786517633
0.03251590659516511
0.056773960651417654
0.340127879035935
0.49592200385763074
0.054219094940538795
0.04076331375344614
0.1570181196166654
0.19808470240206255
0.05193108734995139
0.15194770172061825
0.052853210981702925
1.2879382401997015
0.02868708835560153
0.0146888213

0.17254334012402595
0.03140865636887534
0.3932864650749318
0.07838706494861682
1.0180520984815504
0.02176156662473587
0.08782814490943659
0.033480051219851384
0.13604228179041453
0.07058813761989456
0.043133027730645986
0.054310065169041064
0.08824498703334026
0.9436022929637753
0.20115691183977674
0.012946163926769779
0.039426982915183574
0.06348205111612988
0.6686079421129142
0.03769308433179598
0.03727828387369008
0.053293555231763635
1.5352638731462491
0.38399470530818064
0.14532521104494445
0.021095838645997792
0.05243030115353289
1.4412422846918118
0.0905806377951504
0.01919687330095183
0.2431412918457886
1.1258307790126691
0.05279167314301292
0.14547077420849114
0.06549901805985198
0.14986207436530335
0.013257469925473333
0.2993421588545718
0.03190115690438543
0.07601030858298374
0.6495329189049376
0.02147679657309045
1.678927355736188
0.3785401857027929
0.563804802833155
0.03143217307047877
0.08175269232724783
0.029488617192109788
0.02905798362673553
0.01917067392746968
0.01448

0.029252788241659264
0.04136701291085931
0.3256155172416714
0.21327979149708626
0.040163547307925596
0.012818786808207751
0.01540375009114809
0.6361991398707001
0.22715323073439245
0.2868288408500226
0.2263132044946699
0.07547854525548084
0.04949771169789495
0.034314507584898755
0.11178017006940127
0.03304397072814294
0.03273055204541442
0.13066396571852673
0.44504634656075237
0.16343175020242662
2.167646995977489
0.15899829793093673
0.013573950025739892
0.07867919029674669
0.17050996923277922
0.029520999704428414
0.3949400333255001
0.07638635630678031
1.0156103425326968
0.020286472519488907
0.08249968869750575
0.03227935539314366
0.1296874436554029
0.06867135054003629
0.041578596269851345
0.0517463135413386
0.0828491810306283
0.9406491697658537
0.19238176686649927
0.012289300975424688
0.03700501710074865
0.061670830924539136
0.6631146221619273
0.035919825186770386
0.03431955595915964
0.04967620986891972
1.5578026034888028
0.3860000569223299
0.14661733471489147
0.019410160106927577
0.0

0.031070310967948413
0.7529216166081947
1.731999150834229
0.015777622834015133
0.5369240526035667
0.020546080828133244
0.04703459902632838
1.4288407417111635
0.0562338608143264
0.048066284706216864
0.0526697851028034
0.014852262517963115
0.28252361933132986
0.31261609182112277
1.9928509248809518
0.04572274533095112
1.166121687678934
0.6939311990393489
0.07043699959059954
0.08449766777626519
0.41081304211167646
0.24080077784155124
0.03121646253894656
0.05291722921882335
0.11962066131690291
0.06116107291122087
0.02883601399287438
0.05223091562137057
0.31035007725937164
0.4675628228652191
0.04400798524568231
0.0352168778664259
0.13156720814121248
0.17537273055530772
0.04697664170414082
0.13361376016234902
0.042789675948930544
1.3080911895952099
0.024127411237839593
0.011876398421784857
0.03803584713112603
0.9007712134087154
0.034861613889455925
0.20265190536252767
0.034838494295718045
0.060272537888904604
0.03613907195956579
0.058028214539256204
0.06125543395918222
0.02129341219443387
0.0

0.020231256049397996
0.08890107090531169
0.07927064900574751
0.015638802642119753
0.05912469776325181
0.19927213661302234
0.09745050916874146
1.681749384594631
0.019937578725975652
0.027378795867925382
0.037945013150109945
0.3345552227719777
0.20076640701907966
0.03809220368779091
0.011429563517932426
0.013933240305851751
0.6647054258718414
0.23098644904892202
0.28567152374427185
0.21511818607843503
0.06922052037740149
0.04628845942248734
0.03135804284124284
0.10191128345999526
0.030813577387926318
0.03113483551821207
0.12263731648102855
0.4312066003333639
0.15169046799581568
2.2026970042821254
0.14825130457530258
0.012150063566012975
0.07571799272504895
0.1670280840285631
0.02694436765532726
0.39531649483476544
0.0733526417692662
1.0127285778894732
0.018236801234657284
0.07528152725157111
0.030567328885513598
0.12087945027784433
0.06583730367716056
0.039410318841610004
0.04826346650848429
0.07539163356026689
0.9359356732161401
0.18000240573950052
0.011369611412420839
0.033709428996902

In [24]:
def predict(model, x):
    _,_,_,y_pred = feedfoward(x,model['W1'],model['W2'],model['b1'],model['b2'])
    y_pred = np.argmax(y_pred)
    return y_pred

In [25]:
predict(model, X[40])

1

In [26]:
print(y[40])

1
