In [1]:
from sqlite3 import paramstyle
# %matplotlib inline
# %matplotlib widget
# %matplotlib notebook
from matplotlib import style
import backtrader as bt
import pandas as pd
import datetime
import os

class FeedData(bt.feeds.GenericCSVData):
    lines=("chg","turnover",)
    params = (
        ("fromdate",datetime.datetime(2015,1,1)),
        ("todate",datetime.datetime(2020,1,1)),
        ('nullvalue', float('NaN')),
        ('dtformat', '%Y-%m-%d'),
        # ('tmformat', '%H:%M:%S'),
        ('datetime', 1),
        ('name', 2),
        ('code', 3),
        ('time', -1),
        ('open', 4),
        ("close",5),
        ('high', 6),
        ('low', 7),
        ('volume', 8),
        ('turnover', 9),
        ('amplitude', 10),
        ('chg', 11),
        ('turnover', 13),
        ('openinterest', -1),
    )


class TestStrategy(bt.Strategy):
 
    def log(self, txt, dt=None):
        ''' 提供记录功能'''
        dt = dt or self.datas[0].datetime.date(0)
        print('%s, %s' % (dt.isoformat(), txt))
 
    def __init__(self):
        self.res = pd.read_csv("./res.csv")
        # self.res["date"]= pd.to_datetime(self.res["date"])
        self.ind = {}
        self.order = None
        for i,d in enumerate(self.datas):
            self.ind[d]=bt.indicators.MACD(d)
        
    def next(self):
        print("\n\n","*"*20,len(self),"*"*20)
        date = str(self.datas[0].datetime.date(0))
        info = self.res[self.res["date"]==date]
        buy_list = [str(int(i)) for i in info["code"]]
        print(info,buy_list)

        if len(self)<26:
            return
        # print(self._trades(d).barlen())

        for d in self.datas:
            
            if d._name in buy_list:
                self.log(f'BUY {d._name}, {d.close[0]}' )
                self.order = self.buy(data=d)
            elif self.getposition(d).size:
                self.order = self.sell(data=d)
                self.log(f'SEll {d._name}, %.2f' % d.close[0])
        
           
    
    def notify_trade(self, trade):
        if not trade.isclosed:
            return
        self.log(f'执行标的：{trade.getdataname()}，策略收益：毛收益 {trade.pnl:.2f}, 净收益 {trade.pnlcomm:.2f},')


    def notify_order(self, order):
        if order.status in [order.Submitted, order.Accepted]:
            # Buy/Sell order submitted/accepted to/by broker - Nothing to do
            return
        
 
        # Check if an order has been completed
        # Attention: broker could reject order if not enough cash
        if order.status in [order.Completed]:
            if order.isbuy():
                self.log(
                    'BUY EXECUTED, Price: %.2f, Cost: %.2f, Comm %.2f' %
                    (order.executed.price,
                     order.executed.value,
                     order.executed.comm))
 
                self.buyprice = order.executed.price
                self.buycomm = order.executed.comm
            else:  # Sell
                self.log('SELL EXECUTED, Price: %.2f, Cost: %.2f, Comm %.2f' %
                         (order.executed.price,
                          order.executed.value,
                          order.executed.comm))
 
            self.bar_executed = len(self)
 
 
        elif order.status in [order.Canceled, order.Margin, order.Rejected]:
            self.log('Order Canceled/Margin/Rejected')
 
        self.order = None


if __name__ == '__main__':
    root = "./data"
    cerebro = bt.Cerebro()
    cerebro.addstrategy(TestStrategy)
    #获取数据
    files = os.listdir(root)
    from efinance.stock import get_members
    data = get_members("上证50")["股票代码"]
    
    for file in data:
#         if not file.endswith("csv"):
#             continue
        file = file+".csv"
#         print(file)
        path = os.path.join(root,file)
        if not os.path.exists(path):
            continue
        data = FeedData(dataname=path,encoding="utf-8") # 加载数据
        cerebro.adddata(data,name=file[:-4])  # 将数据传入回测系统
    cerebro.broker.set_coc(True)
    cerebro.broker.setcash(100000.0)
    cerebro.broker.setcommission(commission=0.00015)
    cerebro.addsizer(bt.sizers.PercentSizer, percents=10)
    print('Starting Portfolio Value: %.2f' % cerebro.broker.getvalue())
    
    cerebro.addobserver(bt.observers.DrawDown)
    # cerebro.addobserver(bt.observers.Benchmark, data=data)
    # cerebro.addobserver(bt.observers.Broker)

    cerebro.run()
    cerebro.plot(
        style="candel",
        plotdist=0.1,
        barup = '#ff9896', bardown='#98df8a',
        volup='#ff9896', voldown='#98df8a',
        grid=False
    )
    print('Final Portfolio Value: %.2f' % cerebro.broker.getvalue())

Starting Portfolio Value: 100000.00


 ******************** 34 ********************
Empty DataFrame
Columns: [date, code, res]
Index: [] []


 ******************** 35 ********************
Empty DataFrame
Columns: [date, code, res]
Index: [] []


 ******************** 36 ********************
Empty DataFrame
Columns: [date, code, res]
Index: [] []


 ******************** 37 ********************
Empty DataFrame
Columns: [date, code, res]
Index: [] []


 ******************** 38 ********************
Empty DataFrame
Columns: [date, code, res]
Index: [] []


 ******************** 39 ********************
Empty DataFrame
Columns: [date, code, res]
Index: [] []


 ******************** 40 ********************
Empty DataFrame
Columns: [date, code, res]
Index: [] []


 ******************** 41 ********************
         date      code       res
0  2015-03-09  600346.0  0.688283
1  2015-03-09  600588.0  0.668870
2  2015-03-09  600690.0  0.666879
3  2015-03-09  601888.0  0.658530
4  2015-03-09  603

ValueError: Axis limits cannot be NaN or Inf