In [1]:
%cd ../../..

E:\システムトレード入門\trade_system_git_workspace


In [2]:
from numpy.random import RandomState
import numpy as np
import pandas as pd
from scipy.special import softmax
from pathlib import Path
import datetime
from pytz import timezone

In [3]:
from bokeh.io import output_notebook, show, reset_output, output_file
import bokeh
output_notebook()

In [4]:
import py_workdays

In [5]:
from get_stock_price import StockDatabase

In [6]:
from portfolio.trade_transformer import PortfolioTransformer, PortfolioRestrictorIdentity, FeeCalculatorFree
from portfolio.price_supply import StockDBPriceSupplier

In [8]:
from utils import middle_sample_type_with_check, get_naive_datetime_from_datetime
from utils import TradeSysteBaseError

In [9]:
from visualization import visualize_portfolio_transform_bokeh

## グローバルパラメータ― 

In [10]:
jst = timezone("Asia/Tokyo")
start_datetime = jst.localize(datetime.datetime(2020,11,10,0,0,0))
end_datetime = jst.localize(datetime.datetime(2020,11,20,0,0,0))
ticker_number = 19
episode_length = 300
freq_str = "5T"

## 利用するデータベース

In [11]:
db_path = Path("db/sub_stock_db/nikkei_255_stock.db")
stock_db = StockDatabase(db_path)

## サンプリングするクラス 

In [12]:
class ConstSamper:
    """
    定数をサンプリング値として取得するためのクラス
    """
    def __init__(self, const_object):
        """
        const_object: any
            定数としてサンプリングされる値
        """
        self.const_object = const_object
        
    def sample(self, seed=None):
        """
        seted: int
            ランダムシード
        """
        return self.const_object

In [13]:
class TickerSampler:
    """
    ticker_nameをサンプリングするためのクラス
    """
    def __init__(self, all_ticker_names, sampling_ticker_number):
        """
        all_ticker_names: list of str
            サンプリングを行う銘柄名のリスト
        sampling_ticker_number: int
            サンプリング数
        """
        self.all_ticker_names = all_ticker_names
        self.sampling_ticker_number = sampling_ticker_number

    def sample(self, seed=None):
        """
        Parameters
        ----------
        seed: int
            ランダムシード
            
        Returns
        -------
        list of str 
            サンプリングされた銘柄名のndarray
        datetime.datetime
        """
        random_ticker_names = RandomState(seed).choice(self.all_ticker_names, self.sampling_ticker_number, replace=False)  # 重複を許さない
        return random_ticker_names

In [14]:
class DatetimeSampler:
    """
    datetime.datetimeをサンプリングするためのクラス
    """
    def __init__(self, start_datetime, end_datetime, episode_length, freq_str):
        """
        start_datetime: datetime.datetime
            サンプリングする範囲の開始日時
        end_datetime: datetime.datetime
            サンプリングする範囲の終了日時
        episode_length: int
            エピソード長
        freq_str: str
            サンプリング周期を表す文字列
        """
        
        self.start_datetime = start_datetime
        self.end_datetime = end_datetime

        self.freq_str = middle_sample_type_with_check(freq_str)


        all_datetime_index = pd.date_range(start=self.start_datetime,
                                           end=self.end_datetime,
                                           freq=self.freq_str,
                                           closed="left"
                                          )  
        
        self.all_datetime_value = py_workdays.extract_workdays_intraday_jp_index(all_datetime_index).to_pydatetime()
        
        self.all_datetime_index_range = np.arange(0, len(self.all_datetime_value))
        self.episode_length = episode_length
        
    def sample(self, seed=None, window=np.array([0])):
        """
        Parameters
        ----------
        seed: int
            ランダムシード
            
        window: list of int
            ウィンドウを表すリスト
        
        Returns
        -------
        list of str
        datetime.datetime
            サンプリングされた日時
        """
        
        if not isinstance(window, np.ndarray):
            self.window = np.array(window)
        else:
            self.window = window
        
        min_window = min(self.window)
        max_window = max(self.window)
        
        random_datetime = self.all_datetime_value[RandomState(seed).choice(self.all_datetime_index_range[abs(min_window):-(max_window+self.episode_length)],1)].item()
        
        return random_datetime      

In [15]:
class PortfolioVectorSampler:
    """
    ポートフォリオベクトルをサンプリングするためのクラス
    """
    def __init__(self, vector_length):
        """
        vector_length: int
            ポートフォリオベクトルの長さ
        """
        self.vector_length = vector_length
    def sample(self, seed=None):
        """
        seed: int
            ランダムシード
        """
        portfolio_vector = softmax(RandomState(seed).randn(self.vector_length))
        return portfolio_vector

In [16]:
class MeanCostPriceSampler:
    """
    平均取得価格をサンプリングするためのクラス
    """
    def __init__(self, mean_array=None, var_array=None):
        """
        mean_array: np.ndarray
            平均ベクトル
        var_array: np.ndarray
            分散ベクトル
        """
        self.mean_array = mean_array
        self.var_array = var_array
        
    def sample(self, seed=None):
        """
        seed: int
            ランダムシード
        """
        mean_cost_price = self.mean_array + self.var_array * RandomState(seed).randn(self.vector_length)
        return mean_cost_price

In [17]:
class SamplerManager:
    """
    TradeEnv環境のサンプリングを担うクラス．利用するサンプリングが追加・変更された場合．こちらを変更する．
    """
    def __init__(self, 
                 ticker_names_sampler,
                 datetime_sampler,
                 portfolio_vector_sampler=ConstSamper(None),
                 mean_cost_price_array_sampler=ConstSamper(None)
                ):
        """
        ticker_names_sampler:
            銘柄名をサンプリングするクラス
        datetime_sampler:
            日時をサンプリングするクラス
        portfolio_vector_samper:
            ポートフォリオベクトルをサンプリングするクラス
        mean_cost_price_array_sampler
            平均取得価格をサンプリングするクラス
        """
        
        self.ticker_names_sampler = ticker_names_sampler
        self.datetime_sampler = datetime_sampler
        self.portfolio_vector_sampler = portfolio_vector_sampler
        self.mean_cost_price_array_sampler = mean_cost_price_array_sampler

csvファイルの読み込み

In [18]:
ticker_codes_df = pd.read_csv(Path("portfolio/rl_base/nikkei225.csv"), header=0)  # 自分で作成
ticker_codes = ticker_codes_df["code"].values.astype(str).tolist()
#ticker_codes

In [19]:
ticker_names_sampler = TickerSampler(all_ticker_names=ticker_codes,
                                     sampling_ticker_number=ticker_number)



start_datetime_sampler = DatetimeSampler(start_datetime=start_datetime,
                                         end_datetime=end_datetime,
                                         episode_length=episode_length,
                                         freq_str=freq_str
                                        )

sampler_manager = SamplerManager(ticker_names_sampler=ticker_names_sampler,
                                 datetime_sampler=start_datetime_sampler
                                )

In [20]:
sampler_manager.datetime_sampler.sample(window=[-5,0,5])

datetime.datetime(2020, 11, 11, 12, 30, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>)

In [21]:
sampler_manager.ticker_names_sampler.sample()

array(['8001', '9503', '9501', '3105', '9001', '1605', '8630', '9202',
       '7911', '6702', '4568', '6952', '2531', '3101', '8795', '5202',
       '6472', '8309', '9437'], dtype='<U4')

In [22]:
print(sampler_manager.mean_cost_price_array_sampler.sample())

None


## 環境クラス 

In [23]:
class TradeEnv:
    """
    PortfolioStateを利用して基本的な売買を行う強化学習用の環境
    """
    def __init__(self, 
                 portfolio_transformer,
                 sampler_manager,
                 window=[0],
                 fee_const=0.0025,
                ):
        """
        portfolio_transformer: PortfolioTransformer
             ポートフォリオを遷移させるTransformer
        sampler_manager: SamplerManager
            各種サンプリングを行うためのマネージャー
        window: list
            PortfolioStateのウィンドウサイズ
        fee_const: float
            単位報酬・取引量当たりの手数料
        """
        self.portfolio_transformer = portfolio_transformer
        self.sampler_manager = sampler_manager
        self.window = window
        self.fee_const = fee_const
        
        
        
    def reset(self, seed=None):
        """
        Parameters
        ----------
        seed: int
            乱数シード
            
        Returns
        -------
        PortfolioState
            遷移したPortfolioStateのインスタンス
        float
            報酬．resetなので0
        bool
            終了を示すフラッグ
        dict
            その他の情報
        """
        #from IPython.core.debugger import Pdb; Pdb().set_trace()
        random_ticker_names = self.sampler_manager.ticker_names_sampler.sample(seed=seed)  # 銘柄名のサンプリング
        random_datetime = self.sampler_manager.datetime_sampler.sample(seed=seed, window=self.window)  # 開始日時のサンプリング
        random_portfolio_vector = self.sampler_manager.portfolio_vector_sampler.sample(seed=seed)  # ポートフォリオベクトルのサンプリング
        random_mean_cost_price_array = self.sampler_manager.mean_cost_price_array_sampler.sample(seed=seed)  # 平均取得価格のサンプリング
        
        self.portfolio_transformer.price_supplier.ticker_names = list(random_ticker_names)  # 銘柄名の変更
        self.portfolio_transformer.initial_portfolio_vector = random_portfolio_vector  # 初期ポートフォリオを変更(Noneの場合デフォルト)
        self.portfolio_transformer.initial_mean_cost_price_array = random_mean_cost_price_array  #初期平均取得価格を変更(Noneの場合デフォルト) 
        
        portfolio_state, done = self.portfolio_transformer.reset(random_datetime, window=self.window)
        
        self.portfolio_state = portfolio_state
        
        return self.portfolio_state.copy(), 0, done, None
    
    def step(self, portfolio_vector):
        """
        Parameters
        ----------
        portfolio_vector: ndarray
            actionを意味するポートフォリオベクトル
            
        Returns
        -------
        PortfolioState
            遷移したPortfolioStateのインスタンス
        float
            報酬．
        bool
            終了を示すフラッグ
        dict
            その他の情報
        """        

        previous_portfolio_state = self.portfolio_state
        
        #状態遷移
        portfolio_state, done = self.portfolio_transformer.step(portfolio_vector)
        
        #報酬の計算
        portfolio_vector = portfolio_state.portfolio_vector
        
        price_change_ratio = portfolio_state.now_price_array / previous_portfolio_state.now_price_array  # y
        raw_reward_ratio = np.dot(portfolio_vector, price_change_ratio)  # r
        
        portfolio_change_vector = portfolio_vector - previous_portfolio_state.portfolio_vector #W_{t}-w_{t-1}
        reward = np.log(raw_reward_ratio*(1-self.fee_const*np.dot(portfolio_change_vector, portfolio_change_vector)))
        
        return portfolio_state.copy(), reward, done, None

In [24]:
# samplerの設定
ticker_names_sampler = TickerSampler(all_ticker_names=ticker_codes,
                                     sampling_ticker_number=ticker_number)



start_datetime_sampler = DatetimeSampler(start_datetime=start_datetime,
                                         end_datetime=end_datetime,
                                         episode_length=episode_length,
                                         freq_str=freq_str
                                        )

sampler_manager = SamplerManager(ticker_names_sampler=ticker_names_sampler,
                                 datetime_sampler=start_datetime_sampler
                                )


# PriceSupplierの設定
price_supplier = StockDBPriceSupplier(stock_db,
                                      [],  # 最初は何の銘柄コードも指定しない
                                      episode_length,
                                      freq_str,
                                      interpolate=True
                                     )

# PortfolioTransformerの設定
portfolio_transformer = PortfolioTransformer(price_supplier,
                                             portfolio_restrictor=PortfolioRestrictorIdentity(),
                                             use_ohlc="Close",
                                             initial_all_assets=1e6,  # 学習には関係ない
                                             fee_calculator=FeeCalculatorFree()
                                            )

# TradeEnvの設定
trade_env = TradeEnv(portfolio_transformer,
                     sampler_manager,
                     window=np.arange(0,20),
                     fee_const=0.0025
                    )

In [25]:
portfolio_state,_,_,_ = trade_env.reset()
print(len(portfolio_state.names))
print(portfolio_state)

20
PortfolioState( 
names=['yen', '5803', '9107', '8802', '6472', '5332', '2501', '7201', '2871', '9766', '8697', '9104', '8750', '7012', '4188', '1808', '5019', '8028', '8001', '9613']
key_currency_index=0
window=[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
datetime=2020-11-11 09:35:00+09:00
price_array=[[1.00000000e+00 1.00000000e+00 1.00000000e+00 1.00000000e+00
  1.00000000e+00 1.00000000e+00 1.00000000e+00 1.00000000e+00
  1.00000000e+00 1.00000000e+00 1.00000000e+00 1.00000000e+00
  1.00000000e+00 1.00000000e+00 1.00000000e+00 1.00000000e+00
  1.00000000e+00 1.00000000e+00 1.00000000e+00 1.00000000e+00]
 [4.02000000e+02 4.04000000e+02 4.04000000e+02 4.02000000e+02
  4.01000000e+02 4.00000000e+02 3.98000000e+02 3.99000000e+02
  3.98000000e+02 3.98000000e+02 3.99000000e+02 4.00000000e+02
  4.00000000e+02 4.00000000e+02 4.03000000e+02 4.03000000e+02
  4.03000000e+02 4.03000000e+02 4.03000000e+02 4.02000000e+02]
 [1.48700000e+03 1.49700000e+03 1.48900000e+03 1.483000

In [26]:
portfolio_vector = softmax(np.abs(np.random.randn(1+ticker_number)))
trade_env.step(portfolio_vector)

(PortfolioState(names=['yen', '5803', '9107', '8802', '6472', '5332', '2501', '7201', '2871', '9766', '8697', '9104', '8750', '7012', '4188', '1808', '5019', '8028', '8001', '9613'], key_currency_index=0, window=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19]), datetime=datetime.datetime(2020, 11, 11, 9, 40, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=array([[1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [4.04000000e+02, 4.04000000e+02, 4.02000000e+02, 4.01000000e+02,
         4.00000000e+02, 3.98000000e+02, 3.99000000e+02, 3.98000000e+02,
         3.98000000e+02, 3.99000000e+02, 4.00000000e+02, 4.00000000e+0

### プロファイリング 

In [27]:
def temp_func():
    # samplerの設定
    ticker_names_sampler = TickerSampler(all_ticker_names=ticker_codes,
                                         sampling_ticker_number=ticker_number)



    start_datetime_sampler = DatetimeSampler(start_datetime=start_datetime,
                                             end_datetime=end_datetime,
                                             episode_length=episode_length,
                                             freq_str=freq_str
                                            )

    sampler_manager = SamplerManager(ticker_names_sampler=ticker_names_sampler,
                                     datetime_sampler=start_datetime_sampler
                                    )


    # PriceSupplierの設定
    price_supplier = StockDBPriceSupplier(stock_db,
                                          [],  # 最初は何の銘柄コードも指定しない
                                          episode_length,
                                          freq_str,
                                          interpolate=True
                                         )

    # PortfolioTransformerの設定
    portfolio_transformer = PortfolioTransformer(price_supplier,
                                                 portfolio_restrictor=PortfolioRestrictorIdentity(),
                                                 use_ohlc="Close",
                                                 initial_all_assets=1e6,  # 学習には関係ない
                                                 fee_calculator=FeeCalculatorFree()
                                                )

    # TradeEnvの設定
    trade_env = TradeEnv(portfolio_transformer,
                         sampler_manager,
                         window=np.arange(0,20),
                         fee_const=0.0025
                        )
    
    trade_env.reset()
    for i in range(30):
        portfolio_vector = softmax(np.abs(np.random.randn(1+ticker_number)))
        trade_env.step(portfolio_vector)


from line_profiler import LineProfiler
prf = LineProfiler()                                                                                         
prf.add_module(TradeEnv)                                                                                          
#prf.add_function()                                                                                      
prf.runcall(temp_func)                                                                                          
prf.print_stats()

Timer unit: 1e-07 s

Total time: 8.24e-05 s
File: <ipython-input-23-f3f976fcf34c>
Function: __init__ at line 5

Line #      Hits         Time  Per Hit   % Time  Line Contents
     5                                               def __init__(self, 
     6                                                            portfolio_transformer,
     7                                                            sampler_manager,
     8                                                            window=[0],
     9                                                            fee_const=0.0025,
    10                                                           ):
    11                                                   """
    12                                                   portfolio_transformer: PortfolioTransformer
    13                                                        ポートフォリオを遷移させるTransformer
    14                                                   sampler_manager: SamplerManager
    15      

In [28]:
portfolio_state_list = []
reward_list = []
portfolio_state, reward, done, info = trade_env.reset()
portfolio_state_list.append(portfolio_state.partial("names", "now_price_array", "mean_cost_price_array", "portfolio_vector", "all_assets", "datetime"))
reward_list.append(reward)

while True:
    action = softmax(np.abs(np.random.randn(1+ticker_number)))
    portfolio_state, reward, done, info = trade_env.step(action)
    portfolio_state_list.append(portfolio_state.partial("names", "now_price_array", "mean_cost_price_array", "portfolio_vector", "all_assets", "datetime"))
    reward_list.append(reward)
    if done:
        break

In [29]:
visualize_portfolio_transform_bokeh(portfolio_state_list, is_save=False, is_show=True)

## 学習過程の可視化 

In [30]:
def visualize_portfolio_rl_bokeh(portfolio_state_list, reward_list, save_path=None, is_save=False, is_show=True, is_jupyter=True):
    all_datetime_array = np.array([get_naive_datetime_from_datetime(one_state.datetime) for one_state in portfolio_state_list])
    reward_array = np.array(reward_list)
    x = np.arange(0, len(portfolio_state_list))

    layout_list = visualize_portfolio_transform_bokeh(portfolio_state_list, is_save=False, is_show=False)

    add_p1 = bokeh.plotting.figure(plot_width=1200,plot_height=300,title="報酬")
    add_p1.line(x, reward_array, legend_label="reward", line_width=2, color="green")
    add_p1.xaxis.major_label_overrides = {str(one_x) : str(all_datetime_array[i]) for i, one_x in enumerate(x)}

    add_p1.yaxis[0].axis_label = "報酬"

    layout_list.extend([add_p1])
    created_figure = bokeh.layouts.column(*layout_list)

    if is_save:
            if save_path.suffix == ".png":
                bokeh.io.export_png(created_figure, filename=save_path)
            elif save_path.suffix == ".html":
                output_file(save_path)
                bokeh.io.save(created_figure, filename=save_path, title="trading process")    
            else:
                raise Exception("The suffix of save_path is must be '.png' or '.html'.")
            
            return None
    if is_show:
        try:
            reset_output()
            if is_jupyter:
                output_notebook()
            show(created_figure)
        except:
            if is_jupyter:
                output_notebook()
            show(created_figure)
            
        return None
        
    if not is_save and not is_show:
        return layout_list

In [31]:
visualize_portfolio_rl_bokeh(portfolio_state_list, reward_list, is_show=True)

##  前処理用の関数

In [47]:
class StateTransformInvalidError(TradeSysteBaseError):
    """
    StateTransformで起きたエラー
    """
    def __init__(self, err_str=None):
        """
        err_str:
            エラーメッセージ
        """
        self.err_str = err_str
    def __str__(self):
        if self.err_str is None:
            return "Cannot get all data."
        else:
            return self.err_str

In [32]:
class ComposeFunction:
    """
    callabreのリスト・ディクショナリ全てを実行する．動的なアクセスのために各collableはアトリビュートとして保持する．
    """
    def __init__(self, collection):
        """
        collection: list or dict of function
            適用したいcallableなオブジェクトのリストか辞書．辞書の場合はキーはアトリビュート名とする．
        """
        if isinstance(collection, list):
            self.function_name_list = []
            for i, func in enumerate(collection):
                func_name = "func_"+str(i)
                setattr(self, func_name, func)
                self.function_name_list.append(func_name)
            
        elif isinstance(collection, dict):
            self.function_name_list = []
            for key in collection:
                setattr(self, key, collection[key])
                self.function_name_list.append(key)
                
    def __call__(self, x):
        """
        x: any
            各関数の引数
        """
        for func_name in self.function_name_list:
            x = getattr(self, func_name)(x)
            
        return x

In [33]:
class State2Feature:
    """
    最後に実行
    """
    def __call__(self, portfolio_state):
        price_array = portfolio_state.price_array
        price_portfolio = price_array * portfolio_state.portfolio_vector[:,None]
        price_mean_cost = price_array * portfolio_state.mean_cost_price_array[:,None]
        feature = np.stack([price_array, price_portfolio, price_mean_cost], axis=0)
        return feature

In [34]:
state, _,_,_ = trade_env.reset()

In [35]:
feature = State2Feature()(state)

In [36]:
feature.shape

(3, 20, 20)

In [58]:
class PriceNormalizeConst:
    """
    price_arrayを定数で割る
    """
    def __init__(self, const_array=None):
        self._const_array = const_array
        
    @property
    def const_array(self):
        return self._const_array
    
    @const_array.setter
    def const_array(self, const_array):
        if const_array is not None:
            if np.isnan(self._const_array).sum() > 0:
                raise StateTransformInvalidError("PriceNormalizeConst.const_array has nan.")
        else:
            raise StateTransformInvalidError("PriceNormalizeConst.const_array cannot be sestted None")
        self._const_array = const_array
        
    def __call__(self, portfolio_state):
        if portfolio_state.price_array.shape[0]!=self._const_array.shape[0]:
            err_str = "portfolio_state.price_array shape({}) and PriceNormalizeConst.const_array({})".format(portfolio_state.price_array.shape,
                                                                                                             self._const_array.shape
                                                                                                            )
            raise StateTransformInvalidError(err_str)
        
        new_price_array = portfolio_state.price_array / self._const_array[:,None]
        return portfolio_state._replace(price_array=new_price_array)

In [59]:
PriceNormalizeConst(np.array([100000]*(1+ticker_number)))(state)

PortfolioState(names=['yen', '7832', '5406', '8802', '6301', '5711', '6501', '6178', '6645', '9984', '2432', '4751', '3402', '5233', '4631', '9437', '7211', '4507', '8750', '5541'], key_currency_index=0, window=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19]), datetime=datetime.datetime(2020, 11, 12, 10, 15, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=array([[1.00000000e-05, 1.00000000e-05, 1.00000000e-05, 1.00000000e-05,
        1.00000000e-05, 1.00000000e-05, 1.00000000e-05, 1.00000000e-05,
        1.00000000e-05, 1.00000000e-05, 1.00000000e-05, 1.00000000e-05,
        1.00000000e-05, 1.00000000e-05, 1.00000000e-05, 1.00000000e-05,
        1.00000000e-05, 1.00000000e-05, 1.00000000e-05, 1.00000000e-05],
       [8.68100000e-02, 8.70600000e-02, 8.70000000e-02, 8.71500000e-02,
        8.72400000e-02, 8.71900000e-02, 8.74500000e-02, 8.73800000e-02,
        8.71700000e-02, 8.70400000e-02, 8.71500000e-02, 8.70900000e-02,
     

In [60]:
class MeanCostPriceNormalizeConst:
    """
    mean_cost_price_arrayを定数で割る
    """
    def __init__(self, const_array):
        self._const_array = const_array
        
    @property
    def const_array(self):
        return self._const_array
    
    @const_array.setter
    def const_array(self, const_array):
        if const_array is not None:
            if np.isnan(self._const_array).sum() > 0:
                raise StateTransformInvalidError("MeanCostPriceNormalizeConst.const_array has nan.")
        else:
            raise StateTransformInvalidError("MeanCostPriceNormalizeConst.const_array cannot be sestted None")
        self._const_array = const_array
    
    def __call__(self, portfolio_state):
        if portfolio_state.mean_cost_price_array.shape[0]!=self._const_array.shape[0]:
            err_str = "portfolio_state.price_array shape({}) and MeanCostPriceNormalizeConst.const_array({})".format(portfolio_state.mean_cost_price_array,
                                                                                                             self._const_array.shape
                                                                                                            )
            raise StateTransformInvalidError(err_str)
        new_mean_cost_price = portfolio_state.mean_cost_price_array / self._const_array
        return portfolio_state._replace(mean_cost_price_array=new_mean_cost_price)

In [61]:
MeanCostPriceNormalizeConst(np.array([100000]*(1+ticker_number)))(state)

PortfolioState(names=['yen', '7832', '5406', '8802', '6301', '5711', '6501', '6178', '6645', '9984', '2432', '4751', '3402', '5233', '4631', '9437', '7211', '4507', '8750', '5541'], key_currency_index=0, window=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19]), datetime=datetime.datetime(2020, 11, 12, 10, 15, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=array([[1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
        1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
        1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
        1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
        1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00],
       [8.68100000e+03, 8.70600000e+03, 8.70000000e+03, 8.71500000e+03,
        8.72400000e+03, 8.71900000e+03, 8.74500000e+03, 8.73800000e+03,
        8.71700000e+03, 8.70400000e+03, 8.71500000e+03, 8.70900000e+03,
     

In [56]:
transform = ComposeFunction([PriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             MeanCostPriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             State2Feature()
                            ])

In [57]:
transform(state)

array([[[1.00000000e-05, 1.00000000e-05, 1.00000000e-05, ...,
         1.00000000e-05, 1.00000000e-05, 1.00000000e-05],
        [8.68100000e-02, 8.70600000e-02, 8.70000000e-02, ...,
         8.68000000e-02, 8.68300000e-02, 8.67200000e-02],
        [4.88000000e-03, 4.88000000e-03, 4.87000000e-03, ...,
         4.82000000e-03, 4.82000000e-03, 4.81000000e-03],
        ...,
        [5.66100000e-02, 5.67500000e-02, 5.67400000e-02, ...,
         5.61100000e-02, 5.59800000e-02, 5.59100000e-02],
        [1.72650000e-02, 1.72950000e-02, 1.72850000e-02, ...,
         1.71000000e-02, 1.70950000e-02, 1.70800000e-02],
        [1.77600000e-02, 1.77300000e-02, 1.77100000e-02, ...,
         1.76000000e-02, 1.77100000e-02, 1.76500000e-02]],

       [[1.00000000e-05, 1.00000000e-05, 1.00000000e-05, ...,
         1.00000000e-05, 1.00000000e-05, 1.00000000e-05],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 

In [43]:
dir(transform)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'func_0',
 'func_1',
 'func_2',
 'function_name_list']

In [44]:
transform = ComposeFunction({"price_normalize":PriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             "mean_cost_price_normalize":MeanCostPriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             "state2feature":State2Feature()
                            })

In [45]:
transform(state)

array([[[1.00000000e-05, 1.00000000e-05, 1.00000000e-05, ...,
         1.00000000e-05, 1.00000000e-05, 1.00000000e-05],
        [8.68100000e-02, 8.70600000e-02, 8.70000000e-02, ...,
         8.68000000e-02, 8.68300000e-02, 8.67200000e-02],
        [4.88000000e-03, 4.88000000e-03, 4.87000000e-03, ...,
         4.82000000e-03, 4.82000000e-03, 4.81000000e-03],
        ...,
        [5.66100000e-02, 5.67500000e-02, 5.67400000e-02, ...,
         5.61100000e-02, 5.59800000e-02, 5.59100000e-02],
        [1.72650000e-02, 1.72950000e-02, 1.72850000e-02, ...,
         1.71000000e-02, 1.70950000e-02, 1.70800000e-02],
        [1.77600000e-02, 1.77300000e-02, 1.77100000e-02, ...,
         1.76000000e-02, 1.77100000e-02, 1.76500000e-02]],

       [[1.00000000e-05, 1.00000000e-05, 1.00000000e-05, ...,
         1.00000000e-05, 1.00000000e-05, 1.00000000e-05],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 

In [46]:
dir(transform)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'function_name_list',
 'mean_cost_price_normalize',
 'price_normalize',
 'state2feature']