In [1]:
%cd ../../..

E:\システムトレード入門\trade_system_git_workspace


In [2]:
from numpy.random import RandomState
import numpy as np
import pandas as pd
from scipy.special import softmax
from pathlib import Path
import datetime
from pytz import timezone

In [3]:
from bokeh.io import output_notebook, show, reset_output, output_file
import bokeh
output_notebook()

In [4]:
import py_workdays

In [5]:
from get_stock_price import StockDatabase

In [6]:
from portfolio.trade_transformer import PortfolioTransformer, PortfolioRestrictorIdentity, FeeCalculatorFree
from portfolio.price_supply import StockDBPriceSupplier

In [7]:
from utils import middle_sample_type_with_check, get_naive_datetime_from_datetime

In [8]:
from visualization import visualize_portfolio_transform_bokeh

## グローバルパラメーター

In [9]:
jst = timezone("Asia/Tokyo")
# Episode start datetimes are sampled from [start_datetime, end_datetime) (tz-aware, JST).
start_datetime = jst.localize(datetime.datetime(2020,11,10,0,0,0))
end_datetime = jst.localize(datetime.datetime(2020,11,20,0,0,0))
# Number of tickers sampled per episode; portfolio vectors have 1+ticker_number entries (yen + tickers).
ticker_number = 10
# Episode length in sampling steps.
episode_length = 300
# Sampling period as a pandas offset alias ("5T" = 5 minutes).
freq_str = "5T"

## 利用するデータベース

In [10]:
db_path = Path("db/sub_stock_db/nikkei_255_stock.db")
stock_db = StockDatabase(db_path)

## サンプリングするクラス 

In [11]:
class TickerSampler:
    """
    Sampler that draws a random subset of ticker names.

    The datetime half of the returned ``(ticker_names, datetime)`` pair is not
    sampled: the ``select_datetime`` given at construction is returned as-is.
    """
    def __init__(self, all_ticker_names, sampling_ticker_number, select_datetime=None):
        """
        all_ticker_names: list of str
            Ticker names to sample from.
        sampling_ticker_number: int
            Number of ticker names drawn per call (without replacement).
        select_datetime: datetime.datetime or None
            Fixed datetime returned unchanged by ``sample``.
        """
        self.all_ticker_names = all_ticker_names
        self.sampling_ticker_number = sampling_ticker_number
        self.select_datetime = select_datetime

    def sample(self, seed=None, *args, window=None):
        """
        Parameters
        ----------
        seed: int or None
            Random seed. ``None`` gives a non-reproducible draw.
        window: array-like, optional
            Accepted and ignored. It exists so this sampler supports the same
            ``sample(seed=..., window=...)`` call that ``DatetimeSampler`` /
            ``TickerDatetimeSampler`` support (and that ``TradeEnv.reset`` makes).

        Returns
        -------
        numpy.ndarray of str
            Sampled ticker names, without duplicates.
        datetime.datetime or None
            The ``select_datetime`` given at construction.

        Raises
        ------
        ValueError
            If ``sampling_ticker_number`` exceeds the number of ticker names.
        """
        random_ticker_names = RandomState(seed).choice(
            self.all_ticker_names, self.sampling_ticker_number, replace=False  # no duplicates
        )
        return random_ticker_names, self.select_datetime

In [12]:
class DatetimeSampler:
    """
    Sampler that draws a random episode start datetime.

    Candidate datetimes form a regular ``freq_str`` grid over
    [start_datetime, end_datetime), filtered by
    ``py_workdays.extract_workdays_intraday_jp_index``. Ticker names are not
    sampled: the ``ticker_names`` given at construction are returned as-is.
    """
    def __init__(self, start_datetime, end_datetime, episode_length, freq_str, ticker_names=None):
        """
        start_datetime: datetime.datetime
            Start of the sampling range (inclusive).
        end_datetime: datetime.datetime
            End of the sampling range (exclusive).
        episode_length: int
            Episode length in grid steps; the sampled start leaves room for it.
        freq_str: str
            Sampling period string, validated by ``middle_sample_type_with_check``.
        ticker_names: list of str or None
            Fixed ticker names returned unchanged by ``sample``.
        """
        self.ticker_names = ticker_names

        self.start_datetime = start_datetime
        self.end_datetime = end_datetime

        self.freq_str = middle_sample_type_with_check(freq_str)

        all_datetime_index = self._left_closed_date_range(self.start_datetime, self.end_datetime, self.freq_str)

        self.all_datetime_value = py_workdays.extract_workdays_intraday_jp_index(all_datetime_index).to_pydatetime()

        self.all_datetime_index_range = np.arange(0, len(self.all_datetime_value))
        self.episode_length = episode_length

    @staticmethod
    def _left_closed_date_range(start, end, freq):
        """Return ``pd.date_range`` over [start, end), working across pandas versions."""
        try:
            # pandas >= 1.4: ``closed`` was deprecated in 1.4 and removed in 2.0.
            return pd.date_range(start=start, end=end, freq=freq, inclusive="left")
        except TypeError:
            # Older pandas that does not know ``inclusive``.
            return pd.date_range(start=start, end=end, freq=freq, closed="left")

    def _candidate_indices(self, min_window, max_window):
        """
        Return the grid indices usable as an episode start.

        A start index i is kept when i >= abs(min_window) and
        i + max_window + episode_length <= last index.
        """
        lower = abs(min_window)
        # Explicit, clamped stop instead of ``[:-k]``: ``[:-0]`` would be empty,
        # and a negative stop would wrap around from the end.
        upper = max(0, len(self.all_datetime_index_range) - (max_window + self.episode_length))
        return self.all_datetime_index_range[lower:upper]

    def sample(self, seed=None, window=np.array([0])):
        """
        Parameters
        ----------
        seed: int or None
            Random seed. ``None`` gives a non-reproducible draw.
        window: list of int or ndarray
            Window offsets; the sampled start leaves room for them.

        Returns
        -------
        list of str or None
            The ``ticker_names`` given at construction.
        datetime.datetime
            The sampled start datetime.

        Raises
        ------
        ValueError
            If no grid point leaves room for the window plus the episode.
        """
        self.window = np.asarray(window)

        min_window = min(self.window)
        max_window = max(self.window)

        candidate_indices = self._candidate_indices(min_window, max_window)
        if len(candidate_indices) == 0:
            raise ValueError(
                "No datetime in the sampling range can host window [{}, {}] plus episode_length={} "
                "({} candidate points).".format(
                    min_window, max_window, self.episode_length, len(self.all_datetime_index_range)
                )
            )

        chosen_index = RandomState(seed).choice(candidate_indices, 1)
        random_datetime = self.all_datetime_value[chosen_index].item()

        return self.ticker_names, random_datetime

In [13]:
class TickerDatetimeSampler:
    """
    Combined sampler for ticker names and an episode start datetime.

    Ticker names come from a ``TickerSampler`` and the datetime from a
    ``DatetimeSampler`` restricted to datetimes that leave room for a full episode.
    """
    def __init__(self, all_ticker_names, sampling_ticker_number, start_datetime, end_datetime, episode_length, freq_str):
        """
        all_ticker_names: list of str
            Ticker names to sample from.
        sampling_ticker_number: int
            Number of ticker names drawn per call.
        start_datetime: datetime.datetime
            Start of the datetime sampling range.
        end_datetime: datetime.datetime
            End of the datetime sampling range.
        episode_length: int
            Episode length in sampling steps.
        freq_str: str
            Sampling period string.
        """
        # Each sub-sampler supplies only one half of the returned pair,
        # so its pass-through argument is left as None.
        self.ticker_sampler = TickerSampler(
            all_ticker_names,
            sampling_ticker_number,
            select_datetime=None,
        )
        self.datetime_sampler = DatetimeSampler(
            start_datetime,
            end_datetime,
            episode_length,
            freq_str,
            ticker_names=None,
        )

    def sample(self, seed=None, window=np.array([0])):
        """
        Parameters
        ----------
        seed: int or None
            Random seed, passed to both sub-samplers.
        window: list of int or ndarray
            Window offsets, forwarded to the datetime sampler.

        Returns
        -------
        numpy.ndarray of str
            Sampled ticker names.
        datetime.datetime
            Sampled start datetime.
        """
        # Both halves receive the same seed, so a fixed seed reproduces the whole pair.
        ticker_part = self.ticker_sampler.sample(seed=seed)[0]
        datetime_part = self.datetime_sampler.sample(seed=seed, window=window)[1]
        return ticker_part, datetime_part

## csvファイルの読み込み

In [14]:
# Ticker-code list read from a hand-made CSV with a "code" column; codes are kept as strings.
ticker_codes_df = pd.read_csv(Path("portfolio/rl_base/nikkei225.csv"), header=0)  # created by hand
ticker_codes = ticker_codes_df["code"].values.astype(str).tolist()
ticker_codes

['1333',
 '1332',
 '1605',
 '1802',
 '1801',
 '1963',
 '1812',
 '1803',
 '1925',
 '1808',
 '1928',
 '1721',
 '2002',
 '2269',
 '2914',
 '2503',
 '2501',
 '2801',
 '2502',
 '2802',
 '2531',
 '2871',
 '2282',
 '3101',
 '3401',
 '3402',
 '3103',
 '3863',
 '3861',
 '6988',
 '4021',
 '4061',
 '4901',
 '4004',
 '4911',
 '4043',
 '4183',
 '4042',
 '3407',
 '4208',
 '4188',
 '3405',
 '4005',
 '4631',
 '4063',
 '4452',
 '4519',
 '4578',
 '4506',
 '4502',
 '4507',
 '4523',
 '4568',
 '5019',
 '5020',
 '5101',
 '5202',
 '5214',
 '5232',
 '5233',
 '5301',
 '5332',
 '5333',
 '5401',
 '5406',
 '5411',
 '5541',
 '3436',
 '5703',
 '5706',
 '5707',
 '5711',
 '5713',
 '5714',
 '5801',
 '5802',
 '5803',
 '5901',
 '5631',
 '6103',
 '6113',
 '6301',
 '6302',
 '6305',
 '6326',
 '6361',
 '6367',
 '6471',
 '6472',
 '6473',
 '7004',
 '7011',
 '7013',
 '3105',
 '6479',
 '6501',
 '6503',
 '6504',
 '6506',
 '6645',
 '6674',
 '6701',
 '6702',
 '6703',
 '6724',
 '6752',
 '6758',
 '6762',
 '6770',
 '6841',
 '6857',
 

In [15]:
sampler = TickerDatetimeSampler(all_ticker_names=ticker_codes,
                        sampling_ticker_number=ticker_number,
                        start_datetime=start_datetime,
                        end_datetime=end_datetime,
                        episode_length=episode_length,
                        freq_str=freq_str
                       )

In [16]:
sampler.sample(window=[-5,0,5])

(array(['5631', '5214', '8795', '7733', '8411', '8309', '6758', '6302',
        '6326', '2801'], dtype='<U4'),
 datetime.datetime(2020, 11, 12, 11, 0, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>))

## 環境クラス 

In [17]:
class TradeEnv:
    """
    Reinforcement-learning environment that trades through a PortfolioTransformer.

    ``reset`` samples ticker names and a start datetime and resets the
    transformer. ``step`` applies an action (a portfolio vector) and returns
    the log growth of the portfolio, penalized by a transaction-cost term.
    """
    def __init__(self, 
                 portfolio_transformer,
                 sampler,
                 window=[0],
                 fee_const=0.0025,
                ):
        """
        portfolio_transformer: PortfolioTransformer
            Transformer that advances the portfolio state. Its
            ``price_supplier.ticker_names`` is overwritten on every reset.
        sampler: Sampler
            Object whose ``sample(seed=..., window=...)`` returns
            ``(ticker_names, datetime)``.
        window: list or ndarray
            Window passed to both the sampler and ``portfolio_transformer.reset``.
        fee_const: float
            Transaction-cost coefficient used in the reward.
        """
        self.portfolio_transformer = portfolio_transformer
        self.sampler = sampler
        self.window = window  # not mutated by this class, so the list default is safe here
        self.fee_const = fee_const

    def reset(self, seed=None):
        """
        Parameters
        ----------
        seed: int or None
            Random seed forwarded to the sampler.

        Returns
        -------
        PortfolioState
            Copy of the initial state.
        float
            Reward; always 0 on reset.
        bool
            Termination flag reported by the transformer.
        None
            No extra information is provided.
        """
        random_ticker_names, random_datetime = self.sampler.sample(seed=seed, window=self.window)
        # Retarget the shared price supplier to the newly sampled tickers.
        self.portfolio_transformer.price_supplier.ticker_names = list(random_ticker_names)
        portfolio_state, done = self.portfolio_transformer.reset(random_datetime, window=self.window)

        self.portfolio_state = portfolio_state

        return self.portfolio_state.copy(), 0, done, None

    def step(self, portfolio_vector):
        """
        Parameters
        ----------
        portfolio_vector: ndarray
            Action: target portfolio weights (key currency + tickers).

        Returns
        -------
        PortfolioState
            Copy of the new state.
        float
            Reward: log(w_t . y_t * (1 - fee_const * ||w_t - w_{t-1}||^2)),
            where y_t is the element-wise price ratio since the previous step.
        bool
            Termination flag reported by the transformer.
        None
            No extra information is provided.
        """
        previous_portfolio_state = self.portfolio_state

        # State transition
        portfolio_state, done = self.portfolio_transformer.step(portfolio_vector)
        # Remember the new state so the next step is measured against it,
        # not against the state produced by reset().
        self.portfolio_state = portfolio_state

        # Reward
        new_portfolio_vector = portfolio_state.portfolio_vector

        price_change_ratio = portfolio_state.now_price_array / previous_portfolio_state.now_price_array  # y_t
        raw_reward_ratio = np.dot(new_portfolio_vector, price_change_ratio)  # r_t

        portfolio_change_vector = new_portfolio_vector - previous_portfolio_state.portfolio_vector  # w_t - w_{t-1}
        # NOTE(review): the cost uses the squared L2 norm of the weight change. If fee_const is
        # meant per unit of traded volume, the usual form is the L1 norm
        # (np.abs(portfolio_change_vector).sum()); confirm intent.
        transaction_cost = self.fee_const * np.dot(portfolio_change_vector, portfolio_change_vector)
        reward = np.log(raw_reward_ratio * (1 - transaction_cost))

        return portfolio_state.copy(), reward, done, None

In [18]:
# Build the environment: sampler -> price supplier -> portfolio transformer -> TradeEnv.
sampler = TickerDatetimeSampler(all_ticker_names=ticker_codes,
                        sampling_ticker_number=ticker_number,
                        start_datetime=start_datetime,
                        end_datetime=end_datetime,
                        episode_length=episode_length,
                        freq_str=freq_str
                       )


price_supplier = StockDBPriceSupplier(stock_db,
                                      [],  # no ticker codes at first; TradeEnv.reset assigns them
                                      episode_length,
                                      freq_str,
                                      interpolate=True
                                     )

portfolio_transformer = PortfolioTransformer(price_supplier,
                                             portfolio_restrictor=PortfolioRestrictorIdentity(),
                                             use_ohlc="Close",
                                             initial_portfolio_vector=None,
                                             initial_all_assets=1e6,  # irrelevant to learning
                                             fee_calculator=FeeCalculatorFree()
                                            )

trade_env = TradeEnv(portfolio_transformer,
                     sampler,
                     window=np.arange(0,20),
                     fee_const=0.0025
                    )

In [19]:
trade_env.reset()

(PortfolioState(names=['yen', '8804', '7012', '5703', '2801', '1812', '9064', '3407', '1803', '5411', '5233'], key_currency_index=0, window=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19]), datetime=datetime.datetime(2020, 11, 12, 9, 25, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=array([[1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [1.35200000e+03, 1.34900000e+03, 1.34500000e+03, 1.34800000e+03,
         1.34500000e+03, 1.34900000e+03, 1.34600000e+03, 1.34500000e+03,
         1.34600000e+03, 1.34900000e+03, 1.34900000e+03, 1.34800000e+03,
         1.34700000e+03, 1.34700000e+03, 1.34400000e+03, 1.34000000e+

In [20]:
portfolio_vector = softmax(np.abs(np.random.randn(1+ticker_number)))
trade_env.step(portfolio_vector)

(PortfolioState(names=['yen', '8804', '7012', '5703', '2801', '1812', '9064', '3407', '1803', '5411', '5233'], key_currency_index=0, window=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19]), datetime=datetime.datetime(2020, 11, 12, 9, 30, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=array([[1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00,
         1.00000000e+00, 1.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [1.34900000e+03, 1.34500000e+03, 1.34800000e+03, 1.34500000e+03,
         1.34900000e+03, 1.34600000e+03, 1.34500000e+03, 1.34600000e+03,
         1.34900000e+03, 1.34900000e+03, 1.34800000e+03, 1.34700000e+03,
         1.34700000e+03, 1.34400000e+03, 1.34000000e+03, 1.34300000e+

### プロファイリング 

In [21]:
def temp_func():
    """
    Build a fresh TradeEnv, reset it, and take 30 random steps.

    Used as the profiling target below; returns nothing.
    """
    episode_sampler = TickerDatetimeSampler(
        all_ticker_names=ticker_codes,
        sampling_ticker_number=ticker_number,
        start_datetime=start_datetime,
        end_datetime=end_datetime,
        episode_length=episode_length,
        freq_str=freq_str,
    )
    supplier = StockDBPriceSupplier(
        stock_db,
        [],  # no ticker codes at first; TradeEnv.reset assigns them
        episode_length,
        freq_str,
        interpolate=True,
    )
    transformer = PortfolioTransformer(
        supplier,
        portfolio_restrictor=PortfolioRestrictorIdentity(),
        use_ohlc="Close",
        initial_portfolio_vector=None,
        initial_all_assets=1e6,  # irrelevant to learning
        fee_calculator=FeeCalculatorFree(),
    )
    env = TradeEnv(transformer, episode_sampler, window=np.arange(0, 20), fee_const=0.0025)

    env.reset()
    for _ in range(30):
        env.step(softmax(np.abs(np.random.randn(1 + ticker_number))))


from line_profiler import LineProfiler

prf = LineProfiler()
# The TradeEnv class is registered so that its methods are line-profiled.
prf.add_module(TradeEnv)
prf.runcall(temp_func)
prf.print_stats()

Timer unit: 1e-07 s

Total time: 4.8e-05 s
File: <ipython-input-17-fe21c210166e>
Function: __init__ at line 5

Line #      Hits         Time  Per Hit   % Time  Line Contents
     5                                               def __init__(self, 
     6                                                            portfolio_transformer,
     7                                                            sampler,
     8                                                            window=[0],
     9                                                            fee_const=0.0025,
    10                                                           ):
    11                                                   """
    12                                                   portfolio_transformer: PortfolioTransformer
    13                                                        ポートフォリオを遷移させるTransformer
    14                                                   sampler: Sampler
    15                              

In [22]:
# Subset of each PortfolioState kept for plotting.
RECORDED_STATE_FIELDS = ("names", "now_price_array", "mean_cost_price_array", "portfolio_vector", "all_assets", "datetime")


def record_transition(state, step_reward):
    """Append the plotted subset of `state` and its reward to the history lists."""
    portfolio_state_list.append(state.partial(*RECORDED_STATE_FIELDS))
    reward_list.append(step_reward)


portfolio_state_list = []
reward_list = []

portfolio_state, reward, done, info = trade_env.reset()
record_transition(portfolio_state, reward)

# At least one step is always taken, even if reset() already reports done.
while True:
    action = softmax(np.abs(np.random.randn(1 + ticker_number)))
    portfolio_state, reward, done, info = trade_env.step(action)
    record_transition(portfolio_state, reward)
    if done:
        break

In [23]:
visualize_portfolio_transform_bokeh(portfolio_state_list, is_save=False, is_show=True)

## 学習過程の可視化 

In [24]:
def visualize_portfolio_rl_bokeh(portfolio_state_list, reward_list, save_path=None, is_save=False, is_show=True, is_jupyter=True):
    """
    Plot the portfolio transition together with the per-step reward.

    Parameters
    ----------
    portfolio_state_list: list of PortfolioState
        States in time order; each must expose ``datetime``.
    reward_list: list of float
        Rewards aligned one-to-one with ``portfolio_state_list``.
    save_path: str or pathlib.Path, optional
        Destination used when ``is_save`` is True; suffix must be ``.png`` or ``.html``.
    is_save: bool
        Save the figure and return None.
    is_show: bool
        Show the figure (when not saving) and return None.
    is_jupyter: bool
        Route bokeh output to the notebook before showing.

    Returns
    -------
    list or None
        The list of bokeh layout objects when neither saving nor showing; otherwise None.

    Raises
    ------
    ValueError
        If ``is_save`` is True but ``save_path`` is None.
    Exception
        If ``save_path`` has an unsupported suffix.
    """
    # Validate the save target before doing any plotting work.
    if is_save:
        if save_path is None:
            raise ValueError("save_path must be given when is_save is True.")
        save_path = Path(save_path)
        if save_path.suffix not in (".png", ".html"):
            raise Exception("The suffix of save_path is must be '.png' or '.html'.")

    all_datetime_array = np.array([get_naive_datetime_from_datetime(one_state.datetime) for one_state in portfolio_state_list])
    reward_array = np.array(reward_list)
    x = np.arange(0, len(portfolio_state_list))

    layout_list = visualize_portfolio_transform_bokeh(portfolio_state_list, is_save=False, is_show=False)

    add_p1 = bokeh.plotting.figure(plot_width=1200,plot_height=300,title="報酬")
    add_p1.line(x, reward_array, legend_label="reward", line_width=2, color="green")
    # Label the integer x ticks with the corresponding naive datetimes.
    add_p1.xaxis.major_label_overrides = {str(one_x) : str(all_datetime_array[i]) for i, one_x in enumerate(x)}

    add_p1.yaxis[0].axis_label = "報酬"

    layout_list.extend([add_p1])
    created_figure = bokeh.layouts.column(*layout_list)

    if is_save:
        if save_path.suffix == ".png":
            bokeh.io.export_png(created_figure, filename=save_path)
        else:  # ".html" (validated above)
            output_file(save_path)
            bokeh.io.save(created_figure, filename=save_path, title="trading process")
        return None

    if is_show:
        try:
            reset_output()
            if is_jupyter:
                output_notebook()
            show(created_figure)
        except Exception:
            # Retry once without reset_output(); KeyboardInterrupt/SystemExit still propagate.
            if is_jupyter:
                output_notebook()
            show(created_figure)
        return None

    return layout_list

In [25]:
visualize_portfolio_rl_bokeh(portfolio_state_list, reward_list, is_show=True)

##  前処理用の関数

In [26]:
class ComposeFunction:
    """
    Apply a sequence of callables in order, feeding each output into the next.

    Each callable is stored as an attribute so it can be accessed (or swapped)
    by name at runtime, e.g. ``transform.price_normalize`` or ``transform.func_0``.
    """
    def __init__(self, collection):
        """
        collection: dict, list, or other iterable of callables
            A dict uses its keys as attribute names, applied in insertion order.
            Any other iterable (list, tuple, ...) names its items
            ``func_0``, ``func_1``, ...

        Raises
        ------
        TypeError
            If ``collection`` is not iterable or contains a non-callable.
        ValueError
            If a name would overwrite the internal ``function_name_list``.
        """
        if isinstance(collection, dict):
            named_functions = list(collection.items())
        else:
            try:
                named_functions = [("func_" + str(i), func) for i, func in enumerate(collection)]
            except TypeError:
                raise TypeError(
                    "collection must be a dict or an iterable of callables, got {}".format(type(collection).__name__)
                ) from None

        self.function_name_list = []
        for func_name, func in named_functions:
            if func_name == "function_name_list":
                raise ValueError("'function_name_list' is reserved and cannot be used as a function name.")
            if not callable(func):
                raise TypeError("{!r} is not callable (name: {}).".format(func, func_name))
            setattr(self, func_name, func)
            self.function_name_list.append(func_name)

    def __call__(self, x):
        """
        x: any
            Input to the first callable.

        Returns the output of the last callable (``x`` itself if there are none).
        """
        for func_name in self.function_name_list:
            # Looked up at call time so a reassigned attribute takes effect.
            x = getattr(self, func_name)(x)

        return x

In [27]:
class State2Feature:
    """
    Convert a PortfolioState into a stacked feature array.

    Intended as the final step of a preprocessing pipeline.
    """
    def __call__(self, portfolio_state):
        """
        portfolio_state: PortfolioState
            Must expose ``price_array`` (one row per asset),
            ``portfolio_vector`` and ``mean_cost_price_array`` (one value per asset).

        Returns
        -------
        ndarray
            Three channels stacked on axis 0, each shaped like ``price_array``:
            the prices, the prices scaled row-wise by ``portfolio_vector``, and
            the prices scaled row-wise by ``mean_cost_price_array``
            (e.g. (3, 11, 20) for 11 assets and a 20-step window).
        """
        prices = portfolio_state.price_array
        # [:, None] broadcasts each per-asset scalar across that asset's row.
        weight_column = portfolio_state.portfolio_vector[:, None]
        # NOTE(review): this channel multiplies by mean cost rather than dividing; confirm intent.
        cost_column = portfolio_state.mean_cost_price_array[:, None]
        channels = (prices, prices * weight_column, prices * cost_column)
        return np.stack(channels, axis=0)

In [28]:
state, _,_,_ = trade_env.reset()

In [29]:
feature = State2Feature()(state)

In [30]:
feature.shape

(3, 11, 20)

In [31]:
class PriceNormalizeConst:
    """
    Divide each asset's row of ``price_array`` by a per-asset constant.
    """
    def __init__(self, const_array=None):
        """
        const_array: array-like of shape (n_assets,) or None
            One divisor per row of ``price_array``. Lists are accepted.
            ``None`` (the default) disables normalization: the state is
            returned unchanged.
        """
        self.const_array = None if const_array is None else np.asarray(const_array)

    def __call__(self, portfolio_state):
        """
        portfolio_state: PortfolioState
            State whose ``price_array`` is normalized.

        Returns
        -------
        PortfolioState
            Built with ``portfolio_state._replace`` and the normalized
            ``price_array``. When ``const_array`` is None, the input state is
            returned unchanged.

        Raises
        ------
        ValueError
            If the number of constants differs from the number of rows of ``price_array``.
        """
        if self.const_array is None:
            return portfolio_state

        price_array = portfolio_state.price_array
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if price_array.shape[0] != self.const_array.shape[0]:
            raise ValueError(
                "const_array has {} entries but price_array has {} rows.".format(
                    self.const_array.shape[0], price_array.shape[0]
                )
            )
        new_price_array = price_array / self.const_array[:, None]
        return portfolio_state._replace(price_array=new_price_array)

In [32]:
PriceNormalizeConst(np.array([100000]*(1+ticker_number)))(state)

PortfolioState(names=['yen', '5713', '9433', '7951', '3436', '7735', '2002', '9984', '1808', '1333', '9983'], key_currency_index=0, window=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19]), datetime=datetime.datetime(2020, 11, 12, 13, 5, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=array([[1.000e-05, 1.000e-05, 1.000e-05, 1.000e-05, 1.000e-05, 1.000e-05,
        1.000e-05, 1.000e-05, 1.000e-05, 1.000e-05, 1.000e-05, 1.000e-05,
        1.000e-05, 1.000e-05, 1.000e-05, 1.000e-05, 1.000e-05, 1.000e-05,
        1.000e-05, 1.000e-05],
       [3.678e-02, 3.672e-02, 3.670e-02, 3.669e-02, 3.669e-02, 3.661e-02,
        3.661e-02, 3.669e-02, 3.665e-02, 3.660e-02, 3.658e-02, 3.667e-02,
        3.671e-02, 3.683e-02, 3.672e-02, 3.683e-02, 3.690e-02, 3.704e-02,
        3.704e-02, 3.706e-02],
       [3.143e-02, 3.143e-02, 3.142e-02, 3.145e-02, 3.144e-02, 3.145e-02,
        3.144e-02, 3.146e-02, 3.148e-02, 3.146e-02, 3.144e-02, 3.143e-02,
 

In [33]:
class MeanCostPriceNormalizeConst:
    """
    Divide ``mean_cost_price_array`` element-wise by a per-asset constant.
    """
    def __init__(self, const_array):
        """
        const_array: array-like of shape (n_assets,)
            One divisor per entry of ``mean_cost_price_array``. Lists are accepted.
        """
        self.const_array = np.asarray(const_array)

    def __call__(self, portfolio_state):
        """
        portfolio_state: PortfolioState
            State whose ``mean_cost_price_array`` is normalized.

        Returns
        -------
        PortfolioState
            Built with ``portfolio_state._replace`` and the normalized
            ``mean_cost_price_array``.

        Raises
        ------
        ValueError
            If the number of constants differs from the length of ``mean_cost_price_array``.
        """
        mean_cost_price_array = portfolio_state.mean_cost_price_array
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if mean_cost_price_array.shape[0] != self.const_array.shape[0]:
            raise ValueError(
                "const_array has {} entries but mean_cost_price_array has {}.".format(
                    self.const_array.shape[0], mean_cost_price_array.shape[0]
                )
            )
        new_mean_cost_price = mean_cost_price_array / self.const_array
        return portfolio_state._replace(mean_cost_price_array=new_mean_cost_price)

In [34]:
MeanCostPriceNormalizeConst(np.array([100000]*(1+ticker_number)))(state)

PortfolioState(names=['yen', '5713', '9433', '7951', '3436', '7735', '2002', '9984', '1808', '1333', '9983'], key_currency_index=0, window=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19]), datetime=datetime.datetime(2020, 11, 12, 13, 5, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=array([[1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00,
        1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00,
        1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00, 1.000e+00,
        1.000e+00, 1.000e+00],
       [3.678e+03, 3.672e+03, 3.670e+03, 3.669e+03, 3.669e+03, 3.661e+03,
        3.661e+03, 3.669e+03, 3.665e+03, 3.660e+03, 3.658e+03, 3.667e+03,
        3.671e+03, 3.683e+03, 3.672e+03, 3.683e+03, 3.690e+03, 3.704e+03,
        3.704e+03, 3.706e+03],
       [3.143e+03, 3.143e+03, 3.142e+03, 3.145e+03, 3.144e+03, 3.145e+03,
        3.144e+03, 3.146e+03, 3.148e+03, 3.146e+03, 3.144e+03, 3.143e+03,
 

In [35]:
transform = ComposeFunction([PriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             MeanCostPriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             State2Feature()
                            ])

In [36]:
transform(state)

array([[[1.0000000e-05, 1.0000000e-05, 1.0000000e-05, 1.0000000e-05,
         1.0000000e-05, 1.0000000e-05, 1.0000000e-05, 1.0000000e-05,
         1.0000000e-05, 1.0000000e-05, 1.0000000e-05, 1.0000000e-05,
         1.0000000e-05, 1.0000000e-05, 1.0000000e-05, 1.0000000e-05,
         1.0000000e-05, 1.0000000e-05, 1.0000000e-05, 1.0000000e-05],
        [3.6780000e-02, 3.6720000e-02, 3.6700000e-02, 3.6690000e-02,
         3.6690000e-02, 3.6610000e-02, 3.6610000e-02, 3.6690000e-02,
         3.6650000e-02, 3.6600000e-02, 3.6580000e-02, 3.6670000e-02,
         3.6710000e-02, 3.6830000e-02, 3.6720000e-02, 3.6830000e-02,
         3.6900000e-02, 3.7040000e-02, 3.7040000e-02, 3.7060000e-02],
        [3.1430000e-02, 3.1430000e-02, 3.1420000e-02, 3.1450000e-02,
         3.1440000e-02, 3.1450000e-02, 3.1440000e-02, 3.1460000e-02,
         3.1480000e-02, 3.1460000e-02, 3.1440000e-02, 3.1430000e-02,
         3.1460000e-02, 3.1510000e-02, 3.1510000e-02, 3.1530000e-02,
         3.1570000e-02, 3.157000

In [37]:
dir(transform)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'func_0',
 'func_1',
 'func_2',
 'function_name_list']

In [38]:
transform = ComposeFunction({"price_normalize":PriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             "mean_cost_price_normalize":MeanCostPriceNormalizeConst(np.array([100000]*(1+ticker_number))),
                             "state2feature":State2Feature()
                            })

In [39]:
transform(state)

array([[[1.0000000e-05, 1.0000000e-05, 1.0000000e-05, 1.0000000e-05,
         1.0000000e-05, 1.0000000e-05, 1.0000000e-05, 1.0000000e-05,
         1.0000000e-05, 1.0000000e-05, 1.0000000e-05, 1.0000000e-05,
         1.0000000e-05, 1.0000000e-05, 1.0000000e-05, 1.0000000e-05,
         1.0000000e-05, 1.0000000e-05, 1.0000000e-05, 1.0000000e-05],
        [3.6780000e-02, 3.6720000e-02, 3.6700000e-02, 3.6690000e-02,
         3.6690000e-02, 3.6610000e-02, 3.6610000e-02, 3.6690000e-02,
         3.6650000e-02, 3.6600000e-02, 3.6580000e-02, 3.6670000e-02,
         3.6710000e-02, 3.6830000e-02, 3.6720000e-02, 3.6830000e-02,
         3.6900000e-02, 3.7040000e-02, 3.7040000e-02, 3.7060000e-02],
        [3.1430000e-02, 3.1430000e-02, 3.1420000e-02, 3.1450000e-02,
         3.1440000e-02, 3.1450000e-02, 3.1440000e-02, 3.1460000e-02,
         3.1480000e-02, 3.1460000e-02, 3.1440000e-02, 3.1430000e-02,
         3.1460000e-02, 3.1510000e-02, 3.1510000e-02, 3.1530000e-02,
         3.1570000e-02, 3.157000

In [40]:
dir(transform)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'function_name_list',
 'mean_cost_price_normalize',
 'price_normalize',
 'state2feature']