In [1]:
import numpy as np
import tensorflow as tf

# Parameters of the sampling distribution: draws are formed as x = sigma * z + mu
# from standard-normal noise z.
mu, sigma = 0.5, 1

# Target value; the reward used throughout is the squared distance (x - t)^2.
t = 0.45

def est_deriv_mu(samples):
    """Monte Carlo score-function estimate of d/d(mu) of E[(x - t)^2].

    Each noise value z in ``samples`` is mapped to x = sigma * z + mu (using the
    module-level ``mu``, ``sigma`` and ``t``), its reward (x - t)^2 is weighted
    by z / sigma, and the weighted rewards are averaged.

    Args:
        samples: array of noise values, e.g. as produced by ``get_noise``.

    Returns:
        The mean of the per-sample estimates over every element of ``samples``.
    """
    x = samples * sigma + mu
    reward = np.square(x - t)
    # Keep the (reward * z) / sigma evaluation order of the estimator.
    weighted = reward * samples / sigma
    return np.mean(weighted)

def est_deriv_sigma(samples):
    """Monte Carlo score-function estimate of d/d(sigma) of E[(x - t)^2].

    Each noise value z in ``samples`` is mapped to x = sigma * z + mu (using the
    module-level ``mu``, ``sigma`` and ``t``), its reward (x - t)^2 is weighted
    by (z^2 - 1) / sigma, and the weighted rewards are averaged.

    Args:
        samples: array of noise values, e.g. as produced by ``get_noise``.

    Returns:
        The mean of the per-sample estimates over every element of ``samples``.
    """
    x = samples * sigma + mu
    reward = np.square(x - t)
    score_numerator = np.square(samples) - 1
    # Keep the (reward * (z^2 - 1)) / sigma evaluation order of the estimator.
    weighted = reward * score_numerator / sigma
    return np.mean(weighted)

def get_noise(num):
    """Draw ``num`` standard-normal samples as a column.

    Args:
        num: number of samples to draw.

    Returns:
        A NumPy array of shape ``(num, 1)`` drawn from N(0, 1) using NumPy's
        global random state.
    """
    column_shape = (num, 1)
    return np.random.normal(loc=0.0, scale=1.0, size=column_shape)


# Show how both Monte Carlo estimates settle as the sample count grows by
# powers of ten, from 1 up to 10**6.
for i in range(7):
    num_samples = pow(10, i)
    samples = get_noise(num_samples)
    est_mu = est_deriv_mu(samples)
    est_sigma = est_deriv_sigma(samples)
    print(
        "sample: {} est_mu: {} est_sigma: {}".format(num_samples, est_mu, est_sigma)
    )

sample: 1 est_mu: 0.0006935373586199636 est_sigma: -0.011774594521493162
sample: 10 est_mu: 0.09362509969224495 est_sigma: 0.4174549815457881
sample: 100 est_mu: -0.033885429255860126 est_sigma: 0.722665247297703
sample: 1000 est_mu: 0.17001981843123934 est_sigma: 1.8211163240827901
sample: 10000 est_mu: 0.08997004826256501 est_sigma: 1.9113642117922045
sample: 100000 est_mu: 0.10452717973871888 est_sigma: 1.9908199576288712
sample: 1000000 est_mu: 0.10110889599948944 est_sigma: 2.00844049330221


In [2]:
# Parameters of the surrogate-reward network: a 1 -> 3 -> 1 stack of dense layers.
# Weights start as small Gaussian noise (stddev 0.01).
weights = dict(
    wg1=tf.Variable(tf.random_normal([1, 3], stddev=0.01)),
    wg2=tf.Variable(tf.random_normal([3, 1], stddev=0.01)),
)

# Biases start at zero.
biases = dict(
    bg1=tf.Variable(tf.zeros([3])),
    bg2=tf.Variable(tf.zeros([1])),
)

# Noise fed in at run time (one value per row) and the reward observed for
# the corresponding sample.
Z = tf.placeholder(tf.float32, [None, 1], name = 'gan_X')
reward = tf.placeholder(tf.float32, [None, 1], name = 'reward')
# Reparameterised sample built from the noise.
X = sigma * Z + mu

# Surrogate reward: one hidden ReLU layer followed by a linear output layer.
noise = tf.nn.relu(tf.matmul(X, weights['wg1']) + biases['bg1'])
surr_reward = tf.matmul(noise, weights['wg2']) + biases['bg2']
# Residual between the fed reward and the surrogate's prediction.
rev_reward = (reward - surr_reward)

# tf.gradients returns a *list* with one entry per tensor in `xs`. Take the
# single tensor explicitly instead of relying on implicit list-to-tensor
# packing (which adds a stray leading dimension), and build it once so both
# estimators share the same gradient ops.
surr_grad_X = tf.gradients(surr_reward, X)[0]

# sigma padded with a small epsilon so the divisions below stay finite.
var_e = sigma + 1e-6
# Each estimator = score-function term on the residual + pathwise gradient of
# the surrogate (w.r.t. mu: dX/dmu = 1; w.r.t. sigma: dX/dsigma = Z).
grads_mean = tf.reduce_mean(rev_reward * (Z / var_e) + surr_grad_X)
grads_var = tf.reduce_mean(rev_reward * ((tf.square(Z) / var_e) - 1 / var_e) + (surr_grad_X * Z))


# Trainable parameters of the surrogate network.
S_var_list = [weights['wg1'], biases['bg1'], weights['wg2'], biases['bg2']]

# Gradient of S: differentiate the squared gradient estimates with respect to
# the surrogate's parameters and sum the two contributions per variable.
opt_S = tf.train.AdamOptimizer(1e-5)
grads_S_m = tf.gradients(grads_mean * grads_mean, S_var_list)
grads_S_v = tf.gradients(grads_var * grads_var, S_var_list)
grads_S = [g_mean + g_var for g_mean, g_var in zip(grads_S_m, grads_S_v)]

# Materialise the (gradient, variable) pairs as a list: a bare zip() is a
# one-shot iterator and would be empty if this name were used again after
# apply_gradients has consumed it.
grads_and_vars_S = list(zip(grads_S, S_var_list))


# Ask the optimizer to apply the summed gradients (no clipping is performed).
train_S = opt_S.apply_gradients(grads_and_vars_S)

# Open a session and initialise the surrogate network's variables.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Run 1000 training steps; every step draws a fresh batch of 1000 noise
# samples, updates the surrogate, and logs the current estimates.
for i in range(1000):
    num_samples = 1000
    samples = get_noise(num_samples)
    # Reward per sample: squared distance of x = sigma * z + mu from the target t.
    rewards = np.square(samples * sigma + mu - t)
    _, surr_rew, est_mean, est_var = sess.run(
        [train_S, surr_reward, grads_mean, grads_var],
        feed_dict={Z: samples, reward: rewards},
    )
    print(
        "iter: {}/1000".format(i),
        "surrogate reward: {}, mean: {}, var: {}".format(np.mean(surr_rew), est_mean, est_var),
    )

iter: 0/1000 surrogate reward: -0.00046388315968215466, mean: 0.08464274555444717, var: 1.7661497592926025
iter: 1/1000 surrogate reward: -0.00047576107317581773, mean: -0.02403576672077179, var: 2.38433837890625
iter: 2/1000 surrogate reward: -0.00045934109948575497, mean: -0.026413818821310997, var: 2.315345048904419
iter: 3/1000 surrogate reward: -0.000463202188257128, mean: 0.20425072312355042, var: 2.059697151184082
iter: 4/1000 surrogate reward: -0.0004569202719721943, mean: 0.0093332938849926, var: 1.853589653968811
iter: 5/1000 surrogate reward: -0.0004686950705945492, mean: 0.18985018134117126, var: 1.9116785526275635
iter: 6/1000 surrogate reward: -0.00045815639896318316, mean: 0.19634389877319336, var: 1.79986572265625
iter: 7/1000 surrogate reward: -0.00046500490861944854, mean: 0.1955706626176834, var: 2.1347286701202393
iter: 8/1000 surrogate reward: -0.0004355458077043295, mean: 0.10062416642904282, var: 1.8110918998718262
iter: 9/1000 surrogate reward: -0.00040830904617

iter: 90/1000 surrogate reward: -0.0003811646602116525, mean: 0.3356274962425232, var: 2.278672695159912
iter: 91/1000 surrogate reward: -0.000344652944477275, mean: 0.0032107392325997353, var: 1.8348532915115356
iter: 92/1000 surrogate reward: -0.0003254428447689861, mean: -0.10482460260391235, var: 1.6305885314941406
iter: 93/1000 surrogate reward: -0.00036254231235943735, mean: 0.09972462803125381, var: 1.9460334777832031
iter: 94/1000 surrogate reward: -0.0003701625973917544, mean: 0.06930607557296753, var: 1.8488893508911133
iter: 95/1000 surrogate reward: -0.00036741537041962147, mean: 0.057263460010290146, var: 2.1598174571990967
iter: 96/1000 surrogate reward: -0.0003876510600093752, mean: 0.2992924153804779, var: 1.8502734899520874
iter: 97/1000 surrogate reward: -0.00037511176196858287, mean: 0.09382960200309753, var: 2.0313758850097656
iter: 98/1000 surrogate reward: -0.0003832534421235323, mean: 0.2485537976026535, var: 1.7385777235031128
iter: 99/1000 surrogate reward: -0.

iter: 188/1000 surrogate reward: -0.0003452951496001333, mean: -0.14668403565883636, var: 1.8055270910263062
iter: 189/1000 surrogate reward: -0.0003404889430385083, mean: -0.04915658012032509, var: 2.2181389331817627
iter: 190/1000 surrogate reward: -0.00034862084430642426, mean: -0.0070398407988250256, var: 2.3371551036834717
iter: 191/1000 surrogate reward: -0.0003270064480602741, mean: 0.05679389089345932, var: 1.9746469259262085
iter: 192/1000 surrogate reward: -0.0003421697474550456, mean: 0.2519560158252716, var: 1.756819486618042
iter: 193/1000 surrogate reward: -0.0003417214029468596, mean: 0.005357707850635052, var: 1.749077320098877
iter: 194/1000 surrogate reward: -0.0003458005958236754, mean: 0.16165371239185333, var: 2.1519243717193604
iter: 195/1000 surrogate reward: -0.0003341592091601342, mean: 0.18964610993862152, var: 2.086228847503662
iter: 196/1000 surrogate reward: -0.00034815381513908505, mean: 0.21034130454063416, var: 2.6066372394561768
iter: 197/1000 surrogate

iter: 293/1000 surrogate reward: -0.00017455275519751012, mean: 0.16468673944473267, var: 1.825097680091858
iter: 294/1000 surrogate reward: -0.00018561970500741154, mean: 0.10180182009935379, var: 2.4420604705810547
iter: 295/1000 surrogate reward: -0.00018230141722597182, mean: 0.12512631714344025, var: 1.972392201423645
iter: 296/1000 surrogate reward: -0.00016861027688719332, mean: 0.22708116471767426, var: 2.14278507232666
iter: 297/1000 surrogate reward: -0.00016746972687542439, mean: 0.10011035948991776, var: 1.6381279230117798
iter: 298/1000 surrogate reward: -0.00014588276098947972, mean: 0.07684144377708435, var: 1.8717753887176514
iter: 299/1000 surrogate reward: -0.00018173451826442033, mean: 0.13500192761421204, var: 2.2753350734710693
iter: 300/1000 surrogate reward: -0.00015445213648490608, mean: -0.14031915366649628, var: 2.019444465637207
iter: 301/1000 surrogate reward: -0.00019256929226685315, mean: 0.24155054986476898, var: 2.834646701812744
iter: 302/1000 surrogate

iter: 404/1000 surrogate reward: 6.004412716720253e-05, mean: 0.044140178710222244, var: 2.103801727294922
iter: 405/1000 surrogate reward: 3.422943700570613e-05, mean: 0.43954265117645264, var: 1.603637456893921
iter: 406/1000 surrogate reward: 0.00010134959302376956, mean: 0.0741359069943428, var: 2.2004566192626953
iter: 407/1000 surrogate reward: 7.37949667382054e-05, mean: 0.04976359382271767, var: 2.3600752353668213
iter: 408/1000 surrogate reward: 6.28885391051881e-05, mean: 0.11141293495893478, var: 1.870942234992981
iter: 409/1000 surrogate reward: 6.437161209760234e-05, mean: 0.13961488008499146, var: 2.0605571269989014
iter: 410/1000 surrogate reward: 7.61039336794056e-05, mean: 0.27392345666885376, var: 2.035325765609741
iter: 411/1000 surrogate reward: 7.693606312386692e-05, mean: 0.3614818751811981, var: 1.739935278892517
iter: 412/1000 surrogate reward: 8.272180275525898e-05, mean: 0.11081106960773468, var: 1.663304090499878
iter: 413/1000 surrogate reward: 7.91863785707

iter: 516/1000 surrogate reward: 0.0003151635464746505, mean: 0.06431818753480911, var: 1.6350549459457397
iter: 517/1000 surrogate reward: 0.000336269848048687, mean: -0.10643596947193146, var: 2.227001428604126
iter: 518/1000 surrogate reward: 0.00034708110615611076, mean: -0.12553313374519348, var: 1.9793362617492676
iter: 519/1000 surrogate reward: 0.0003401131252758205, mean: -0.007559513207525015, var: 1.8935410976409912
iter: 520/1000 surrogate reward: 0.0003622381773311645, mean: -0.04883022978901863, var: 2.0540804862976074
iter: 521/1000 surrogate reward: 0.0003296477079857141, mean: 0.22888709604740143, var: 2.306286573410034
iter: 522/1000 surrogate reward: 0.0003590369888115674, mean: 0.08013517409563065, var: 1.9741110801696777
iter: 523/1000 surrogate reward: 0.0003507809597067535, mean: 0.07866127789020538, var: 1.7116882801055908
iter: 524/1000 surrogate reward: 0.00035056861815974116, mean: 0.19695305824279785, var: 1.7885581254959106
iter: 525/1000 surrogate reward: 

iter: 627/1000 surrogate reward: 0.0005110027268528938, mean: -0.08986084163188934, var: 2.1871109008789062
iter: 628/1000 surrogate reward: 0.00048122438602149487, mean: 0.1645391434431076, var: 1.6054537296295166
iter: 629/1000 surrogate reward: 0.0004805619828402996, mean: 0.10641523450613022, var: 2.5583794116973877
iter: 630/1000 surrogate reward: 0.0004980082157999277, mean: -0.04652063921093941, var: 2.078390598297119
iter: 631/1000 surrogate reward: 0.0005090912454761565, mean: 0.009051910601556301, var: 2.080514907836914
iter: 632/1000 surrogate reward: 0.000508566910866648, mean: 0.051185738295316696, var: 1.4929617643356323
iter: 633/1000 surrogate reward: 0.0004996749339625239, mean: 0.25157633423805237, var: 2.036365509033203
iter: 634/1000 surrogate reward: 0.0004989607841707766, mean: 0.3055306375026703, var: 2.1608777046203613
iter: 635/1000 surrogate reward: 0.0004858992178924382, mean: 0.17301301658153534, var: 2.2535572052001953
iter: 636/1000 surrogate reward: 0.000

iter: 739/1000 surrogate reward: 0.0007108732243068516, mean: 0.011041466146707535, var: 1.8203171491622925
iter: 740/1000 surrogate reward: 0.0006885639159008861, mean: 0.14152073860168457, var: 2.0505552291870117
iter: 741/1000 surrogate reward: 0.0007018688484095037, mean: 0.03704359382390976, var: 1.77042818069458
iter: 742/1000 surrogate reward: 0.000690542918164283, mean: 0.20675094425678253, var: 1.9509419202804565
iter: 743/1000 surrogate reward: 0.0007054904126562178, mean: 0.0666700154542923, var: 1.7315864562988281
iter: 744/1000 surrogate reward: 0.0007075195317156613, mean: 0.11372941732406616, var: 1.947776198387146
iter: 745/1000 surrogate reward: 0.0007002280908636749, mean: 0.20954324305057526, var: 2.0399417877197266
iter: 746/1000 surrogate reward: 0.0006892592646181583, mean: 0.26190024614334106, var: 1.8348965644836426
iter: 747/1000 surrogate reward: 0.0006934707053005695, mean: 0.08523593097925186, var: 2.1669626235961914
iter: 748/1000 surrogate reward: 0.000696

iter: 851/1000 surrogate reward: 0.0006210122373886406, mean: 0.11104398965835571, var: 2.138169288635254
iter: 852/1000 surrogate reward: 0.0006366823217831552, mean: 0.048602212220430374, var: 1.6067484617233276
iter: 853/1000 surrogate reward: 0.0006250198930501938, mean: -0.0148044778034091, var: 2.227543354034424
iter: 854/1000 surrogate reward: 0.0006358270766213536, mean: -0.060290977358818054, var: 1.6241652965545654
iter: 855/1000 surrogate reward: 0.0006156269810162485, mean: 0.20535977184772491, var: 2.022353410720825
iter: 856/1000 surrogate reward: 0.0006378740072250366, mean: 0.06465256959199905, var: 1.940861463546753
iter: 857/1000 surrogate reward: 0.0006218779017217457, mean: 0.13691815733909607, var: 2.3009986877441406
iter: 858/1000 surrogate reward: 0.0006544198840856552, mean: -0.04708308354020119, var: 1.5802398920059204
iter: 859/1000 surrogate reward: 0.000647944281809032, mean: -0.09314162284135818, var: 1.710227131843567
iter: 860/1000 surrogate reward: 0.000

iter: 966/1000 surrogate reward: 0.0007895356393419206, mean: -0.16246536374092102, var: 1.5772888660430908
iter: 967/1000 surrogate reward: 0.0007891146815381944, mean: 0.0481988750398159, var: 2.022115468978882
iter: 968/1000 surrogate reward: 0.0007517116027884185, mean: 0.2295171469449997, var: 2.1948673725128174
iter: 969/1000 surrogate reward: 0.0007771980599500239, mean: 0.15162041783332825, var: 1.8270587921142578
iter: 970/1000 surrogate reward: 0.0007782243774272501, mean: 0.07273710519075394, var: 1.8539619445800781
iter: 971/1000 surrogate reward: 0.0007740508299320936, mean: -0.018332328647375107, var: 1.9331002235412598
iter: 972/1000 surrogate reward: 0.0007622329285368323, mean: 0.16171488165855408, var: 1.8676691055297852
iter: 973/1000 surrogate reward: 0.000747120997402817, mean: 0.2039431780576706, var: 1.7812553644180298
iter: 974/1000 surrogate reward: 0.0007684850716032088, mean: 0.10738377273082733, var: 1.4333873987197876
iter: 975/1000 surrogate reward: 0.0007