In [54]:
import uuid
import numpy as np
import pandas as pd
from collections import OrderedDict
from functools import lru_cache 
from datetime import date, datetime, timedelta
from zenquant.trader.database import get_database
from zenquant.trader.constant import Interval
from zenquant.trader.object import  OrderData, TradeData, BarData, TickData
from zenquant.ctastrategy.base import (
    STOPORDER_PREFIX,
    StopOrder,
    StopOrderStatus,
    INTERVAL_DELTA_MAP
)
from zenquant.trader.constant import (
    Status,
    Direction,
    Offset,
    Exchange
)
import gym
from zenquant.feed.data import BarDataFeed,TickDataFeed 
from zenquant.feed.portfolio import PortfolioDataStream,NetPortfolioDataStream
from zenquant.env.action import ContinueAction 
from zenquant.env.observer import Observer
from zenquant.env.reward import Reward 
from zenquant.env.stopper import Stopper
from zenquant.env.informer import Informer
from zenquant.env.renender import BarRenderer
from zenquant.utils.get_indicators_info import (
    get_bar_level_indicator_info,
    get_tick_level_indicator_info
)

class ContinueEnv(gym.Env):
    """
    A trading environment made for use with Gym-compatible reinforcement
    learning algorithms with continue actions.
    Parameters
    ----------
    """
    def __init__(self):
        self.clock_step = 0 
        self.if_discrete = False
        self.agent_id = 0    ## updated by agent 
        self.env_num  = 0    ## updated by agent 
        self.target_return = 10 
        self.env_name = "ContinueEnv"
        self.episode_id = 0
        self.tick: TickData
        self.bar: BarData
        self.datetime = None
        self.last_price = 6000  ##division by zero
        self.interval = None
        self.min_step = 250

        self.history_data = []
        #history 
        self.history_action = 0
        self.history_pnl = 0

        self.stop_order_count = 0
        self.stop_orders = {}
        self.active_stop_orders = {}

        self.limit_order_count = 0
        self.limit_orders = {}
        self.active_limit_orders = {}

        self.trade_count = 0
        self.trades = OrderedDict()
    def on_init(self,**kwargs):
        '''
        Configure the environment from ``kwargs``, load history and build
        every component.

        Steps, in order:
          1. read each setting from ``kwargs`` (every key has a default);
          2. load history with ``load_data`` and push it into a bar or tick
             data feed;
          3. pre-compute the indicator, ATR, SMA and Bollinger arrays;
          4. build portfolio, action, observer, reward, stopper, informer
             and renderer;
          5. advance ``min_step`` until no indicator value is NaN and align
             every component clock to it;
          6. build the initial observation and expose the gym spaces.

        Raises:
            NotImplementedError: if ``mode`` is neither "bar" nor "tick".
        '''
        ## parameters for the environment
        self.gateway_name = kwargs.get("gateway_name","CryptoBacktest")
        self.mode = kwargs.get('mode',"bar")
        self.vt_symbol = kwargs.get("vt_symbol", "BTCUSDT.BINANCE")
        self.interval = Interval( kwargs.get("interval","1m"))
        self.min_step = kwargs.get("min_step",250)
        self.symbol = kwargs.get("symbol", "BTC/USDT")
        self.exchange= kwargs.get("exchange",Exchange.BINANCE)

        self.start = kwargs.get("start", datetime(2021, 9, 1))
        self.end = kwargs.get("end", datetime.now())

        ## parameters for the environment's components
        ## portfolio
        self.MarginLevel= kwargs.get("MarginLevel", 1)
        self.risk_free = kwargs.get("risk_free", 0)
        self.capital =kwargs.get("capital", 100000)
        self.commission_rate = kwargs.get("commission_rate",0.0)
        # The key used to be looked up as "slippage_rate " (trailing space),
        # so a caller's "slippage_rate" was silently ignored. The misspelled
        # key is still honoured as a fallback for existing callers.
        self.slippage_rate = kwargs.get(
            "slippage_rate", kwargs.get("slippage_rate ", 0.0)
        )
        ## Action
        self.action_dim = kwargs.get("action_dim",1)
        self.pricetick = kwargs.get("pricetick", 0.01)
        self.min_volume = kwargs.get("min_volume", 0.001)
        self.min_trade_balance = kwargs.get("min_trade_balance", 5)
        self.limit_total_margin_rate = kwargs.get("limit_total_margin_rate", 0.5)
        self.available_change_percent= kwargs.get("available_change_percent", 0.5)
        self.skip_mode = kwargs.get("skip_mode", "sma")
        self.sma_window = kwargs.get("sma_window", 10)
        self.atr_window = kwargs.get("atr_window", 14)
        self.boll_window = kwargs.get("boll_window", 18)
        self.boll_dev = kwargs.get("boll_dev", 3.4)
        self.holding_pos_mode = kwargs.get("holding_pos_mode", "net")
        self.use_stop = kwargs.get("use_stop", False)
        ## Observer
        self.pos_info_scale = kwargs.get("pos_info_scale", 2**-7)
        self.indicator_info_scale = kwargs.get("indicator_info_scale", 2**-8)
        self.history_action_scale = kwargs.get("history_action_scale", 2**-7)
        self.history_pnl_scale = kwargs.get("history_pnl_scale", 2**-8)
        self.state_dim= kwargs.get("state_dim", 3)
        self.windows_size = kwargs.get("windows_size", 5)
        self.indicator_windows_list = kwargs.get("indicator_windows_list",[10,20,40,80])
        ## Rewarder
        self.lag_window = kwargs.get("lag_window", 5)
        self.extra_reward = kwargs.get("extra_reward", 0.001)
        self.survive_reward_scale = kwargs.get("survive_reward_scale", 0.001)
        self.reward_mode = kwargs.get("reward_mode", "differential_sharpe_ratio")
        ## Stopper and Informer
        self.max_allowed_loss = kwargs.get("max_allowed_loss", 0.05)
        ## training params
        self.profit_stop_rate = kwargs.get("profit_stop_rate", 1)
        self.loss_stop_rate = kwargs.get("loss_stop_rate", -0.5)
        ## Renderer (configured through **kwargs)

        ## load data into history_data
        self.load_data()
        self.max_step= len(self.history_data) -1

        # Pick the feed type and indicator builder for the mode; everything
        # after that is identical for bars and ticks.
        if self.mode == "bar":
            self.datafeed = BarDataFeed(len(self.history_data))
            build_indicators = get_bar_level_indicator_info
        elif self.mode == "tick":
            self.datafeed = TickDataFeed(len(self.history_data))
            build_indicators = get_tick_level_indicator_info
        else:
            raise NotImplementedError
        for idx, item in enumerate(self.history_data):
            self.datafeed.update_by_index(idx, item)
        self.indicator_array = build_indicators(self.datafeed,self.indicator_windows_list)
        self.atr_array = self.datafeed.atr(self.atr_window,array = True)
        self.sma_array = self.datafeed.sma(self.sma_window,array = True)
        self.boll_up,self.boll_down = self.datafeed.boll(self.boll_window,self.boll_dev,array = True)

        ## create the components
        # state_dim is set before the components are built: the portfolio
        # and action objects receive `self`.
        if self.holding_pos_mode == "net":
            self.state_dim= len(self.indicator_array) + 5
            self.portfolio = NetPortfolioDataStream(self)
        else:
            self.state_dim= len(self.indicator_array) + 9
            self.portfolio = PortfolioDataStream(self)
        self.action = ContinueAction(self)
        self.observer = Observer(self.state_dim,self.windows_size)
        self.rewarder = Reward(reward_mode=self.reward_mode)
        self.stopper = Stopper(self.max_allowed_loss)
        self.informer = Informer()
        self.renderer = BarRenderer()

        ## make sure min_step points at a row where every indicator is defined
        self.indicator_info = np.array([item[self.min_step] for item in self.indicator_array])
        while np.isnan(self.indicator_info).any():
            self.min_step += 1
            self.indicator_info = np.array([item[self.min_step] for item in self.indicator_array])

        ## move every clock to min_step
        self.clock_step = self.min_step
        self.portfolio.clock_step = self.min_step
        self.action.clock_step = self.min_step
        self.observer.clock_step = self.min_step
        self.rewarder.clock_step = self.min_step
        self.stopper.clock_step = self.min_step
        self.informer.clock_step = self.min_step
        if self.mode == "bar":
            self.last_price = self.datafeed.close_array[self.clock_step]
        elif self.mode == "tick":
            self.last_price = self.datafeed.last_price_array[self.clock_step]
        self.portfolio.occupy_rate = 0
        if self.holding_pos_mode == "net":
            self.portfolio.pos_occupy_rate = 0
            self.portfolio.pos_avgprice = self.last_price
            self.pos_info =  np.array([self.portfolio.occupy_rate,
            abs(self.portfolio.pos)>self.min_volume,
            1.0-self.portfolio.pos_avgprice/self.last_price])
        else:
            self.portfolio.long_pos_occupy_rate = 0
            self.portfolio.short_pos_occupy_rate = 0
            self.portfolio.long_pos_avgprice = self.last_price
            self.portfolio.short_pos_avgprice = self.last_price
            self.pos_info = np.array([self.portfolio.long_pos_occupy_rate ,
                                  self.portfolio.short_pos_occupy_rate,
                                  self.portfolio.occupy_rate,     #long+short+locked
                                  self.portfolio.long_pos>self.min_volume,
                                  self.portfolio.short_pos>self.min_volume,
                                  1.0-self.portfolio.long_pos_avgprice/self.last_price,
                                  self.portfolio.short_pos_avgprice/self.last_price-1.0])

        self.indicator_info = np.array([item[self.clock_step] for item in self.indicator_array])
        ## scale the features and append the last action / pnl
        self.pos_info = self.pos_info * self.pos_info_scale
        self.pos_info = np.hstack([self.pos_info,self.history_action,self.history_pnl])
        self.indicator_info = self.indicator_info * self.indicator_info_scale
        self.init_observation = self.observer.observe(self.indicator_info,self.pos_info).reshape((-1,))

        ## expose the spaces to the agent
        self.observation_space = self.observer.observation_space
        self.action_space = self.action.action_space
    def load_data(self):
        """
        Load bars or ticks between ``self.start`` and ``self.end`` into
        ``self.history_data``.

        The range is fetched in about ten date chunks so progress can be
        printed between chunks. If ``self.end`` is falsy it is set to
        ``datetime.now()``. If ``self.start >= self.end`` a message is
        printed and the method returns without touching ``history_data``.
        """
        self.output("开始加载历史数据")

        if not self.end:
            self.end = datetime.now()

        if self.start >= self.end:
            self.output("起始日期必须小于结束日期")
            return
        self.history_data.clear()       # Clear previously loaded history data

        # Chunk size is one tenth of the range (at least one day).
        # total_days is clamped to 1: a range shorter than a day has
        # ``.days == 0``, which made the progress update below raise
        # ZeroDivisionError.
        total_days = max((self.end - self.start).days, 1)
        progress_days = max(int(total_days / 10), 1)
        progress_delta = timedelta(days=progress_days)
        interval_delta = INTERVAL_DELTA_MAP[self.interval]

        start = self.start
        end = self.start + progress_delta
        progress = 0

        while start < self.end:
            progress_bar = "#" * int(progress * 10 + 1)
            self.output(f"加载进度：{progress_bar} [{progress:.0%}]")

            end = min(end, self.end)  # Make sure end time stays within set range

            if self.mode == "bar":
                data = load_bar_data(
                    self.symbol,
                    self.exchange,
                    self.interval,
                    start,
                    end
                )
            else:
                data = load_tick_data(
                    self.symbol,
                    self.exchange,
                    start,
                    end
                )

            self.history_data.extend(data)

            progress += progress_days / total_days
            progress = min(progress, 1)

            # The next chunk starts one interval after this one ended.
            start = end + interval_delta
            end += progress_delta

        self.output(f"历史数据加载完成，数据量：{len(self.history_data)}")

    def step(self):
        """Build the labelled feature table from the loaded data feed.

        The frame starts with time/high/low/open/close/volume from the data
        feed plus the ATR array, is passed through ``create_label``, and then
        gets one column per entry of ``indicator_array`` named "0", "1", ...
        Takes no action argument and returns the DataFrame.
        """
        feed = self.datafeed
        base_columns = (
            ("time", feed.datetime_array),
            ("high", feed.high),
            ("low", feed.low),
            ("open", feed.open),
            ("close", feed.close),
            ("volume", feed.volume),
            ("atr", self.atr_array),
        )
        table = pd.DataFrame([])
        for column_name, values in base_columns:
            table[column_name] = values

        table = self.create_label(table)

        for position, values in enumerate(self.indicator_array):
            table[str(position)] = values
        return table
        
        
        
    def create_label(self,df,atr_multiplier=2):
        """Mark zig-zag turning points on an OHLC frame and derive a target.

        Rows are scanned in order. The first row seeds the pivot list with
        its close. A later row becomes a new pivot when price has moved away
        from the last pivot by at least ``atr / open * atr_multiplier``; it
        replaces the last pivot when it extends the current leg. The last
        row is always allowed to close the final leg.

        Parameters
        ----------
        df : pandas.DataFrame
            Needs the columns "time", "open", "high", "low", "close", "atr".
        atr_multiplier : float, default 2
            Multiple of the relative ATR used as the reversal threshold.

        Returns
        -------
        pandas.DataFrame
            ``df`` right-merged with the pivot table. Added columns:
            "Time"/"Value" (filled on pivot rows only), "Type" (1 on pivot
            rows, 0 elsewhere), "PrevExt" (pivot value two pivots back,
            forward-filled) and "target" (``PrevExt / close``).
        """
        pivots = []

        def change_since_pivot(price):
            """Relative move of ``price`` away from the latest pivot value."""
            pivot = pivots[-1]["Value"]
            if pivot == 0:
                # Tiny stand-in so the ratio stays defined. The previous
                # ``1 ** (-100)`` evaluates to 1, not to a small number.
                pivot = 1e-100
            return (price - pivot) / abs(pivot)

        def make_pivot(row, kind=None):
            """Pivot record: peaks use the high, troughs the low, else close."""
            key = {"Peak": "high", "Trough": "low"}.get(kind, "close")
            return {"Time": row["time"], "Value": row[key], "Type": kind}

        # Positions are used instead of index labels so the first/last row
        # are detected correctly even when the index is not 0..n-1.
        last_pos = len(df.index) - 1
        for pos, (_, row) in enumerate(df.iterrows()):
            threshold = row["atr"] / row["open"] * atr_multiplier

            # First row: seed pivot, direction not known yet.
            if pos == 0:
                pivots.append(make_pivot(row))
                continue

            # Only the seed exists: wait for the first move past the
            # threshold, which also fixes the type of the seed pivot.
            if len(pivots) == 1:
                move = change_since_pivot(row["close"])
                if abs(move) >= threshold:
                    if move > 0:
                        pivots.append(make_pivot(row, "Peak"))
                        pivots[0]["Type"] = "Trough"
                    else:
                        pivots.append(make_pivot(row, "Trough"))
                        pivots[0]["Type"] = "Peak"
                continue

            is_ending = pos == last_pos
            last_value = float(pivots[-1]["Value"])
            last_is_trough = pivots[-2]["Value"] > pivots[-1]["Value"]
            if last_is_trough:
                if row["low"] <= last_value:
                    # A lower low extends the leg: move the trough here.
                    pivots[-1] = make_pivot(row, "Trough")
                elif change_since_pivot(row["high"]) >= threshold or is_ending:
                    pivots.append(make_pivot(row, "Peak"))
            else:
                if row["high"] >= last_value:
                    # A higher high extends the leg: move the peak here.
                    pivots[-1] = make_pivot(row, "Peak")
                elif change_since_pivot(row["low"]) <= -threshold or is_ending:
                    pivots.append(make_pivot(row, "Trough"))

        zigzags = pd.DataFrame(pivots, columns=["Time", "Value", "Type"])
        zigzags["PrevExt"] = zigzags["Value"].shift(2)
        labelled = zigzags.merge(df, left_on="Time", right_on="time", how="right")
        # Peaks and troughs both map to 1: "Type" is an is-pivot flag.
        labelled["Type"] = labelled["Type"].map({"Trough": 1, "Peak": 1}).fillna(0)
        # ``.ffill()`` replaces the deprecated ``fillna(method='ffill')``.
        labelled["PrevExt"] = labelled["PrevExt"].ffill()
        labelled["target"] = labelled["PrevExt"] / labelled["close"]
        return labelled
    def new_bar(self, bar: BarData):
        """
        撮合订单，并更新portfolio
        """
        self.bar = bar
        self.datetime = bar.datetime

        self.cross_limit_order()
        self.cross_stop_order()
       

    def new_tick(self, tick: TickData):
        """
        撮合订单，并更新portfolio
        """
        self.tick = tick
        self.datetime = tick.datetime

        self.cross_limit_order()
        self.cross_stop_order()
    def cross_limit_order(self):
        """
        Match every active limit order against the latest bar or tick.

        Bar mode: a buy fills when its price is at or above the bar low, a
        sell when its price is at or below the bar high; the bar open is the
        best achievable price. Tick mode: ask_price_1 / bid_price_1 serve as
        both trigger and best price. A filled order is marked ALLTRADED,
        removed from the active book and turned into a trade that is passed
        to ``update_portfolio``.
        """
        if self.mode == "bar":
            buy_trigger = self.bar.low_price
            sell_trigger = self.bar.high_price
            buy_best = sell_best = self.bar.open_price
        else:
            buy_trigger = buy_best = self.tick.ask_price_1
            sell_trigger = sell_best = self.tick.bid_price_1

        # Iterate over a snapshot: filled orders are removed inside the loop.
        for order in list(self.active_limit_orders.values()):
            # A freshly submitted order becomes "not traded" (pending).
            if order.status == Status.SUBMITTING:
                order.status = Status.NOTTRADED

            buy_fills = (
                order.direction == Direction.LONG
                and order.price >= buy_trigger
                and buy_trigger > 0
            )
            sell_fills = (
                order.direction == Direction.SHORT
                and order.price <= sell_trigger
                and sell_trigger > 0
            )
            if not (buy_fills or sell_fills):
                continue

            # The whole volume is filled at once.
            order.traded = order.volume
            order.status = Status.ALLTRADED
            self.active_limit_orders.pop(order.vt_orderid, None)

            # Never fill worse than the limit price, and take the better
            # best-price when the market went through the limit.
            if buy_fills:
                fill_price = min(order.price, buy_best)
            else:
                fill_price = max(order.price, sell_best)

            trade = TradeData(
                symbol=order.symbol,
                exchange=order.exchange,
                orderid=order.orderid,
                tradeid=str(self.trade_count),
                direction=order.direction,
                offset=order.offset,
                price=fill_price,
                volume=order.volume,
                datetime=self.datetime,
                gateway_name=self.gateway_name,
            )

            # Record the trade only if the portfolio accepted it.
            if self.update_portfolio(trade):
                self.trade_count += 1
                self.trades[trade.vt_tradeid] = trade

    def cross_stop_order(self):
        """
        Trigger every active stop order reached by the latest bar or tick.

        Bar mode: a buy stop triggers when its price is at or below the bar
        high, a sell stop when its price is at or above the bar low; the bar
        open is the reference fill price. Tick mode uses ``last_price`` for
        both. A triggered stop creates an already-filled order record and a
        trade, is marked TRIGGERED, leaves the active book, and the trade is
        passed to ``update_portfolio``.
        """
        if self.mode == "bar":
            buy_trigger = self.bar.high_price
            sell_trigger = self.bar.low_price
            buy_best = sell_best = self.bar.open_price
        else:
            buy_trigger = buy_best = self.tick.last_price
            sell_trigger = sell_best = self.tick.last_price

        # Iterate over a snapshot: triggered stops are removed inside the loop.
        for stop_order in list(self.active_stop_orders.values()):
            buy_hit = (
                stop_order.direction == Direction.LONG
                and stop_order.price <= buy_trigger
            )
            sell_hit = (
                stop_order.direction == Direction.SHORT
                and stop_order.price >= sell_trigger
            )
            if not (buy_hit or sell_hit):
                continue

            # Order record for the triggered stop, created as fully traded.
            self.limit_order_count += 1
            order = OrderData(
                symbol=self.symbol,
                exchange=self.exchange,
                orderid=str(self.limit_order_count),
                direction=stop_order.direction,
                offset=stop_order.offset,
                price=stop_order.price,
                volume=stop_order.volume,
                traded=stop_order.volume,
                status=Status.ALLTRADED,
                gateway_name=self.gateway_name,
                datetime=self.datetime
            )
            self.limit_orders[order.vt_orderid] = order

            # A stop fills at its own price or worse: the higher of stop
            # and best price for buys, the lower for sells.
            if buy_hit:
                fill_price = max(stop_order.price, buy_best)
            else:
                fill_price = min(stop_order.price, sell_best)

            trade = TradeData(
                symbol=order.symbol,
                exchange=order.exchange,
                orderid=order.orderid,
                tradeid=str(self.trade_count),
                direction=order.direction,
                offset=order.offset,
                price=fill_price,
                volume=order.volume,
                datetime=self.datetime,
                gateway_name=self.gateway_name,
            )

            # Link the stop to its order, mark it triggered and retire it.
            stop_order.vt_orderids.append(order.vt_orderid)
            stop_order.status = StopOrderStatus.TRIGGERED
            self.active_stop_orders.pop(stop_order.stop_orderid, None)

            # Record the trade only if the portfolio accepted it.
            if self.update_portfolio(trade):
                self.trade_count += 1
                self.trades[trade.vt_tradeid] = trade
    def update_portfolio(self,trade):
        '''
        Cap a closing trade at the open position and apply it to the portfolio.

        For CLOSE trades the volume is limited to the position being closed
        (guards against decimal rounding): the absolute net position in
        "net" mode, otherwise the short position for a LONG trade and the
        long position for a SHORT trade. The trade is then forwarded to
        ``portfolio.update_by_trade``. Always returns True.
        '''
        if trade.offset == Offset.CLOSE:
            book = self.portfolio
            if self.holding_pos_mode == "net":
                trade.volume = min(trade.volume, abs(book.pos))
            elif trade.direction == Direction.LONG:
                # buying closes a short
                trade.volume = min(trade.volume, book.short_pos)
            elif trade.direction == Direction.SHORT:
                # selling closes a long
                trade.volume = min(trade.volume, book.long_pos)

        self.portfolio.update_by_trade(trade)
        return True

    def reset(self):
        """
        Start a new episode and return its first observation.

        Assigns a fresh episode id, empties the order/trade books, resets
        every component, moves all clocks to ``min_step`` and rebuilds the
        initial observation from the indicator row and the flat position.
        """
        self.episode_id = str(uuid.uuid4())
        self.clock_step = 0
        self.tick: TickData
        self.bar: BarData
        self.datetime = None
        self.last_price = 60000

        # last action / pnl fed back to the agent
        self.history_action = [0]
        self.history_pnl = [0]

        # order and trade books
        self.stop_order_count = 0
        self.stop_orders = {}
        self.active_stop_orders = {}

        self.limit_order_count = 0
        self.limit_orders = {}
        self.active_limit_orders = {}

        self.trade_count = 0
        self.trades = OrderedDict()

        # reset all components (renderer included) ...
        clocked_parts = (
            self.portfolio,
            self.action,
            self.observer,
            self.rewarder,
            self.stopper,
            self.informer,
        )
        for part in clocked_parts + (self.renderer,):
            part.reset()

        # ... then align every clock except the renderer's to min_step
        self.clock_step = self.min_step
        for part in clocked_parts:
            part.clock_step = self.min_step

        if self.mode == "bar":
            self.last_price = self.datafeed.close_array[self.clock_step]
        elif self.mode == "tick":
            self.last_price = self.datafeed.last_price_array[self.clock_step]

        portfolio = self.portfolio
        portfolio.occupy_rate = 0
        if self.holding_pos_mode == "net":
            portfolio.pos_occupy_rate = 0
            portfolio.pos_avgprice = self.last_price
            self.pos_info = np.array([
                portfolio.occupy_rate,
                abs(portfolio.pos) > self.min_volume,
                1.0 - portfolio.pos_avgprice / self.last_price,
            ])
        else:
            portfolio.long_pos_occupy_rate = 0
            portfolio.short_pos_occupy_rate = 0
            portfolio.long_pos_avgprice = self.last_price
            portfolio.short_pos_avgprice = self.last_price
            self.pos_info = np.array([
                portfolio.long_pos_occupy_rate,
                portfolio.short_pos_occupy_rate,
                portfolio.occupy_rate,     # long + short + locked
                portfolio.long_pos > self.min_volume,
                portfolio.short_pos > self.min_volume,
                1.0 - portfolio.long_pos_avgprice / self.last_price,
                portfolio.short_pos_avgprice / self.last_price - 1.0,
            ])

        step_index = self.clock_step
        self.indicator_info = np.array([series[step_index] for series in self.indicator_array])

        # scale the features and append the last action / pnl
        self.pos_info = self.pos_info * self.pos_info_scale
        self.pos_info = np.hstack([self.pos_info, self.history_action, self.history_pnl])
        self.indicator_info = self.indicator_info * self.indicator_info_scale
        observed = self.observer.observe(self.indicator_info, self.pos_info)
        self.init_observation = observed.reshape((-1,))

        self.action_space = self.action.action_space
        # observation at min_step, where the episode starts
        return self.init_observation

    def render(self, **kwargs) -> None:
        """Draw the environment by handing it, with any keyword options, to the renderer."""
        renderer = self.renderer
        renderer.render(self, **kwargs)

    def save(self) -> None:
        """Ask the renderer to save its rendered view."""
        renderer = self.renderer
        renderer.save()

    def close(self) -> None:
        """Shut the environment down by closing its renderer."""
        renderer = self.renderer
        renderer.close()
    def output(self, msg) -> None:
        """
        Print ``msg`` on one line, prefixed by the current local time and a tab.
        """
        stamp = datetime.now()
        print(f"{stamp}\t{msg}")
@lru_cache(maxsize=999)
def load_bar_data(
    symbol: str,
    exchange: Exchange,
    interval: Interval,
    start: datetime,
    end: datetime
):
    """Return the bars the database holds for the given symbol, exchange,
    interval and time range.

    Results are memoised per argument tuple (up to 999 entries), so a
    repeated request does not call the database again.
    """
    return get_database().load_bar_data(symbol, exchange, interval, start, end)


@lru_cache(maxsize=999)
def load_tick_data(
    symbol: str,
    exchange: Exchange,
    start: datetime,
    end: datetime
):
    """Return the ticks the database holds for the given symbol, exchange
    and time range.

    Results are memoised per argument tuple (up to 999 entries), so a
    repeated request does not call the database again.
    """
    return get_database().load_tick_data(symbol, exchange, start, end)



In [55]:


test_env= ContinueEnv()
# Settings passed to `ContinueEnv.on_init`, which reads each key with
# `kwargs.get`. Keys it does not read (the DQN block at the end) are ignored
# there. Key names must match exactly: "pos_info_scale " and
# "history_action_scale " used to carry a trailing space, so on_init fell
# back to its defaults for both.
config={
    "gateway_name":"CryptoContinue",
    'mode':"bar",
    "vt_symbol":"BTCUSDT.BINANCE",
    "interval":"1m",
    "symbol": "BTCUSDT",
    "exchange":Exchange.BINANCE,
    "min_step":100,
    "start":datetime(2021, 12, 1),
    "end":datetime(2021, 12, 19),
    "MarginLevel":10,
    "risk_free":0,
    "capital":10000,
    "commission_rate":0.0004,
    "slippage_rate":0,
    "pricetick": 0.01,
    "min_volume":0.001,
    "min_trade_balance":5,
    "limit_total_margin_rate":0.8,
    "available_change_percent":0.2,
    "skip_mode":"",
    "sma_window":10,
    "atr_window":14,
    "boll_window":20,
    "boll_dev":1.8,
    "holding_pos_mode":"net",
    "use_stop":False,
    "pos_info_scale":1,
    "indicator_info_scale":10**-2,
    "history_action_scale":1,
    "history_pnl_scale":1,
    "windows_size": 1,
    "indicator_windows_list":[12,48,168],
    "lag_window":20,
    "extra_reward":0,
    "reward_mode":'differential_sharpe_ratio',
    "max_allowed_loss":0.5,
    "loss_stop_rate":-0.3,
    ##DQN params 
    "learning_rate":2**-15,
    "batch_size": 2**11,
    "gamma":  0.97,
    "seed":312,
    "net_dim": 2**9,
    "worker_num":4,
    "reward_scale":1,
    "target_step": 10000, #collect target_step, then update network
    "eval_gap": 30  #used for evaluate, evaluate the agent per eval_gap seconds
}


In [56]:
# Load the history and build the environment's components from `config`.
test_env.on_init(**config)

2021-12-26 14:57:14.284902	开始加载历史数据
2021-12-26 14:57:14.285047	加载进度：# [0%]
2021-12-26 14:57:14.392712	加载进度：# [6%]
2021-12-26 14:57:14.488931	加载进度：## [11%]
2021-12-26 14:57:14.596787	加载进度：## [17%]
2021-12-26 14:57:14.698225	加载进度：### [22%]
2021-12-26 14:57:14.798983	加载进度：### [28%]
2021-12-26 14:57:14.900235	加载进度：#### [33%]
2021-12-26 14:57:14.997087	加载进度：#### [39%]
2021-12-26 14:57:15.098618	加载进度：##### [44%]
2021-12-26 14:57:15.197935	加载进度：###### [50%]
2021-12-26 14:57:15.289647	加载进度：###### [56%]
2021-12-26 14:57:15.388858	加载进度：####### [61%]
2021-12-26 14:57:15.675413	加载进度：####### [67%]
2021-12-26 14:57:15.772977	加载进度：######## [72%]
2021-12-26 14:57:15.868338	加载进度：######## [78%]
2021-12-26 14:57:15.968319	加载进度：######### [83%]
2021-12-26 14:57:16.066709	加载进度：######### [89%]
2021-12-26 14:57:16.160630	加载进度：########## [94%]
2021-12-26 14:57:16.242568	历史数据加载完成，数据量：25647


In [57]:
# Build the labelled table: OHLCV, ATR, zig-zag label columns and one column per indicator.
df= test_env.step()

In [58]:
# Preview the first rows of the labelled table.
df.head()

Unnamed: 0,Time,Value,Type,PrevExt,time,high,low,open,close,volume,...,64,65,66,67,68,69,70,71,72,73
0,2021-12-01,57232.39,1.0,,2021-12-01 00:00:00,57235.01,57198.76,57205.42,57232.39,73.878,...,,,,,,,-0.000471,4.6e-05,-0.000588,73.878
1,NaT,,0.0,,2021-12-01 00:01:00,57247.2,57140.96,57232.39,57148.04,156.242,...,,,,,,,0.001476,0.001735,-0.000124,156.242
2,NaT,,0.0,,2021-12-01 00:02:00,57195.47,57116.19,57149.41,57194.86,90.869,...,,,,,,,-0.000795,1.1e-05,-0.001375,90.869
3,NaT,,0.0,,2021-12-01 00:03:00,57227.72,57128.09,57194.86,57160.84,88.416,...,,,,,,,0.000595,0.00117,-0.000573,88.416
4,NaT,,0.0,,2021-12-01 00:04:00,57170.74,57111.0,57160.84,57135.0,59.97,...,,,,,,,0.000452,0.000626,-0.00042,59.97


In [59]:
# "Time" and "Value" are only filled on pivot rows (see the NaT / empty cells
# in the preview above), so they must not take part in the NaN filter:
# a plain dropna() would discard every non-pivot row and leave Type == 1 only.
df = df.dropna(subset=df.columns.difference(["Time", "Value"]))

In [60]:
# List the column names of the labelled table.
df.columns.to_list()

['Time',
 'Value',
 'Type',
 'PrevExt',
 'time',
 'high',
 'low',
 'open',
 'close',
 'volume',
 'atr',
 'target',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '10',
 '11',
 '12',
 '13',
 '14',
 '15',
 '16',
 '17',
 '18',
 '19',
 '20',
 '21',
 '22',
 '23',
 '24',
 '25',
 '26',
 '27',
 '28',
 '29',
 '30',
 '31',
 '32',
 '33',
 '34',
 '35',
 '36',
 '37',
 '38',
 '39',
 '40',
 '41',
 '42',
 '43',
 '44',
 '45',
 '46',
 '47',
 '48',
 '49',
 '50',
 '51',
 '52',
 '53',
 '54',
 '55',
 '56',
 '57',
 '58',
 '59',
 '60',
 '61',
 '62',
 '63',
 '64',
 '65',
 '66',
 '67',
 '68',
 '69',
 '70',
 '71',
 '72',
 '73']

In [61]:
# Model inputs: raw price/volume columns followed by the 74 indicator
# columns, which `step` names "0" .. "73".
features_columns = ['high', 'low', 'open', 'volume'] + [str(i) for i in range(74)]

In [66]:
import copy

# Feature names plus the label column, built on an independent copy so
# `features_columns` itself stays unchanged.
all_feature = copy.deepcopy(features_columns) + ["Type"]

In [67]:
# Keep only the feature and label columns, then drop rows with any missing value.
test_df = df[all_feature].dropna()

In [68]:
# Store the label as an integer class id.
test_df = test_df.astype({"Type": int})

In [69]:
import numpy as np
from scipy import optimize
from scipy import special

class FocalLoss:
    """Binary focal loss with first and second derivatives.

    The loss for one sample is ``-a_t * (1 - p_t) ** gamma * log(p_t)``,
    where ``p_t`` is the probability given to the true class and ``a_t`` an
    optional class weight. ``grad`` and ``hess`` are used by ``lgb_obj`` on
    probabilities obtained by applying the logistic function to raw scores.
    """

    def __init__(self, gamma, alpha=None):
        # gamma: focusing exponent; alpha: weight of the positive class
        # (None means no class weighting).
        self.gamma = gamma
        self.alpha = alpha

    def at(self, y):
        """Class weight per sample: alpha where ``y`` is truthy, 1 - alpha
        elsewhere; all ones when alpha is None."""
        if self.alpha is not None:
            return np.where(y, self.alpha, 1 - self.alpha)
        return np.ones_like(y)

    def pt(self, y, p):
        """Probability assigned to the true class, with ``p`` clipped away
        from 0 and 1 so the logarithm stays finite."""
        clipped = np.clip(p, 1e-15, 1 - 1e-15)
        return np.where(y, clipped, 1 - clipped)

    def __call__(self, y_true, y_pred):
        """Per-sample focal loss for labels ``y_true`` and probabilities ``y_pred``."""
        weight = self.at(y_true)
        prob = self.pt(y_true, y_pred)
        return -weight * (1 - prob) ** self.gamma * np.log(prob)

    def grad(self, y_true, y_pred):
        """First derivative of the loss (see the class docstring for its use)."""
        sign = 2 * y_true - 1  # {0, 1} -> {-1, 1}
        weight = self.at(y_true)
        prob = self.pt(y_true, y_pred)
        gamma = self.gamma
        modulator = (1 - prob) ** gamma
        inner = gamma * prob * np.log(prob) + prob - 1
        return weight * sign * modulator * inner

    def hess(self, y_true, y_pred):
        """Second derivative of the loss, obtained with the product rule on
        ``u * v`` from ``grad``."""
        sign = 2 * y_true - 1  # {0, 1} -> {-1, 1}
        weight = self.at(y_true)
        prob = self.pt(y_true, y_pred)
        gamma = self.gamma

        u = weight * sign * (1 - prob) ** gamma
        du = -weight * sign * gamma * (1 - prob) ** (gamma - 1)
        v = gamma * prob * np.log(prob) + prob - 1
        dv = gamma * np.log(prob) + gamma + 1

        return (du * v + u * dv) * sign * (prob * (1 - prob))

    def init_score(self, y_true):
        """Log-odds of the constant probability that minimises the summed loss."""
        def total_loss(p):
            return self(y_true, p).sum()

        best = optimize.minimize_scalar(
            total_loss,
            bounds=(0, 1),
            method='bounded'
        )
        p_star = best.x
        return np.log(p_star / (1 - p_star))

    def lgb_obj(self, preds, train_data):
        """Return ``(grad, hess)`` for raw scores ``preds`` and the labels
        from ``train_data.get_label()``."""
        labels = train_data.get_label()
        probs = special.expit(preds)
        return self.grad(labels, probs), self.hess(labels, probs)

    def lgb_eval(self, preds, train_data):
        """Return ``('focal_loss', mean loss, False)``; the final False is
        the higher-is-better flag."""
        labels = train_data.get_label()
        probs = special.expit(preds)
        is_higher_better = False
        return 'focal_loss', self(labels, probs).mean(), is_higher_better