# 強化学習で株取引　その１ 

今回は環境の作成もかねて，一つの銘柄の取引を強化学習で行う．予測とはアルゴリズムを分離するため，未来の株価も含めるように調整する．

In [1]:
import sys
sys.path.append(r"E:\システムトレード入門\tutorials\rl\pfrl")
sys.path.append(r"E:\システムトレード入門\trade_system_git_workspace")

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from numpy.random import RandomState
import pandas as pd
from tqdm.notebook import tqdm
from collections import namedtuple
import collections
from copy import deepcopy
import matplotlib.pyplot as plt

In [3]:
import gym
from gym import spaces, logger
from gym.utils import seeding

In [4]:
import datetime
from pytz import timezone

In [5]:
from pathlib import Path

In [6]:
from bokeh.io import output_notebook
output_notebook()

In [7]:
import pfrl

In [8]:
from get_stock_price import StockDatabase

In [9]:
from utils import middle_sample_type_with_check, get_previous_datetime, get_next_workday_jp
from utils import extract_workdays_intraday_jp_index, extract_workdays_intraday_jp

In [10]:
from utils import py_restart

In [11]:
from visualize_trading_process_ver2 import plot_trading_process_matplotlib, plot_trading_process_bokeh

In [12]:
from envs_ver4 import OneStockEnv, make_env

### データベース 

In [13]:
# Open the stock price database stored under the project's db directory.
db_path = Path("E:/システムトレード入門/trade_system_git_workspace/db/sub_stock_db", "sub_stock.db")
stock_db = StockDatabase(db_path)

### 利用するデータ

In [14]:
# Data window handed to the environment: 2020-11-01 00:00 to 2020-12-01 00:00,
# both localized to Asia/Tokyo.
jst_timezone = timezone("Asia/Tokyo")
start_datetime = jst_timezone.localize(datetime.datetime(2020,11,1,0,0,0))
end_datetime = jst_timezone.localize(datetime.datetime(2020,12,1,0,0,0))
#end_datetime = get_next_workday_jp(start_datetime, days=11)  # one week of business days (5 days)

# Ticker code(s) to trade; the commented-out lines are alternatives tried earlier.
#stock_names = "4755"
#stock_names = "9984"
stock_names = "6502"
#stock_names = ["6502","4755"]
#stock_list = ["4755","9984","6701","7203","7267"]

# OHLC column used as the price series.
use_ohlc="Close"

### 環境クラス 

In [15]:
initial_cash = 10.e6  # starting cash: 10 million yen (1e7)
initial_unit = 100  # number of units held at the start

freq_str = "5T"  # sampling frequency string: 5 minutes
episode_length = 12*5*7  # one week: 12 steps/hour * 5 hours/day * 7 days

#state_time_list = [0,1,12,12*3,12*5,12*5*3],  # [now, next step, 1 hour, 3 hours, 5 hours (1 day), 15 hours (3 days)]
# Offsets (in 5-minute steps) of the prices included in each state:
# now, 5 min, 10 min, 30 min, 1 h, 2 h, 3 h, 4 h, then 1, 2, 3, 4 and 5 days ahead
# (one day = 12*5 steps).
state_time_list = [0,
                   1,
                   2,
                   6,
                   12,
                   12*2,
                   12*3,
                   12*4,
                   12*5*1,
                   12*5*2,
                   12*5*3,
                   12*5*4,
                   12*5*5,
                   ]

one_unit_stocks = 20
max_units_number = 5
stay_penalty_unit_bound=30
stay_penalty_cash_bound = 1.e5
penalty_mcp_np_diff_bound = 3


env = OneStockEnv(stock_db,
                  stock_names=stock_names,
                  start_datetime=start_datetime,
                  end_datetime=end_datetime,
                  freq_str=freq_str,  # reuse the variable instead of repeating the "5T" literal
                  episode_length=episode_length,  # one week
                  state_time_list=state_time_list,
                  use_ohlc=use_ohlc,  # use the closing price
                  initial_cash=initial_cash,  # starting cash
                  initial_unit=initial_unit,
                  use_view=False,
                  one_unit_stocks=one_unit_stocks,  # custom number of shares per unit
                  max_units_number=max_units_number,  # custom units that can be traded at once
                  low_limmit=1.e4,  # episode ends when total assets fall to this value or below
                  interpolate=True,
                  stay_penalty_unit_bound=stay_penalty_unit_bound,  # "stay" is penalized at or below this unit count
                  stay_penalty_cash_bound=stay_penalty_cash_bound,  # "stay" is penalized at or below this cash amount
                  penalty_mcp_np_diff_bound=penalty_mcp_np_diff_bound
                 )

In [16]:
# Sanity check: reset() returns a 4-tuple whose first item is the state and
# whose last item is the info dict (see the cell output).
#env.reset(select_datetime=jst_timezone.localize(datetime.datetime(2020, 11, 4, 14, 0)), select_stock_name="4755")
env.reset()

(StockState(cash=10000000.0, unit_number=100, mean_cost_price=2804.9044568037407, all_property=15594000.0, price_array=array([2797., 2792., 2790., 2765., 2748., 2802., 2770., 2803., 2748.,
        2722., 2706., 2823., 2837.])),
 0,
 False,
 {'datetime': datetime.datetime(2020, 11, 11, 13, 0, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>),
  'stock_name': '6502',
  'done_active': None,
  'iter_counter': 1,
  'penalty': 0,
  'gein_reward': 0,
  'price_reward': 0,
  'all_property_reward': 0,
  'reward': 0})

### 環境の並列化

In [20]:
# Vectorized environment: one subprocess per entry, each built by make_env.
env_number = 4
batch_env = pfrl.envs.MultiprocessVectorEnv([make_env] * env_number)

### 前処理用のクラス

In [29]:
# Preprocessing applied to states and rewards before they reach the agent.
# NOTE(review): NormalizeState / NormalizeReward are not defined in the visible
# cells; the *_const values are presumably scaling divisors - confirm.
state_transform = NormalizeState(cash_const=initial_cash,
                                 unit_const=100,
                                 price_const=1.e4,
                                 all_property_const=5*initial_cash
                                )

reward_transform = NormalizeReward(reward_const=1.e5,
                                  )

### モデルの定義 

In [None]:
class PolicyValueModel(nn.Module):
    """Actor-critic network with a shared trunk and separate policy/value branches.

    The trunk is three Linear + BatchNorm + ReLU layers
    (obs_size -> 32 -> 128 -> 256). The policy branch maps 256 -> 50 ->
    n_actions and feeds a ``pfrl.policies.SoftmaxCategoricalHead``; the value
    branch maps 256 -> 50 -> 1.

    Args:
        obs_size: Number of input features per observation.
        n_actions: Number of discrete actions.
    """

    def __init__(self, obs_size, n_actions):
        super().__init__()
        # Shared trunk.
        self.fc1 = nn.Linear(obs_size, 32)
        self.bn1 = nn.BatchNorm1d(32)

        self.fc2 = nn.Linear(32, 128)
        self.bn2 = nn.BatchNorm1d(128)

        self.fc3 = nn.Linear(128, 256)
        self.bn3 = nn.BatchNorm1d(256)

        # Policy branch.
        self.br1_fc1 = nn.Linear(256, 50)
        self.br1_bn1 = nn.BatchNorm1d(50)

        self.br1_fc2 = nn.Linear(50, n_actions)
        self.br1_bn2 = nn.BatchNorm1d(n_actions)
        self.policy_head = pfrl.policies.SoftmaxCategoricalHead()

        # Value branch.
        self.br2_fc1 = nn.Linear(256, 50)
        self.br2_bn1 = nn.BatchNorm1d(50)
        self.value_head = nn.Linear(50, 1)

    def forward(self, x):
        """Return ``(out_policy, out_value)`` for a batch of observations.

        Args:
            x: Float tensor whose last dimension is ``obs_size``
                (expected layout: (batch, obs_size)).

        Returns:
            tuple: the output of the softmax categorical head and the output
            of the final value layer.
        """
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        branch = F.relu(self.bn3(self.fc3(x)))

        policy_x = F.relu(self.br1_bn1(self.br1_fc1(branch)))
        # NOTE(review): BatchNorm + ReLU are applied to the logits before the
        # softmax head, unlike the value branch - confirm this is intended.
        policy_x = F.relu(self.br1_bn2(self.br1_fc2(policy_x)))
        out_policy = self.policy_head(policy_x)

        value_x = F.relu(self.br2_bn1(self.br2_fc1(branch)))
        # The reference example has no activation on the final value layer.
        out_value = self.value_head(value_x)

        return out_policy, out_value

In [32]:
# Input/output sizes of the networks, taken from the environment's spaces.
obs_size = env.observation_space.low.size
n_actions = env.action_space.n  # name used by initialize_agent() below
action_dim = n_actions  # earlier name, kept so existing references keep working

print("observation size:", obs_size)
print("action size:", n_actions)

policy_value_model = PolicyValueModel(obs_size, n_actions)

observation size: 17
action size: 11


### 学習係数の減衰 

In [33]:
def extinction_lr(episode_number):
    """Return the learning-rate multiplier for ``episode_number``.

    The multiplier decays stepwise every 200 episodes, starting at 1.0 and
    ending at 0.01 once ``episode_number`` reaches 1800.
    """
    # (exclusive upper bound, multiplier) pairs in ascending order.
    schedule = (
        (200, 1.0),
        (400, 0.8),
        (600, 0.6),
        (800, 0.4),
        (1000, 0.2),
        (1200, 0.1),
        (1400, 0.08),
        (1600, 0.04),
        (1800, 0.02),
    )
    for upper_bound, multiplier in schedule:
        if episode_number < upper_bound:
            return multiplier
    return 0.01

### エージェントの定義 

今回はDouble DQNを用いる．

In [None]:
# Optimizer for the shared policy/value network.
a2c_opt = pfrl.optimizers.RMSpropEpsInsideSqrt(
        policy_value_model.parameters(),
        lr=7e-4,
        eps=1e-5,
        alpha=0.99,
    )

gamma = 0.99

update_steps = 10

# Cast observations to float32 before they are fed to the model.
phi = lambda x: x.astype(np.float32, copy=False)

use_gae = False

tau = 0.95

max_grad_norm = 40

gpu = -1  # -1 is cpu

# `process_number` was never defined in this notebook; the number of parallel
# environments created for `batch_env` is `env_number`.
num_processes = env_number


a2c_agent = pfrl.agents.A2C(
    policy_value_model,
    a2c_opt,
    gamma=gamma,
    gpu=gpu,
    num_processes=num_processes,
    update_steps=update_steps,
    phi=phi,
    use_gae=use_gae,
    tau=tau,
    max_grad_norm=max_grad_norm,
)

In [34]:
# Passed as `gamma` to the DoubleDQN agent built in initialize_agent().
gamma = 0.95

# Epsilon-greedy exploration with a constant epsilon; random actions are drawn
# from the environment's action space.
init_episilon = 0.3
init_explorer = pfrl.explorers.ConstantEpsilonGreedy(epsilon=init_episilon,
                                                random_action_func=env.action_space.sample
                                               )

#replay_buffer = pfrl.replay_buffers.ReplayBuffer(capacity=10**6)

def phi_func(observe):
    """Turn an observation into a float32 array via its ``to_numpy()`` method."""
    return observe.to_numpy().astype(np.float32, copy=False)


# Feature extractor handed to the agent.
phi = phi_func

#gpu = 0 # -1 is cpu
gpu = -1

def initialize_agent():
    """Build a brand-new DoubleDQN agent and its learning-rate scheduler.

    Every call creates a fresh Q-function, optimizer, scheduler and replay
    buffer, so callers never need to modify an existing agent's attributes.

    Returns:
        tuple: ``(agent, scheduler)``.
    """
    q_network = QFunction(obs_size, n_actions)
    adam = torch.optim.Adam(q_network.parameters(), eps=1e-4)
    # Learning rate is scaled by extinction_lr(step count of the scheduler).
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(adam, lr_lambda=extinction_lr)
    buffer = pfrl.replay_buffers.ReplayBuffer(capacity=10**6)

    dqn_agent = pfrl.agents.DoubleDQN(
        q_function=q_network,
        optimizer=adam,
        replay_buffer=buffer,
        gamma=gamma,
        explorer=init_explorer,
        replay_start_size=500,
        update_interval=1,
        target_update_interval=100,
        phi=phi,
        gpu=gpu,
    )
    return dqn_agent, lr_scheduler

agent, scheduler = initialize_agent()

### モデルのロード

In [35]:
# Optionally restore a previously saved agent from agents/<folder_name>.
UseLoad = False
if UseLoad:
    folder_name = "2020_12_26__22_59_52"
    load_path = Path("agents", folder_name)
    agent.load(load_path)

###  イプシロンの減衰

In [36]:
def extinction_epsilon(episode_number):
    """Return the exploration epsilon for ``episode_number``.

    Epsilon starts at 0.9 and drops by 0.1 every 200 episodes, staying at 0.1
    once ``episode_number`` reaches 1600.
    """
    # (exclusive upper bound, epsilon) pairs in ascending order.
    schedule = (
        (200, 0.9),
        (400, 0.8),
        (600, 0.7),
        (800, 0.6),
        (1000, 0.5),
        (1200, 0.4),
        (1400, 0.3),
        (1600, 0.2),
    )
    for upper_bound, epsilon in schedule:
        if episode_number < upper_bound:
            return epsilon
    return 0.1

### ランダムシードの設定

In [37]:
class SeedSetter():
    """Holds two random seeds and applies them to an environment per episode.

    ``initial_seed`` and ``mcp_seed`` are drawn from ``np.random.randint(0, 1000)``
    unless given explicitly, and are passed to ``env.seed(initial_seed, mcp_seed)``.
    """

    def __init__(self, env):
        self.env = env
        # Draws both seeds at random (initial seed first, then mcp seed).
        self.initialize()

    def initialize(self, initial_seed=None, mcp_seed=None):
        """Set both seeds; any seed left as ``None`` is drawn at random."""
        self.initial_seed = np.random.randint(0,1000) if initial_seed is None else initial_seed
        self.mcp_seed = np.random.randint(0,1000) if mcp_seed is None else mcp_seed

    def set_seed(self, episode_number):
        """Seed the environment according to the training phase.

        Below episode 1000 both seeds are fixed; from 1000 up to 2999 only the
        initial seed is fixed (``None`` is passed as the second argument);
        from 3000 on the environment is not re-seeded at all.
        """
        if episode_number < 1000:
            second_seed = self.mcp_seed
        elif episode_number < 3000:
            second_seed = None
        else:
            return
        self.env.seed(self.initial_seed, second_seed)

    def get_seed(self):
        """Return ``(initial_seed, mcp_seed)``."""
        return self.initial_seed, self.mcp_seed

In [38]:
# Seed controller bound to the single (non-vectorized) environment.
seed_setter = SeedSetter(env)

### 初期探索の終了判定クラス

In [39]:
def visualize_unit_number_matplotlib(title_text,
                                     x,
                                     unit_number_array,
                                     ransac_solution,
                                     sampled_x_tensor,
                                     b_beta_hat,
                                     sampled_x_tensor_flatten,
                                     sampled_unit_number_tensor_flatten,
                                     ransac_solution_voted_points_bool,
                                     save_path=None,
                                    ):
    """Plot the held unit numbers together with the RANSAC line-fit result.

    Args:
        title_text: Figure title.
        x: Step indices used as bar positions.
        unit_number_array: Unit number at each step (bar heights).
        ransac_solution: ``[slope, intercept]`` of the selected line.
        sampled_x_tensor: Sampled x values, one row per candidate line.
        b_beta_hat: Per-line least-squares solutions; each entry has shape (2, 1).
        sampled_x_tensor_flatten: All sampled x values as a flat tensor.
        sampled_unit_number_tensor_flatten: Unit numbers at the sampled points, flat.
        ransac_solution_voted_points_bool: Boolean mask over the flat sampled
            points marking those that voted for the selected line.
        save_path: When given, the figure is written there and then closed.
    """
    fig, ax = plt.subplots(figsize=(20,4))
    ax.set_title(title_text)

    ax.bar(x, unit_number_array, color="blue",zorder=1)

    # Selected RANSAC line, drawn across the whole x range.
    x_min = 0
    x_max = len(x)
    X_start_end = np.array([[x_min, 1],[x_max,1]])
    y_hat_start_end = np.dot(X_start_end, ransac_solution)
    ax.plot([x_min, x_max],y_hat_start_end, color="limegreen")

    # One black segment per candidate line, spanning only its own sampled points.
    for one_beta_hat, sampled_points_x in zip(b_beta_hat, sampled_x_tensor):
        one_beta_hat_squeezed = torch.squeeze(one_beta_hat,dim=1).numpy()

        sampled_points_x_min = torch.min(sampled_points_x).item()
        sampled_points_x_max = torch.max(sampled_points_x).item()
        one_X_start_end = np.array([[sampled_points_x_min, 1],[sampled_points_x_max, 1]])
        one_y_hat_start_end = np.dot(one_X_start_end, one_beta_hat_squeezed)

        ax.plot([sampled_points_x_min, sampled_points_x_max], one_y_hat_start_end, color="black",zorder=2)

    # Sampled points: red when they voted for the selected line, pink otherwise.
    sampled_x = sampled_x_tensor_flatten.numpy().astype(int)
    sampled_units = sampled_unit_number_tensor_flatten.numpy()
    voted = ransac_solution_voted_points_bool
    ax.scatter(sampled_x[~voted], sampled_units[~voted], color="pink",zorder=3)
    ax.scatter(sampled_x[voted], sampled_units[voted], color="red",zorder=3)

    if save_path is not None:
        fig.savefig(save_path,bbox_inches='tight', pad_inches=0)
        # Close this specific figure rather than whichever one is "current".
        plt.close(fig)


class RansacGradLearningDecider():
    """Judge an episode by RANSAC-style line fitting on the held unit numbers.

    ``line_number`` candidate lines are each fitted by least squares to
    ``point_number`` randomly chosen (step, unit_number) points. Every sampled
    point votes for each line it lies within ``distance_th`` of, and the line
    with the most votes is selected.

    Args:
        line_number: Number of candidate lines.
        point_number: Number of points sampled per line (without repetition
            within one line).
        distance_th: Point-to-line distance below which a point votes for a line.
        decision_rate: The episode passes only when the best line's vote ratio
            is below this value.
        grad_abs_limit: The episode passes only when the absolute slope of the
            best line is below this value.
    """

    def __init__(self, line_number=10, point_number=2, distance_th=0.2, decision_rate=0.7, grad_abs_limit=20):
        self.line_number = line_number
        self.point_number = point_number
        self.distance_th = distance_th
        self.decision_rate = decision_rate
        self.grad_abs_limit = grad_abs_limit
        self.save_path = None

    def set_save_path(self, save_path):
        """Set where ``decide`` saves its diagnostic figure (``None`` disables it)."""
        self.save_path = save_path

    def decide(self, state_list, info_list, env):
        """Return True when the episode is judged to be learning well.

        Args:
            state_list: Sequence of states exposing ``unit_number``.
            info_list: Unused; kept for the common decider interface.
            env: Unused; kept for the common decider interface.

        Returns:
            bool: True when the best line's vote ratio is below
            ``decision_rate`` and its absolute slope is below ``grad_abs_limit``.
        """
        unit_number_array = np.array([state.unit_number for state in state_list])
        x = np.arange(0, len(unit_number_array))

        # Index array of shape (line_number, point_number); indices within one
        # row are distinct.
        random_index = np.array([
            np.random.permutation(len(unit_number_array))[:self.point_number]
            for _ in range(self.line_number)
        ])

        sampled_unit_number_tensor = torch.from_numpy(unit_number_array[random_index]).float()  # (m, s)
        sampled_x_tensor = torch.from_numpy(x[random_index]).float()  # (m, s)

        # Flat views over all m*s sampled points.
        sampled_unit_number_tensor_flatten = sampled_unit_number_tensor.flatten()
        sampled_x_tensor_flatten = sampled_x_tensor.flatten()

        # Batched least squares: beta = (X^T X)^-1 X^T y for every line.
        b_X = torch.stack([sampled_x_tensor, torch.ones_like(sampled_x_tensor).float()],dim=2)
        b_XtX = torch.bmm(b_X.transpose(1,2),b_X)
        b_inv_XtX = torch.inverse(b_XtX)
        b_Xty = torch.bmm(b_X.transpose(1,2), sampled_unit_number_tensor[:,:,None])

        b_beta_hat = torch.bmm(b_inv_XtX, b_Xty)  # (m, 2, 1)

        # Drop only the trailing axis: a bare squeeze() would also remove the
        # batch axis when line_number == 1 and break the indexing below.
        b_beta_hat_squeeze = b_beta_hat.squeeze(dim=2)  # (m, 2)

        # Distance from every sampled point to every line y = a*x + c.
        a = b_beta_hat_squeeze[:,0]  # (m)
        c = b_beta_hat_squeeze[:,1]  # (m)
        d_num = a[:,None] * sampled_x_tensor_flatten[None,:]  - sampled_unit_number_tensor_flatten[None,:] + c[:,None]  # (m, n)
        d_den = torch.sqrt(a**2+(-1)**2)  # (m)
        d = torch.abs(d_num) / d_den[:,None]

        distance_boolean = d < self.distance_th  # (m, n)

        b_vote_number = distance_boolean.sum(dim=1)  # (m)
        # Votes are compared in numpy for an exact integer comparison.
        vote_numbers = b_vote_number.numpy()
        vote_max = np.amax(vote_numbers, axis=0)
        ransac_solution_index = np.argmax(vote_numbers,axis=0)
        ransac_solution_voted_points_bool = distance_boolean[ransac_solution_index].numpy()

        ransac_solution = b_beta_hat_squeeze[ransac_solution_index].numpy()

        # Share of sampled points that lie close to the best line.
        voted_ratio = vote_max / (self.line_number * self.point_number)
        return_bool = bool(
            voted_ratio < self.decision_rate
            and abs(ransac_solution[0]) < self.grad_abs_limit  # slope is not too steep
        )

        if self.save_path is not None:
            visualize_unit_number_matplotlib("voted_ratio:{}".format(voted_ratio),
                                             x,
                                             unit_number_array,
                                             ransac_solution,
                                             sampled_x_tensor,
                                             b_beta_hat,
                                             sampled_x_tensor_flatten,
                                             sampled_unit_number_tensor_flatten,
                                             ransac_solution_voted_points_bool,
                                             save_path=self.save_path
                                            )

        return return_bool
    

class UnitMinusDecider():
    """Pass an episode only when the holding is rarely negative.

    Args:
        minus_ratio: The episode passes when the share of steps with a negative
            stock count is strictly below this value.
    """

    def __init__(self, minus_ratio=0.2):
        self.minus_ratio = minus_ratio

    def decide(self, state_list, info_list, env):
        """Return True when the share of negative-holding steps is below ``minus_ratio``."""
        # Stock count per step = unit number * shares per unit.
        stocks_held = np.array([state.unit_number*env.one_unit_stocks for state in state_list])
        negative_share = (stocks_held < 0).sum()/len(stocks_held)
        return bool(negative_share < self.minus_ratio)
        
class DeciderComposer():
    """Combine several deciders; the episode passes only if every one passes."""

    def __init__(self, decider_list):
        self.decider_list = decider_list

    def decide(self, state_list, info_list, env):
        """Run every decider (no short-circuit) and return whether all returned truthy."""
        verdicts = [member.decide(state_list, info_list, env) for member in self.decider_list]
        return all(verdicts)

### プロファイリング

#### 学習のための関数 

In [40]:
def episode(env, agent, state_transform=None, reward_transform=None, print_span=None, is_observe=True):
    """Run one episode and return its trajectory.

    Args:
        env: Environment whose ``reset()`` returns ``(obs, _, _, info)`` and
            whose ``step(action)`` returns ``(obs, reward, done, info)``.
        agent: Object providing ``act(obs)`` and ``observe(obs, reward, done, reset)``.
        state_transform: Optional callable applied to each observation before
            it is handed to the agent.
        reward_transform: Optional callable applied to each reward before it
            is handed to the agent.
        print_span: When not None, progress is printed at the first step,
            every ``print_span`` steps, and at the end.
        is_observe: When True, ``agent.observe`` is called after every step.

    Returns:
        tuple: ``(state_list, info_list, action_list)`` holding the raw
        (untransformed) states, the info dicts and the chosen actions.
    """
    verbose = print_span is not None

    def _report(step, state, step_info):
        # One-line summary of the current portfolio.
        print("\tt:{},all_property:{}, unit_number:{}, price:{}, penalty:{}, cash:{}".format(
            step,
            state.all_property,
            state.unit_number,
            state.now_price,
            step_info["penalty"],
            state.cash,
        ))

    def _apply(transform, value):
        return value if transform is None else transform(value)

    obs, _, _, info = env.reset()
    state_list = [obs]
    info_list = [info]
    action_list = []

    total_reward = 0
    t = 1
    if verbose:
        _report(t, obs, info)

    normalized_obs = _apply(state_transform, obs)
    reset = False  # never set to True here; forwarded to agent.observe as-is

    while True:
        action = agent.act(normalized_obs)
        action_list.append(action)
        obs, reward, done, info = env.step(action)
        total_reward += reward
        t += 1

        # Preprocess the state and the reward for the agent.
        normalized_obs = _apply(state_transform, obs)
        normalized_reward = _apply(reward_transform, reward)

        if is_observe:  # let the agent observe (learn from) this step
            agent.observe(normalized_obs, normalized_reward, done, reset)

        state_list.append(obs)
        info_list.append(info)

        if done or reset:
            break
        if verbose and t%print_span==0:
            _report(t, obs, info)
            print("\taction_counter:",collections.Counter(action_list))

    if verbose:
        _report(t, obs, info)
        print("\taction_counter:",collections.Counter(action_list))
        print("finished. episode length: {}".format(t))
    return state_list, info_list, action_list

In [41]:
def profile_episode():
    """Entry point handed to the line profiler.

    The episode call is currently commented out, so this is a no-op.
    """
    #episode(env, agent, state_transform=state_transform, reward_transform=None, print_span=100, is_observe=True)
    pass

#### 速度のプロファイリング

In [42]:
# Line-by-line timing of the environment classes and episode() while running
# profile_episode().
# NOTE(review): add_module() is given classes (StockState, OneStockEnv) rather
# than modules - confirm this registers their methods as intended.
from line_profiler import LineProfiler
prf = LineProfiler()                                                                                         
prf.add_module(StockState)
prf.add_module(OneStockEnv)
prf.add_function(episode)                                                                                      
prf.runcall(profile_episode)                                                                                          
prf.print_stats()

Timer unit: 1e-07 s

Total time: 0 s
File: <ipython-input-17-33450f4d8cfd>
Function: to_numpy at line 5

Line #      Hits         Time  Per Hit   % Time  Line Contents
     5                                               def to_numpy(self):
     6                                                   """
     7                                                   ndarrayに変更する．最後に利用するのがいい？
     8                                                   """
     9                                                   cash_unit_mean_all_array = np.array([self.cash, self.unit_number, self.mean_cost_price, self.all_property])
    10                                                   return np.concatenate([cash_unit_mean_all_array, self.price_array.copy()], axis=0)  # コピーすることに注意

Total time: 0 s
File: <ipython-input-17-33450f4d8cfd>
Function: copy at line 25

Line #      Hits         Time  Per Hit   % Time  Line Contents
    25                                               def copy(self):
    26               

   504                                                       #prepenalty = - self.initial_cash * 0
   505                                                       prepenalty = -self.penalty_const
   506                                                       penalty += prepenalty
   507                                                   
   508                                                   prepenalty = - cash * 0.0  # 条件に関わらないペナルティ
   509                                                   penalty += prepenalty
   510                                                       
   511                                                   return penalty

Total time: 0 s
File: <ipython-input-22-ab1a00779c1a>
Function: get_mcp_np_diff_penalty at line 513

Line #      Hits         Time  Per Hit   % Time  Line Contents
   513                                               def get_mcp_np_diff_penalty(self, mean_cost_price, now_price):
   514                                                   penalty = 0
   

#### メモリのプロファイリング 

### 学習 

#### 学習がうまく進んでいるか判定 

In [43]:
# Decider used by the initial seed search: an episode passes only when both the
# RANSAC line-fit check and the negative-holding check pass.
lr_eval_ransac = RansacGradLearningDecider(line_number=20, point_number=4, distance_th=2, decision_rate=0.25)
lr_eval_decider = DeciderComposer([lr_eval_ransac,
                                   UnitMinusDecider(minus_ratio=0.1),
                                  ])

#### エージェント名

In [44]:
# The agent is named after the local wall-clock time at which this cell ran
# (naive datetime, no timezone attached).
now_datetime = datetime.datetime.now()
now_str = now_datetime.strftime("%Y_%m_%d__%H_%M_%S")
agent_name = now_str

####  一時保存用のオブジェクト

In [45]:
# State persisted across restarts: the agent name and the two seeds in use.
temp_save_dict = {"agent_name":agent_name}
temp_save_dict["initial_seed"], temp_save_dict["mcp_seed"] = seed_setter.get_seed()

#### 一時保存用の設定 

In [46]:
temp_filepath = Path("training_temp.tmp")

# Resolve the global `agent` at call time. The training cell rebinds `agent`
# during the seed search, so bound methods captured here (agent.save /
# agent.load) would keep saving and loading the stale, pre-search agent.
save_funcs = [lambda *args, **kwargs: agent.save(*args, **kwargs)]
load_funcs = [lambda *args, **kwargs: agent.load(*args, **kwargs)]
func_paths = [Path("temp_agent")]

object_temp_filepath = Path("temp_save_dict")

#### 一時保存のロード 

In [52]:
# Restartable loop counter from the project's py_restart helper.
# NOTE(review): presumably this restores temp_save_dict and the agent when a
# previous temporary save exists - confirm against py_restart.
counter = py_restart.enable_counter(temp_filepath, each_save=True, save_span=100)

temp_save_dict = counter.save_load_object(temp_save_dict, object_temp_filepath)
counter.save_load_funcs(save_funcs=save_funcs,
                        load_funcs=load_funcs,
                        func_paths=func_paths
                        )

In [53]:
# Take the agent name and both seeds from temp_save_dict so that they match
# whatever save_load_object returned above.
agent_name = temp_save_dict["agent_name"]
print(agent_name)
seed_setter.initialize(initial_seed=temp_save_dict["initial_seed"],
                       mcp_seed=temp_save_dict["mcp_seed"]
                       )

2021_01_16__12_22_59


#### 画像を保存するディレクトリ

In [54]:
# Directory for the trading-process figures of this agent.
save_fig_dir_path = Path("tradng_process_figures") / Path(agent_name)
# parents=True also creates the top-level folder on a fresh checkout;
# exist_ok=True makes re-running the cell harmless.
save_fig_dir_path.mkdir(parents=True, exist_ok=True)

#### 画像を保存するディレクトリ2 

In [55]:
# Directory for the seed-search diagnostic figures of this agent.
save_fig_dir_path2 = Path("search_result_figures") / Path(agent_name)
# parents=True also creates the top-level folder on a fresh checkout;
# exist_ok=True makes re-running the cell harmless.
save_fig_dir_path2.mkdir(parents=True, exist_ok=True)

#### 学習 

In [None]:
# Training driver: an optional initial seed search followed by the main
# training loop with periodic evaluation, plotting and checkpointing.
# n_episodes = 300
n_episodes = 5000
n_search_episodes = 50

# Search for a random seed with which learning gets off to a good start.
search_learning = True
search_counter = 0

# Run the initial search only when no temporary save exists.
if not temp_filepath.exists():
    while search_learning:
        seed_setter.initialize()
        seed_setter.set_seed(0)  # apply the seeds
        agent, scheduler = initialize_agent()  # fresh agent for every search attempt
        for i in range(1, n_search_episodes+1):
            print("\rsearch counter{}, i:{}".format(search_counter, i), end="")
            _,_,_ = episode(env, agent, state_transform=state_transform, reward_transform=None, print_span=None, is_observe=True)
        
        # One evaluation episode without learning, used to judge this attempt.
        with agent.eval_mode():
            search_state_list, search_info_list, search_action_list = episode(env, agent, state_transform=state_transform, reward_transform=None, print_span=None, is_observe=False) 
        
        print("\nsearch_counter:{}".format(search_counter))
        print("action_counter:",collections.Counter(search_action_list))
        
        # Save the decider's diagnostic figure; keep searching while it rejects.
        save_image_path = save_fig_dir_path2 / Path("trading_process_initial_search_count_{}.png".format(search_counter))
        lr_eval_ransac.set_save_path(save_image_path)
        search_learning = not lr_eval_decider.decide(search_state_list, search_info_list, env)
        # Plot the trading process.
        save_image_path = save_fig_dir_path / Path("trading_process_initial_search_count_{}.png".format(search_counter))
        plot_trading_process_matplotlib(search_state_list,
                                search_info_list,
                                env,
                                title="stock_name:{},initial_seed:{},mcp_seed:{}".format(env.stock_name, env.initial_seed_number, env.mcp_seed_number),
                                save_path=save_image_path,
                                is_save=True,
                                )

        search_counter += 1

    # Save the model found by the search.
    # NOTE(review): this is a /content/gdrive path, while other cells use
    # E:\ paths - confirm the target environment.
    agent.save(Path("/content/gdrive/MyDrive/trading_rl/searched_agents")/Path(agent_name))

# Re-initialization before training (disabled).
#agent, scheduler = initialize_agent()  # re-create the agent

# Main training loop.
for i in counter(tqdm(range(1, n_episodes + 1))):
    # Apply the seeds for this episode.
    seed_setter.set_seed(i)
    # Keep temp_save_dict up to date for the restart counter.
    temp_save_dict["initial_seed"], temp_save_dict["mcp_seed"] = seed_setter.get_seed()
    counter.object = temp_save_dict
    
    # Epsilon decay (disabled).
    #epsilon = extinction_epsilon(i)
    #explorer = pfrl.explorers.ConstantEpsilonGreedy(epsilon=epsilon, random_action_func=env.action_space.sample)
    
    #agent.explorer = explorer
    
    _,_,train_action_list = episode(env, agent, state_transform=state_transform, reward_transform=None, print_span=None, is_observe=True)
    scheduler.step()
    
    if i%100 == 0:
        print("episode:{}".format(i))
        print("action_counter:",collections.Counter(train_action_list))
    if i%500 == 0:
        print("statistics:", agent.get_statistics())
        
    # Test run in eval mode (epsilon exploration is ignored).
    if i%100 == 0:
        with agent.eval_mode():
            state_list, info_list, _ = episode(env, agent, state_transform=state_transform, reward_transform=None, print_span=1000, is_observe=False) 
        
        # Plot the trading process.
        save_image_path = save_fig_dir_path / Path("trading_process_i_{}.png".format(i))
        plot_trading_process_matplotlib(state_list,
                                info_list,
                                env,
                                title="stock_name:{},initial_seed:{},mcp_seed:{}".format(env.stock_name, env.initial_seed_number, env.mcp_seed_number),
                                save_path=save_image_path,
                                is_save=True,
                                )
        # Save the model.
        agent.save(Path("/content/gdrive/MyDrive/trading_rl/agents")/Path(agent_name))


print("Finshed")

HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))

episode:700
action_counter: Counter({0: 146, 10: 116, 9: 40, 4: 24, 5: 22, 7: 18, 8: 11, 2: 11, 6: 11, 1: 10, 3: 10})
	t:1,all_property:12158000.0, unit_number:100, price:1079.0, penalty:0, cash:10000000.0
	t:420,all_property:12348280.0, unit_number:168, price:1115.0, penalty:0.0, cash:8601880.0
	action_counter: Counter({0: 172, 10: 119, 9: 70, 7: 17, 8: 13, 5: 13, 4: 9, 3: 4, 6: 1, 1: 1})
finished. episode length: 420
episode:800
action_counter: Counter({0: 141, 10: 93, 8: 45, 9: 37, 3: 18, 6: 18, 5: 16, 7: 13, 1: 13, 4: 13, 2: 12})
	t:1,all_property:12158000.0, unit_number:100, price:1079.0, penalty:0, cash:10000000.0
	t:420,all_property:12419820.0, unit_number:180, price:1115.0, penalty:0.0, cash:8405820.0
	action_counter: Counter({0: 166, 10: 134, 9: 48, 5: 40, 8: 20, 7: 3, 4: 3, 1: 2, 3: 2, 2: 1})
finished. episode length: 420
episode:900
action_counter: Counter({0: 152, 10: 118, 9: 39, 5: 22, 8: 21, 4: 13, 2: 13, 7: 13, 3: 12, 1: 10, 6: 6})
	t:1,all_property:12158000.0, unit_numb

	t:420,all_property:12418360.0, unit_number:1821, price:1115.0, penalty:-100000.0, cash:-28189940.0
	action_counter: Counter({10: 318, 9: 53, 0: 22, 8: 14, 5: 6, 1: 4, 6: 1, 7: 1})
finished. episode length: 420
episode:2600
action_counter: Counter({0: 114, 10: 91, 8: 49, 1: 46, 9: 41, 4: 17, 7: 15, 3: 14, 5: 12, 2: 12, 6: 8})
	t:1,all_property:12158000.0, unit_number:100, price:1079.0, penalty:0, cash:10000000.0
	t:420,all_property:12425140.0, unit_number:125, price:1115.0, penalty:0.0, cash:9637640.0
	action_counter: Counter({0: 149, 10: 97, 9: 60, 8: 45, 5: 30, 1: 20, 4: 8, 7: 5, 3: 3, 2: 2})
finished. episode length: 420
episode:2700
action_counter: Counter({0: 100, 10: 95, 1: 57, 9: 43, 8: 36, 3: 23, 6: 17, 5: 16, 2: 12, 4: 11, 7: 9})
	t:1,all_property:12158000.0, unit_number:100, price:1079.0, penalty:0, cash:10000000.0
	t:420,all_property:12355200.0, unit_number:134, price:1115.0, penalty:0.0, cash:9367000.0
	action_counter: Counter({1: 96, 10: 94, 0: 77, 8: 62, 9: 52, 3: 32, 4: 

###  モデルの評価

In [None]:
# Evaluation episode in eval mode, without learning.
# `i` is the loop variable left over from the training loop, so the seeding
# follows the phase of the last trained episode.
with agent.eval_mode():
    seed_setter.set_seed(i)
    #env.seed()
    #env.seed(0,0)
    #env.seed(0)
    #seed_setter.set_seed(0)
    state_list, info_list, _ = episode(env, agent, state_transform=state_transform, reward_transform=reward_transform, print_span=100, is_observe=False)

bokehで描画

In [None]:
# Bokeh plot of the evaluation episode; is_save=False, so nothing is written to disk.
plot_trading_process_bokeh(state_list, 
                           info_list,
                           env,
                           title="stock_name:{},initial_seed:{},mcp_seed:{}".format(env.stock_name, env.initial_seed_number, env.mcp_seed_number),
                           is_save=False)

### 環境のクローズ 

In [None]:
# Close the single-stock environment once training and evaluation are done.
env.close()

### モデルの保存 

In [None]:
# Persist the trained agent under agents/<agent_name> and echo the name.
final_agent_dir = Path("agents", agent_name)
agent.save(final_agent_dir)
print(agent_name)