In [1]:
%cd ../../..

E:\システムトレード入門\trade_system_git_workspace


In [2]:
from numpy.random import RandomState
import numpy as np
import pandas as pd
from scipy.special import softmax
from pathlib import Path
import datetime
from pytz import timezone

In [3]:
from bokeh.io import output_notebook, show, reset_output, output_file
import bokeh
output_notebook()

In [4]:
import py_workdays

In [5]:
from get_stock_price import StockDatabase

In [6]:
from portfolio.trade_transformer import PortfolioTransformer, PortfolioRestrictorIdentity, FeeCalculatorFree
from portfolio.price_supply import StockDBPriceSupplier

In [7]:
from utils import middle_sample_type_with_check, get_naive_datetime_from_datetime

In [8]:
from visualization import visualize_portfolio_transform_bokeh

## グローバルパラメータ― 

In [9]:
jst = timezone("Asia/Tokyo")
start_datetime = jst.localize(datetime.datetime(2020,11,10,0,0,0))
end_datetime = jst.localize(datetime.datetime(2020,11,20,0,0,0))
ticker_number = 19
episode_length = 300
freq_str = "5T"

## 利用するデータベース

In [10]:
db_path = Path("db/sub_stock_db/nikkei_255_stock.db")
stock_db = StockDatabase(db_path)

## サンプリングするクラス 

In [11]:
class ConstSamper:
    """
    定数をサンプリング値として取得するためのクラス
    """
    def __init__(self, const_object):
        """
        const_object: any
            定数としてサンプリングされる値
        """
        self.const_object = const_object
        
    def sample(self, seed=None):
        """
        seted: int
            ランダムシード
        """
        return self.const_object

In [12]:
class TickerSampler:
    """
    ticker_nameをサンプリングするためのクラス
    """
    def __init__(self, all_ticker_names, sampling_ticker_number):
        """
        all_ticker_names: list of str
            サンプリングを行う銘柄名のリスト
        sampling_ticker_number: int
            サンプリング数
        """
        self.all_ticker_names = all_ticker_names
        self.sampling_ticker_number = sampling_ticker_number

    def sample(self, seed=None):
        """
        Parameters
        ----------
        seed: int
            ランダムシード
            
        Returns
        -------
        list of str 
            サンプリングされた銘柄名のndarray
        datetime.datetime
        """
        random_ticker_names = RandomState(seed).choice(self.all_ticker_names, self.sampling_ticker_number, replace=False)  # 重複を許さない
        return random_ticker_names

In [13]:
class DatetimeSampler:
    """
    datetime.datetimeをサンプリングするためのクラス
    """
    def __init__(self, start_datetime, end_datetime, episode_length, freq_str):
        """
        start_datetime: datetime.datetime
            サンプリングする範囲の開始日時
        end_datetime: datetime.datetime
            サンプリングする範囲の終了日時
        episode_length: int
            エピソード長
        freq_str: str
            サンプリング周期を表す文字列
        """
        
        self.start_datetime = start_datetime
        self.end_datetime = end_datetime

        self.freq_str = middle_sample_type_with_check(freq_str)


        all_datetime_index = pd.date_range(start=self.start_datetime,
                                           end=self.end_datetime,
                                           freq=self.freq_str,
                                           closed="left"
                                          )  
        
        self.all_datetime_value = py_workdays.extract_workdays_intraday_jp_index(all_datetime_index).to_pydatetime()
        
        self.all_datetime_index_range = np.arange(0, len(self.all_datetime_value))
        self.episode_length = episode_length
        
    def sample(self, seed=None, window=np.array([0])):
        """
        Parameters
        ----------
        seed: int
            ランダムシード
            
        window: list of int
            ウィンドウを表すリスト
        
        Returns
        -------
        list of str
        datetime.datetime
            サンプリングされた日時
        """
        
        if not isinstance(window, np.ndarray):
            self.window = np.array(window)
        else:
            self.window = window
        
        min_window = min(self.window)
        max_window = max(self.window)
        
        random_datetime = self.all_datetime_value[RandomState(seed).choice(self.all_datetime_index_range[abs(min_window):-(max_window+self.episode_length)],1)].item()
        
        return random_datetime      

In [14]:
class PortfolioVectorSampler:
    """
    ポートフォリオベクトルをサンプリングするためのクラス
    """
    def __init__(self, vector_length):
        """
        vector_length: int
            ポートフォリオベクトルの長さ
        """
        self.vector_length = vector_length
    def sample(self, seed=None):
        """
        seed: int
            ランダムシード
        """
        portfolio_vector = softmax(RandomState(seed).randn(self.vector_length))
        return portfolio_vector

In [15]:
class MeanCostPriceSampler:
    """
    平均取得価格をサンプリングするためのクラス
    """
    def __init__(self, mean_array=None, var_array=None):
        """
        mean_array: np.ndarray
            平均ベクトル
        var_array: np.ndarray
            分散ベクトル
        """
        self.mean_array = mean_array
        self.var_array = var_array
        
    def sample(self, seed=None):
        """
        seed: int
            ランダムシード
        """
        mean_cost_price = self.mean_array + self.var_array * RandomState(seed).randn(self.vector_length)
        return mean_cost_price

In [16]:
class SamplerManager:
    """
    TradeEnv環境のサンプリングを担うクラス．利用するサンプリングが追加・変更された場合．こちらを変更する．
    """
    def __init__(self, 
                 ticker_names_sampler,
                 datetime_sampler,
                 portfolio_vector_sampler=ConstSamper(None),
                 mean_cost_price_array_sampler=ConstSamper(None)
                ):
        """
        ticker_names_sampler:
            銘柄名をサンプリングするクラス
        datetime_sampler:
            日時をサンプリングするクラス
        portfolio_vector_samper:
            ポートフォリオベクトルをサンプリングするクラス
        mean_cost_price_array_sampler
            平均取得価格をサンプリングするクラス
        """
        
        self.ticker_names_sampler = ticker_names_sampler
        self.datetime_sampler = datetime_sampler
        self.portfolio_vector_sampler = portfolio_vector_sampler
        self.mean_cost_price_array_sampler = mean_cost_price_array_sampler

csvファイルの読み込み

In [17]:
ticker_codes_df = pd.read_csv(Path("portfolio/rl_base/nikkei225.csv"), header=0)  # 自分で作成
ticker_codes = ticker_codes_df["code"].values.astype(str).tolist()
#ticker_codes

In [18]:
ticker_names_sampler = TickerSampler(all_ticker_names=ticker_codes,
                                     sampling_ticker_number=ticker_number)



start_datetime_sampler = DatetimeSampler(start_datetime=start_datetime,
                                         end_datetime=end_datetime,
                                         episode_length=episode_length,
                                         freq_str=freq_str
                                        )

sampler_manager = SamplerManager(ticker_names_sampler=ticker_names_sampler,
                                 datetime_sampler=start_datetime_sampler
                                )

In [19]:
sampler_manager.datetime_sampler.sample(window=[-5,0,5])

datetime.datetime(2020, 11, 10, 13, 45, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>)

In [20]:
sampler_manager.ticker_names_sampler.sample()

array(['5802', '6506', '8002', '7735', '3436', '9502', '8035', '2502',
       '9301', '2801', '1812', '6703', '6473', '7011', '2501', '5332',
       '5801', '6326', '6113'], dtype='<U4')

In [21]:
print(sampler_manager.mean_cost_price_array_sampler.sample())

None


## 環境クラス 

In [22]:
class TradeEnv:
    """
    PortfolioStateを利用して基本的な売買を行う強化学習用の環境
    """
    def __init__(self, 
                 portfolio_transformer,
                 sampler_manager,
                 window=[0],
                 fee_const=0.0025,
                ):
        """
        portfolio_transformer: PortfolioTransformer
             ポートフォリオを遷移させるTransformer
        sampler_manager: SamplerManager
            各種サンプリングを行うためのマネージャー
        window: list
            PortfolioStateのウィンドウサイズ
        fee_const: float
            単位報酬・取引量当たりの手数料
        """
        self.portfolio_transformer = portfolio_transformer
        self.sampler_manager = sampler_manager
        self.window = window
        self.fee_const = fee_const
        
        
        
    def reset(self, seed=None):
        """
        Parameters
        ----------
        seed: int
            乱数シード
            
        Returns
        -------
        PortfolioState
            遷移したPortfolioStateのインスタンス
        float
            報酬．resetなので0
        bool
            終了を示すフラッグ
        dict
            その他の情報
        """
        #from IPython.core.debugger import Pdb; Pdb().set_trace()
        random_ticker_names = self.sampler_manager.ticker_names_sampler.sample(seed=seed)  # 銘柄名のサンプリング
        random_datetime = self.sampler_manager.datetime_sampler.sample(seed=seed, window=self.window)  # 開始日時のサンプリング
        random_portfolio_vector = self.sampler_manager.portfolio_vector_sampler.sample(seed=seed)  # ポートフォリオベクトルのサンプリング
        random_mean_cost_price_array = self.sampler_manager.mean_cost_price_array_sampler.sample(seed=seed)  # 平均取得価格のサンプリング
        
        self.portfolio_transformer.price_supplier.ticker_names = list(random_ticker_names)  # 銘柄名の変更
        self.portfolio_transformer.initial_portfolio_vector = random_portfolio_vector  # 初期ポートフォリオを変更(Noneの場合デフォルト)
        self.portfolio_transformer.initial_mean_cost_price_array = random_mean_cost_price_array  #初期平均取得価格を変更(Noneの場合デフォルト) 
        
        portfolio_state, done = self.portfolio_transformer.reset(random_datetime, window=self.window)
        
        self.portfolio_state = portfolio_state
        
        return self.portfolio_state.copy(), 0, done, None
    
    def step(self, portfolio_vector):
        """
        Parameters
        ----------
        portfolio_vector: ndarray
            actionを意味するポートフォリオベクトル
            
        Returns
        -------
        PortfolioState
            遷移したPortfolioStateのインスタンス
        float
            報酬．
        bool
            終了を示すフラッグ
        dict
            その他の情報
        """        

        previous_portfolio_state = self.portfolio_state
        
        #状態遷移
        portfolio_state, done = self.portfolio_transformer.step(portfolio_vector)
        
        #報酬の計算
        portfolio_vector = portfolio_state.portfolio_vector
        
        price_change_ratio = portfolio_state.now_price_array / previous_portfolio_state.now_price_array  # y
        raw_reward_ratio = np.dot(portfolio_vector, price_change_ratio)  # r
        
        portfolio_change_vector = portfolio_vector - previous_portfolio_state.portfolio_vector #W_{t}-w_{t-1}
        reward = np.log(raw_reward_ratio*(1-self.fee_const*np.dot(portfolio_change_vector, portfolio_change_vector)))
        
        return portfolio_state.copy(), reward, done, None

In [23]:
# samplerの設定
ticker_names_sampler = TickerSampler(all_ticker_names=ticker_codes,
                                     sampling_ticker_number=ticker_number)



start_datetime_sampler = DatetimeSampler(start_datetime=start_datetime,
                                         end_datetime=end_datetime,
                                         episode_length=episode_length,
                                         freq_str=freq_str
                                        )

sampler_manager = SamplerManager(ticker_names_sampler=ticker_names_sampler,
                                 datetime_sampler=start_datetime_sampler
                                )


# PriceSupplierの設定
price_supplier = StockDBPriceSupplier(stock_db,
                                      [],  # 最初は何の銘柄コードも指定しない
                                      episode_length,
                                      freq_str,
                                      interpolate=True
                                     )

# PortfolioTransformerの設定
portfolio_transformer = PortfolioTransformer(price_supplier,
                                             portfolio_restrictor=PortfolioRestrictorIdentity(),
                                             use_ohlc="Close",
                                             initial_all_assets=1e6,  # 学習には関係ない
                                             fee_calculator=FeeCalculatorFree()
                                            )

# TradeEnvの設定
trade_env = TradeEnv(portfolio_transformer,
                     sampler_manager,
                     window=np.arange(0,20),
                     fee_const=0.0025
                    )

In [24]:
portfolio_state,_,_,_ = trade_env.reset()
print(len(portfolio_state.names))
print(portfolio_state)

20
PortfolioState( 
names=['yen', '2269', '5711', '3863', '7751', '7912', '9602', '6752', '5020', '2768', '5802', '8795', '8355', '7731', '2002', '9984', '6770', '9983', '3103', '5703']
key_currency_index=0
window=[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
datetime=2020-11-10 09:45:00+09:00
price_array=[[1.00000000e+00 1.00000000e+00 1.00000000e+00 1.00000000e+00
  1.00000000e+00 1.00000000e+00 1.00000000e+00 1.00000000e+00
  1.00000000e+00 1.00000000e+00 1.00000000e+00 1.00000000e+00
  1.00000000e+00 1.00000000e+00 1.00000000e+00 1.00000000e+00
  1.00000000e+00 1.00000000e+00 1.00000000e+00 1.00000000e+00]
 [7.82000000e+03 7.83000000e+03 7.82000000e+03 7.83000000e+03
  7.84000000e+03 7.84000000e+03 7.85000000e+03 7.84000000e+03
  7.85000000e+03 7.84000000e+03 7.83000000e+03 7.82000000e+03
  7.81000000e+03 7.81000000e+03 7.81000000e+03 7.80000000e+03
  7.81000000e+03 7.82000000e+03 7.82000000e+03 7.82000000e+03]
 [2.00300000e+03 2.00600000e+03 2.00300000e+03 2.003000

In [25]:
portfolio_vector = softmax(np.abs(np.random.randn(1+ticker_number)))
trade_env.step(portfolio_vector)

(PortfolioState(names=['yen', '2269', '5711', '3863', '7751', '7912', '9602', '6752', '5020', '2768', '5802', '8795', '8355', '7731', '2002', '9984', '6770', '9983', '3103', '5703'], key_currency_index=0, window=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19]), datetime=datetime.datetime(2020, 11, 10, 9, 50, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=array([[1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [7.83000000e+03, 7.82000000e+03, 7.83000000e+03, 7.84000000e+03,
         7.84000000e+03, 7.85000000e+03, 7.84000000e+03, 7.85000000e+03,
         7.84000000e+03, 7.83000000e+03, 7.82000000e+03, 7.81000000e+0

### プロファイリング 

In [26]:
def temp_func():
    # samplerの設定
    ticker_names_sampler = TickerSampler(all_ticker_names=ticker_codes,
                                         sampling_ticker_number=ticker_number)



    start_datetime_sampler = DatetimeSampler(start_datetime=start_datetime,
                                             end_datetime=end_datetime,
                                             episode_length=episode_length,
                                             freq_str=freq_str
                                            )

    sampler_manager = SamplerManager(ticker_names_sampler=ticker_names_sampler,
                                     datetime_sampler=start_datetime_sampler
                                    )


    # PriceSupplierの設定
    price_supplier = StockDBPriceSupplier(stock_db,
                                          [],  # 最初は何の銘柄コードも指定しない
                                          episode_length,
                                          freq_str,
                                          interpolate=True
                                         )

    # PortfolioTransformerの設定
    portfolio_transformer = PortfolioTransformer(price_supplier,
                                                 portfolio_restrictor=PortfolioRestrictorIdentity(),
                                                 use_ohlc="Close",
                                                 initial_all_assets=1e6,  # 学習には関係ない
                                                 fee_calculator=FeeCalculatorFree()
                                                )

    # TradeEnvの設定
    trade_env = TradeEnv(portfolio_transformer,
                         sampler_manager,
                         window=np.arange(0,20),
                         fee_const=0.0025
                        )
    
    trade_env.reset()
    for i in range(30):
        portfolio_vector = softmax(np.abs(np.random.randn(1+ticker_number)))
        trade_env.step(portfolio_vector)


from line_profiler import LineProfiler
prf = LineProfiler()                                                                                         
prf.add_module(TradeEnv)                                                                                          
#prf.add_function()                                                                                      
prf.runcall(temp_func)                                                                                          
prf.print_stats()

Timer unit: 1e-07 s

Total time: 5.45e-05 s
File: <ipython-input-22-f3f976fcf34c>
Function: __init__ at line 5

Line #      Hits         Time  Per Hit   % Time  Line Contents
     5                                               def __init__(self, 
     6                                                            portfolio_transformer,
     7                                                            sampler_manager,
     8                                                            window=[0],
     9                                                            fee_const=0.0025,
    10                                                           ):
    11                                                   """
    12                                                   portfolio_transformer: PortfolioTransformer
    13                                                        ポートフォリオを遷移させるTransformer
    14                                                   sampler_manager: SamplerManager
    15      

In [27]:
portfolio_state_list = []
reward_list = []
portfolio_state, reward, done, info = trade_env.reset()
portfolio_state_list.append(portfolio_state.partial("names", "now_price_array", "mean_cost_price_array", "portfolio_vector", "all_assets", "datetime"))
reward_list.append(reward)

while True:
    action = softmax(np.abs(np.random.randn(1+ticker_number)))
    portfolio_state, reward, done, info = trade_env.step(action)
    portfolio_state_list.append(portfolio_state.partial("names", "now_price_array", "mean_cost_price_array", "portfolio_vector", "all_assets", "datetime"))
    reward_list.append(reward)
    if done:
        break

In [28]:
visualize_portfolio_transform_bokeh(portfolio_state_list, is_save=False, is_show=True)

## 学習過程の可視化 

In [29]:
def visualize_portfolio_rl_bokeh(portfolio_state_list, reward_list, save_path=None, is_save=False, is_show=True, is_jupyter=True):
    all_datetime_array = np.array([get_naive_datetime_from_datetime(one_state.datetime) for one_state in portfolio_state_list])
    reward_array = np.array(reward_list)
    x = np.arange(0, len(portfolio_state_list))

    layout_list = visualize_portfolio_transform_bokeh(portfolio_state_list, is_save=False, is_show=False)

    add_p1 = bokeh.plotting.figure(plot_width=1200,plot_height=300,title="報酬")
    add_p1.line(x, reward_array, legend_label="reward", line_width=2, color="green")
    add_p1.xaxis.major_label_overrides = {str(one_x) : str(all_datetime_array[i]) for i, one_x in enumerate(x)}

    add_p1.yaxis[0].axis_label = "報酬"

    layout_list.extend([add_p1])
    created_figure = bokeh.layouts.column(*layout_list)

    if is_save:
            if save_path.suffix == ".png":
                bokeh.io.export_png(created_figure, filename=save_path)
            elif save_path.suffix == ".html":
                output_file(save_path)
                bokeh.io.save(created_figure, filename=save_path, title="trading process")    
            else:
                raise Exception("The suffix of save_path is must be '.png' or '.html'.")
            
            return None
    if is_show:
        try:
            reset_output()
            if is_jupyter:
                output_notebook()
            show(created_figure)
        except:
            if is_jupyter:
                output_notebook()
            show(created_figure)
            
        return None
        
    if not is_save and not is_show:
        return layout_list

In [30]:
visualize_portfolio_rl_bokeh(portfolio_state_list, reward_list, is_show=True)

##  前処理用の関数

In [31]:
class ComposeFunction:
    """
    callabreのリスト・ディクショナリ全てを実行する．動的なアクセスのために各collableはアトリビュートとして保持する．
    """
    def __init__(self, collection):
        """
        collection: list or dict of function
            適用したいcallableなオブジェクトのリストか辞書．辞書の場合はキーはアトリビュート名とする．
        """
        if isinstance(collection, list):
            self.function_name_list = []
            for i, func in enumerate(collection):
                func_name = "func_"+str(i)
                setattr(self, func_name, func)
                self.function_name_list.append(func_name)
            
        elif isinstance(collection, dict):
            self.function_name_list = []
            for key in collection:
                setattr(self, key, collection[key])
                self.function_name_list.append(key)
                
    def __call__(self, x):
        """
        x: any
            各関数の引数
        """
        for func_name in self.function_name_list:
            x = getattr(self, func_name)(x)
            
        return x

In [32]:
class State2Feature:
    """
    最後に実行
    """
    def __call__(self, portfolio_state):
        price_array = portfolio_state.price_array
        price_portfolio = price_array * portfolio_state.portfolio_vector[:,None]
        price_mean_cost = price_array * portfolio_state.mean_cost_price_array[:,None]
        feature = np.stack([price_array, price_portfolio, price_mean_cost], axis=0)
        return feature

In [33]:
state, _,_,_ = trade_env.reset()

In [34]:
feature = State2Feature()(state)

In [35]:
feature.shape

(3, 20, 20)

In [36]:
class PriceNormalizeConst:
    """
    price_arrayを定数で割る
    """
    def __init__(self, const_array=None):
        self.const_array = const_array
        
    def __call__(self, portfolio_state):
        if np.isnan(self.const_array).sum() > 0:
            raise Exception("const array has nan data.")
        assert portfolio_state.price_array.shape[0]==self.const_array.shape[0]
        new_price_array = portfolio_state.price_array / self.const_array[:,None]
        return portfolio_state._replace(price_array=new_price_array)

In [37]:
PriceNormalizeConst(np.array([100000]*(1+ticker_number)))(state)

PortfolioState(names=['yen', '7201', '2002', '4005', '2801', '9602', '6762', '8253', '7733', '9766', '6301', '9433', '6506', '5411', '7211', '6701', '8601', '2502', '4506', '9001'], key_currency_index=0, window=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19]), datetime=datetime.datetime(2020, 11, 10, 13, 40, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=array([[1.00000000e-05, 1.00000000e-05, 1.00000000e-05, 1.00000000e-05,
        1.00000000e-05, 1.00000000e-05, 1.00000000e-05, 1.00000000e-05,
        1.00000000e-05, 1.00000000e-05, 1.00000000e-05, 1.00000000e-05,
        1.00000000e-05, 1.00000000e-05, 1.00000000e-05, 1.00000000e-05,
        1.00000000e-05, 1.00000000e-05, 1.00000000e-05, 1.00000000e-05],
       [4.07600006e-03, 4.06899994e-03, 4.06899994e-03, 4.06500000e-03,
        4.07700012e-03, 4.08000000e-03, 4.08700012e-03, 4.08700012e-03,
        4.09299988e-03, 4.09299988e-03, 4.09899994e-03, 4.08700012e-03,
     

In [38]:
class MeanCostPriceNormalizeConst:
    """
    mean_cost_price_arrayを定数で割る
    """
    def __init__(self, const_array):
        self.const_array = const_array
    
    def __call__(self, portfolio_state):
        if np.isnan(self.const_array).sum() > 0:
            raise Exception("const array has nan data.")
        assert portfolio_state.mean_cost_price_array.shape[0]==self.const_array.shape[0]
        new_mean_cost_price = portfolio_state.mean_cost_price_array / self.const_array
        return portfolio_state._replace(mean_cost_price_array=new_mean_cost_price)

In [39]:
MeanCostPriceNormalizeConst(np.array([100000]*(1+ticker_number)))(state)

PortfolioState(names=['yen', '7201', '2002', '4005', '2801', '9602', '6762', '8253', '7733', '9766', '6301', '9433', '6506', '5411', '7211', '6701', '8601', '2502', '4506', '9001'], key_currency_index=0, window=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19]), datetime=datetime.datetime(2020, 11, 10, 13, 40, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=array([[1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
        1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
        1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
        1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
        1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00],
       [4.07600006e+02, 4.06899994e+02, 4.06899994e+02, 4.06500000e+02,
        4.07700012e+02, 4.08000000e+02, 4.08700012e+02, 4.08700012e+02,
        4.09299988e+02, 4.09299988e+02, 4.09899994e+02, 4.08700012e+02,
     

In [40]:
transform = ComposeFunction([PriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             MeanCostPriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             State2Feature()
                            ])

In [41]:
transform(state)

array([[[1.00000000e-05, 1.00000000e-05, 1.00000000e-05, ...,
         1.00000000e-05, 1.00000000e-05, 1.00000000e-05],
        [4.07600006e-03, 4.06899994e-03, 4.06899994e-03, ...,
         4.24200012e-03, 4.21399994e-03, 4.22700012e-03],
        [1.70600000e-02, 1.70400000e-02, 1.70700000e-02, ...,
         1.76400000e-02, 1.75700000e-02, 1.75600000e-02],
        ...,
        [4.00300000e-02, 4.00600000e-02, 4.01800000e-02, ...,
         4.07600000e-02, 4.07500000e-02, 4.09400000e-02],
        [1.35000000e-02, 1.34900000e-02, 1.35100000e-02, ...,
         1.40600000e-02, 1.41400000e-02, 1.41000000e-02],
        [3.38500000e-02, 3.38500000e-02, 3.38500000e-02, ...,
         3.50000000e-02, 3.48500000e-02, 3.49000000e-02]],

       [[1.00000000e-05, 1.00000000e-05, 1.00000000e-05, ...,
         1.00000000e-05, 1.00000000e-05, 1.00000000e-05],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 

In [42]:
dir(transform)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'func_0',
 'func_1',
 'func_2',
 'function_name_list']

In [43]:
transform = ComposeFunction({"price_normalize":PriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             "mean_cost_price_normalize":MeanCostPriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             "state2feature":State2Feature()
                            })

In [44]:
transform(state)

array([[[1.00000000e-05, 1.00000000e-05, 1.00000000e-05, ...,
         1.00000000e-05, 1.00000000e-05, 1.00000000e-05],
        [4.07600006e-03, 4.06899994e-03, 4.06899994e-03, ...,
         4.24200012e-03, 4.21399994e-03, 4.22700012e-03],
        [1.70600000e-02, 1.70400000e-02, 1.70700000e-02, ...,
         1.76400000e-02, 1.75700000e-02, 1.75600000e-02],
        ...,
        [4.00300000e-02, 4.00600000e-02, 4.01800000e-02, ...,
         4.07600000e-02, 4.07500000e-02, 4.09400000e-02],
        [1.35000000e-02, 1.34900000e-02, 1.35100000e-02, ...,
         1.40600000e-02, 1.41400000e-02, 1.41000000e-02],
        [3.38500000e-02, 3.38500000e-02, 3.38500000e-02, ...,
         3.50000000e-02, 3.48500000e-02, 3.49000000e-02]],

       [[1.00000000e-05, 1.00000000e-05, 1.00000000e-05, ...,
         1.00000000e-05, 1.00000000e-05, 1.00000000e-05],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 

In [45]:
dir(transform)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'function_name_list',
 'mean_cost_price_normalize',
 'price_normalize',
 'state2feature']