In [1]:
import sys

### pfrlライブラリのパスを `sys.path` に追加

In [2]:
import os

# Local checkout of the pfrl library. Set the PFRL_PATH environment variable to
# point elsewhere on other machines; the default is the original author's path.
PFRL_PATH = os.environ.get("PFRL_PATH", r"E:\システムトレード入門\tutorials\rl\pfrl")
# Only append once, so re-running this cell in the same kernel does not keep
# adding duplicate entries to sys.path.
if PFRL_PATH not in sys.path:
    sys.path.append(PFRL_PATH)

### インポート 

In [3]:
import pfrl
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import numpy as np

In [4]:
from tqdm.notebook import tqdm

### 環境の作成 

今回はカートポール (CartPole) の環境を利用する．状態は [カートの位置，カートの速度，ポールの角度，ポールの角速度] の4次元の連続値．行動はカートを左右のどちらへ押すかを表す離散値 (2通り)．

In [5]:
# Fix every RNG so that Restart & Run All reproduces the same training run.
SEED = 0
pfrl.utils.set_random_seed(SEED)  # Python `random`, NumPy and PyTorch

env = gym.make("CartPole-v0")
env.seed(SEED)  # the environment's own RNG (e.g. initial states on reset)
env.action_space.seed(SEED)  # RNG behind env.action_space.sample(), also used by the explorer

print("observation space:", env.observation_space)
print("action space:", env.action_space)

observation space: Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)
action space: Discrete(2)


#### 最初の状態を観測 

In [6]:
# Begin a new episode; reset() hands back its first observation.
obs = env.reset()
print("initial observation: {}".format(obs))

initial observation: [ 0.03072463  0.02555485  0.0458568  -0.02463818]


#### ランダムな行動を1ステップ実行し，観測 (observation) と報酬 (reward) を取得

In [7]:
# Take one uniformly random action and inspect what the environment returns.
action = env.action_space.sample()
# 4-tuple step API: (next observation, reward, episode-finished flag, extra info).
obs, r, done, info = env.step(action)
print("next observation:", obs)
print("reward:", r)
print("done:", done)
print("info:", info)

next obserbation: [ 0.03123572  0.2199902   0.04536404 -0.30250743]
reward: 1.0
done: False
info: {}


### Q関数の定義 

In [8]:
class MyQFunction(nn.Module):
    """Fully connected Q-network mapping an observation to one Q-value per action.

    Two ReLU hidden layers followed by a linear output layer. The output is
    wrapped in ``pfrl.action_value.DiscreteActionValue``, the Q-function output
    type this notebook hands to its DoubleDQN agent.

    Args:
        obs_size: Number of observation features (input width).
        n_actions: Number of discrete actions (output width).
        n_hidden_channels: Width of each hidden layer. Defaults to 50, the
            value previously hard-coded, so existing calls are unchanged.
    """

    def __init__(self, obs_size, n_actions, n_hidden_channels=50):
        super().__init__()
        # The attribute names fc1/fc2/fc3 become the state_dict keys, so they
        # are kept stable for agents saved with agent.save().
        self.fc1 = nn.Linear(obs_size, n_hidden_channels)
        self.fc2 = nn.Linear(n_hidden_channels, n_hidden_channels)
        self.fc3 = nn.Linear(n_hidden_channels, n_actions)

    def forward(self, x):
        """Return the Q-values for observations ``x`` as a DiscreteActionValue.

        NOTE(review): assumes ``x`` is a float tensor whose last dimension is
        ``obs_size`` (e.g. shape (batch, obs_size)) — confirm against the agent.
        """
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        q_values = self.fc3(h)
        return pfrl.action_value.DiscreteActionValue(q_values)

In [9]:
# Network input width: number of features in one observation.
obs_size = env.observation_space.low.size
# Network output width: number of discrete actions.
n_actions = env.action_space.n
print("observation size: {}\naction size: {}".format(obs_size, n_actions))
q_func = MyQFunction(obs_size, n_actions)

observation size: 4
action size: 2


### エージェントの作成 

In [10]:
# Adam over every Q-network parameter. eps (the numerical-stability term in the
# denominator) is raised from PyTorch's default 1e-8 to 1e-2.
optimizer = torch.optim.Adam(
    params=q_func.parameters(),
    eps=1e-2,
)

In [11]:
gamma = 0.9  # discount factor applied to future rewards

# Epsilon-greedy exploration with a fixed epsilon of 0.3; random actions are
# drawn with env.action_space.sample.
explorer = pfrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.3,
    random_action_func=env.action_space.sample,
)

# Experience replay memory holding up to one million transitions.
replay_buffer = pfrl.replay_buffers.ReplayBuffer(capacity=10**6)


def phi(x):
    """Feature extractor applied to observations before they reach q_func.

    Casts to float32; with ``copy=False`` NumPy returns ``x`` itself when it
    is already float32.
    """
    return x.astype(np.float32, copy=False)


gpu = -1  # device id; -1 means run on the CPU

agent = pfrl.agents.DoubleDQN(
    q_function=q_func,
    optimizer=optimizer,
    replay_buffer=replay_buffer,
    gamma=gamma,
    explorer=explorer,
    replay_start_size=500,  # transitions collected before updates start
    update_interval=1,  # update the model every step
    target_update_interval=100,  # steps between target-network syncs
    phi=phi,
    gpu=gpu,
)

### 学習のイテレーション 

In [12]:
n_episodes = 300  # number of training episodes
max_episode_len = 200  # step cap per episode (CartPole-v0 also stops itself at 200)
for i in tqdm(range(1, n_episodes + 1)):
    obs = env.reset()  # start a new episode
    R = 0  # return (sum of rewards) of the current episode
    t = 0  # time step within the current episode
    while True:
        action = agent.act(obs)
        obs, reward, done, info = env.step(action)
        R += reward
        t += 1
        # gym's TimeLimit wrapper reports its step cap as done=True and sets
        # info["TimeLimit.truncated"]. A cutoff is not a real terminal state,
        # so pass it to the agent as a reset (keeps bootstrapping) rather than
        # as done. Envs that lack the key fall back to the plain `done` flag.
        truncated = info.get("TimeLimit.truncated", False)
        terminal = done and not truncated
        reset = t == max_episode_len or truncated
        agent.observe(obs, reward, terminal, reset)
        if terminal or reset:
            break

    if i % 10 == 0:
        print("episode:{}, return:{}".format(i, R))
    if i % 50 == 0:
        print("statistics:", agent.get_statistics())

print("Finished")

HBox(children=(FloatProgress(value=0.0, max=300.0), HTML(value='')))

episode:10, return:9.0
episode:20, return:8.0
episode:30, return:13.0
episode:40, return:11.0
episode:50, return:13.0
statistics: [('average_q', 0.8225266), ('average_loss', 0.19188212599144894), ('cumulative_steps', 588), ('n_updates', 89), ('rlen', 588)]
episode:60, return:11.0
episode:70, return:10.0
episode:80, return:21.0
episode:90, return:11.0
episode:100, return:11.0
statistics: [('average_q', 5.2032013), ('average_loss', 0.21298901025205852), ('cumulative_steps', 1251), ('n_updates', 752), ('rlen', 1251)]
episode:110, return:18.0
episode:120, return:39.0
episode:130, return:62.0
episode:140, return:19.0
episode:150, return:68.0
statistics: [('average_q', 9.157737), ('average_loss', 0.2495987607515417), ('cumulative_steps', 3228), ('n_updates', 2729), ('rlen', 3228)]
episode:160, return:200.0
episode:170, return:196.0
episode:180, return:200.0
episode:190, return:200.0
episode:200, return:200.0
statistics: [('average_q', 10.138581), ('average_loss', 0.11087872546864673), ('cumu

In [13]:
# Location passed to pfrl's Agent.save; a relative path resolves against the
# notebook's working directory.
AGENT_SAVE_PATH = "agent"
agent.save(AGENT_SAVE_PATH)