ルートディレクトリに移動 

In [1]:
%cd ../..

E:\システムトレード入門\trade_system_git_workspace


In [2]:
from collections import namedtuple
import pandas as pd
from pathlib import Path
from pytz import timezone
import datetime
import numpy as np
from scipy.special import softmax
from abc import ABCMeta, abstractmethod
import dataclasses
from dataclasses import dataclass

In [3]:
import bokeh.plotting
from bokeh.models import Range1d, LinearAxis, Div, HoverTool
from bokeh.io import show
from bokeh.io import output_notebook, reset_output, output_file
from bokeh.palettes import d3
# Send bokeh plot output to this notebook (rendered inline) rather than to a file
output_notebook()

In [4]:
from get_stock_price import StockDatabase

In [5]:
from utils import get_previous_workday_intraday_datetime, get_next_workday_intraday_datetime, get_naive_datetime_from_datetime

In [6]:
from utils import py_restart
from utils import py_workdays

## データベース 

In [7]:
# Locate the database relative to the repository root (the working directory
# after the `%cd ../..` cell) instead of a machine-specific absolute path,
# and fail early with a clear message if the file is missing.
db_path = Path.cwd() / "db" / "sub_stock_db" / "nikkei_255_stock_v2.db"
if not db_path.exists():
    raise FileNotFoundError("Stock database not found: {}".format(db_path))
stock_db = StockDatabase(db_path)

In [8]:
# JST timezone; used to localize the datetimes below and in later cells
jst_timezone = timezone("Asia/Tokyo")
# Overall period of interest: 2020-11-01 00:00 to 2020-12-01 00:00 (JST)
all_start_datetime = jst_timezone.localize(datetime.datetime(2020,11,1,0,0,0))
all_end_datetime = jst_timezone.localize(datetime.datetime(2020,12,1,0,0,0))
# Display the Japanese workdays within that period (holidays are excluded, see output)
py_workdays.get_workdays_jp(all_start_datetime.date(), all_end_datetime.date())

array([datetime.date(2020, 11, 2), datetime.date(2020, 11, 4),
       datetime.date(2020, 11, 5), datetime.date(2020, 11, 6),
       datetime.date(2020, 11, 9), datetime.date(2020, 11, 10),
       datetime.date(2020, 11, 11), datetime.date(2020, 11, 12),
       datetime.date(2020, 11, 13), datetime.date(2020, 11, 16),
       datetime.date(2020, 11, 17), datetime.date(2020, 11, 18),
       datetime.date(2020, 11, 19), datetime.date(2020, 11, 20),
       datetime.date(2020, 11, 24), datetime.date(2020, 11, 25),
       datetime.date(2020, 11, 26), datetime.date(2020, 11, 27),
       datetime.date(2020, 11, 30)], dtype=object)

In [9]:
#start_datetime = jst_timezone.localize(datetime.datetime(2020,11,10,9,0,0))
#end_datetime = jst_timezone.localize(datetime.datetime(2020,11,20,15,0,0))
#stock_list = ["4755","9984","6701","7203","7267"]

#stock_df = stock_db.search_span(stock_names=stock_list, 
#                                start_datetime=start_datetime,
#                                end_datetime=end_datetime,
#                                freq_str="5T",
#                                to_tokyo=True
#                               )

#stock_df

## 各種エラー 

In [10]:
class TradeSystemBaseError(Exception):
    """
    Root of the custom exception hierarchy of the TradeSystem.

    Every custom exception defined in this notebook derives from this class,
    so catching it handles all of them at once.
    """

In [11]:
class UnitStateError(TradeSystemBaseError):
    """
    Base class for errors raised while validating data units
    (DataSupplyUnit) and states (PortfolioState).
    """

In [12]:
class UnitStateHasNanError(UnitStateError):
    """
    Raised by the NaN check of DataSupplyUnit / PortfolioState when one of
    their fields contains NaN.

    Parameters
    ----------
    err_str: str or None
        Error message. When None, a generic default message is used.
    """
    def __init__(self, err_str=None):
        self.err_str = err_str

    def __str__(self):
        if self.err_str is not None:
            return self.err_str
        return "This Unit has nan data"

In [13]:
class UnitStateHasWrongLengthError(UnitStateError):
    """
    Raised by the length check of DataSupplyUnit / PortfolioState when an
    array field's shape does not agree with len(names) (and len(window)).

    Parameters
    ----------
    err_str: str or None
        Error message. When None, a generic default message is used.
    """
    def __init__(self, err_str=None):
        self.err_str = err_str

    def __str__(self):
        default_message = "This Unit has wrong length"
        return default_message if self.err_str is None else self.err_str

In [14]:
class PortfolioVectorInvalidError(TradeSystemBaseError):
    """
    Raised by the portfolio check of PortfolioState when the portfolio
    vector is invalid (contains NaN, leaves [0, 1], or does not sum to 1).

    Parameters
    ----------
    err_str: str or None
        Error message. When None, a generic default message is used.
    """
    def __init__(self, err_str=None):
        self.err_str = err_str

    def __str__(self):
        if self.err_str is not None:
            return self.err_str
        return "This portfolio is invalid"

In [15]:
class CannotGetAllDataError(TradeSystemBaseError):
    """
    Raised by a PriceSupplier when the data it needs could not be fetched
    completely.

    Parameters
    ----------
    err_str: str or None
        Error message. When None, a generic default message is used.
    """
    def __init__(self, err_str=None):
        self.err_str = err_str

    def __str__(self):
        default_message = "Cannot get all data."
        return default_message if self.err_str is None else self.err_str

## 供給データ単位クラス

供給されるデータの単位

In [16]:
@dataclass
class DataSupplyUnit:
    """
    Data class provided by a DataSupplier (e.g. a PriceSupplier).

    Class-level debug switches (plain class attributes, not dataclass fields;
    shared by all instances):
    nan_check: bool
        If True, __post_init__ raises UnitStateHasNanError when any ndarray
        field contains NaN. Intended for debugging.
    length_check: bool
        If True, __post_init__ raises UnitStateHasWrongLengthError when an
        OHLCV array is not shaped (len(names), len(window)). Intended for
        debugging.
    """
    nan_check = False
    length_check = False

    names: np.ndarray  # ticker names
    key_currency_index: int  # index of the key currency in names
    datetime: datetime.datetime  # datetime of the data
    window: np.ndarray  # data window
    open_array: np.ndarray  # open prices indexed as [ticker, window(time)]
    close_array: np.ndarray  # close prices indexed as [ticker, window(time)]
    high_array: np.ndarray  # high prices indexed as [ticker, window(time)]
    low_array: np.ndarray  # low prices indexed as [ticker, window(time)]
    volume_array: np.ndarray  # trading volumes indexed as [ticker, window(time)]

    def _replace(self, **kwargs):
        """
        Return a new unit with the given fields replaced (for compatibility
        with the namedtuple API). The new unit goes through __post_init__.
        """
        return dataclasses.replace(self, **kwargs)

    def __post_init__(self):
        # --- NaN check ---
        # pd.isnull instead of np.isnan: np.isnan raises TypeError on
        # non-numeric arrays such as names given as an ndarray of str.
        if DataSupplyUnit.nan_check:
            for field in dataclasses.fields(self):
                value = getattr(self, field.name)
                if isinstance(value, np.ndarray) and pd.isnull(value).any():
                    raise UnitStateHasNanError("This Unit has nan data about {}".format(field.name))

        # --- shape check: every OHLCV array must be (len(names), len(window)) ---
        if DataSupplyUnit.length_check:
            name_length = len(self.names)
            window_length = len(self.window)
            # Same order as the dataclass fields, so the first offending field is reported
            for field_name in ("open_array", "close_array", "high_array", "low_array", "volume_array"):
                value = getattr(self, field_name)
                # Comparing the whole shape tuple also rejects arrays of the wrong
                # rank, instead of failing with IndexError on shape[1].
                if np.shape(value) != (name_length, window_length):
                    err_str = "This Unit has wrong length about {}({}) with names({}) and window({})".format(field_name,
                                                                                                             np.shape(value),
                                                                                                             name_length,
                                                                                                             window_length
                                                                                                            )
                    raise UnitStateHasWrongLengthError(err_str)

    def __str__(self):
        return_str = "DataSupplyUnit( \n"
        for field in dataclasses.fields(self):
            return_str += field.name + "="
            return_str += str(getattr(self, field.name)) + "\n"
        return_str += ")"
        return return_str

In [17]:
# Field layout of the legacy (namedtuple-based) DataSupplyUnit, in order:
#   names              : ticker names
#   key_currency_index : index of the key currency
#   datetime           : datetime of the data
#   window             : data window
#   open/close/high/low/volume_array : OHLCV values indexed as [ticker, window(time)]
field_list = [
    "names",
    "key_currency_index",
    "datetime",
    "window",
    "open_array",
    "close_array",
    "high_array",
    "low_array",
    "volume_array",
]

DataSupplyUnitBaseLegacy = namedtuple("DataSupplyUnitBase", field_list)

In [18]:
class DataSupplyUnitLegacy(DataSupplyUnitBaseLegacy):
    """
    Legacy namedtuple version of the data class provided by a DataSupplier.

    __str__ prints one "field=value" line per field, wrapped in
    "DataSupplyUnit( ... )".
    """
    def __str__(self):
        lines = [name + "=" + str(getattr(self, name)) + "\n" for name in self._fields]
        return "DataSupplyUnit( \n" + "".join(lines) + ")"

In [19]:
# Example unit: key currency "yen" plus one ticker, a 6-step window, and
# random (unseeded) values for every OHLCV array of shape (2, 6)
one_data = DataSupplyUnit(names=["yen", "1001"],
                          key_currency_index=0,
                          datetime=datetime.datetime(2021,1,1,0,0,0),
                          window=[0,1,2,3,4,5],
                          open_array=np.random.randn(2,6),
                          close_array=np.random.randn(2,6),
                          high_array=np.random.randn(2,6),
                          low_array=np.random.randn(2,6),
                          volume_array=np.random.randn(2,6),
                         )

In [20]:
# _replace returns a new unit with key_currency_index changed; one_data itself is unchanged
print(one_data._replace(key_currency_index=1))

DataSupplyUnit( 
names=['yen', '1001']
key_currency_index=1
datetime=2021-01-01 00:00:00
window=[0, 1, 2, 3, 4, 5]
open_array=[[ 0.5874532   1.24634727 -0.11923178 -0.81670214 -0.58627205 -0.33390115]
 [-0.7141509   1.52538916  1.72543283 -0.65183355  1.02129596  0.42153376]]
close_array=[[-1.47940934  0.82485802  0.80215739 -0.27562762  0.57531573 -0.20494426]
 [ 0.24372     0.3387662   0.28303827 -1.16576931 -0.2896231  -1.60117283]]
high_array=[[ 0.27589806  1.4405244  -0.54848456  0.96900059 -0.22075289  1.42110936]
 [-0.82903232 -0.62501196  0.92551081 -3.1377388   1.84704614  2.01540813]]
low_array=[[ 0.73752888  2.11769921  0.72840431  1.83239935 -0.05126364 -1.56547018]
 [-0.02610597 -0.88506695 -0.21943197 -0.58597981  0.35626412 -2.32032454]]
volume_array=[[-0.88702687 -1.44218096  0.96472738 -2.17363413 -0.9108788  -1.12732916]
 [ 1.3881225  -0.61921945 -0.48811268 -0.17467973 -0.80001206 -0.69210997]]
)


In [21]:
# Enable the debug-time NaN and shape checks for every DataSupplyUnit created from now on
DataSupplyUnit.nan_check = True
DataSupplyUnit.length_check = True

In [22]:
# With the checks enabled, a well-formed unit is accepted ...
one_data = DataSupplyUnit(names=["yen", "1001"],
                          key_currency_index=0,
                          datetime=datetime.datetime(2021,1,1,0,0,0),
                          window=[0,1,2,3,4,5],
                          open_array=np.random.randn(2,6),
                          close_array=np.random.randn(2,6),
                          high_array=np.random.randn(2,6),
                          low_array=np.random.randn(2,6),
                          volume_array=np.random.randn(2,6),
                         )

# ... while a unit containing NaN is rejected with UnitStateHasNanError.
# The error is caught and displayed so the notebook still runs top to bottom.
nan_array = np.full((2, 6), np.nan)
try:
    DataSupplyUnit(names=["yen", "1001"],
                   key_currency_index=0,
                   datetime=datetime.datetime(2021,1,1,0,0,0),
                   window=[0,1,2,3,4,5],
                   open_array=nan_array,
                   close_array=np.random.randn(2,6),
                   high_array=np.random.randn(2,6),
                   low_array=np.random.randn(2,6),
                   volume_array=np.random.randn(2,6),
                  )
except UnitStateHasNanError as err:
    print("NaN check triggered as expected:", err)
else:
    print("NaN check did not trigger (is DataSupplyUnit.nan_check enabled?)")

## ポートフォリオ状態クラス 

ポートフォリオの状態を表すクラス

In [23]:
@dataclass
class PortfolioState:
    """
    Data class provided by the Transformer used for backtesting and
    reinforcement learning; it holds the reinforcement-learning state.

    Class-level debug switches (plain class attributes, not dataclass fields;
    shared by all instances):
    nan_check: bool
        If True, __post_init__ raises UnitStateHasNanError when an ndarray
        field contains NaN. Intended for debugging.
    length_check: bool
        If True, __post_init__ raises UnitStateHasWrongLengthError when
        price_array / volume_array are not shaped (len(names), len(window)),
        or when now_price_array / portfolio_vector / mean_cost_price_array are
        not 1-D of length len(names). The check is skipped while names or
        window is None (e.g. on a partial() copy); None-valued fields are
        skipped as well. Intended for debugging.
    portfoio_check: bool
        If True, __post_init__ raises PortfolioVectorInvalidError when
        portfolio_vector contains NaN, has an element outside [0, 1], or does
        not sum to 1 within 1e-5. (The attribute's spelling is kept for
        backward compatibility.)
    """
    nan_check = False
    length_check = False
    portfoio_check = False

    names: np.ndarray  # ticker names
    key_currency_index: int  # index of the key currency in names
    window: np.ndarray  # data window
    datetime: datetime.datetime  # datetime of the data
    price_array: np.ndarray  # current prices indexed as [ticker, window(time)]
    volume_array: np.ndarray  # trading volumes indexed as [ticker, window(time)]
    now_price_array: np.ndarray  # current price of each ticker
    portfolio_vector: np.ndarray  # portfolio vector
    mean_cost_price_array: np.ndarray  # mean acquisition price of each ticker
    all_assets: float  # total assets converted into the key currency

    def _replace(self, **kwargs):
        """
        Return a new state with the given fields replaced (for compatibility
        with the namedtuple API). The new state goes through __post_init__.
        """
        return dataclasses.replace(self, **kwargs)

    def __post_init__(self):
        # --- NaN check ---
        # pd.isnull instead of np.isnan: np.isnan raises TypeError on
        # non-numeric arrays such as names given as an ndarray of str.
        if PortfolioState.nan_check:
            for field in dataclasses.fields(self):
                value = getattr(self, field.name)
                if isinstance(value, np.ndarray) and pd.isnull(value).any():
                    raise UnitStateHasNanError("This State has nan data about {}".format(field.name))

        # --- shape check ---
        if PortfolioState.length_check and self.names is not None and self.window is not None:
            name_length = len(self.names)
            window_length = len(self.window)
            for field in dataclasses.fields(self):
                value = getattr(self, field.name)
                if value is None:
                    continue
                # Comparing whole shape tuples also rejects arrays of the wrong
                # rank, instead of failing with IndexError on shape[1].
                if field.name in {"price_array", "volume_array"}:
                    if np.shape(value) != (name_length, window_length):
                        err_str = "This State has wrong length about {}({}) with names({}) and window({})".format(field.name,
                                                                                                                  np.shape(value),
                                                                                                                  name_length,
                                                                                                                  window_length
                                                                                                                 )
                        raise UnitStateHasWrongLengthError(err_str)
                elif field.name in {"now_price_array", "portfolio_vector", "mean_cost_price_array"}:
                    if np.shape(value) != (name_length,):
                        err_str = "This State has wrong length about {}({}) with names({})".format(field.name,
                                                                                                   np.shape(value),
                                                                                                   name_length
                                                                                                  )
                        raise UnitStateHasWrongLengthError(err_str)

        # --- portfolio vector check ---
        if PortfolioState.portfoio_check and self.portfolio_vector is not None:
            portfolio_vector = np.asarray(self.portfolio_vector)
            if np.isnan(portfolio_vector).any():
                raise PortfolioVectorInvalidError("This portfolio has nan")

            # Every element must lie in the closed interval [0, 1]
            if (portfolio_vector > 1).any() or (portfolio_vector < 0).any():
                raise PortfolioVectorInvalidError("This portfolio is not in [0,1].{}".format(portfolio_vector))

            # The elements must sum to 1 (absolute tolerance 1e-5)
            if abs(portfolio_vector.sum() - 1) > 1.e-5:
                raise PortfolioVectorInvalidError("The portfolio sum is must be 1. This portfolio is {}, sum is {}".format(portfolio_vector,
                                                                                                                           portfolio_vector.sum()))

    @property
    def numbers(self):
        """
        Holding quantity of each ticker:
        all_assets * portfolio_vector / now_price_array.
        """
        return self.all_assets*self.portfolio_vector/self.now_price_array

    def __str__(self):
        return_str = "PortfolioState( \n"
        for field in dataclasses.fields(self):
            return_str += field.name + "="
            return_str += str(getattr(self, field.name)) + "\n"
        return_str += ")"
        return return_str

    def copy(self):
        """
        Return a copy of this state of the same type. ndarray fields are
        copied, so the copy shares no array memory with the original.
        """
        arg_dict = {}
        for field in dataclasses.fields(self):
            field_value = getattr(self, field.name)
            if isinstance(field_value, np.ndarray):
                field_value = field_value.copy()
            arg_dict[field.name] = field_value

        # type(self) so subclasses are preserved by copy()
        return type(self)(**arg_dict)

    def partial(self, *args):
        """
        Return a partial copy of this state (e.g. when memory is tight).

        *args: str
            Names of the fields to keep. Kept ndarray values are copied;
            every field not named in args is set to None.
        """
        arg_dict = {}
        for field in dataclasses.fields(self):
            if field.name in args:
                field_value = getattr(self, field.name)
                if isinstance(field_value, np.ndarray):
                    field_value = field_value.copy()
            else:
                field_value = None
            arg_dict[field.name] = field_value

        # type(self) so subclasses are preserved by partial()
        return type(self)(**arg_dict)

In [24]:
# Field layout of the legacy (namedtuple-based) PortfolioState, in order:
#   names                 : ticker names
#   key_currency_index    : index of the key currency
#   window                : data window
#   datetime              : datetime of the data
#   price_array           : current prices indexed as [ticker, window(time)]
#   volume_array          : trading volumes indexed as [ticker, window(time)]
#   now_price_array       : current price of each ticker
#   portfolio_vector      : portfolio vector
#   mean_cost_price_array : mean acquisition price of each ticker
#   all_assets            : total assets converted into the key currency
field_list = [
    "names",
    "key_currency_index",
    "window",
    "datetime",
    "price_array",
    "volume_array",
    "now_price_array",
    "portfolio_vector",
    "mean_cost_price_array",
    "all_assets",
]

PortfolioStateBaseLegacy = namedtuple("PortfolioStateBase", field_list)

In [25]:
class PortfolioStateLegacy(PortfolioStateBaseLegacy):
    """
    Legacy namedtuple version of PortfolioState: the data class provided by
    the Transformer used for backtesting and reinforcement learning (holds
    the reinforcement-learning state). Performs none of PortfolioState's
    debug checks.
    """

    @property
    def numbers(self):
        """
        Holding quantity of each ticker:
        all_assets * portfolio_vector / now_price_array.
        """
        return self.all_assets*self.portfolio_vector/self.now_price_array

    def __str__(self):
        return_str = "PortfolioState( \n"
        for field_str in self._fields:
            return_str += field_str + "="
            return_str += str(getattr(self, field_str)) + "\n"
        return_str += ")"
        return return_str

    def copy(self):
        """
        Return a copy of this state. ndarray fields are copied, so the copy
        shares no array memory with the original.
        """
        arg_dict = {}
        for field_str in self._fields:
            field_value = getattr(self, field_str)
            if isinstance(field_value, np.ndarray):
                field_value = field_value.copy()
            arg_dict[field_str] = field_value

        # type(self) keeps the copy a PortfolioStateLegacy (or subclass) instead
        # of converting it to the PortfolioState dataclass.
        return type(self)(**arg_dict)

    def partial(self, *args):
        """
        Return a partial copy of this state (e.g. when memory is tight).

        *args: str
            Names of the fields to keep. Kept ndarray values are copied;
            every field not named in args is set to None.
        """
        arg_dict = {}
        for field_str in self._fields:
            if field_str in args:
                field_value = getattr(self, field_str)
                if isinstance(field_value, np.ndarray):
                    field_value = field_value.copy()
            else:
                field_value = None
            arg_dict[field_str] = field_value

        # Same type as self (see copy())
        return type(self)(**arg_dict)

In [26]:
# Example state: key currency "yen" plus one ticker, a 6-step window, random
# (unseeded) values; softmax makes portfolio_vector non-negative and sum to 1
one_state = PortfolioState(names=["yen", "1001"],
                           key_currency_index=0,
                           window=[0,1,2,3,4,5],
                           datetime=datetime.datetime(2021,1,1,0,0,0),
                           price_array=np.random.randn(2,6),
                           volume_array=np.random.randn(2,6),
                           now_price_array=np.random.randn(2),
                           portfolio_vector=softmax(np.random.randn(2)),
                           mean_cost_price_array=np.random.randn(2),
                           all_assets=0.0
                          )

In [27]:
# Partial copy: only price_array and all_assets are kept, every other field becomes None
one_state.partial("price_array", "all_assets")

PortfolioState(names=None, key_currency_index=None, window=None, datetime=None, price_array=array([[-0.76663687, -2.5236359 , -1.04007341,  0.69908806, -0.22193732,
        -1.2725906 ],
       [-1.69763469,  0.9293035 , -0.74064379, -1.07487236,  0.35656603,
         0.82256116]]), volume_array=None, now_price_array=None, portfolio_vector=None, mean_cost_price_array=None, all_assets=0.0)

In [28]:
# Full copy; ndarray fields are copied, so the copy shares no arrays with one_state
one_state.copy()

PortfolioState(names=['yen', '1001'], key_currency_index=0, window=[0, 1, 2, 3, 4, 5], datetime=datetime.datetime(2021, 1, 1, 0, 0), price_array=array([[-0.76663687, -2.5236359 , -1.04007341,  0.69908806, -0.22193732,
        -1.2725906 ],
       [-1.69763469,  0.9293035 , -0.74064379, -1.07487236,  0.35656603,
         0.82256116]]), volume_array=array([[ 0.25869216,  0.65497567,  0.83698721,  1.23511275,  0.39501303,
        -0.40306677],
       [ 0.88188049,  0.9236073 , -1.02428213, -0.2670186 , -0.36258674,
        -0.38758421]]), now_price_array=array([0.01626379, 0.4702189 ]), portfolio_vector=array([0.1655847, 0.8344153]), mean_cost_price_array=array([ 0.15074982, -0.26328573]), all_assets=0.0)

In [29]:
# Enable the debug-time NaN, shape, and portfolio-vector checks for every
# PortfolioState created from now on
PortfolioState.nan_check = True
PortfolioState.length_check = True
PortfolioState.portfoio_check = True

In [30]:
# Re-create the example state with all checks enabled: the shapes match
# names/window and softmax yields a valid portfolio vector, so no error is raised
one_state = PortfolioState(names=["yen", "1001"],
                           key_currency_index=0,
                           window=[0,1,2,3,4,5],
                           datetime=datetime.datetime(2021,1,1,0,0,0),
                           price_array=np.random.randn(2,6),
                           volume_array=np.random.randn(2,6),
                           now_price_array=np.random.randn(2),
                           portfolio_vector=softmax(np.random.randn(2)),
                           mean_cost_price_array=np.random.randn(2),
                           all_assets=0.0
                          )

## データの供給クラス 

In [31]:
class PriceSuppliier(metaclass=ABCMeta):
    """
    Abstract interface of a PriceSupplier.

    A concrete supplier only has to implement reset() and step(); the class
    cannot be instantiated until both are overridden.
    """
    @abstractmethod
    def reset(self, start_datetime, window):
        """
        Start a new episode at start_datetime using the given window.
        """

    @abstractmethod
    def step(self):
        """
        Advance the episode by one step.
        """

今回は株価データベースを用いて価格を供給する．

In [32]:
class StockDBPriceSupplier(PriceSuppliier):
    """
    PriceSupplier backed by get_stock_price.StockDatabase.

    reset() fetches every bar an episode needs from the database once and
    caches it as a float ndarray; step() then advances one bar per call.
    Each supplied DataSupplyUnit has the key currency "yen" as row 0 (all
    OHLCV values 1.0), followed by ticker_names in the given order.
    """
    def __init__(self, stock_db, ticker_names, episode_length, freq_str, interpolate=True):
        """
        stock_db: get_stock_price.StockDatabase
            StockDatabase to read prices from
        ticker_names: list of str
            tickers to supply (must not contain duplicates)
        episode_length: int
            length of an episode
        freq_str: str
            pandas frequency string of the bars (e.g. "5T")
        interpolate: bool
            if True, bars missing from the database are filled by
            interpolation (at most 10 % of them); if False they are skipped
        """
        self.stock_db = stock_db
        self.ticker_names = ticker_names
        self.episode_length = episode_length
        self.freq_str = freq_str
        self.interpolate = interpolate

    def reset(self, start_datetime, window=np.array([0])):
        """
        Load the data of a new episode and return its first unit.

        Parameters
        ----------
        start_datetime: datetime.datetime
            datetime at which data supply starts (the first unit's datetime)
        window: array-like of int
            bar offsets, relative to the current bar, included in each unit;
            must contain 0

        Returns
        -------
        DataSupplyUnit
            supplied price data at start_datetime
        bool
            whether the episode has finished (always False here)

        Raises
        ------
        AssertionError
            if 0 is not in window
        ValueError
            if ticker_names contains duplicates
        CannotGetAllDataError
            if the database cannot provide complete data for the episode
        """
        assert 0 in window
        # Always copy, so neither the shared default argument nor the caller's
        # array is aliased by the DataSupplyUnits handed out later
        self.window = np.array(window)

        if isinstance(self.ticker_names, np.ndarray):
            self.ticker_names = self.ticker_names.tolist()

        if len(self.ticker_names) != len(set(self.ticker_names)):
            # ValueError is an Exception subclass, so existing handlers still catch it
            raise ValueError("Ticker_names is duplicate")

        min_window = min(self.window)
        max_window = max(self.window)

        # First bar needed: |min_window| bars before start_datetime
        if min_window <= 0:
            episode_start_datetime = get_previous_workday_intraday_datetime(start_datetime, self.freq_str, abs(min_window))
        else:
            episode_start_datetime = get_next_workday_intraday_datetime(start_datetime, self.freq_str, min_window)

        # Last bar needed: episode_length + max_window bars after start_datetime
        end_offset = self.episode_length + max_window
        if end_offset <= 0:  # should not happen in practice
            episode_end_datetime = get_previous_workday_intraday_datetime(start_datetime, self.freq_str, abs(end_offset))
        else:
            episode_end_datetime = get_next_workday_intraday_datetime(start_datetime, self.freq_str, end_offset)

        episode_df = self.stock_db.search_span(stock_names=self.ticker_names,
                                               start_datetime=episode_start_datetime,
                                               end_datetime=episode_end_datetime,
                                               freq_str=self.freq_str,
                                               is_end_include=True,  # include the last bar as well
                                               to_tokyo=True,  # must always be True
                                              )

        if episode_df is None:
            raise CannotGetAllDataError("This ticker names not in stockdb")

        # Five OHLCV columns are expected per ticker
        if len(self.ticker_names)*5 != len(episode_df.columns):
            raise CannotGetAllDataError("Cannot get dataframe from stockdb.")

        episode_df = py_workdays.extract_workdays_intraday_jp(episode_df)

        # Column masks per OHLCV kind, selected by column-name prefix
        column_names_array = episode_df.columns.values.astype(str)
        self.open_bool_array = np.char.startswith(column_names_array, "Open")
        self.high_bool_array = np.char.startswith(column_names_array, "High")
        self.low_bool_array = np.char.startswith(column_names_array, "Low")
        self.close_bool_array = np.char.startswith(column_names_array, "Close")
        self.volume_bool_array = np.char.startswith(column_names_array, "Volume")

        all_datetime_index = self._expected_datetime_index(episode_start_datetime, episode_end_datetime)
        if len(all_datetime_index) == 0:
            # Guards the interpolation ratio below against division by zero
            raise CannotGetAllDataError("No intraday workday datetimes in [{},{})".format(episode_start_datetime,
                                                                                          episode_end_datetime))

        if self.interpolate:
            episode_df = self._fill_missing_rows(episode_df, all_datetime_index,
                                                 episode_start_datetime, episode_end_datetime)
        else:
            # Keep only the expected datetimes that actually exist in the data
            all_datetime_index = all_datetime_index[all_datetime_index.isin(episode_df.index)]

        # Cache as plain arrays; the DataFrame and Index themselves are not kept.
        # NOTE(review): row i of episode_df_values is assumed to correspond to
        # all_datetime_index_values[i] (episode_df may carry the end bar as an
        # extra trailing row because of is_end_include=True).
        self.episode_df_values = episode_df.values.astype(float)
        self.all_datetime_index_values = all_datetime_index.to_pydatetime()

        # Was the data fetched completely?
        if np.isnan(self.episode_df_values).any():
            err_str = "PriceSupplier cannot get data about tickers={}, datetimes[{},{}]".format(self.ticker_names,
                                                                                                episode_start_datetime,
                                                                                                episode_end_datetime)
            raise CannotGetAllDataError(err_str)

        # Row of start_datetime in the cached arrays
        self.now_index = abs(min_window)
        out_unit = self._make_unit()
        done = False
        return out_unit, done

    def step(self):
        """
        Advance one bar.

        Returns
        -------
        DataSupplyUnit
            supplied price data at the new bar
        bool
            whether the episode has finished
        """
        self.now_index += 1
        out_unit = self._make_unit()
        # NOTE(review): now_index starts at abs(min(window)), so with a negative
        # window the episode finishes after episode_length - abs(min(window))
        # steps, whereas reset() sizes the fetched span from start_datetime +
        # episode_length + max(window) bars — confirm which length is intended.
        done = self.now_index >= self.episode_length
        return out_unit, done

    def _expected_datetime_index(self, episode_start_datetime, episode_end_datetime):
        """Intraday workday datetimes in [start, end) at self.freq_str."""
        all_datetime_index = pd.date_range(start=episode_start_datetime,
                                           end=episode_end_datetime,
                                           freq=self.freq_str)
        # Exclude the end point by hand: same result as the former closed="left",
        # but `closed` was deprecated in pandas 1.4 and removed in 2.0.
        all_datetime_index = all_datetime_index[all_datetime_index < episode_end_datetime]
        return py_workdays.extract_workdays_intraday_jp_index(all_datetime_index)

    def _fill_missing_rows(self, episode_df, all_datetime_index, episode_start_datetime, episode_end_datetime):
        """Insert NaN rows for expected-but-missing datetimes and interpolate them."""
        add_datetime_bool = ~all_datetime_index.isin(episode_df.index)

        # Refuse when more than 10 % of the expected datetimes would be interpolated
        if (add_datetime_bool.sum()/len(all_datetime_index)) > 0.1:
            err_str = "Interpolate exceeds 10 % about tickers={}, datetimes[{},{}]".format(self.ticker_names,
                                                                                           episode_start_datetime,
                                                                                           episode_end_datetime)
            raise CannotGetAllDataError(err_str)

        # All-NaN rows typed float, so the concatenated columns stay numeric and
        # interpolate() can fill them
        nan_df = pd.DataFrame(np.nan,
                              index=all_datetime_index[add_datetime_bool],
                              columns=episode_df.columns,
                              dtype=float)

        # pd.concat replaces DataFrame.append (deprecated in pandas 1.4, removed in 2.0)
        filled_df = pd.concat([episode_df, nan_df]).sort_index()
        return filled_df.interpolate(limit_direction="both")

    def _make_unit(self):
        """Build the DataSupplyUnit at self.now_index: rows are ["yen"] + ticker_names, columns follow self.window."""
        now_datetime = self.all_datetime_index_values[self.now_index]
        window_data_value = self.episode_df_values[self.now_index + self.window, :]

        # Row 0 is the key currency, whose OHLCV values are all 1.0
        key_currency_row = np.ones((1, window_data_value.shape[0]))

        def with_key_currency(bool_array):
            # [ticker, window] array for one OHLCV kind, key-currency row prepended
            return np.concatenate([key_currency_row, window_data_value[:, bool_array].T], axis=0)

        out_ticker_names = ["yen"] + list(self.ticker_names)

        return DataSupplyUnit(names=out_ticker_names,
                              key_currency_index=0,
                              datetime=now_datetime,
                              window=self.window,
                              open_array=with_key_currency(self.open_bool_array),
                              close_array=with_key_currency(self.close_bool_array),
                              high_array=with_key_currency(self.high_bool_array),
                              low_array=with_key_currency(self.low_bool_array),
                              volume_array=with_key_currency(self.volume_bool_array)
                             )

In [33]:
# Episode start (JST) and the tickers to supply
start_datetime = jst_timezone.localize(datetime.datetime(2020,11,10,9,0,0))
stock_list = ["4755","9984","6701","7203","7267"]

# Episode length passed to the supplier
episode_length = 50

# 5-minute bars; timestamps missing from the DB are skipped, not interpolated
price_supplier = StockDBPriceSupplier(stock_db=stock_db,
                                      ticker_names=stock_list,
                                      episode_length=episode_length,
                                      freq_str="5T",
                                      interpolate=False
                                     )
# Alternative window that also includes the 3 previous bars
#data_unit, _ = price_supplier.reset(start_datetime, window=[-3,-2,-1,0,1,2,3])
data_unit, _ = price_supplier.reset(start_datetime, window=[0,1,2,3])
# Shape is (1 + len(stock_list), len(window)); row 0 is the key currency "yen"
print(data_unit.close_array.shape)
print(data_unit)

(6, 4)
DataSupplyUnit( 
names=['yen', '4755', '9984', '6701', '7203', '7267']
key_currency_index=0
datetime=2020-11-10 09:00:00+09:00
window=[0 1 2 3]
open_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00]
 [1.1250e+03 1.1270e+03 1.1080e+03 1.1050e+03]
 [7.0040e+03 7.0720e+03 7.0190e+03 7.0040e+03]
 [5.7300e+03 5.7000e+03 5.7300e+03 5.7100e+03]
 [7.3200e+03 7.3430e+03 7.3380e+03 7.3550e+03]
 [2.9200e+03 2.9595e+03 2.9420e+03 2.9710e+03]]
close_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00]
 [1.1280e+03 1.1070e+03 1.1050e+03 1.0620e+03]
 [7.0720e+03 7.0190e+03 7.0040e+03 6.9880e+03]
 [5.6800e+03 5.7200e+03 5.7100e+03 5.6900e+03]
 [7.3420e+03 7.3360e+03 7.3530e+03 7.3280e+03]
 [2.9595e+03 2.9415e+03 2.9705e+03 2.9310e+03]]
high_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00]
 [1.1320e+03 1.1290e+03 1.1100e+03 1.1050e+03]
 [7.0770e+03 7.0770e+03 7.0270e+03 7.0070e+03]
 [5.7300e+03 5.7300e+03 5.7400e+03 5.7200e+03]
 [7.3440e+03 7.3550e+03 7.3600e+03 7.3560e+03]
 [2.9635e+03 2

## ポートフォリオの制限 

学習時にはポートフォリオベクトルに制限を設けない．制限はバックテストや実際の運用時にのみ，単元数や基軸通貨換算の資産額に応じて設ける．これは，強化学習を資産の保有比率を求める問題として扱うためである．

In [34]:
class PortfilioRestrictor(metaclass=ABCMeta):
    """
    Abstract base class for objects that restrict a portfolio vector.

    Subclasses must override restrict(); until they do, the class cannot be
    instantiated.
    """
    @abstractmethod
    def restrict(self, portfolio_state, supplied_data_unit, portfolio_vector):
        """
        Return the portfolio vector after applying this restriction.

        Parameters
        ----------
        portfolio_state: PortfolioState
            Current portfolio state.
        supplied_data_unit: DataSupplyUnit
            Data supplied by the price supplier for the step being taken.
        portfolio_vector: ndarray
            Portfolio vector proposed by the agent.
        """

In [35]:
class PortfolioRestrictorSingleKey(PortfilioRestrictor):
    """
    Placeholder for a restrictor with a single key currency. Not implemented yet.

    NOTE(review): the intended meaning of unit_number (trading unit size?) and
    key_name (key currency name?) is inferred from the parameter names only —
    confirm before implementing.
    """
    def __init__(self, unit_number, key_name):
        # Keep the arguments so a future implementation can use them.
        self.unit_number = unit_number
        self.key_name = key_name

    def restrict(self, portfolio_state, supplied_data_unit, portfolio_vector):
        # Returning None here would only surface later as an opaque numpy error
        # inside PortfolioTransformer.step; fail fast with a clear message instead.
        raise NotImplementedError("PortfolioRestrictorSingleKey.restrict is not implemented yet")

In [36]:
class PortfolioRestrictorIdentity(PortfilioRestrictor):
    """
    PortfolioRestrictor that applies no restriction.

    restrict() hands back the very same portfolio_vector object it receives.
    """
    def restrict(self, portfolio_state, supplied_data_unit, portfolio_vector):
        unrestricted_vector = portfolio_vector
        return unrestricted_vector

## 取引手数料の計算クラス 

In [37]:
class FeeCalculator(metaclass=ABCMeta):
    """
    Abstract base class for trading-fee calculators.

    Subclasses implement calculate(), which receives the portfolio states before
    and after a transition and returns the fee charged for it.
    """
    @abstractmethod
    def calculate(self, pre_portfolio_state, new_portfolio_state):
        """Return the fee for moving from pre_portfolio_state to new_portfolio_state."""

In [38]:
class FeeCalculatorPerNumber(FeeCalculator):
    """
    FeeCalculator that charges a flat fee per traded unit.

    The fee is ``fee_per_number * sum(|new.numbers - pre.numbers|)`` over every
    asset except the key currency, so buys and sells are both charged.
    """
    def __init__(self, fee_per_number):
        """
        fee_per_number: float
            Fee charged per traded unit (PortfolioTransformer subtracts the
            resulting fee from all_assets).
        """
        self.fee_per_number = fee_per_number

    def calculate(self, pre_portfolio_state, new_portfolio_state):
        """
        Parameters
        ----------
        pre_portfolio_state: PortfolioState
            State before the transition (names, key_currency_index and numbers are read).
        new_portfolio_state: PortfolioState
            State after the transition (numbers is read).

        Returns
        -------
        float
            Total commission fee for the transition.
        """
        # A boolean mask instead of an index list: an empty index list would become a
        # float64 array, which numpy rejects as an index (key-currency-only portfolio).
        not_key_currency_mask = (np.arange(len(pre_portfolio_state.names))
                                 != pre_portfolio_state.key_currency_index)
        traded_numbers = np.abs(new_portfolio_state.numbers[not_key_currency_mask]
                                - pre_portfolio_state.numbers[not_key_currency_mask])
        return self.fee_per_number * traded_numbers.sum()

##  ポートフォリオの遷移クラス  

強化学習の環境だけでなく，バックテスト等でも利用できるよう汎用的に設計している．

In [39]:
class PortfolioTransformer:
    """
    Transition a PortfolioState according to the data supplied by a price supplier.

    Meant to be shared by backtests and by the reinforcement-learning environment.
    """
    def __init__(self, 
                 price_supplier, 
                 portfolio_restrictor=None, 
                 use_ohlc="Close", 
                 initial_portfolio_vector=None,
                 initial_mean_cost_price_array=None,
                 initial_all_assets=None, 
                 fee_calculator=None):
        """
        Parameters
        ----------
        price_supplier: PriceSupplier
            Supplies price data through reset(start_datetime, window) and step().
        portfolio_restrictor: PortfilioRestrictor, default: PortfolioRestrictorIdentity()
            Restricts the portfolio_vector passed by the agent.
        use_ohlc: str, default: 'Close'
            Price field to use: one of 'Open', 'High', 'Low', 'Close'.
        initial_portfolio_vector: array-like, default: None
            Initial portfolio vector; must sum to 1. None puts everything in the key currency.
        initial_mean_cost_price_array: array-like, default: None
            Initial mean cost price per asset. None uses the price at reset time.
        initial_all_assets: float, default: None
            Initial total assets. None uses 1e6.
        fee_calculator: FeeCalculator, default: FeeCalculatorPerNumber(fee_per_number=1e-3)
            Computes the trading fee charged at each step.

        Raises
        ------
        Exception
            If use_ohlc is not one of the supported fields.
        """
        # Create default collaborators per instance instead of sharing one
        # instance built at class-definition time.
        if portfolio_restrictor is None:
            portfolio_restrictor = PortfolioRestrictorIdentity()
        if fee_calculator is None:
            fee_calculator = FeeCalculatorPerNumber(fee_per_number=1e-3)

        self.price_supplier = price_supplier
        self.portfolio_restrictor = portfolio_restrictor
        self.initial_portfolio_vector = initial_portfolio_vector
        self.initial_mean_cost_price_array = initial_mean_cost_price_array
        self.initial_all_assets = initial_all_assets
        self.fee_calculator = fee_calculator

        if use_ohlc not in {"Open","High","Low","Close"}:
            raise Exception("use_ohlc must be in {'Open','High','Low','Close'}")

        field_name_dict = {"Open":"open_array",
                           "Close":"close_array",
                           "Low":"low_array",
                           "High":"high_array"
                          }
        # Name of the data-unit attribute holding the chosen prices
        # (attribute name kept unchanged for compatibility).
        self.use_ohlc_filed = field_name_dict[use_ohlc]

    def reset(self, start_datetime, window=[0]):
        """
        Start a new episode.

        Parameters
        ----------
        start_datetime: datetime.datetime
            Start time of the data supply.
        window: array-like
            Window offsets of the data supply (must contain 0, the current time).

        Returns
        -------
        PortfolioState
            Copy of the initial portfolio state.
        bool
            Whether the episode has already finished.

        Raises
        ------
        PortfolioVectorInvalidError
            If initial_portfolio_vector does not sum to (roughly) 1.
        """
        initial_data_unit, done = self.price_supplier.reset(start_datetime, window)
        now_price_array = self._get_now_price_array(initial_data_unit)
        n_assets = len(initial_data_unit.names)

        # Resolve the defaults into locals: writing them back to self.initial_*
        # would make later resets reuse values computed for the first episode.
        if self.initial_portfolio_vector is None:
            initial_portfolio_vector = np.zeros(n_assets)
            initial_portfolio_vector[initial_data_unit.key_currency_index] = 1.0
        else:
            initial_portfolio_vector = np.array(self.initial_portfolio_vector, dtype=float)
            assert n_assets == len(initial_portfolio_vector)
            if abs(initial_portfolio_vector.sum() - 1.0) > 1.e-5:  # roughly 1 is accepted
                raise PortfolioVectorInvalidError("initial portfolio vector sum must be 1. This portfolio vector is {}.\n This sum is {}".format(initial_portfolio_vector,
                                                                                                                               initial_portfolio_vector.sum()))

        if self.initial_mean_cost_price_array is None:
            initial_mean_cost_price_array = now_price_array
        else:
            initial_mean_cost_price_array = np.array(self.initial_mean_cost_price_array, dtype=float)
            assert initial_mean_cost_price_array.shape[0] == now_price_array.shape[0]

        initial_all_assets = 1.e6 if self.initial_all_assets is None else self.initial_all_assets

        self.portfolio_state = self._make_portfolio_state(initial_data_unit,
                                                          now_price_array,
                                                          initial_portfolio_vector,
                                                          initial_mean_cost_price_array,
                                                          initial_all_assets)
        return self.portfolio_state.copy(), done

    def step(self, action):
        """
        Advance one step with the portfolio vector chosen by the agent.

        Parameters
        ----------
        action: array-like
            Portfolio vector from the agent: 1-D, entries in [0, 1], summing to 1.

        Returns
        -------
        PortfolioState
            Copy of the new portfolio state.
        bool
            Whether the episode has finished.

        Raises
        ------
        PortfolioVectorInvalidError
            If action does not sum to 1 or does not have one entry per asset.
        """
        action = np.asarray(action)
        assert (action<0).sum() == 0 and (action>1).sum() == 0
        if abs(action.sum() - 1.0) > 1.e-5:  # roughly 1 is accepted
            raise PortfolioVectorInvalidError("action sum must be 1. This action is {}.\n This sum is {}".format(action, action.sum()))

        previous_portfolio_state = self.portfolio_state
        n_assets = len(previous_portfolio_state.names)
        # Validate the shape before advancing the price supplier so that an invalid
        # action does not consume a time step.
        if action.ndim != 1 or len(action) != n_assets:
            raise PortfolioVectorInvalidError("Action dimmention must be names({})".format(n_assets))

        supplied_data_unit, done = self.price_supplier.step()
        restricted_portfolio_vector = self.portfolio_restrictor.restrict(previous_portfolio_state, supplied_data_unit, action)

        # Total assets: the new allocation is taken at the previous prices and
        # held until the current prices.
        now_price_array = self._get_now_price_array(supplied_data_unit)
        price_change_ratio = now_price_array / previous_portfolio_state.now_price_array
        all_assets = previous_portfolio_state.all_assets * np.dot(restricted_portfolio_vector, price_change_ratio)

        new_numbers = all_assets * restricted_portfolio_vector / now_price_array
        mean_cost_price_array = self._calculate_mean_cost_price_array(previous_portfolio_state.numbers,
                                                                      previous_portfolio_state.mean_cost_price_array,
                                                                      new_numbers,
                                                                      now_price_array)

        self.portfolio_state = self._make_portfolio_state(supplied_data_unit,
                                                          now_price_array,
                                                          restricted_portfolio_vector,
                                                          mean_cost_price_array,
                                                          all_assets)

        # The fee depends on the traded numbers, so it is computed from the
        # fee-free state and then deducted from the total assets.
        all_fee = self.fee_calculator.calculate(previous_portfolio_state, self.portfolio_state)
        self.portfolio_state = self.portfolio_state._replace(all_assets=all_assets-all_fee)

        return self.portfolio_state.copy(), done

    def _get_now_price_array(self, data_unit):
        """Return the chosen OHLC price at window offset 0 for every asset."""
        now_price_bool = data_unit.window == 0
        return getattr(data_unit, self.use_ohlc_filed)[:, now_price_bool].squeeze()

    def _make_portfolio_state(self, data_unit, now_price_array, portfolio_vector, mean_cost_price_array, all_assets):
        """Build a PortfolioState from a supplied data unit plus derived portfolio quantities."""
        return PortfolioState(names=data_unit.names,
                              key_currency_index=data_unit.key_currency_index,
                              window=data_unit.window,
                              datetime=data_unit.datetime,
                              price_array=getattr(data_unit, self.use_ohlc_filed),
                              volume_array=data_unit.volume_array,
                              now_price_array=now_price_array,
                              portfolio_vector=portfolio_vector,
                              mean_cost_price_array=mean_cost_price_array,
                              all_assets=all_assets
                             )

    @staticmethod
    def _calculate_mean_cost_price_array(pre_numbers, pre_mean_cost_price_array, new_numbers, now_price_array):
        """
        Moving-average cost price per asset after holdings change from pre_numbers to new_numbers.

        - buying (new > pre): (pre * previous mean cost + bought * now price) / new
        - selling or unchanged: previous mean cost (selling does not change the cost basis)
        - holdings below one unit: reset to the now price
        """
        pre_numbers = np.asarray(pre_numbers, dtype=float)
        new_numbers = np.asarray(new_numbers, dtype=float)
        pre_mean_cost_price_array = np.asarray(pre_mean_cost_price_array, dtype=float)
        now_price_array = np.asarray(now_price_array, dtype=float)

        is_near_zero = new_numbers < 1  # fewer than one unit is treated as "not held"
        is_buy = new_numbers > pre_numbers
        # Placeholder denominator for near-zero entries; they are overwritten below.
        safe_new_numbers = np.where(is_near_zero, 1.0, new_numbers)
        blended_cost = (pre_numbers * pre_mean_cost_price_array
                        + (new_numbers - pre_numbers) * now_price_array) / safe_new_numbers

        mean_cost_price_array = np.where(is_buy, blended_cost, pre_mean_cost_price_array)
        return np.where(is_near_zero, now_price_array, mean_cost_price_array)
        

In [40]:
transformer = PortfolioTransformer(
    price_supplier=price_supplier,
    portfolio_restrictor=PortfolioRestrictorIdentity(),
    use_ohlc="Close",
    initial_portfolio_vector=None,
    initial_all_assets=1e6,
    fee_calculator=FeeCalculatorPerNumber(fee_per_number=0.01),
)

# Reset with a window of three bars before and after the current time.
portfolio_state, _ = transformer.reset(start_datetime, window=list(range(-3, 4)))
print(portfolio_state)

PortfolioState( 
names=['yen', '4755', '9984', '6701', '7203', '7267']
key_currency_index=0
window=[-3 -2 -1  0  1  2  3]
datetime=2020-11-10 09:00:00+09:00
price_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [1.1200e+03 1.1190e+03 1.1170e+03 1.1280e+03 1.1070e+03 1.1050e+03
  1.0620e+03]
 [7.0690e+03 7.0740e+03 7.0960e+03 7.0720e+03 7.0190e+03 7.0040e+03
  6.9880e+03]
 [5.7600e+03 5.7600e+03 5.7600e+03 5.6800e+03 5.7200e+03 5.7100e+03
  5.6900e+03]
 [7.1750e+03 7.1770e+03 7.1770e+03 7.3420e+03 7.3360e+03 7.3530e+03
  7.3280e+03]
 [2.8400e+03 2.8395e+03 2.8340e+03 2.9595e+03 2.9415e+03 2.9705e+03
  2.9310e+03]]
volume_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [7.1600e+04 9.3300e+04 4.1190e+05 8.7840e+05 5.8380e+05 2.5710e+05
  7.9410e+05]
 [2.3750e+05 3.0980e+05 3.7840e+05 2.5031e+06 8.4220e+05 7.3700e+05
  4.2530e+05]
 [4.8200e+04 1.5600e+04 4.2800e+04 2.6540e+05 7.0900e+04 5.5200e+04
  7.5200e+0

In [41]:
# Move the whole portfolio into the first stock ("4755"); index 0 is the key currency.
new_portfolio_vector = [0, 1, 0, 0, 0, 0]
portfolio_state, _ = transformer.step(new_portfolio_vector); print(portfolio_state)

PortfolioState( 
names=['yen', '4755', '9984', '6701', '7203', '7267']
key_currency_index=0
window=[-3 -2 -1  0  1  2  3]
datetime=2020-11-10 09:05:00+09:00
price_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [1.1190e+03 1.1170e+03 1.1280e+03 1.1070e+03 1.1050e+03 1.0620e+03
  1.0760e+03]
 [7.0740e+03 7.0960e+03 7.0720e+03 7.0190e+03 7.0040e+03 6.9880e+03
  6.9570e+03]
 [5.7600e+03 5.7600e+03 5.6800e+03 5.7200e+03 5.7100e+03 5.6900e+03
  5.6700e+03]
 [7.1770e+03 7.1770e+03 7.3420e+03 7.3360e+03 7.3530e+03 7.3280e+03
  7.3170e+03]
 [2.8395e+03 2.8340e+03 2.9595e+03 2.9415e+03 2.9705e+03 2.9310e+03
  2.9170e+03]]
volume_array=[[1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00 1.0000e+00
  1.0000e+00]
 [9.3300e+04 4.1190e+05 8.7840e+05 5.8380e+05 2.5710e+05 7.9410e+05
  4.6700e+05]
 [3.0980e+05 3.7840e+05 2.5031e+06 8.4220e+05 7.3700e+05 4.2530e+05
  4.1350e+05]
 [1.5600e+04 4.2800e+04 2.6540e+05 7.0900e+04 5.5200e+04 7.5200e+04
  7.9000e+0

高速化できる部分は無い

## ポートフォリオの遷移を可視化 

In [42]:
# Seed the random actions so the visualized episode is reproducible on re-run.
SEED = 0
rng = np.random.default_rng(SEED)

start_datetime = jst_timezone.localize(datetime.datetime(2020,11,10,9,0,0))
stock_list = ["4755","9984","6701","7203","7267"]
episode_length = 100

price_supplier = StockDBPriceSupplier(stock_db=stock_db,
                                     ticker_names=stock_list,
                                     episode_length=episode_length,
                                     freq_str="5T",
                                     interpolate=False
                                    )

transformer = PortfolioTransformer(price_supplier=price_supplier,
                                   portfolio_restrictor=PortfolioRestrictorIdentity(),
                                   use_ohlc="Close",
                                   initial_portfolio_vector=None,
                                   initial_all_assets=1e6,
                                   fee_calculator=FeeCalculatorPerNumber(0)
                                  )

# PortfolioState fields needed by the visualization below.
recorded_fields = ("names", "now_price_array", "mean_cost_price_array", "portfolio_vector", "all_assets", "datetime")

portfolio_state_list = []
initial_state, _ = transformer.reset(start_datetime, window=[-1,0,1])
portfolio_state_list.append(initial_state.partial(*recorded_fields))

done = False
while not done:
    # softmax of |N(0, 1)| samples is positive and sums to 1, i.e. a valid portfolio vector.
    action = softmax(np.abs(rng.standard_normal(1+len(stock_list))))
    portfolio_state, done = transformer.step(action)
    portfolio_state_list.append(portfolio_state.partial(*recorded_fields))

In [43]:
def make_y_limit(y_array, upper_ratio=0.1, lowwer_ratio=0.1):
    """
    Compute (lower, upper) axis limits for y_array with a relative margin.

    Parameters
    ----------
    y_array: array-like
        Values to fit.
    upper_ratio: float, default: 0.1
        Margin above the maximum, as a fraction of the data span.
    lowwer_ratio: float, default: 0.1
        Margin below the minimum, as a fraction of the data span.

    Returns
    -------
    tuple
        (lower_limit, upper_limit)

    Notes
    -----
    When every value is equal the span is zero, which would produce a degenerate
    range (start == end). In that case the margin is taken relative to |value|
    (or to 1.0 when the value is 0).
    """
    min_value = np.amin(y_array)
    max_value = np.amax(y_array)
    diff = max_value - min_value
    if diff == 0:
        diff = abs(max_value) if max_value != 0 else 1.0
    return min_value-lowwer_ratio*diff, max_value+upper_ratio*diff

In [44]:
def make_y_limit_multi(y_arrays, upper_ratio=0.1, lowwer_ratio=0.1):
    """
    Compute (lower, upper) axis limits covering several arrays with a relative margin.

    Parameters
    ----------
    y_arrays: iterable of array-like
        Arrays that must all fit inside the limits.
    upper_ratio: float, default: 0.1
        Margin above the overall maximum, as a fraction of the overall span.
    lowwer_ratio: float, default: 0.1
        Margin below the overall minimum, as a fraction of the overall span.

    Returns
    -------
    tuple
        (lower_limit, upper_limit)

    Notes
    -----
    When the overall span is zero the margin is taken relative to |value|
    (or to 1.0 when the value is 0), so the range never degenerates to start == end.
    """
    y_arrays = list(y_arrays)  # iterated twice below; also accepts generators
    min_value = min(np.amin(y_array) for y_array in y_arrays)
    max_value = max(np.amax(y_array) for y_array in y_arrays)
    diff = max_value - min_value
    if diff == 0:
        diff = abs(max_value) if max_value != 0 else 1.0

    return min_value-lowwer_ratio*diff, max_value+upper_ratio*diff

In [45]:
def make_ticker_text(ticker_value_array, ticker_names, text_sum_line=150):
    """
    Build a "name=value, " listing of tickers, inserting a newline once the
    current line grows beyond text_sum_line characters.

    Parameters
    ----------
    ticker_value_array: sequence
        Values aligned with ticker_names; each is rendered with str().
    ticker_names: sequence of str
        Ticker names.
    text_sum_line: int, default: 150
        Line length (in characters) after which a newline is inserted.

    Returns
    -------
    str
        e.g. "yen=1.0, 4755=1128.0, " — every entry, including the last, ends with ", ".
    """
    div_text = ""
    line_length = 0

    for i, ticker_name in enumerate(ticker_names):
        entry = ticker_name + "=" + str(ticker_value_array[i]) + ", "
        div_text += entry
        # Count the whole entry, including the "=" separator.
        line_length += len(entry)

        if line_length > text_sum_line:
            div_text += "\n"
            line_length = 0

    return div_text

以下の可視化はプロトタイプ版であり，メインの開発はここでは行わない．

In [46]:
def visualize_portfolio_transform_bokeh(portfolio_state_list, save_path=None, is_save=False, is_show=True, is_jupyter=True):
    """
    Plot how a sequence of portfolio states evolves, using bokeh (prototype).

    The upper panel shows each asset's price normalized by its first value as lines,
    overlaid with a stacked bar chart of the portfolio vector (holding ratios).
    The lower panel shows each asset's mean cost price normalized by its first
    value, together with the total assets on a secondary axis.

    Parameters
    ----------
    portfolio_state_list: list
        States (or partial states) exposing names, now_price_array, portfolio_vector,
        mean_cost_price_array, all_assets and datetime.
    save_path: pathlib.Path or str, default: None
        Output path with suffix '.png' or '.html'. Required when is_save is True.
    is_save: bool, default: False
        Whether to save the figure to save_path.
    is_show: bool, default: True
        Whether to display the figure.
    is_jupyter: bool, default: True
        Whether to route output to the notebook when showing.

    Raises
    ------
    Exception
        If both is_save and is_show are False, or if save_path has an unsupported suffix.
    ValueError
        If is_save is True but save_path is None.
    """
    # Validate the arguments before building anything.
    if not is_save and not is_show:
        raise Exception("is_save and is_show is False. This function do nothing")
    if is_save:
        if save_path is None:
            raise ValueError("save_path must be given when is_save is True.")
        save_path = Path(save_path)
        if save_path.suffix not in {".png", ".html"}:
            raise Exception("The suffix of save_path is must be '.png' or '.html'.")

    # Extract the data.
    ticker_names = portfolio_state_list[0].names
    n_tickers = len(ticker_names)
    # d3["Category20"] only provides palettes for 3 to 20 colors; clamp the key and
    # cycle so that fewer than 3 or more than 20 tickers still get colors.
    palette = d3["Category20"][min(max(n_tickers, 3), 20)]
    colors = tuple(palette[i % len(palette)] for i in range(n_tickers))

    all_price_array = np.stack([one_state.now_price_array for one_state in portfolio_state_list], axis=1)
    all_portfolio_vector = np.stack([one_state.portfolio_vector for one_state in portfolio_state_list], axis=1)
    all_mean_cost_price_array = np.stack([one_state.mean_cost_price_array for one_state in portfolio_state_list], axis=1)
    all_assets_array = np.array([one_state.all_assets for one_state in portfolio_state_list])
    all_datetime_array = np.array([get_naive_datetime_from_datetime(one_state.datetime) for one_state in portfolio_state_list])
    x = np.arange(0, len(portfolio_state_list))
    # Show datetimes instead of step indices on the x axes.
    x_label_overrides = {str(one_x): str(all_datetime_array[i]) for i, one_x in enumerate(x)}

    # Build the data sources.
    portfolio_vector_source = {"x":x, "datetime":all_datetime_array}
    price_source_x = []
    price_source_y = []

    mean_cost_price_source_x = []
    mean_cost_price_source_y = []

    for i, ticker_name in enumerate(ticker_names):
        portfolio_vector_source[ticker_name] = all_portfolio_vector[i,:]

        price_source_x.append(x)
        price_source_y.append(all_price_array[i,:]/all_price_array[i,0])

        mean_cost_price_source_x.append(x)
        mean_cost_price_source_y.append(all_mean_cost_price_array[i,:]/all_mean_cost_price_array[i,0])

    # Hover tool: datetime plus each ticker's holding ratio.
    tool_tips = [("datetime", "@datetime{%F %H:%M:%S}")]
    tool_tips.extend([(ticker_name, "@"+ticker_name+"{0.000}") for ticker_name in ticker_names])

    hover_tool = HoverTool(
        tooltips=tool_tips,
        formatters={'@datetime' : 'datetime'}
    )

    # Upper panel: normalized prices and holding ratios.
    # Div renders its text as HTML, so plain "\n" line breaks would be collapsed.
    p1_text = Div(text=make_ticker_text(all_price_array[:,0], ticker_names).replace("\n", "<br>"))

    p1 = bokeh.plotting.figure(plot_width=1200,plot_height=500,title="正規化価格・ポートフォリオ")
    p1.add_tools(hover_tool)

    # The stacked bars sum to 1, so a 0..3 axis keeps them in the lower third.
    p1.extra_y_ranges = {"portfolio_vector": Range1d(start=0, end=3)}
    p1.add_layout(LinearAxis(y_range_name="portfolio_vector"), 'right')
    p1.vbar_stack(ticker_names, x='x', width=1, color=colors,y_range_name="portfolio_vector", source=portfolio_vector_source, legend_label=ticker_names, alpha=0.8)

    p1.multi_line(xs=price_source_x, ys=price_source_y, line_color=colors, line_width=2)
    y_min, y_max = make_y_limit_multi(price_source_y, lowwer_ratio=0.1, upper_ratio=0.1)
    y_min -= (y_max - y_min) * 0.66  # leave room at the bottom for the holding-ratio bars
    p1.y_range = Range1d(start=y_min, end=y_max)

    p1.yaxis[0].axis_label = "正規化価格"
    p1.yaxis[1].axis_label = "保有割合"

    p1.xaxis.major_label_overrides = dict(x_label_overrides)

    # Lower panel: normalized mean cost prices and total assets.
    p2_text = Div(text=make_ticker_text(all_mean_cost_price_array[:,0], ticker_names).replace("\n", "<br>"))

    p2 = bokeh.plotting.figure(plot_width=1200,plot_height=300,title="正規化平均取得価格・全資産")
    p2.multi_line(xs=mean_cost_price_source_x, ys=mean_cost_price_source_y, line_color=colors, line_width=2)
    y_min, y_max = make_y_limit_multi(mean_cost_price_source_y, lowwer_ratio=0.1, upper_ratio=0.1)
    p2.y_range = Range1d(start=y_min, end=y_max)

    assets_min, assets_max = make_y_limit(all_assets_array, upper_ratio=0.1, lowwer_ratio=0.1)
    p2.extra_y_ranges = {"all_assets": Range1d(start=assets_min, end=assets_max)}
    p2.add_layout(LinearAxis(y_range_name="all_assets"), 'right')
    p2.line(x, all_assets_array, color="red", legend_label="all_assets", line_width=4, y_range_name="all_assets")

    # Empty lines that only add legend entries for the multi_line colors.
    for ticker_name, color in zip(ticker_names, colors):
        p2.line([], [], legend_label=ticker_name, color=color, line_width=2)

    p2.yaxis[0].axis_label = "正規化平均取得価格"
    p2.yaxis[1].axis_label = "全資産 [円]"

    p2.xaxis.major_label_overrides = dict(x_label_overrides)

    created_figure = bokeh.layouts.column(p1_text, p1, p2_text, p2)

    if is_save:
        if save_path.suffix == ".png":
            bokeh.io.export_png(created_figure, filename=save_path)
        else:  # ".html", validated above
            output_file(save_path)
            bokeh.io.save(created_figure, filename=save_path, title="trading process")
    if is_show:
        try:
            reset_output()
            if is_jupyter:
                output_notebook()
            show(created_figure)
        except Exception:
            # Best-effort fallback: retry without resetting the output state.
            if is_jupyter:
                output_notebook()
            show(created_figure)
        

In [47]:
visualize_portfolio_transform_bokeh(
    portfolio_state_list,
    save_path=Path("visualization/trade_transform.png"),
    is_save=False,
    is_show=True,
)