In [None]:
# Use vnpy 2.0.8's backtesting logic without upgrading the vnpy installed on this machine:
# put a vnpy 2.0.8 source tree at the front of sys.path so `import vnpy` resolves to it.
import os
import sys
import warnings
from pathlib import Path

# Location of the vnpy 2.0.8 source tree. Set the VNPY_2_0_8_PATH environment variable to
# override it on another machine (e.g. E:\vnpy\vnpy-2.0.8) instead of editing this cell.
new_version_path = Path(os.environ.get('VNPY_2_0_8_PATH', r'D:\vnpy-2.0.8'))

# Drop any existing copies of the entry, then put it first: re-running the cell keeps exactly
# one entry at the highest priority instead of piling up duplicates.
path_entry = str(new_version_path)
sys.path[:] = [p for p in sys.path if p != path_entry]
sys.path.insert(0, path_entry)

import vnpy
print(vnpy.__version__)

# If vnpy was already imported in this kernel (e.g. from site-packages), the sys.path change
# above has no effect; warn so a different vnpy version is not used silently.
if new_version_path.resolve() not in Path(vnpy.__file__).resolve().parents:
    warnings.warn(f'vnpy was loaded from {vnpy.__file__}, not {new_version_path}; restart the kernel.')

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
plt.style.use('ggplot')

from datetime import datetime
from vnpy.app.cta_strategy.backtesting import BacktestingEngine, OptimizationSetting
from boll_channel_strategy import BollChannelStrategy
from utility import (vt_trade_to_df, load_data)

### 设置回测参数和策略

In [None]:
# Backtest setup: continuous-contract symbol, bar interval and date range. These names are
# also used later when reloading the bars for the trade-point chart.
test_symbol = 'RB888.SHFE'
test_interval = '1h'
test_start, test_end = datetime(2018, 1, 1), datetime(2019, 12, 1)

# Engine parameters. Commission (rate) and slippage are both zero, so the results
# below exclude trading costs.
backtest_setting = {
    'vt_symbol': test_symbol,
    'interval': test_interval,
    'start': test_start,
    'end': test_end,
    'rate': 0,
    'slippage': 0,
    'size': 10,          # contract size (multiplier)
    'pricetick': 1,      # minimum price increment
    'capital': 100000,   # starting capital
}

engine = BacktestingEngine()
engine.set_parameters(**backtest_setting)
# Empty setting dict: no parameter overrides are passed to the strategy.
engine.add_strategy(BollChannelStrategy, {})

### 运行回测输出统计结果

In [None]:
# Run the backtest: load the configured bar history, then replay it through the strategy.
engine.load_data()
engine.run_backtesting()

# Fail fast with a clear message when the strategy produced no trades. Every step below
# (trade export, daily PnL, statistics, and the plots in later cells) needs at least one
# trade. In vnpy 2.0.x, calculate_result() returns early without a DataFrame in that case,
# so pnl_df.to_csv would otherwise crash with an unhelpful AttributeError.
trades = engine.get_all_trades()
if not trades:
    raise RuntimeError(f'No trades generated for {engine.vt_symbol}; nothing to save or plot.')

# Save the trade records
trade_df = vt_trade_to_df(trades)
trade_df.to_csv(f'{engine.vt_symbol}_trade_continuous.csv')

# Compute the daily mark-to-market results and save them
pnl_df = engine.calculate_result()
pnl_df.to_csv(f'{engine.vt_symbol}_pnl_continuous.csv')

# Last expression: the statistics' return value (if any) is shown as the cell output.
engine.calculate_statistics()

### 绘制回测结果

In [None]:
stats_df = engine.daily_df

%matplotlib notebook
plt.figure(figsize=(10, 16))


balance_plot = plt.subplot(4, 1, 1)
balance_plot.set_title("Balance")
balance_plot.plot(stats_df["balance"].index, stats_df["balance"].values)

drawdown_plot = plt.subplot(4, 1, 2)
drawdown_plot.set_title("DrawdownPercent")
drawdown_plot.fill_between(stats_df["ddpercent"].index, stats_df['ddpercent'].values, facecolor='green', alpha=0.5)

pnl_plot = plt.subplot(4, 1, 3)
pnl_plot.set_title("Daily Pnl")
stats_df["net_pnl"].plot(legend=False)

distribution_plot = plt.subplot(4, 1, 4)
distribution_plot.set_title("Daily Pnl Distribution")
stats_df["net_pnl"].hist(bins=50)

plt.tight_layout() 

### 绘制成交点

In [None]:
long_df = trade_df[trade_df.direction == '多']
short_df = trade_df[trade_df.direction == '空']

# 计算成交点x,y序列
price = load_data(test_symbol, test_interval, test_start, test_end)
price['id'] = range(len(price))

long_x = price.loc[long_df.index].id.values
long_y = long_df.price.values
short_x = price.loc[short_df.index].id.values
short_y = short_df.price.values

# 绘制成交点位
%matplotlib notebook
price2 = price.reset_index()
axe = price2.close.plot(figsize=(12, 6), color='#a8a8a8', zorder=10)
axe.scatter(long_x, long_y, color='r', marker=6, zorder=20)
axe.scatter(short_x, short_y, color='g', marker=7, zorder=30)
# plt.show()