In [1]:
import numpy as np
from numba import jit
from timeit import default_timer as timer

In [2]:
IL = 1 #input layer nodes
HL1 = 10 #hidden layer nodes
HL2 = 10 #hidden layer nodes
OL = 1 #output layer nodes
w1 = np.random.randn(HL1,IL) #weight matrix W1
b1 = np.random.randn(HL1) #bias b1
w2 = np.random.randn(HL2,HL1) #weight matrix W2
b2 = np.random.randn(HL2) #bias b2
w3 = np.random.randn(OL,HL2)  #weight matrix W3

#Number of elements in weight matrices
NumWeights1 = len(w1.flatten())
NumWeights2 = len(w2.flatten())
NumWeights3 = len(w3.flatten())

In [3]:
s = np.random.randn(IL,10000) #input data 
x = 2*s**2 + 5 #output data

In [4]:
#forward propagation
@jit
def predict(s,w1,w2,w3,b1,b2):
    h1 = np.dot(w1, s) + b1 #input to hidden layer 1        
    h1 = np.where(h1 < 0, h1, 0) #relu                      
    h2 = np.dot(w2, h1) + b2 #input to hidden layer 2          
    h2 = np.where(h2 < 0, h2, 0) #relu          
    out = np.dot(w3, h2) #hidden layer to output
    #out = 1.0 / (1.0 + np.exp(-out)) #sigmoid if needed
    return out

In [5]:
#reward function
@jit
def f(out): return np.linalg.norm(x)**2/np.linalg.norm(out-x)**2

In [6]:
npop = 100    # population size
sigma = 0.01    # noise standard deviation
alpha = 0.0001  # learning rate


@jit
def ES_DL():
    w = np.random.randn(NumWeights1 + NumWeights2 + NumWeights3 + HL1 + HL2)
    for i in range(5000):
        N = np.random.randn(npop, NumWeights1 + NumWeights2 + NumWeights3 + HL1 + HL2) #initiate population
        R = np.zeros(npop) #reward
        for j in range(npop):
            w_trial = w + sigma*N[j]
            
            #Reshape weight and biases
            w1_trial = w_trial [:NumWeights1].reshape(w1.shape)
            w2_trial = w_trial [NumWeights1:NumWeights1+NumWeights2].reshape(w2.shape)
            w3_trial = w_trial [NumWeights1+NumWeights2: NumWeights1 + NumWeights2 + NumWeights3].reshape(w3.shape)
            b1_trial = w_trial [NumWeights1 + NumWeights2 + NumWeights3 : NumWeights1 + NumWeights2 + NumWeights3 + HL1].reshape((HL1,1))
            b2_trial = w_trial [NumWeights1 + NumWeights2 + NumWeights3 + HL1:].reshape((HL2,1))
            
            #Compute output
            out = predict(s,w1_trial,w2_trial,w3_trial,b1_trial,b2_trial)
            
            #Observe reward score
            R[j] = f(out)
        
        #Reward Standardization
        A = (R - np.mean(R)) / np.std(R)
        
        #Update
        w = w + alpha/(npop*sigma) * np.dot(N.T, A)
        
        #Check current performance
        w1_test = w [:NumWeights1].reshape(w1.shape)
        w2_test = w [NumWeights1:NumWeights1+NumWeights2].reshape(w2.shape)
        w3_test = w [NumWeights1+NumWeights2: NumWeights1 + NumWeights2 + NumWeights3].reshape(w3.shape)
        b1_test = w [NumWeights1 + NumWeights2 + NumWeights3 : NumWeights1 + NumWeights2 + NumWeights3 + HL1].reshape((HL1,1))
        b2_test = w [NumWeights1 + NumWeights2 + NumWeights3 + HL1:].reshape((HL2,1))
        
        out_test = predict(s,w1_test,w2_test,w3_test,b1_test,b2_test)
        
        print('At i =', i) 
        print('NMSE =',1/f(out_test))
    return w

In [7]:
start = timer() 
w = ES_DL()
print("Execution time:", timer()-start) 



('At i =', 0)

('NMSE =', 6.723838395675305)

('At i =', 1)

('NMSE =', 6.662387904944113)

('At i =', 2)

('NMSE =', 6.594545152755219)

('At i =', 3)

('NMSE =', 6.528466776242186)

('At i =', 4)

('NMSE =', 6.46816873627983)

('At i =', 5)

('NMSE =', 6.397782371325512)

('At i =', 6)

('NMSE =', 6.3318951985350775)

('At i =', 7)

('NMSE =', 6.265987390824007)

('At i =', 8)

('NMSE =', 6.199608645917765)

('At i =', 9)

('NMSE =', 6.137650547284819)

('At i =', 10)

('NMSE =', 6.079376136195273)

('At i =', 11)

('NMSE =', 6.01619755765749)

('At i =', 12)

('NMSE =', 5.953349138131918)

('At i =', 13)

('NMSE =', 5.88576041152035)

('At i =', 14)

('NMSE =', 5.8270953325208135)

('At i =', 15)

('NMSE =', 5.7640707381029435)

('At i =', 16)

('NMSE =', 5.704930164163159)

('At i =', 17)

('NMSE =', 5.6500696137639)

('At i =', 18)

('NMSE =', 5.595011620729443)

('At i =', 19)

('NMSE =', 5.5315127110984745)

('At i =', 20)

('NMSE =', 5.474770604365594)

('At i =', 21)

('NMSE =

('At i =', 172)

('NMSE =', 0.5508316009360805)

('At i =', 173)

('NMSE =', 0.5370689279991727)

('At i =', 174)

('NMSE =', 0.524016168968783)

('At i =', 175)

('NMSE =', 0.5113162713871771)

('At i =', 176)

('NMSE =', 0.4988325509248608)

('At i =', 177)

('NMSE =', 0.4870038971095926)

('At i =', 178)

('NMSE =', 0.47491664439686404)

('At i =', 179)

('NMSE =', 0.4621785454649419)

('At i =', 180)

('NMSE =', 0.4510663447528036)

('At i =', 181)

('NMSE =', 0.4384533815307264)

('At i =', 182)

('NMSE =', 0.42537633545273906)

('At i =', 183)

('NMSE =', 0.4127881518630182)

('At i =', 184)

('NMSE =', 0.4004263762699472)

('At i =', 185)

('NMSE =', 0.3888166755046046)

('At i =', 186)

('NMSE =', 0.37664484211217725)

('At i =', 187)

('NMSE =', 0.36580853463887314)

('At i =', 188)

('NMSE =', 0.35507831725307826)

('At i =', 189)

('NMSE =', 0.3449510781437781)

('At i =', 190)

('NMSE =', 0.3339165439756779)

('At i =', 191)

('NMSE =', 0.32351546507115364)

('At i =', 192)

KeyboardInterrupt: 