## 教師あり学習でLQRエージェントを再現する

In [1]:
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import matplotlib.pyplot as plt

import sys
sys.path.append('../../module/')

import keras2
from keras2.models import Model
from keras2.layers import concatenate, Dense, Input, Flatten
from keras2.optimizers import Adam
from util import moving_average, dlqr, discretized_system, lqr
import gym2
from rl2.agents import selfDDPGAgent
from rl2.memory import SequentialMemory

Using TensorFlow backend.
Using TensorFlow backend.


In [2]:
# GymのPendulum環境を作成
env = gym2.make("Pendulum-v2")
clip = 10.
env.max_torque, max_torque = clip, clip
env.action_space.high, env.action_space.low = max_torque, -max_torque
control_interval = 5

# 取りうる”打ち手”のアクション数と値の定義
nb_actios = 2
ACT_ID_TO_VALUE = {0: [-1], 1: [+1]}

In [3]:
def gain():
    m, l, g = env.m, env.l, env.g

    A = np.array([[0, 1], [(3*g)/(2*l), 0]])
    B = np.array([[0], [3/(m*l**2)]])
    Q = np.array([[1, 0], [0, 0.1]])
    R = np.array([[0.001]])
    
    # Ad, Bd = discretized_system(A, B, control_interval * dt)
    
    # K = dlqr(Ad,Bd,Q,R)[0]
    K = lqr(A,B,Q,R)[0]
    
    return K

def model(a_shape, s_shape):
    action_input = Input(shape=(1,)+s_shape)
    x = Flatten()(action_input)
    x = Dense(16, activation="relu")(x)
    x = Dense(16, activation="relu")(x)
    x = Dense(2, activation="self_trigger_output")(x)
    model = Model(inputs=action_input, outputs=x)
    return model

In [4]:
NN = model((2,),(2,))






In [5]:
# 学習データの用意
K = gain()
x_train = []
y_train = []
for i in range(10000):
    x_train.append([np.random.randn(2,) / 16.])
    y_train.append([np.dot(K, x_train[-1][0]), .7])

x_train = np.array(x_train)
y_train = np.array(y_train)
print(x_train.shape, y_train.shape)

(10000, 1, 2) (10000, 2)


In [6]:
# learn
NN.compile(loss='mean_squared_error',optimizer='adam')
history = NN.fit(x_train, y_train, batch_size=128, epochs=50, verbose=0)






In [7]:
# save
NN.save_weights('./saved_agent/supervised_test.h5')

## 学習したNNを, actorにloadする

In [10]:
def actor_net(a_shape, s_shape):
    action_input = Input(shape=(1,)+s_shape)
    x = Flatten()(action_input)
    x = Dense(16, activation="relu")(x)
    x = Dense(16, activation="relu")(x)
    x = Dense(2, activation="self_trigger_output")(x)
    actor = Model(inputs=action_input, outputs=x)
    return actor

def critic_net(a_shape , s_shape):
    action_input = Input(a_shape)
    observation_input = Input(shape=(1,)+s_shape)
    flattened_observation = Flatten()(observation_input)
    x = concatenate([action_input, flattened_observation])
    x = Dense(16, activation="relu")(x)
    x = Dense(16, activation="relu")(x)
    x = Dense(1, activation="linear")(x)
    critic = Model(inputs=[action_input, observation_input], outputs=x)
    return (critic, action_input)

def agent(a_shape, s_shape):
    actor = actor_net(a_shape, s_shape)
    critic,  critic_action_input = critic_net(a_shape, s_shape)
    memory = SequentialMemory(limit = 50000, window_length = 1)
    agent = selfDDPGAgent(
        a_shape[0],
        actor,
        critic,
        critic_action_input,
        memory,
        clip_com = 0.01
    )
    return agent

In [11]:
a = agent((2,), (2,))
a.compile(Adam(lr=0.001, clipnorm=1.), metrics=["mae"])
a.actor.load_weights('./saved_agent/supervised_test.h5')

In [17]:
for i in range(1000):
    x = np.random.randn(2,) / 16.
    print(f'output of NN: {a.forward(x)[0]}, optimal_agent: {np.dot(K, x)}')

output of NN: -2.1574277877807617, optimal_agent: -2.098307421766262
output of NN: 0.7644964456558228, optimal_agent: 0.7292422287641805
output of NN: -1.5843902826309204, optimal_agent: -1.5305860994738802
output of NN: 4.115272521972656, optimal_agent: 4.0838682177452
output of NN: -2.3474106788635254, optimal_agent: -2.2895092103578665
output of NN: 1.3494236469268799, optimal_agent: 1.2950294996851153
output of NN: 0.948069155216217, optimal_agent: 0.9059683734154873
output of NN: 0.13344940543174744, optimal_agent: 0.12449022561917486
output of NN: 4.727049827575684, optimal_agent: 4.707878641288088
output of NN: 0.05330720543861389, optimal_agent: 0.048123404194963304
output of NN: -0.15598075091838837, optimal_agent: -0.1521015356992535
output of NN: -0.48911604285240173, optimal_agent: -0.4711356821931548
output of NN: 1.2730462551116943, optimal_agent: 1.2209812113917846
output of NN: -0.845525860786438, optimal_agent: -0.8134098503448557
output of NN: 0.24300017952919006, opt

output of NN: 0.013554980978369713, optimal_agent: 0.009575195376511969
output of NN: -0.8997747898101807, optimal_agent: -0.8660150317223838
output of NN: -4.255058765411377, optimal_agent: -4.3471376676190605
output of NN: 0.04042408987879753, optimal_agent: 0.03582097615447899
output of NN: -1.1342437267303467, optimal_agent: -1.0920616802804877
output of NN: 1.9547491073608398, optimal_agent: 1.8900106244232444
output of NN: -0.04150724038481712, optimal_agent: -0.04276150155940711
output of NN: -0.5806862711906433, optimal_agent: -0.5588722538681716
output of NN: 4.04924201965332, optimal_agent: 4.019050924107083
output of NN: -3.2939071655273438, optimal_agent: -3.2737791641877427
output of NN: -0.24922123551368713, optimal_agent: -0.24105521646096295
output of NN: 1.1883068084716797, optimal_agent: 1.138323587807385
output of NN: 0.16935960948467255, optimal_agent: 0.15876449582186442
output of NN: -0.3877403736114502, optimal_agent: -0.3743857800392003
output of NN: 1.396290063

output of NN: 1.9984130859375, optimal_agent: 1.9330257943108577
output of NN: 2.6870272159576416, optimal_agent: 2.6302017696673534
output of NN: -0.3177803158760071, optimal_agent: -0.3068511696333901
output of NN: -4.307676792144775, optimal_agent: -4.4083861554982215
output of NN: 5.051145553588867, optimal_agent: 5.059464003812652
output of NN: 2.3222813606262207, optimal_agent: 2.2582816962598753
output of NN: 5.619335651397705, optimal_agent: 5.574495854640217
output of NN: -0.19612835347652435, optimal_agent: -0.19087467971287464
output of NN: -1.0069634914398193, optimal_agent: -0.969281518347475
output of NN: -3.9288570880889893, optimal_agent: -3.9722777179688427
output of NN: -1.5839956998825073, optimal_agent: -1.529922376814148
output of NN: 0.2793991267681122, optimal_agent: 0.26432137170362346
output of NN: -1.157780408859253, optimal_agent: -1.1148957960267518
output of NN: -2.0784764289855957, optimal_agent: -2.0193324619481254
output of NN: -0.786963701248169, optima

output of NN: -1.5326799154281616, optimal_agent: -1.4797462162616501
output of NN: 0.5355321168899536, optimal_agent: 0.5094499524328986
output of NN: -0.3527750074863434, optimal_agent: -0.3405607722447731
output of NN: 1.0077993869781494, optimal_agent: 0.9634428362478031
output of NN: 4.512246131896973, optimal_agent: 4.485256395547662
output of NN: 2.006504774093628, optimal_agent: 1.9414191303195336
output of NN: 2.50728178024292, optimal_agent: 2.4459129278091787
output of NN: -0.3478791117668152, optimal_agent: -0.3358159932483377
output of NN: 0.20368468761444092, optimal_agent: 0.19167959384995958
output of NN: 0.1134360209107399, optimal_agent: 0.10545623833003859
output of NN: 3.3904573917388916, optimal_agent: 3.371457498024279
output of NN: -0.18476557731628418, optimal_agent: -0.17962064641820952
output of NN: -2.90447735786438, optimal_agent: -2.861686753383645
output of NN: -0.07801225036382675, optimal_agent: -0.07745276817308144
output of NN: 1.9965373277664185, opti

output of NN: 2.6609458923339844, optimal_agent: 2.603393819839925
output of NN: -0.523228645324707, optimal_agent: -0.5035286442832732
output of NN: 2.6336309909820557, optimal_agent: 2.5753971117250223
output of NN: -2.40065860748291, optimal_agent: -2.3435837689361496
output of NN: 0.47562655806541443, optimal_agent: 0.45226694604630713
output of NN: 2.4885880947113037, optimal_agent: 2.427235982856063
output of NN: 2.9424452781677246, optimal_agent: 2.895712992192726
output of NN: -1.5086745023727417, optimal_agent: -1.4562268823300684
output of NN: 2.030013084411621, optimal_agent: 1.9650356421874762
output of NN: -4.456619739532471, optimal_agent: -4.5850087340998815
output of NN: 1.1564528942108154, optimal_agent: 1.1077436177720235
output of NN: 2.074678659439087, optimal_agent: 2.009481174382292
output of NN: -1.8477380275726318, optimal_agent: -1.7898399059348509
output of NN: 0.389963835477829, optimal_agent: 0.3700660709267458
output of NN: -1.2562065124511719, optimal_agen