In [1]:
import warnings
warnings.filterwarnings('ignore')  # silences every warning for the rest of the session
import numpy as np
import matplotlib.pyplot as plt

import sys
# Extend the import path with a directory relative to the notebook's working directory.
# Presumably this is where keras2, gym2, rl2 and util live — TODO confirm.
sys.path.append('../../../module/')

from keras2.models import Model
from keras2.layers import concatenate, Dense, Input, Flatten
from keras2.optimizers import Adam
import csv
# NOTE(review): the wildcard import hides which names later cells take from `util`.
from util import *
import gym2
from rl2.agents import selfDDPGAgent, selfDDPGAgent2
from rl2.memory import SequentialMemory

Using TensorFlow backend.
Using TensorFlow backend.


In [2]:
# Create the Pendulum environment through the project's gym2 package
env = gym2.make("Pendulum-v2")

# Define the number of available actions ("moves") and the value each action id maps to
# NOTE(review): `nb_actios` looks like a typo for `nb_actions`; neither name defined here is
# referenced in the visible cells — confirm nothing else depends on it before renaming.
nb_actios = 2
ACT_ID_TO_VALUE = {0: [-1], 1: [+1]}

In [3]:
def critic_net(a_shape, s_shape):
    """Build the critic model that scores an (action, observation) pair.

    The observation input has shape ``(1,) + s_shape`` and is flattened before
    being concatenated with the action input. The joined vector passes through
    two 16-unit ReLU layers and a single linear output unit.

    Args:
        a_shape: shape tuple of the action input.
        s_shape: shape tuple of one observation (without the leading length-1 axis).

    Returns:
        A ``(critic, action_input)`` tuple: the Keras model and its action
        input tensor.
    """
    act_in = Input(a_shape)
    obs_in = Input(shape=(1,) + s_shape)

    # Action first, flattened observation second.
    hidden = concatenate([act_in, Flatten()(obs_in)])
    for width in (16, 16):
        hidden = Dense(width, activation="relu")(hidden)
    q_value = Dense(1, activation="linear")(hidden)

    return Model(inputs=[act_in, obs_in], outputs=q_value), act_in

def branch_actor(a_shape, s_shape):
    """Build the two-headed actor model.

    A shared flattened observation feeds two independent branches, each made
    of two 8-unit ReLU layers followed by a single output unit. The first
    branch ends in the "multiple_tanh" activation (action signal), the second
    in "tau_output" (tau). The two outputs are concatenated in that order.

    Args:
        a_shape: accepted for signature symmetry with ``critic_net``; not used here.
        s_shape: shape tuple of one observation (without the leading length-1 axis).

    Returns:
        The Keras actor model.
    """
    def _branch(features, head_activation):
        # Two hidden ReLU layers, then one unit with the branch-specific activation.
        h = Dense(8, activation="relu")(features)
        h = Dense(8, activation="relu")(h)
        return Dense(1, activation=head_activation)(h)

    state_in = Input(shape=(1,) + s_shape)
    features = Flatten()(state_in)  # effectively the input layer

    action_head = _branch(features, "multiple_tanh")  # action signal
    tau_head = _branch(features, "tau_output")  # tau

    return Model(inputs=state_in, outputs=concatenate([action_head, tau_head]))


def agent2(a_shape, s_shape):
    """Assemble a ``selfDDPGAgent2`` from freshly built actor and critic models.

    Args:
        a_shape: action shape tuple; its first element is passed to the agent
            as the first positional argument.
        s_shape: observation shape tuple forwarded to the network builders.

    Returns:
        The constructed (not yet compiled) agent.
    """
    actor_model = branch_actor(a_shape, s_shape)
    critic_model, critic_act_in = critic_net(a_shape, s_shape)
    replay = SequentialMemory(limit=50000, window_length=1)

    # Keyword options kept together so the positional wiring below stays readable.
    options = dict(
        original_noise=True,
        action_clipper=[-10., 10.],
        tau_clipper=[0.001, 1.],
        params_logging=False,
        gradient_logging=False,
        batch_size=128,
    )
    return selfDDPGAgent2(
        a_shape[0],
        actor_model,
        critic_model,
        critic_act_in,
        replay,
        **options
    )

In [4]:
# agent compilation
l = 1.
step = 50000  # num of interval
episode_step = step
# Both action and observation shapes are (2,)
a = agent2((2,), (2,))
# Per the original author's note, the actor-side optimizer settings do not matter.
# NOTE(review): presumably `action_lr` / `tau_lr` below drive the actor updates instead — confirm in rl2.
actor_optimizer, critic_optimizer = Adam(lr=100., clipnorm=1.), Adam(lr=0.001, clipnorm=1.) # the actor one can be anything
optimizer = [actor_optimizer, critic_optimizer]
a.compile(optimizer=optimizer, metrics=["mse"], action_lr=0.0001, tau_lr=0.001)









In [5]:
# agent setup: load previously saved weights from a path relative to the notebook
a.load_weights('../saved_agent/learned_self_clipped3.h5')
# Mark the agent as not training; presumably this switches off exploration/learning in
# `forward` — TODO confirm against rl2.
a.training = False

In [14]:
# experiment: roll out the loaded agent from a random initial state and
# accumulate the discounted return, averaged over `n_episodes`.

l = 1.  # weight of the `l * tau` term in the step reward

view_path = False  # set True to draw each episode's state trajectory
cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

# step wise evaluation

step_limit = 5000  # decision steps per episode
n_episodes = 1
gamma = .99  # discount factor, applied once per decision step
print_first_n_steps = 5000  # per-step rewards are printed for step indices below this
average_reward = 0

#start_point = np.array([[1.5, -1.1],[-1.5, -1.1]])


for ep in range(n_episodes):
    # Wrap around the color cycle so running more episodes than there are
    # colors cannot raise IndexError when plotting.
    color = cycle[ep % len(cycle)]
    state_log = []
    env.reset()
    # Draw the initial state uniformly from [-pi, pi] x [-2*pi, 2*pi].
    env.set_state(np.random.uniform(low=-np.array([np.pi,2*np.pi]), high=np.array([np.pi,2*np.pi])))
    episode_reward = 0
    for steps in range(step_limit):
        reward = 0
        x = env.state
        a_agent, tau = a.forward(x)
        # Log a copy so later in-place changes to the environment state (if any)
        # cannot alter already-recorded points.
        state_log.append(np.copy(x))
        # Smallest repetition count that keeps `dt` at or below 1/20 = 0.05.
        # Clamped to 1 so a non-positive `tau` cannot produce a zero divisor.
        action_repetition = max(1, int(np.ceil(20 * tau)))
        dt = tau / action_repetition
        for p in range(action_repetition):
            _,r,_,_ = env.step(np.array([a_agent]), dt, tau)
            reward += r
        reward *= dt
        reward += - tau * 0.01 * a_agent**2 + l * tau # step reward
        discounted_reward = pow(gamma, steps) * reward
        if steps < print_first_n_steps:
            print(steps, reward, discounted_reward)

        episode_reward += discounted_reward
    #print(episode_reward)
    average_reward += episode_reward / n_episodes
    state_log = np.array(state_log)
    if view_path:
        plt.scatter(state_log[:,0], state_log[:,1], alpha=0.4, color=color)
        plt.plot(state_log[:,0], state_log[:,1], alpha=0.4, color=color)
        # Circle marks the first logged state, cross marks the last one.
        plt.scatter(state_log[0,0], state_log[0,1], marker='o', color=color)
        plt.scatter(state_log[-1,0], state_log[-1,1], marker='x', color=color)

if view_path:
    plt.show()

0 -3.7119965521505165 -3.7119965521505165
1 -0.5853970428010997 -0.5795430723730887
2 0.17655748389949252 0.17304398996989262
3 0.8216476434183237 0.7972438867611561
4 0.516158844114647 0.49582012618274185
5 0.6672401138451437 0.6345387091608748
6 0.8494791481780214 0.7997677553396778
7 0.4271655081711453 0.39814616798740465
8 0.5949514336949667 0.548988278884315
9 0.7669159868694605 0.7005909813761896
10 0.4698186534354515 0.42489556867179595
11 0.6251041685664653 0.5596796750141454
12 0.799473516895306 0.7086412307136885
13 0.3885904004291232 0.3409962457121427
14 0.5719118882202621 0.4968460581641526
15 0.7478803489439951 0.6432207423813251
16 0.46883136434476597 0.3991901085043639
17 0.6196791268103531 0.5223543020268824
18 0.7924286940989478 0.6612926501934937
19 0.4188707610547969 0.3460578802256064
20 0.5853077295987626 0.478727252668112
21 0.754629806295293 0.6110447843477087
22 0.5173082754615368 0.41469013783165887
23 0.6664201508333092 0.5288805506092734
24 0.848335435672252

227 0.6721073631896652 0.06864799019455763
228 0.8530466320556164 0.08625755443206735
229 0.4157256704457528 0.041616581692915974
230 0.583536514325512 0.05783127860075703
231 0.7529402176117965 0.07387380134952684
232 0.5228720443408283 0.05078793665823658
233 0.6708748443051297 0.06451219944205688
234 0.8520378430441774 0.08111373883214089
235 0.4192435426036758 0.039512736886561456
236 0.587097061067024 0.05477921840174174
237 0.7574118096277597 0.06996376816400322
238 0.5062120564098039 0.046292304240867214
239 0.6571953084572911 0.059498492641086025
240 0.8403314611241078 0.07531774811265998
241 0.46011879075790213 0.040827413538068555
242 0.6170977145694669 0.054208954407574374
243 0.7914156913032564 0.06882669948900456
244 0.41921082139041643 0.036092749446343525
245 0.5854419500811237 0.04990068589851151
246 0.7547142022781352 0.0636854751907516
247 0.5171097538218336 0.043199209045653636
248 0.6662659964284262 0.0551030844197586
249 0.8482041307749433 0.06944865479116127
250 0

474 0.8532829798098602 0.007280736644126047
475 0.4148710773495571 0.003504536543737054
476 0.5826641986199828 0.00487271445852031
477 0.7523226723098919 0.006228621967514536
478 0.5247692294974549 0.0043012172546667265
479 0.6723922672776368 0.005456082423967179
480 0.853285537053108 0.006854689566535921
481 0.414868887753651 0.0032994341750180564
482 0.5826624943456367 0.004587550517916746
483 0.7523216015658644 0.005864115594417424
484 0.5247681572978927 0.004049502389638493
485 0.6723928337962723 0.0051367976236241985
486 0.8532872497813869 0.006453567110876092
487 0.4148509934522774 0.003106217795548183
488 0.5826471233762485 0.004318973806991282
489 0.7523179666896526 0.0055209217512221295
490 0.524771937096738 0.003812553575650659
491 0.6723952332836852 0.004836210252472525
492 0.8532878331748026 0.006075909481820047
493 0.4148551415720507 0.0029244716358982955
494 0.5826486982146877 0.004066239095683777
495 0.7523181981593025 0.005197839834418236
496 0.5247750894287471 0.003589

679 0.41488362467540124 0.0004510486896060569
680 0.5826786017912756 0.000627135514941129
681 0.7523255180352096 0.0008016288039561794
682 0.5247605701001259 0.0005535589931584759
683 0.672387078724407 0.0007021942504141066
684 0.8532808638691995 0.0008821960537324478
685 0.414869006649463 0.00042463842542328386
686 0.5826629955009923 0.0005904198242490063
687 0.7523217956644929 0.0007547138719064724
688 0.524770073212367 0.0005211742415760857
689 0.6723951606897435 0.0006611098941085963
690 0.8532858745218078 0.0008305749497588515
691 0.4148441106700882 0.0003997646571905959
692 0.5826384661151522 0.0005558451429704219
693 0.752316368767032 0.0007105430033157403
694 0.5247587131431091 0.000490664580829543
695 0.6723849294933149 0.0006224123710757842
696 0.8532789726790594 0.0007819635027881263
697 0.41489636535174523 0.0003764178976362465
698 0.5826904555970323 0.0005233638644237468
699 0.7523291594067764 0.0006689735063955059
700 0.5247600974421797 0.0004619521814791278
701 0.6723834

872 0.58267577318722 9.105769450721013e-05
873 0.7523248900095303 0.00011639392510458888
874 0.524772126512794 8.037682325469374e-05
875 0.6723952855415011 0.00010195767715136357
876 0.8532878079592126 0.0001280931801051022
877 0.41484623278884486 6.165280107100026e-05
878 0.5826402500265593 8.572378531068373e-05
879 0.7523165705593544 0.00010958135819759543
880 0.5247688961651217 7.567271774365253e-05
881 0.6723943661526002 9.599099786511508e-05
882 0.8532867320428174 0.00012059703428078909
883 0.4148627862768287 5.80472045115316e-05
884 0.5826559233438912 8.07094132671458e-05
885 0.7523201863079001 0.00010316916933186664
886 0.5247689445219973 7.124436817196313e-05
887 0.6723937794975602 9.037354016143266e-05
888 0.8532867484571064 0.0001135397160361067
889 0.41482503742988674 5.464531808237268e-05
890 0.5826176366295128 7.598131734902103e-05
891 0.7523106228196866 9.713049021842922e-05
892 0.5247699374011445 6.707528529880333e-05
893 0.6723948361463644 8.508502780171365e-05
894 0.85

1052 0.5826419696150631 1.4915110722462014e-05
1053 0.7523168103028215 1.906604779284077e-05
1054 0.524748944665124 1.3165782603819325e-05
1055 0.6723773329401244 1.670103416956092e-05
1056 0.853274656508429 2.098236036710697e-05
1057 0.41488843212187954 1.0100247491519032e-05
1058 0.582683041477791 1.4043270546151994e-05
1059 0.7523266738604862 1.7950540869322744e-05
1060 0.5247709040265216 1.2395841684447991e-05
1061 0.6723943976256025 1.572409120654095e-05
1062 0.8532879114481633 1.975478264320013e-05
1063 0.41485745991121514 9.508472638738998e-06
1064 0.5826522728343568 1.3220762291166053e-05
1065 0.7523193911179535 1.6899914301775236e-05
1066 0.5247705765391846 1.1670431597997689e-05
1067 0.6723925185751536 1.480387836779997e-05
1068 0.8532850320187036 1.859867295266573e-05
1069 0.41487827100830665 8.952487314587002e-06
1070 0.5826725077027043 1.2447517530572211e-05
1071 0.7523241555045329 1.5911034604541072e-05
1072 0.5247651254345373 1.0987365550958974e-05
1073 0.67238763268986 

1244 0.5826504141337568 2.1656593202857653e-06
1245 0.752319711828344 2.7683418301457042e-06
1246 0.52477930434879 1.911741500604793e-06
1247 0.6724009202028874 2.4250235598256755e-06
1248 0.8532919612670188 3.0466355961699615e-06
1249 0.4148515317265006 1.466394254305178e-06
1250 0.5826451426926003 2.038906813547592e-06
1251 0.7523170018717018 2.606329491455926e-06
1252 0.5247780433139145 1.7998623485594624e-06
1253 0.6724008002872873 2.2831111362367994e-06
1254 0.8532902184868919 2.8683410778848313e-06
1255 0.41483296318962104 1.3805192875392107e-06
1256 0.5826284500731604 1.919535295711636e-06
1257 0.7523141803204589 2.453798276043803e-06
1258 0.5247522848394135 1.6944514974144278e-06
1259 0.6723795975384503 2.1494360335492554e-06
1260 0.8532753069507601 2.7004389946275174e-06
1261 0.4148928331613677 1.299919086339914e-06
1262 0.5826870465214709 1.8073861321997737e-06
1263 0.7523272730710885 2.310242572678353e-06
1264 0.5247647338414755 1.5953302949842466e-06
1265 0.672389404896285 

1452 0.8532860317479563 3.9210031568204554e-07
1453 0.41484857001206354 1.8872409334941775e-07
1454 0.582641693621051 2.6240644898275654e-07
1455 0.7523171604661917 3.3543556525491245e-07
1456 0.5247783644716308 2.316430399366392e-07
1457 0.6724002886577751 2.9383695839866703e-07
1458 0.8532907874224795 3.691567212222024e-07
1459 0.4148647167637949 1.7768690326937528e-07
1460 0.5826590924903953 2.4705784022278167e-07
1461 0.752320927020315 3.158075072058243e-07
1462 0.5247643855275562 2.1808151447950463e-07
1463 0.672388397727218 2.7663677127766545e-07
1464 0.8532827462583528 3.475504498034738e-07
1465 0.4148719499716399 1.6729160893173925e-07
1466 0.5826636533631662 2.326018730439457e-07
1467 0.7523213099497446 2.973266504045252e-07
1468 0.524778811664162 2.0532506120680993e-07
1469 0.6724005817539185 2.604527482057463e-07
1470 0.8532928473133352 3.2721572289851993e-07
1471 0.41479946250162175 1.5747420987256862e-07
1472 0.5825926076935255 2.1896334416788403e-07
1473 0.752304569605413

1694 0.5826338021481922 2.3518809010828723e-08
1695 0.7523146348744542 3.006452544796169e-08
1696 0.5247584128839544 2.0761056355372217e-08
1697 0.6723835949934158 2.6335546856799676e-08
1698 0.8532754250500825 3.308641132001196e-08
1699 0.4149026815877719 1.592728909564389e-08
1700 0.5826973796182886 2.2144908027803113e-08
1701 0.7523308280664733 2.830576316453922e-08
1702 0.5247532844559342 1.9545931416262834e-08
1703 0.6723780084138928 2.4794188582103687e-08
1704 0.8532727378563953 3.1150101372356e-08
1705 0.41488749082930787 1.499467749974465e-08
1706 0.5826813493419039 2.084841775302302e-08
1707 0.7523265788093981 2.6649163614454854e-08
1708 0.5247505544138662 1.840201069254654e-08
1709 0.6723780056731111 2.3343236275402646e-08
1710 0.8532727922392023 2.9327203963052993e-08
1711 0.41493117478646113 1.4118677627164512e-08
1712 0.5827293046725429 1.962998689794614e-08
1713 0.7523386054492379 2.50900596226676e-08
1714 0.5247506499425106 1.7325130930064592e-08
1715 0.6723753139583784 

1885 0.4149106451953179 2.4563847578547326e-09
1886 0.5827064993977108 3.4152843900211917e-09
1887 0.7523323753356866 4.365379020218309e-09
1888 0.5247516315921303 3.014402197874234e-09
1889 0.672380173907576 3.8238204714213e-09
1890 0.85327532681241 4.804044008358525e-09
1891 0.4149091482170154 2.312629144922958e-09
1892 0.5827038003439905 3.215407564163321e-09
1893 0.7523317415016011 4.109914229573922e-09
1894 0.5247667147133018 2.83808140523916e-09
1895 0.6723883837836537 3.6000950259510803e-09
1896 0.853282689956777 4.5229511001541024e-09
1897 0.41487796667909965 2.1771308033411804e-09
1898 0.582670987023852 3.027071923280991e-09
1899 0.7523242889047927 3.869364332589092e-09
1900 0.5247568013298427 2.6719468286369923e-09
1901 0.6723835887905475 3.389393831984907e-09
1902 0.8532757769636959 4.258234178527147e-09
1903 0.4148922568146798 2.0497960351278113e-09
1904 0.5826878641064716 2.8500106748318453e-09
1905 0.752327944472177 3.642947411042928e-09
1906 0.5247580167340556 2.51559072

2115 0.7523308377376868 4.4141332340599163e-10
2116 0.5247459437485953 3.048041379690145e-10
2117 0.6723724048086698 3.8664895793200574e-10
2118 0.8532691671236821 4.857672260564966e-10
2119 0.41489470862249445 2.3383808636673175e-10
2120 0.5826859284064556 3.251225383778267e-10
2121 0.7523265798175542 4.1557952962489133e-10
2122 0.5247764491251777 2.8698372778370324e-10
2123 0.6724003099483287 3.6403742651797343e-10
2124 0.8532911799633639 4.573519991323866e-10
2125 0.4148483333511832 2.2012930856383402e-10
2126 0.5826429089310737 3.060738170262914e-10
2127 0.7523170814296315 3.9125493784554383e-10
2128 0.5247665048422411 2.7018436293757346e-10
2129 0.6723894448203258 3.427284725629537e-10
2130 0.8532819870783404 4.3058318955800666e-10
2131 0.4148641825397243 2.0725529215417485e-10
2132 0.5826577575745557 2.881697667955307e-10
2133 0.7523210337441016 3.683606925176017e-10
2134 0.5247643501358058 2.543721699205107e-10
2135 0.6723869130214682 3.2267083857104575e-10
2136 0.85328141751475

2313 0.7523215130777055 6.034045560684917e-11
2314 0.5247626507499715 4.166805376753358e-11
2315 0.6723872736627993 5.2856083300193725e-11
2316 0.8532798017431517 6.640520983284941e-11
2317 0.414856186826872 3.1962699643905125e-11
2318 0.5826507279550677 4.4441567700363786e-11
2319 0.7523194772428162 5.6809187429590905e-11
2320 0.5247715266041695 3.9230309017929845e-11
2321 0.6723937945925225 4.976343581206714e-11
2322 0.8532852540978905 6.251958636447897e-11
2323 0.4148664737910486 3.0092993417159686e-11
2324 0.5826599337491413 4.184151487736484e-11
2325 0.7523210878976256 5.3484836775000747e-11
2326 0.5247701503061385 3.6934460328410205e-11
2327 0.6723934323078401 4.6851261739654864e-11
2328 0.8532855969691102 5.886097316271062e-11
2329 0.4148503063405021 2.8330851834881258e-11
2330 0.5826447462770628 3.939192887070678e-11
2331 0.7523175350260793 5.03546743141867e-11
2332 0.5247756320750607 3.4773424468755303e-11
2333 0.6723993749873934 4.41099227466832e-11
2334 0.8532908575652751 5.

2509 0.41485269002836667 4.640845984850556e-12
2510 0.5826475442755679 6.452742878651503e-12
2511 0.7523185474703667 8.248508065275839e-12
2512 0.5247734231567005 5.696140085374866e-12
2513 0.6723946793605657 7.225506430235652e-12
2514 0.8532856179551711 9.077654388897941e-12
2515 0.41485192171546315 4.369256279226256e-12
2516 0.5826456019630204 6.075109077401542e-12
2517 0.752318035148932 7.76580131719433e-12
2518 0.5247767438583033 5.362836753743632e-12
2519 0.6723991737363605 6.8027163433963524e-12
2520 0.8532914133096294 8.546489456015259e-12
2521 0.4148351188231307 4.11340144125302e-12
2522 0.5826291260740943 5.719432864727316e-12
2523 0.7523136080648252 7.311304760037596e-12
2524 0.5247680792973852 5.048920984285367e-12
2525 0.6723936132244612 6.40456943526996e-12
2526 0.853286286667301 8.046301826794823e-12
2527 0.414852498994282 3.872848055732906e-12
2528 0.5826467858117244 5.384895721514395e-12
2529 0.7523184851786258 6.883492921953612e-12
2530 0.5247789381832739 4.75355724465

2696 0.5826293577020216 9.951235387329694e-13
2697 0.7523139847918918 1.272093298147226e-12
2698 0.5247497028714511 8.784301339394275e-13
2699 0.6723785198547245 1.1143049238937494e-12
2700 0.8532728282588126 1.3999525936928856e-12
2701 0.41488004310874893 6.738809082560933e-13
2702 0.5826725892278831 9.36958575774333e-13
2703 0.7523239985679404 1.1976665298802138e-12
2704 0.524760532997252 8.270416024100708e-13
2705 0.6723856274733371 1.0491070560714254e-12
2706 0.853280413292286 1.3180392934613957e-12
2707 0.4148678912647087 6.344269152631986e-13
2708 0.5826621764363734 8.821121356224614e-13
2709 0.7523215292145968 1.1275755624303603e-12
2710 0.5247593786627676 7.786415385885091e-13
2711 0.6723846454700853 9.877120253556895e-13
2712 0.8532791191952545 1.2409059489462184e-12
2713 0.4148858040930539 5.973261367145343e-13
2714 0.5826800239903189 8.305165040471662e-13
2715 0.7523262209981755 1.0615966294856776e-12
2716 0.5247709357828794 7.330916970840176e-13
2717 0.6723950738469945 9.29

2897 0.6724005050764056 1.52330469310426e-13
2898 0.8532905695312417 1.9137750554051694e-13
2899 0.41484031343799577 9.211070196492782e-14
2900 0.5826325326077825 1.280734367644493e-13
2901 0.7523137047678341 1.6371878043047594e-13
2902 0.5247700318866496 1.1305863762899702e-13
2903 0.6723965124844511 1.4341526142589883e-13
2904 0.853288119837514 1.8017760523874103e-13
2905 0.41485264599920024 8.672297551072368e-14
2906 0.5826471452029205 1.2058162252597545e-13
2907 0.7523185778079817 1.5413898027348535e-13
2908 0.5247635110547246 1.06441140383962e-13
2909 0.6723873217031882 1.3502077617092218e-13
2910 0.8532827385706113 1.6963256890327332e-13
2911 0.4148532407201062 8.164807698851905e-14
2912 0.5826459840842071 1.1352497775394778e-13
2913 0.7523174865169353 1.4511857967137993e-13
2914 0.5247733690250382 1.0021410329262436e-13
2915 0.6723969642069744 1.271212035023352e-13
2916 0.8532880762031272 1.5970669533877805e-13
2917 0.41485044567764934 7.68695258153266e-14
2918 0.582645024299586

KeyboardInterrupt: 

In [9]:
# Show the discounted return averaged over the evaluation episodes.
# NOTE(review): this cell's execution count (9) is lower than the experiment cell's (14),
# so the value shown below may come from an earlier run — re-run after the experiment cell.
print(average_reward)

63.26072555038902


In [None]:
# What about good_agent?