ルートディレクトリに移動 

In [1]:
%cd ../..

E:\システムトレード入門\trade_system_git_workspace


In [2]:
from collections import namedtuple
import pandas as pd
from pathlib import Path
from pytz import timezone
import datetime
import numpy as np
from scipy.special import softmax
from abc import ABCMeta, abstractmethod

In [3]:
import bokeh.plotting
from bokeh.models import Range1d, LinearAxis, Div, HoverTool
from bokeh.io import show
from bokeh.io import output_notebook, reset_output, output_file
from bokeh.palettes import d3
output_notebook()

In [4]:
from get_stock_price import StockDatabase

In [5]:
from utils import get_previous_workday_intraday_datetime, get_next_workday_intraday_datetime, get_naive_datetime_from_datetime

In [6]:
from utils import py_restart
from utils import py_workdays

## データベース 

In [7]:
db_path = Path("E:/システムトレード入門/trade_system_git_workspace/db/sub_stock_db") / Path("sub_stock.db")
stock_db = StockDatabase(db_path)

In [8]:
jst_timezone = timezone("Asia/Tokyo")
all_start_datetime = jst_timezone.localize(datetime.datetime(2020,11,1,0,0,0))
all_end_datetime = jst_timezone.localize(datetime.datetime(2020,12,1,0,0,0))
py_workdays.get_workdays_jp(all_start_datetime.date(), all_end_datetime.date())

array([datetime.date(2020, 11, 2), datetime.date(2020, 11, 4),
       datetime.date(2020, 11, 5), datetime.date(2020, 11, 6),
       datetime.date(2020, 11, 9), datetime.date(2020, 11, 10),
       datetime.date(2020, 11, 11), datetime.date(2020, 11, 12),
       datetime.date(2020, 11, 13), datetime.date(2020, 11, 16),
       datetime.date(2020, 11, 17), datetime.date(2020, 11, 18),
       datetime.date(2020, 11, 19), datetime.date(2020, 11, 20),
       datetime.date(2020, 11, 24), datetime.date(2020, 11, 25),
       datetime.date(2020, 11, 26), datetime.date(2020, 11, 27),
       datetime.date(2020, 11, 30)], dtype=object)

In [9]:
#start_datetime = jst_timezone.localize(datetime.datetime(2020,11,10,9,0,0))
#end_datetime = jst_timezone.localize(datetime.datetime(2020,11,20,15,0,0))
#stock_list = ["4755","9984","6701","7203","7267"]

#stock_df = stock_db.search_span(stock_names=stock_list, 
#                                start_datetime=start_datetime,
#                                end_datetime=end_datetime,
#                                freq_str="5T",
#                                to_tokyo=True
#                               )

#stock_df

## 供給データ単位クラス

供給されるデータの単位

In [10]:
field_list = ["names",  # 銘柄名
              "key_currency_index",  # 基軸通貨のインデックス
              "datetime",  # データの日時
              "window",  # データのウィンドウ
              "open_array",  # [銘柄名, ウィンドウ(時間)]に対応する始値
              "close_array",  # [銘柄名, ウィンドウ(時間)]に対応する終値
              "high_array",  # [銘柄名, ウィンドウ(時間)]に対応する高値
              "low_array",  # [銘柄名, ウィンドウ(時間)]に対応する低値
              "volume_array"  # [銘柄名, ウィンドウ(時間)]に対応する取引量
             ]

DataSupplyUnitBase = namedtuple("DataSupplyUnitBase", field_list)

In [11]:
class DataSupplyUnit(DataSupplyUnitBase):
    """
    DataSupplierによって提供されるデータクラス
    """
    def __str__(self):
        return_str = "DataSupplyUnit( \n"
        for field_str in self._fields:
            return_str += field_str + "="
            return_str += str(getattr(self, field_str)) + "\n"
        return_str += ")"
        return return_str

## ポートフォリオ状態クラス 

ポートフォリオの状態を表すクラス

In [12]:
field_list = ["names",  # 銘柄名
              "key_currency_index",  # 基軸通貨のインデックス
              "window",  # データのウィンドウ
              "datetime",  # データの日時
              "price_array",  # [銘柄名, ウィンドウ(時間)]に対応する現在価格
              "volume_array",  # [銘柄名, ウィンドウ(時間)]に対応する取引量
              "now_price_array",  # 銘柄名に対応する現在価格
              "portfolio_vector",  # ポートフォリオベクトル
              "mean_cost_price_array",  # 銘柄名に対応する平均取得価格
              "all_assets"  # 基軸通貨で換算した全資産
             ]

PortfolioStateBase = namedtuple("PortfolioStateBase", field_list)

In [13]:
class PortfolioState(PortfolioStateBase):
    """
    バックテスト・強化学習で利用するTransformerが提供するデータクラス．強化学習における状態を内包する．
    """
    
    @property
    def numbers(self):
        """
        保有量のプロパティ
        """
        return self.all_assets*self.portfolio_vector/self.now_price_array
    
    def __str__(self):
        return_str = "PortfolioState( \n"
        for field_str in self._fields:
            return_str += field_str + "="
            return_str += str(getattr(self, field_str)) + "\n"
        return_str += ")"
        return return_str
    
    def copy(self):
        """
        自身のコビーを返す．ndarrayのプロパティの場合はそのコビーを保持する．
        """
        arg_dict = {}
        for field_str in self._fields:
            field_value = getattr(self, field_str)
            if isinstance(field_value, np.ndarray):
                field_value = field_value.copy()
            
            arg_dict[field_str] = field_value
        
        return PortfolioState(**arg_dict)
    
    def partial(self, *args):
        """
        メモリ等の状況によって，自身の部分的なコビーを返す．
        引数にを耐えられなかったプロパティはNoneとなる．
        """
        arg_dict = {}
        for field_str in self._fields:
            if field_str in args:
                field_value = getattr(self, field_str)
                if isinstance(field_value, np.ndarray):
                    field_value = field_value.copy()
            else:
                field_value = None
            
            arg_dict[field_str] = field_value
            
        return PortfolioState(**arg_dict)

## データの供給クラス 

In [14]:
class PriceSuppliier(metaclass=ABCMeta):
    @abstractmethod
    def reset(self, start_datetime, window):
        pass
    
    @abstractmethod
    def step(self):
        pass

今回は株価データベースを用いて価格を供給する．

### ナイーブな実装 

In [15]:
class StockDBPriceSupplier(PriceSuppliier):
    """
    StockDatabaseに対応するPriceSupplier
    """
    def __init__(self, stock_db, ticker_names, episode_length, freq_str, interpolate=True):
        self.stock_db = stock_db
        self.ticker_names = ticker_names
        self.episode_length = episode_length
        self.freq_str = freq_str        
        self.interpolate = interpolate
        # column_namesを分かりやすくまとめる
        self.column_names_list_dict = {}
        for column_type in ["Open", "High", "Low", "Close", "Volume"]:
            self.column_names_list_dict[column_type] = [column_type+"_"+ticker_name for ticker_name in self.ticker_names]
        
    def reset(self, start_datetime, window=np.array([0])):
        """
        start_datetime: datetime.datetime 
            データ供給の開始時刻
        window: ndarray
            データ供給のウィンドウ
        """
        # 終了時刻を求める
        # 全datetimeデータを保持
        assert 0 in window
        if not isinstance(window, np.ndarray):
            self.window = np.array(window)
        else:
            self.window = window
        
        min_window = min(self.window)
        max_window = max(self.window)
        
        if min_window <= 0:
            episode_start_datetime = get_previous_workday_intraday_datetime(start_datetime, self.freq_str, abs(min_window))
        else:
            episode_start_datetime = get_next_workday_intraday_datetime(start_datetime, self.freq_str, min_window)
            
        if self.episode_length+max_window <= 0:  # 基本的にあり得ない
            episode_end_datetime = get_previous_workday_intraday_datetime(start_datetime, self.freq_str, abs(self.episode_length+max_window))
        else:
            episode_end_datetime = get_next_workday_intraday_datetime(start_datetime, self.freq_str, self.episode_length+max_window)
        
        episode_df = self.stock_db.search_span(stock_names=self.ticker_names,
                                               start_datetime=episode_start_datetime,
                                               end_datetime=episode_end_datetime,
                                               freq_str=self.freq_str,
                                               is_end_include=True,  # 最後の値も含める
                                               to_tokyo=True,  #必ずTrueに
                                              )
        
        self.episode_df = py_workdays.extract_workdays_intraday_jp(episode_df)
        
        all_datetime_index = pd.date_range(start=episode_start_datetime,
                                           end=episode_end_datetime,
                                           freq=self.freq_str,
                                           closed="left"
                                          )
        self.all_datetime_index = py_workdays.extract_workdays_intraday_jp_index(all_datetime_index)
        
        # episode_dfの補間
        if self.interpolate:
            add_datetime_bool = ~self.all_datetime_index.isin(self.episode_df.index)
            add_datetime_index = self.all_datetime_index[add_datetime_bool]
            # Noneのdfを作成
            nan_df = pd.DataFrame(None, index=add_datetime_index)
            for one_column in self.episode_df.columns:
                  nan_df[one_column] = np.nan
                    
            # Noneのdfを追加
            self.episode_df = self.episode_df.append(nan_df)
            self.episode_df.sort_index(inplace=True)
            
            # np.nanの補間
            self.episode_df.interpolate(limit_direction="both",inplace=True)
        
        # データの取得
        self.now_index = np.argwhere(self.window==0).item()
        now_datetime = self.all_datetime_index[self.now_index].to_pydatetime()
        
        add_window = [self.now_index+one_value for one_value in self.window]
        window_index_array = self.all_datetime_index[add_window]
        window_data_df = self.episode_df.loc[window_index_array,:]
        
        open_array = window_data_df.loc[:,self.column_names_list_dict["Open"]].values.T
        high_array = window_data_df.loc[:,self.column_names_list_dict["High"]].values.T
        low_array = window_data_df.loc[:,self.column_names_list_dict["Low"]].values.T
        close_array = window_data_df.loc[:,self.column_names_list_dict["Close"]].values.T
        volume_array = window_data_df.loc[:,self.column_names_list_dict["Volume"]].values.T
        
        open_array = np.concatenate([np.ones((1, open_array.shape[1])), open_array], axis=0)
        high_array = np.concatenate([np.ones((1, high_array.shape[1])), high_array], axis=0)
        low_array = np.concatenate([np.ones((1, low_array.shape[1])), low_array], axis=0)
        close_array = np.concatenate([np.ones((1, close_array.shape[1])), close_array], axis=0)
        volume_array = np.concatenate([np.ones((1, volume_array.shape[1])), volume_array], axis=0)
        
        
        out_ticker_names = ["yen"]
        out_ticker_names.extend(self.ticker_names)
        
        out_unit = DataSupplyUnit(names=out_ticker_names,
                                  key_currency_index=0,
                                  datetime=now_datetime,
                                  window=self.window,
                                  open_array=open_array,
                                  close_array=close_array,
                                  high_array=high_array,
                                  low_array=low_array,
                                  volume_array=volume_array
                                 )
        done = False
        return out_unit, done
    
    def step(self):
        # indexの更新
        self.now_index += 1
        now_datetime = self.all_datetime_index[self.now_index].to_pydatetime()
        
        add_window = [self.now_index+one_value for one_value in self.window]
        window_index_array = self.all_datetime_index[add_window]
        window_data_df = self.episode_df.loc[window_index_array,:]
        
        open_array = window_data_df.loc[:,self.column_names_list_dict["Open"]].values.T
        high_array = window_data_df.loc[:,self.column_names_list_dict["High"]].values.T
        low_array = window_data_df.loc[:,self.column_names_list_dict["Low"]].values.T
        close_array = window_data_df.loc[:,self.column_names_list_dict["Close"]].values.T
        volume_array = window_data_df.loc[:,self.column_names_list_dict["Volume"]].values.T
        
        
        open_array = np.concatenate([np.ones((1, open_array.shape[1])), open_array], axis=0)
        high_array = np.concatenate([np.ones((1, high_array.shape[1])), high_array], axis=0)
        low_array = np.concatenate([np.ones((1, low_array.shape[1])), low_array], axis=0)
        close_array = np.concatenate([np.ones((1, close_array.shape[1])), close_array], axis=0)
        volume_array = np.concatenate([np.ones((1, volume_array.shape[1])), volume_array], axis=0)
        
        out_ticker_names = ["yen"]
        out_ticker_names.extend(self.ticker_names)
        
        out_unit = DataSupplyUnit(names=out_ticker_names,
                                  key_currency_index=0,
                                  datetime=now_datetime,
                                  window=self.window,
                                  open_array=open_array,
                                  close_array=close_array,
                                  high_array=high_array,
                                  low_array=low_array,
                                  volume_array=volume_array
                                 )
        done = self.now_index >= self.episode_length
        
        return out_unit, done 

In [16]:
start_datetime = jst_timezone.localize(datetime.datetime(2020,11,10,9,0,0))
stock_list = ["4755","9984","6701","7203","7267"]
episode_length = 50

price_supplier = StockDBPriceSupplier(stock_db=stock_db,
                                     ticker_names=stock_list,
                                     episode_length=episode_length,
                                     freq_str="5T",
                                     interpolate=True
                                    )

In [17]:
data_unit, _ = price_supplier.reset(start_datetime, window=[-3,-2,-1,0,1,2,3])
print(data_unit)

DataSupplyUnit( 
names=['yen', '4755', '9984', '6701', '7203', '7267']
key_currency_index=0
datetime=2020-11-10 09:00:00+09:00
window=[-3 -2 -1  0  1  2  3]
open_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [1.1200e+03 1.1190e+03 1.1190e+03 1.1250e+03 1.1270e+03 1.1080e+03
  1.1050e+03]
 [7.0770e+03 7.0680e+03 7.0730e+03 7.0040e+03 7.0720e+03 7.0190e+03
  7.0040e+03]
 [5.7600e+03 5.7600e+03 5.7600e+03 5.7300e+03 5.7000e+03 5.7300e+03
  5.7100e+03]
 [7.1810e+03 7.1750e+03 7.1770e+03 7.3200e+03 7.3430e+03 7.3380e+03
  7.3550e+03]
 [2.8435e+03 2.8400e+03 2.8395e+03 2.9200e+03 2.9595e+03 2.9420e+03
  2.9710e+03]]
close_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [1.1200e+03 1.1190e+03 1.1170e+03 1.1280e+03 1.1070e+03 1.1050e+03
  1.0620e+03]
 [7.0690e+03 7.0740e+03 7.0960e+03 7.0720e+03 7.0190e+03 7.0040e+03
  6.9880e+03]
 [5.7600e+03 5.7600e+03 5.7600e+03 5.6800e+03 5.7200e+03 5.7100e+03
  5.6900e+03]

In [18]:
data_unit, _ = price_supplier.step()
print(data_unit)

DataSupplyUnit( 
names=['yen', '4755', '9984', '6701', '7203', '7267']
key_currency_index=0
datetime=2020-11-10 09:05:00+09:00
window=[-3 -2 -1  0  1  2  3]
open_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [1.1190e+03 1.1190e+03 1.1250e+03 1.1270e+03 1.1080e+03 1.1050e+03
  1.0630e+03]
 [7.0680e+03 7.0730e+03 7.0040e+03 7.0720e+03 7.0190e+03 7.0040e+03
  6.9890e+03]
 [5.7600e+03 5.7600e+03 5.7300e+03 5.7000e+03 5.7300e+03 5.7100e+03
  5.6800e+03]
 [7.1750e+03 7.1770e+03 7.3200e+03 7.3430e+03 7.3380e+03 7.3550e+03
  7.3280e+03]
 [2.8400e+03 2.8395e+03 2.9200e+03 2.9595e+03 2.9420e+03 2.9710e+03
  2.9325e+03]]
close_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [1.1190e+03 1.1170e+03 1.1280e+03 1.1070e+03 1.1050e+03 1.0620e+03
  1.0760e+03]
 [7.0740e+03 7.0960e+03 7.0720e+03 7.0190e+03 7.0040e+03 6.9880e+03
  6.9570e+03]
 [5.7600e+03 5.7600e+03 5.6800e+03 5.7200e+03 5.7100e+03 5.6900e+03
  5.6700e+03]

### プロファイリング 

In [19]:
def temp_func():
    price_supplier = StockDBPriceSupplier(stock_db=stock_db,
                                         ticker_names=stock_list,
                                         episode_length=episode_length,
                                         freq_str="5T",
                                         interpolate=True
                                        )
    data_unit, _ = price_supplier.reset(start_datetime, window=[-3,-2,-1,0,1,2,3])
    for i in range(30):
        data_unit, _ = price_supplier.step()
    
from line_profiler import LineProfiler
prf = LineProfiler()                                                                                         
prf.add_module(StockDBPriceSupplier)                                                                                          
#prf.add_function()                                                                                      
prf.runcall(temp_func)                         
prf.print_stats()

Timer unit: 1e-07 s

Total time: 0.0002287 s
File: <ipython-input-15-94ef1fcb0f78>
Function: __init__ at line 5

Line #      Hits         Time  Per Hit   % Time  Line Contents
     5                                               def __init__(self, stock_db, ticker_names, episode_length, freq_str, interpolate=True):
     6         1        221.0    221.0      9.7          self.stock_db = stock_db
     7         1         78.0     78.0      3.4          self.ticker_names = ticker_names
     8         1         68.0     68.0      3.0          self.episode_length = episode_length
     9         1         62.0     62.0      2.7          self.freq_str = freq_str        
    10         1         62.0     62.0      2.7          self.interpolate = interpolate
    11                                                   # column_namesを分かりやすくまとめる
    12         1         67.0     67.0      2.9          self.column_names_list_dict = {}
    13         6        373.0     62.2     16.3          for colum

以下の部分の処理が2msもかかるのはおかしい

In [20]:
window_index_array = price_supplier.all_datetime_index[[1,2,3,4,5]]
window_data_df = price_supplier.episode_df.loc[window_index_array,:]
%timeit open_array = window_data_df.loc[:,price_supplier.column_names_list_dict["Open"]]

2.35 ms ± 613 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
%timeit window_data_df.loc[:,["Open_4755", "High_4755"]]

2.28 ms ± 516 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
%timeit window_data_df[["Open_4755", "High_4755"]]

2.38 ms ± 600 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


ちなみにカラムが一つだけなら100倍近く高速である

In [23]:
%timeit window_data_df["Open_4755"].values

16.2 µs ± 5.54 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [24]:
%timeit window_data_df.loc[:,"Open_4755"].values

99.6 µs ± 4.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [25]:
%timeit window_data_df.loc[:,["Open_4755", "High_4755"]].values

2.52 ms ± 511 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [26]:
bool_list = [False for _ in window_data_df.columns]
bool_list[0] = True
bool_list[1] = True
%timeit window_data_df.loc[:,bool_list]

2.04 ms ± 506 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


ndarrayを直接扱う方がずっと高速である．しかし，一つの時はpandasの方が高速なのはなぜ？

In [27]:
%timeit window_data_df.values[:,0]

25.5 µs ± 4.09 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [28]:
%timeit window_data_df.values[:,[0,1]]

42.9 µs ± 2.89 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [29]:
%%timeit 
bool_list = [False for _ in window_data_df.columns]
bool_list[0] = True
bool_list[1] = True

window_data_df.values[:,bool_list]

69.9 µs ± 2.91 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


locが遅いので，自分で簡単なクエリを実装する

In [30]:
%%timeit
bool_list = window_data_df.columns.str.startswith("Open")
window_data_df.values[:,bool_list]

520 µs ± 80.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [31]:
%%timeit
bool_array = np.char.startswith(window_data_df.columns.values.astype(str), "Open")
window_data_df.values[:,bool_list]

221 µs ± 55.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


reset時にbool_arrayを求めてしまえば

In [32]:
bool_array = np.char.startswith(window_data_df.columns.values.astype(str), "Open")

In [33]:
%%timeit
window_data_df.values[:,bool_list]

43.7 µs ± 1.43 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


### 速度面で修正

dfの保持を止め, ndarrayを直接クエリする

In [34]:
class StockDBPriceSupplier(PriceSuppliier):
    """
    StockDatabaseに対応するPriceSupplier
    """
    def __init__(self, stock_db, ticker_names, episode_length, freq_str, interpolate=True):
        """
        stock_db: get_stock_price.StockDatabase
            利用するStockDatabaseクラス
        ticker_names: list of str
            利用する銘柄名のリスト
        episode_length: int
            エピソードの長さ
        freq_str: str
            利用する周期を表す文字列
        interpolate: bool
            補間するかどうか
        """
        self.stock_db = stock_db
        self.ticker_names = ticker_names
        self.episode_length = episode_length
        self.freq_str = freq_str        
        self.interpolate = interpolate
        
    def reset(self, start_datetime, window=np.array([0])):
        """
        Parameters
        ----------
        start_datetime: datetime.datetime 
            データ供給の開始時刻
        window: ndarray
            データ供給のウィンドウ
        
        Returns
        -------
        DatasupplyUnit
            提供する価格データ
        bool
            エピソードが終了したかどうか
        """
        # 終了時刻を求める
        # 全datetimeデータを保持
        assert 0 in window
        if not isinstance(window, np.ndarray):
            self.window = np.array(window)
        else:
            self.window = window
        
        min_window = min(self.window)
        max_window = max(self.window)
        
        if min_window <= 0:
            episode_start_datetime = get_previous_workday_intraday_datetime(start_datetime, self.freq_str, abs(min_window))
        else:
            episode_start_datetime = get_next_workday_intraday_datetime(start_datetime, self.freq_str, min_window)
            
        if self.episode_length+max_window <= 0:  # 基本的にあり得ない
            episode_end_datetime = get_previous_workday_intraday_datetime(start_datetime, self.freq_str, abs(self.episode_length+max_window))
        else:
            episode_end_datetime = get_next_workday_intraday_datetime(start_datetime, self.freq_str, self.episode_length+max_window)
        
        episode_df = self.stock_db.search_span(stock_names=self.ticker_names,
                                               start_datetime=episode_start_datetime,
                                               end_datetime=episode_end_datetime,
                                               freq_str=self.freq_str,
                                               is_end_include=True,  # 最後の値も含める
                                               to_tokyo=True,  #必ずTrueに
                                              )
        
        self.episode_df = py_workdays.extract_workdays_intraday_jp(episode_df)
        
        # 各OHLCVに対応するboolを求めておく
        column_names_array = self.episode_df.columns.values.astype(str)
        self.open_bool_array = np.char.startswith(column_names_array, "Open")
        self.high_bool_array = np.char.startswith(column_names_array, "High")
        self.low_bool_array = np.char.startswith(column_names_array, "Low")
        self.close_bool_array = np.char.startswith(column_names_array, "Close")
        self.volume_bool_array = np.char.startswith(column_names_array, "Volume")
        
        
        all_datetime_index = pd.date_range(start=episode_start_datetime,
                                           end=episode_end_datetime,
                                           freq=self.freq_str,
                                           closed="left"
                                          )
        self.all_datetime_index = py_workdays.extract_workdays_intraday_jp_index(all_datetime_index)
        
        # episode_dfの補間
        if self.interpolate:
            add_datetime_bool = ~self.all_datetime_index.isin(self.episode_df.index)
            add_datetime_index = self.all_datetime_index[add_datetime_bool]
            # Noneのdfを作成
            nan_df = pd.DataFrame(index=add_datetime_index, columns=self.episode_df.columns)

            # Noneのdfを追加
            self.episode_df = self.episode_df.append(nan_df)
            self.episode_df.sort_index(inplace=True)
            
            # np.nanの補間
            self.episode_df.interpolate(limit_direction="both",inplace=True)
        else:
            share_index_bool = self.all_datetime_index.isin(self.episode_df.index)
            self.all_datetime_index = self.all_datetime_index[share_index_bool]
        
        # dfをndarrayに変更
        self.episode_df_values = self.episode_df.values
        del self.episode_df
        
        self.all_datetime_index_values = self.all_datetime_index.to_pydatetime()
        del self.all_datetime_index
        
        # データの取得
        self.now_index = np.argwhere(self.window==0).item()
        now_datetime = self.all_datetime_index_values[self.now_index]
        
        add_window = self.now_index + self.window
        window_data_value = self.episode_df_values[add_window,:]
        
        open_array = window_data_value[:,self.open_bool_array].T
        high_array = window_data_value[:,self.high_bool_array].T
        low_array = window_data_value[:,self.low_bool_array].T
        close_array = window_data_value[:,self.close_bool_array].T
        volume_array = window_data_value[:,self.volume_bool_array].T
        
        open_array = np.concatenate([np.ones((1, open_array.shape[1])), open_array], axis=0)
        high_array = np.concatenate([np.ones((1, high_array.shape[1])), high_array], axis=0)
        low_array = np.concatenate([np.ones((1, low_array.shape[1])), low_array], axis=0)
        close_array = np.concatenate([np.ones((1, close_array.shape[1])), close_array], axis=0)
        volume_array = np.concatenate([np.ones((1, volume_array.shape[1])), volume_array], axis=0)
        
        
        out_ticker_names = ["yen"]
        out_ticker_names.extend(self.ticker_names)
        
        out_unit = DataSupplyUnit(names=out_ticker_names,
                                  key_currency_index=0,
                                  datetime=now_datetime,
                                  window=self.window,
                                  open_array=open_array,
                                  close_array=close_array,
                                  high_array=high_array,
                                  low_array=low_array,
                                  volume_array=volume_array
                                 )
        done = False
        return out_unit, done
    
    def step(self):
        """
        Returns
        -------
        DatasupplyUnit
            提供する価格データ
        bool
            エピソードが終了したかどうか
        """
        # indexの更新
        self.now_index += 1
        now_datetime = self.all_datetime_index_values[self.now_index]
        
        add_window = self.now_index + self.window
        window_data_value = self.episode_df_values[add_window,:]
        
        open_array = window_data_value[:,self.open_bool_array].T
        high_array = window_data_value[:,self.high_bool_array].T
        low_array = window_data_value[:,self.low_bool_array].T
        close_array = window_data_value[:,self.close_bool_array].T
        volume_array = window_data_value[:,self.volume_bool_array].T
         
        open_array = np.concatenate([np.ones((1, open_array.shape[1])), open_array], axis=0)
        high_array = np.concatenate([np.ones((1, high_array.shape[1])), high_array], axis=0)
        low_array = np.concatenate([np.ones((1, low_array.shape[1])), low_array], axis=0)
        close_array = np.concatenate([np.ones((1, close_array.shape[1])), close_array], axis=0)
        volume_array = np.concatenate([np.ones((1, volume_array.shape[1])), volume_array], axis=0)
        
        out_ticker_names = ["yen"]
        out_ticker_names.extend(self.ticker_names)
        
        out_unit = DataSupplyUnit(names=out_ticker_names,
                                  key_currency_index=0,
                                  datetime=now_datetime,
                                  window=self.window,
                                  open_array=open_array,
                                  close_array=close_array,
                                  high_array=high_array,
                                  low_array=low_array,
                                  volume_array=volume_array
                                 )
        done = self.now_index >= self.episode_length
        
        return out_unit, done 

In [35]:
start_datetime = jst_timezone.localize(datetime.datetime(2020,11,10,9,0,0))
stock_list = ["4755","9984","6701","7203","7267"]
episode_length = 50

price_supplier = StockDBPriceSupplier(stock_db=stock_db,
                                     ticker_names=stock_list,
                                     episode_length=episode_length,
                                     freq_str="5T",
                                     interpolate=False
                                    )
data_unit, _ = price_supplier.reset(start_datetime, window=[-3,-2,-1,0,1,2,3])
print(data_unit)

DataSupplyUnit( 
names=['yen', '4755', '9984', '6701', '7203', '7267']
key_currency_index=0
datetime=2020-11-10 09:00:00+09:00
window=[-3 -2 -1  0  1  2  3]
open_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [1.1200e+03 1.1190e+03 1.1190e+03 1.1250e+03 1.1270e+03 1.1080e+03
  1.1050e+03]
 [7.0770e+03 7.0680e+03 7.0730e+03 7.0040e+03 7.0720e+03 7.0190e+03
  7.0040e+03]
 [5.7600e+03 5.7600e+03 5.7600e+03 5.7300e+03 5.7000e+03 5.7300e+03
  5.7100e+03]
 [7.1810e+03 7.1750e+03 7.1770e+03 7.3200e+03 7.3430e+03 7.3380e+03
  7.3550e+03]
 [2.8435e+03 2.8400e+03 2.8395e+03 2.9200e+03 2.9595e+03 2.9420e+03
  2.9710e+03]]
close_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [1.1200e+03 1.1190e+03 1.1170e+03 1.1280e+03 1.1070e+03 1.1050e+03
  1.0620e+03]
 [7.0690e+03 7.0740e+03 7.0960e+03 7.0720e+03 7.0190e+03 7.0040e+03
  6.9880e+03]
 [5.7600e+03 5.7600e+03 5.7600e+03 5.6800e+03 5.7200e+03 5.7100e+03
  5.6900e+03]

In [36]:
def temp_func():
    price_supplier = StockDBPriceSupplier(stock_db=stock_db,
                                         ticker_names=stock_list,
                                         episode_length=episode_length,
                                         freq_str="5T",
                                         interpolate=True
                                        )
    data_unit, _ = price_supplier.reset(start_datetime, window=[-3,-2,-1,0,1,2,3])
    for i in range(30):
        data_unit, _ = price_supplier.step()
    
from line_profiler import LineProfiler
prf = LineProfiler()                                                                                         
prf.add_module(StockDBPriceSupplier)                                                                                          
#prf.add_function()                                                                                      
prf.runcall(temp_func)                         
prf.print_stats()

Timer unit: 1e-07 s

Total time: 3.17e-05 s
File: <ipython-input-34-43659aa98df6>
Function: __init__ at line 5

Line #      Hits         Time  Per Hit   % Time  Line Contents
     5                                               def __init__(self, stock_db, ticker_names, episode_length, freq_str, interpolate=True):
     6                                                   """
     7                                                   stock_db: get_stock_price.StockDatabase
     8                                                       利用するStockDatabaseクラス
     9                                                   ticker_names: list of str
    10                                                       利用する銘柄名のリスト
    11                                                   episode_length: int
    12                                                       エピソードの長さ
    13                                                   freq_str: str
    14                                                       利用する周期を表す

## ポートフォリオの制限 

学習時にはポートフォリオベクトルに制限はもうけない．バックテストや実際の運用時に単元数・基軸通貨換算資産によって制限を設ける．これは強化学習を資産の保有率を求める問題にするためである．

In [37]:
class PortfilioRestrictor(metaclass=ABCMeta):
    """
    ポートフォリオの制限を行う抽象基底クラス
    restrictメソッドをオーバーライドする必要がある
    """
    @abstractmethod
    def restrict(self, portfolio_state, supplied_data_unit, portfolio_vector):
        pass

In [38]:
class PortfolioRestrictorSingleKey(PortfilioRestrictor):
    """
    """
    def __init__(self, unit_number, key_name):
        pass
        
    def restrict(self, portfolio_state, supplied_data_unit, portfolio_vector):
        pass

In [39]:
class PortfolioRestrictorIdentity(PortfilioRestrictor):
    """
    portfolioの恒等写像を行うPortfolioRestrictor
    """
    def restrict(self, portfolio_state, supplied_data_unit, portfolio_vector):
        return portfolio_vector

## 取引手数料の計算クラス 

In [40]:
class FeeCalculator(metaclass=ABCMeta):
    """
    手数料を計算する抽象基底クラス
    """
    @abstractmethod
    def calculate(self, pre_portfolio_state, new_portfolio_state):
        pass

In [41]:
class FeeCalculatorPerNumber(FeeCalculator):
    """
    取引個数に応じて手数料を計算するFeeCalculator
    """
    def __init__(self, fee_per_number):
        self.fee_per_number = fee_per_number
    def calculate(self, pre_portfolio_state, new_portfolio_state):
        not_key_currency_indices_list = list(range(len(pre_portfolio_state.names)))
        not_key_currency_indices_list.remove(pre_portfolio_state.key_currency_index)
        
        not_key_currency_indices = np.array(not_key_currency_indices_list) 
        commition_fee = self.fee_per_number*np.abs((new_portfolio_state.numbers[not_key_currency_indices] - pre_portfolio_state.numbers[not_key_currency_indices])).sum()
        return commition_fee

##  ポートフォリオの遷移クラス  

強化学習の環境だけでなく，バックテスト等でも利用できるように汎用的なもの

In [42]:
class PortfolioTransformer:
    """
    price_supplierの提供するデータに応じてPortfolioStateを遷移させるクラス
    バックテスト・強化学習のどちらでも使えるようにする．
    """
    def __init__(self, 
                 price_supplier, 
                 portfolio_restrictor=PortfolioRestrictorIdentity(), 
                 use_ohlc="Close", 
                 initial_portfolio_vector=None, 
                 initial_all_assets=1e6, 
                 fee_calculator=FeeCalculatorPerNumber(fee_per_number=1e-3)):
        """
        price_supplier: PriceSupplier
            価格データを供給するクラス
        portfolio_restrictor: PortfolioRestrictor
            エージェントが渡すportfolio_vectorを制限するクラス
        use_ohlc: str, defalt:'Close'
            利用する価格データの指定
        initial_portfolio_vector: any, defalt:None
            初期ポートフォリオベクトル
        fee_calculator: FeeCalculator
            手数料を計算するクラス
        """
        self.price_supplier = price_supplier
        self.portfolio_restrictor = portfolio_restrictor
        self.initial_portfolio_vector = initial_portfolio_vector
        self.initial_all_assets = initial_all_assets
        self.fee_calculator = fee_calculator
    
        # 利用するohlcのいずれか
        if use_ohlc not in {"Open","High","Low","Close"}:
            raise Exception("use_ohlc must be in {'Open','High','Low','Close'}")
        
        field_name_dict = {"Open":"open_array",
                           "Close":"close_array",
                           "Low":"low_array",
                           "High":"high_array"
                          }
            
        self.use_ohlc_filed = field_name_dict[use_ohlc]
        
    def reset(self, start_datetime, window=[0]):
        """
        Parameters
        ----------
        start_datetime: datetime.datetime 
            データ供給の開始時刻
        window: ndarray
            データ供給のウィンドウ
            
        Returns
        -------
        PortfolioStat
             ポートフォリオ状態
        bool
            エピソードが終了したかどうか
        """
        initial_data_unit, done = self.price_supplier.reset(start_datetime, window)
    
        now_price_bool = initial_data_unit.window==0 
    
        if self.initial_portfolio_vector is None:
            self.initial_portfolio_vector = np.zeros(len(initial_data_unit.names))
            self.initial_portfolio_vector[initial_data_unit.key_currency_index] = 1.0
            
        else:
            assert len(initial_data_unit.names) == len(self.initial_portfolio_vector)
            assert self.initial_portfolio_vector.sum() == 1.0
            
        
        now_price_array = getattr(initial_data_unit, self.use_ohlc_filed)[:,now_price_bool].squeeze()
    
        self.portfolio_state = PortfolioState(names=initial_data_unit.names,
                                              key_currency_index=initial_data_unit.key_currency_index,
                                              window=initial_data_unit.window,
                                              datetime=initial_data_unit.datetime,
                                              price_array=getattr(initial_data_unit, self.use_ohlc_filed),
                                              volume_array=initial_data_unit.volume_array,
                                              now_price_array=now_price_array,
                                              portfolio_vector=self.initial_portfolio_vector,
                                              mean_cost_price_array=now_price_array,
                                              all_assets=self.initial_all_assets
                                             )
        
        
        return self.portfolio_state.copy(), done
    
    def step(self, action):
        """
        Parameters
        ----------
        action: ndarray
            エージェントが渡すポートフォリオベクトル
            
        Returns
        -------
        PortfolioStat
             ポートフォリオ状態
        bool
            エピソードが終了したかどうか
        """
        
        if not isinstance(action, np.ndarray):
            action = np.array(action)
        assert (action<0).sum() == 0 and (action>1).sum() == 0
        assert abs(action.sum() - 1.0) < 1.e-5  # 大体1ならOK
        
        #from IPython.core.debugger import Pdb; Pdb().set_trace()
        
        previous_portfolio_state = self.portfolio_state
        supplied_data_unit, done = self.price_supplier.step()
        
        assert len(action)==len(supplied_data_unit.names)
        
        restricted_portfolio_vector = self.portfolio_restrictor.restrict(previous_portfolio_state, supplied_data_unit, action)
        
        # 全資産の変化率を求める
        now_price_bool = supplied_data_unit.window==0
        now_price_array = getattr(supplied_data_unit, self.use_ohlc_filed)[:,now_price_bool].squeeze()
        
        price_change_ratio = now_price_array / previous_portfolio_state.now_price_array
        
        all_assets_change_ratio = np.dot(restricted_portfolio_vector, price_change_ratio)
        all_assets = previous_portfolio_state.all_assets * all_assets_change_ratio
        
        # 平均取得価格を設ける
        new_numbers = all_assets*restricted_portfolio_vector/now_price_array
        pre_numbers = previous_portfolio_state.numbers
        mean_num = pre_numbers*previous_portfolio_state.now_price_array + (new_numbers - pre_numbers) * now_price_array
        mean_den = new_numbers
        
        new_numbers_near_zero_bool = new_numbers < 1  # 取り合えず1以下の場合
        mean_num[new_numbers_near_zero_bool] = 1  # 適当に1にしておく
        mean_den[new_numbers_near_zero_bool] = 1  # 適当に1にしておく
        
        mean_cost_price_array = mean_num / mean_den
        mean_cost_price_array[new_numbers_near_zero_bool] = now_price_array[new_numbers_near_zero_bool]
        
        self.portfolio_state = PortfolioState(names=supplied_data_unit.names,
                                              key_currency_index=supplied_data_unit.key_currency_index,
                                              window=supplied_data_unit.window,
                                              datetime=supplied_data_unit.datetime,
                                              price_array=getattr(supplied_data_unit, self.use_ohlc_filed),
                                              volume_array=supplied_data_unit.volume_array,
                                              now_price_array=now_price_array,
                                              portfolio_vector=restricted_portfolio_vector,
                                              mean_cost_price_array=mean_cost_price_array,
                                              all_assets=all_assets
                                             )
                 
        # 手数料の計算と更新
        all_fee = self.fee_calculator.calculate(previous_portfolio_state, self.portfolio_state)
        self.portfolio_state = self.portfolio_state._replace(all_assets=all_assets-all_fee)   
        
        return self.portfolio_state.copy(), done
        

In [43]:
transformer = PortfolioTransformer(price_supplier=price_supplier,
                                   portfolio_restrictor=PortfolioRestrictorIdentity(),
                                   use_ohlc="Close",
                                   initial_portfolio_vector=None,
                                   initial_all_assets=1e6,
                                   fee_calculator=FeeCalculatorPerNumber(0.01)
                                  )

portfolio_state, _ = transformer.reset(start_datetime, window=[-3,-2,-1,0,1,2,3])
print(portfolio_state)

PortfolioState( 
names=['yen', '4755', '9984', '6701', '7203', '7267']
key_currency_index=0
window=[-3 -2 -1  0  1  2  3]
datetime=2020-11-10 09:00:00+09:00
price_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [1.1200e+03 1.1190e+03 1.1170e+03 1.1280e+03 1.1070e+03 1.1050e+03
  1.0620e+03]
 [7.0690e+03 7.0740e+03 7.0960e+03 7.0720e+03 7.0190e+03 7.0040e+03
  6.9880e+03]
 [5.7600e+03 5.7600e+03 5.7600e+03 5.6800e+03 5.7200e+03 5.7100e+03
  5.6900e+03]
 [7.1750e+03 7.1770e+03 7.1770e+03 7.3420e+03 7.3360e+03 7.3530e+03
  7.3280e+03]
 [2.8400e+03 2.8395e+03 2.8340e+03 2.9595e+03 2.9415e+03 2.9705e+03
  2.9310e+03]]
volume_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [7.1600e+04 9.3300e+04 4.1190e+05 8.7840e+05 5.8380e+05 2.5710e+05
  7.9410e+05]
 [2.3750e+05 3.0980e+05 3.7840e+05 2.5031e+06 8.4220e+05 7.3700e+05
  4.2530e+05]
 [4.8200e+04 1.5600e+04 4.2800e+04 2.6540e+05 7.0900e+04 5.5200e+04
  7.5200e+0

In [44]:
new_portfolio_vector = [0,1,0,0,0,0]
portfolio_state, _ = transformer.step(new_portfolio_vector)
print(portfolio_state)

PortfolioState( 
names=['yen', '4755', '9984', '6701', '7203', '7267']
key_currency_index=0
window=[-3 -2 -1  0  1  2  3]
datetime=2020-11-10 09:05:00+09:00
price_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [1.1190e+03 1.1170e+03 1.1280e+03 1.1070e+03 1.1050e+03 1.0620e+03
  1.0760e+03]
 [7.0740e+03 7.0960e+03 7.0720e+03 7.0190e+03 7.0040e+03 6.9880e+03
  6.9570e+03]
 [5.7600e+03 5.7600e+03 5.6800e+03 5.7200e+03 5.7100e+03 5.6900e+03
  5.6700e+03]
 [7.1770e+03 7.1770e+03 7.3420e+03 7.3360e+03 7.3530e+03 7.3280e+03
  7.3170e+03]
 [2.8395e+03 2.8340e+03 2.9595e+03 2.9415e+03 2.9705e+03 2.9310e+03
  2.9170e+03]]
volume_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [9.3300e+04 4.1190e+05 8.7840e+05 5.8380e+05 2.5710e+05 7.9410e+05
  4.6700e+05]
 [3.0980e+05 3.7840e+05 2.5031e+06 8.4220e+05 7.3700e+05 4.2530e+05
  4.1350e+05]
 [1.5600e+04 4.2800e+04 2.6540e+05 7.0900e+04 5.5200e+04 7.5200e+04
  7.9000e+0

In [46]:
def temp_func():
    price_supplier = StockDBPriceSupplier(stock_db=stock_db,
                                     ticker_names=stock_list,
                                     episode_length=episode_length,
                                     freq_str="5T",
                                     interpolate=True
                                    )
    transformer = PortfolioTransformer(price_supplier=price_supplier,
                                       portfolio_restrictor=PortfolioRestrictorIdentity(),
                                       use_ohlc="Close",
                                       initial_portfolio_vector=None,
                                       initial_all_assets=1e6,
                                       fee_calculator=FeeCalculatorPerNumber(0.01)
                                      )
    portfolio_state, _ = transformer.reset(start_datetime, window=[-3,-2,-1,0,1,2,3])
    for _ in range(30):
        new_portfolio_vector = [0,1,0,0,0,0]
        portfolio_state, _ = transformer.step(new_portfolio_vector)
    
from line_profiler import LineProfiler
prf = LineProfiler()                                                                                         
prf.add_module(PortfolioTransformer)                                                                                          
#prf.add_function()                                                                                      
prf.runcall(temp_func)                                                                                          
prf.print_stats()

Timer unit: 1e-07 s

Total time: 6.2e-05 s
File: <ipython-input-42-db242ee61848>
Function: __init__ at line 6

Line #      Hits         Time  Per Hit   % Time  Line Contents
     6                                               def __init__(self, 
     7                                                            price_supplier, 
     8                                                            portfolio_restrictor=PortfolioRestrictorIdentity(), 
     9                                                            use_ohlc="Close", 
    10                                                            initial_portfolio_vector=None, 
    11                                                            initial_all_assets=1e6, 
    12                                                            fee_calculator=FeeCalculatorPerNumber(fee_per_number=1e-3)):
    13                                                   """
    14                                                   price_supplier: PriceSupplier
  

高速化できる部分は無い

## ポートフォリオの遷移を可視化 

In [48]:
start_datetime = jst_timezone.localize(datetime.datetime(2020,11,10,9,0,0))
stock_list = ["4755","9984","6701","7203","7267", "6502"]
episode_length = 100

price_supplier = StockDBPriceSupplier(stock_db=stock_db,
                                     ticker_names=stock_list,
                                     episode_length=episode_length,
                                     freq_str="5T",
                                     interpolate=False
                                    )

transformer = PortfolioTransformer(price_supplier=price_supplier,
                                   portfolio_restrictor=PortfolioRestrictorIdentity(),
                                   use_ohlc="Close",
                                   initial_portfolio_vector=None,
                                   initial_all_assets=1e6,
                                   fee_calculator=FeeCalculatorPerNumber(0)
                                  )


portfolio_state_list = []
initial_state, _ = transformer.reset(start_datetime, window=[-1,0,1])
portfolio_state_list.append(initial_state.partial("names", "now_price_array", "mean_cost_price_array", "portfolio_vector", "all_assets", "datetime"))

while True:
    action = softmax(np.abs(np.random.randn(1+len(stock_list))))
    portfolio_state, done = transformer.step(action)
    portfolio_state_list.append(portfolio_state.partial("names", "now_price_array", "mean_cost_price_array", "portfolio_vector", "all_assets", "datetime"))
    if done:
        break

In [49]:
portfolio_state_list

[PortfolioState(names=['yen', '4755', '9984', '6701', '7203', '7267', '6502'], key_currency_index=None, window=None, datetime=datetime.datetime(2020, 11, 10, 9, 0, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=None, volume_array=None, now_price_array=array([1.0000e+00, 1.1280e+03, 7.0720e+03, 5.6800e+03, 7.3420e+03,
        2.9595e+03, 2.7800e+03]), portfolio_vector=array([1., 0., 0., 0., 0., 0., 0.]), mean_cost_price_array=array([1.0000e+00, 1.1280e+03, 7.0720e+03, 5.6800e+03, 7.3420e+03,
        2.9595e+03, 2.7800e+03]), all_assets=1000000.0),
 PortfolioState(names=['yen', '4755', '9984', '6701', '7203', '7267', '6502'], key_currency_index=None, window=None, datetime=datetime.datetime(2020, 11, 10, 9, 5, tzinfo=<DstTzInfo 'Asia/Tokyo' JST+9:00:00 STD>), price_array=None, volume_array=None, now_price_array=array([1.0000e+00, 1.1070e+03, 7.0190e+03, 5.7200e+03, 7.3360e+03,
        2.9415e+03, 2.8040e+03]), portfolio_vector=array([0.24120279, 0.14954645, 0.15028568, 0.11

In [50]:
def make_y_limit(y_array, upper_ratio=0.1, lowwer_ratio=0.1):
    min_value = np.amin(y_array)
    max_value = np.amax(y_array)
    diff = max_value - min_value
    return min_value-lowwer_ratio*diff, max_value+upper_ratio*diff

In [51]:
def make_y_limit_multi(y_arrays, upper_ratio=0.1, lowwer_ratio=0.1):
    min_values = []
    max_values = []
    for y_array in y_arrays:
        min_values.append(np.amin(y_array))
        max_values.append(np.amax(y_array))
        
    min_value = min(min_values)
    max_value = max(max_values)
    diff = max_value - min_value
    
    return min_value-lowwer_ratio*diff, max_value+upper_ratio*diff

In [52]:
def make_ticker_text(ticker_value_array, ticker_names):
    div_text = ""
    text_sum_line = 150
    text_sum_count = 0

    for i, ticker_name in enumerate(ticker_names):
        div_text += ticker_name + "="
        text_sum_count += len(ticker_name)
        ticke_value_str = str(ticker_value_array[i])
        div_text += ticke_value_str
        text_sum_count += len(ticke_value_str)

        div_text += ", "
        text_sum_count += 2

        if text_sum_count > text_sum_line:
            div_text += "\n"
            text_sum_count = 0
            
    return div_text

ここはメインの開発場所ではない，プロトタイプ版

In [53]:
def visualize_portfolio_transform_bokeh(portfolio_state_list, save_path=None, is_save=False, is_show=True, is_jupyter=True):
    # テータの取り出し
    ticker_names = portfolio_state_list[0].names
    colors = d3["Category20"][len(ticker_names)]

    all_price_array = np.stack([one_state.now_price_array for one_state in portfolio_state_list], axis=1)
    all_portfolio_vector = np.stack([one_state.portfolio_vector for one_state in portfolio_state_list], axis=1)
    all_mean_cost_price_array = np.stack([one_state.mean_cost_price_array for one_state in portfolio_state_list], axis=1)
    all_assets_array = np.array([one_state.all_assets for one_state in portfolio_state_list])
    all_datetime_array = np.array([get_naive_datetime_from_datetime(one_state.datetime) for one_state in portfolio_state_list])
    x = np.arange(0, len(portfolio_state_list))


    # sorceの作成
    portfolio_vector_source = {"x":x, "datetime":all_datetime_array}
    price_source_x = []
    price_source_y = []

    mean_cost_price_source_x = []
    mean_cost_price_source_y = []

    for i, ticker_name in enumerate(ticker_names):
        portfolio_vector_source[ticker_name] = all_portfolio_vector[i,:]

        price_source_x.append(x)
        price_source_y.append(all_price_array[i,:]/all_price_array[i,0])

        mean_cost_price_source_x.append(x)
        mean_cost_price_source_y.append(all_mean_cost_price_array[i,:]/all_mean_cost_price_array[i,0])

    # ホバーツールの設定
    #tool_tips = [("x", "@x")]
    tool_tips = [("datetime", "@datetime{%F %H:%M:%S}")]
    tool_tips.extend([(ticker_name, "@"+ticker_name+"{0.000}") for ticker_name in ticker_names])

    hover_tool = HoverTool(
        tooltips=tool_tips,
        formatters={'@datetime' : 'datetime'}
    )

    # 描画

    p1_text = Div(text=make_ticker_text(all_price_array[:,0], ticker_names))

    p1 = bokeh.plotting.figure(plot_width=1200,plot_height=500,title="正規化価格・ポートフォリオ")
    p1.add_tools(hover_tool)

    p1.extra_y_ranges = {"portfolio_vector": Range1d(start=0, end=3)}
    p1.add_layout(LinearAxis(y_range_name="portfolio_vector"), 'right')
    p1.vbar_stack(ticker_names, x='x', width=1, color=colors,y_range_name="portfolio_vector", source=portfolio_vector_source, legend_label=ticker_names, alpha=0.8)

    p1.multi_line(xs=price_source_x, ys=price_source_y, line_color=colors, line_width=2)
    y_min, y_max = make_y_limit_multi(price_source_y, lowwer_ratio=0.1, upper_ratio=0.1)
    y_min -= (y_max - y_min) * 0.66  #  ポートフォリオ割合のためのオフセット
    p1.y_range = Range1d(start=y_min, end=y_max)

    p1.yaxis[0].axis_label = "正規化価格"
    p1.yaxis[1].axis_label = "保有割合"

    p1.xaxis.major_label_overrides = {str(one_x) : str(all_datetime_array[i]) for i, one_x in enumerate(x)}

    p2_text = Div(text=make_ticker_text(all_mean_cost_price_array[:,0], ticker_names))

    p2 = bokeh.plotting.figure(plot_width=1200,plot_height=300,title="正規化平均取得価格・全資産")
    p2.multi_line(xs=mean_cost_price_source_x, ys=mean_cost_price_source_y, line_color=colors, line_width=2)
    y_min, y_max = make_y_limit_multi(mean_cost_price_source_y, lowwer_ratio=0.1, upper_ratio=0.1)
    p2.y_range = Range1d(start=y_min, end=y_max)

    y_max, y_min = make_y_limit(all_assets_array, upper_ratio=0.1, lowwer_ratio=0.1)
    p2.extra_y_ranges = {"all_assets": Range1d(start=y_max, end=y_min)}
    p2.add_layout(LinearAxis(y_range_name="all_assets"), 'right')
    p2.line(x, all_assets_array, color="red", legend_label="all_assets", line_width=4, y_range_name="all_assets")

    # 疑似的なレジェンドをつける
    for ticker_name, color in zip(ticker_names, colors):
        p2.line([], [], legend_label=ticker_name, color=color, line_width=2)

    p2.yaxis[0].axis_label = "正規化平均取得価格"
    p2.yaxis[1].axis_label = "全資産 [円]"

    p2.xaxis.major_label_overrides = {str(one_x) : str(all_datetime_array[i]) for i, one_x in enumerate(x)}

    created_figure = bokeh.layouts.column(p1_text, p1, p2_text, p2)

    if is_save:
            if save_path.suffix == ".png":
                bokeh.io.export_png(created_figure, filename=save_path)
            elif save_path.suffix == ".html":
                output_file(save_path)
                bokeh.io.save(created_figure, filename=save_path, title="trading process")    
            else:
                raise Exception("The suffix of save_path is must be '.png' or '.html'.")
    if is_show:
        try:
            reset_output()
            if is_jupyter:
                output_notebook()
            show(created_figure)
        except:
            if is_jupyter:
                output_notebook()
            show(created_figure)
        
    if not is_save and not is_show:
        raise Exception("is_save and is_show is False. This function do nothing")
        

In [54]:
visualize_portfolio_transform_bokeh(portfolio_state_list, save_path=Path("visualization/trade_transform.png"), is_save=True, is_show=True)