In [11]:
import os
import random as rd
import sqlite3
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import backtrader as bt # 导入 Backtrader
import backtrader.indicators as btind # 导入策略分析模块
import backtrader.feeds as btfeeds # 导入数据模块

# Silence print() output produced inside called functions.
class hiddenPrints:
    """Context manager that discards everything printed inside its ``with`` body.

    While active, ``sys.stdout`` points at a handle on ``os.devnull``. On exit
    that handle is closed and the previous ``sys.stdout`` is restored. This
    happens even if the body raised. Exceptions from the body are not
    suppressed.

    Needs the module-level ``import os`` and ``import sys``.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference so __exit__ closes exactly the handle we
        # opened, even if the body reassigned sys.stdout in the meantime.
        # Explicit encoding: writes of text the locale codec cannot encode
        # must not raise UnicodeEncodeError.
        self._devnull = open(os.devnull, 'w', encoding='utf-8')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self._devnull.close()
        finally:
            # Always put the real stdout back, even if close() failed.
            sys.stdout = self._original_stdout
        return False  # never swallow exceptions from the with-body

# Open a SQLite database quickly with a `with` statement.
class fast_db_conn:
    """Context manager that opens a SQLite connection and always closes it.

    Usage::

        with fast_db_conn(path) as conn:
            df = pd.read_sql(query, con=conn)

    On a clean exit the pending transaction is committed. If the ``with``
    body raised, the transaction is rolled back instead, so a failed batch
    of writes is not persisted half-done. The connection is closed in every
    case, and the body's exception propagates to the caller.

    Args:
        db_path: anything ``sqlite3.connect`` accepts (``str`` or a
            ``pathlib.Path``).
    """

    def __init__(self, db_path):
        self._db = db_path

    def __enter__(self):
        self._conn = sqlite3.connect(self._db)
        print("打开连接 {}".format(self._db))
        return self._conn

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is None:
                self._conn.commit()
            else:
                # Do not persist a partial transaction from a failed body.
                self._conn.rollback()
        finally:
            # Close even if commit()/rollback() itself raised.
            self._conn.close()
            print("关闭连接 {}".format(self._db))
        return False  # never swallow exceptions from the with-body


In [10]:
# Project paths, relative to the notebook's working directory
# (the recorded output shows it running from <repo>/finance/jupyterFile/bt).
cur_path = Path('.')
print(f"当前路径为 {cur_path.resolve()}")
# Three levels up from the working directory: the repository root.
gitee_path = cur_path.joinpath('..', '..', '..')
# Folder holding the em_stock_*.db SQLite files.
db_file_path = gitee_path.joinpath('finance', 'db')
print(f"数据路径为 {db_file_path.resolve()}")

当前路径为 D:\gitee\finance\jupyterFile\bt
数据路径为 D:\gitee\finance\db


In [12]:
# SQL statements keyed by purpose.
#   'simple-data': a str.format template. Fill {start_date} / {end_date} to
#                  select the half-open range start_date <= date < end_date.
#   'data_spot'  : the whole stock_spot table.
#   'data_hfq'   : the whole stock_hfq table (presumably back-adjusted
#                  prices, 后复权 — confirm against the data source).
sql = {
    'simple-data': (
        "select date,code,name,open,high,low,close from stock_all_221102 "
        "where date < '{end_date}' and date >= '{start_date}'"
    ),
    'data_spot': "select * from stock_spot",
    'data_hfq': "select * from stock_hfq",
}

# 2022 data is split across four database files, em_stock_2022_1.db ...
# em_stock_2022_4.db (presumably one per quarter — confirm).
dbname = [f'em_stock_2022_{part}.db' for part in range(1, 5)]

# Sanity check: display where the first database file resolves to.
dbpath = db_file_path.joinpath(dbname[0])
dbpath.resolve()

WindowsPath('D:/gitee/finance/db/em_stock_2022_1.db')

In [13]:
# Read the stock_hfq table from each of the 2022 database files and stack
# the pieces into one DataFrame (`data`).
stock_hfq = []
for db in dbname:
    with fast_db_conn((db_file_path / Path(db)).resolve()) as conn:
        # 'd' holds the trading date; parse it so the datetime column is
        # a real datetime type downstream.
        part = pd.read_sql(sql=sql['data_hfq'], con=conn, parse_dates=['d'])
        print(part.shape)
        stock_hfq.append(part)

# One concat over the collected list (not one per iteration).
data = pd.concat(stock_hfq, ignore_index=True)

# The next cell selects these columns; fail here with a clear message if
# the table schema changed.
missing = {'d', 'code', 'open', 'high', 'low', 'close', 'volume'} - set(data.columns)
assert not missing, f"stock_hfq is missing expected columns: {sorted(missing)}"

data.head()

打开连接 D:\gitee\finance\db\em_stock_2022_1.db
(274032, 13)
关闭连接 D:\gitee\finance\db\em_stock_2022_1.db
打开连接 D:\gitee\finance\db\em_stock_2022_2.db
(282131, 13)
关闭连接 D:\gitee\finance\db\em_stock_2022_2.db
打开连接 D:\gitee\finance\db\em_stock_2022_3.db
(316312, 13)
关闭连接 D:\gitee\finance\db\em_stock_2022_3.db
打开连接 D:\gitee\finance\db\em_stock_2022_4.db
(299536, 13)
关闭连接 D:\gitee\finance\db\em_stock_2022_4.db


Unnamed: 0,d,open,close,high,low,volume,price_volume,amplitude,p_change,price_change,turnover,code,record_date
0,2022-01-04,45.43,45.7,46.58,45.19,206882,397126992.0,3.08,1.2,0.54,4.86,2222,2022-11-19
1,2022-01-05,45.23,43.68,45.55,42.73,276311,506863760.0,6.17,-4.42,-2.02,6.49,2222,2022-11-19
2,2022-01-06,43.5,43.54,43.84,42.62,128472,232326894.0,2.79,-0.32,-0.14,3.02,2222,2022-11-19
3,2022-01-07,43.72,42.51,44.24,42.4,151622,272725584.0,4.23,-2.37,-1.03,3.56,2222,2022-11-19
4,2022-01-10,42.62,43.36,44.53,42.13,175901,319923328.0,5.65,2.0,0.85,4.13,2222,2022-11-19


In [14]:
# Keep only the fields backtrader needs. Rename 'd' -> 'datetime' and add a
# constant 'openinterest' column, so the frame carries the OHLCV +
# openinterest column names used with bt.feeds.PandasData further down.
# NOTE(review): codes display as e.g. 2222; if 'code' is stored as an
# integer, A-share leading zeros (002222) are lost — confirm the dtype.
col_map = {"d": "datetime"}
df = (
    data.loc[:, ['code', 'd', 'open', 'high', 'low', 'close', 'volume']]
    .rename(columns=col_map)
    .assign(openinterest=0)
)

# Distinct stock codes (order of first appearance) and the overall date span.
code = df['code'].unique()
st_date = df['datetime'].min()
ed_date = df['datetime'].max()
df.head()

Unnamed: 0,code,datetime,open,high,low,close,volume,openinterest
0,2222,2022-01-04,45.43,46.58,45.19,45.7,206882,0
1,2222,2022-01-05,45.23,45.55,42.73,43.68,276311,0
2,2222,2022-01-06,43.5,43.84,42.62,43.54,128472,0
3,2222,2022-01-07,43.72,44.24,42.4,42.51,151622,0
4,2222,2022-01-10,42.62,44.53,42.13,43.36,175901,0


In [15]:
%%time
# Build a Cerebro with one PandasData feed per stock for the first N_FEEDS codes.
# The frame is grouped once instead of calling df.query(f"code=='{c}'") per code.
# query re-scans the whole frame on every call. Its quoted literal is also a
# string, which cannot match a numeric 'code' column. get_group uses the
# key exactly as it came from df['code'].unique().
N_FEEDS = 200
cerebro = bt.Cerebro()
rows_by_code = df.groupby('code', sort=False)
for c in code[:N_FEEDS]:
    temp = rows_by_code.get_group(c).set_index('datetime').drop(columns=['code'])
    datafeed1 = bt.feeds.PandasData(dataname=temp, fromdate=st_date, todate=ed_date)
    cerebro.adddata(datafeed1, name=c)


CPU times: total: 12 s
Wall time: 11.9 s


In [16]:
class maCross(bt.Strategy):
    """Long-only simple-moving-average crossover strategy.

    Enters long when the fast SMA crosses above the slow SMA. Closes the
    position when it crosses back below.

    With the default ``all_datas=False`` the behaviour is the classic
    single-feed one. The SMAs are built on the first data feed
    (``self.data``) and orders go to that feed only. Any further feeds added
    to Cerebro (e.g. one feed per stock) are loaded but never traded.

    Pass ``all_datas=True``, i.e.
    ``cerebro.addstrategy(maCross, all_datas=True)``, to apply the same rule
    independently to every feed.

    Params:
        pfast: period of the fast moving average.
        pslow: period of the slow moving average.
        all_datas: trade every data feed instead of only the first one.

    NOTE(review): with many feeds that start on different dates, backtrader
    may not call ``next`` until every feed's slow SMA has warmed up.
    Confirm this, and override ``prenext`` if that delay matters.
    """

    # list of parameters which are configurable for the strategy
    params = dict(
        pfast=5,  # period for the fast moving average
        pslow=15,  # period for the slow moving average
        all_datas=False,  # False: trade only the first feed (original behaviour)
    )

    def __init__(self):
        feeds = self.datas if self.p.all_datas else [self.data]
        # (feed, crossover) pairs. CrossOver is > 0 on an upward cross of the
        # fast SMA over the slow one, and < 0 on a downward cross.
        self._signals = []
        for feed in feeds:
            sma_fast = bt.ind.SMA(feed, period=self.p.pfast)
            sma_slow = bt.ind.SMA(feed, period=self.p.pslow)
            self._signals.append((feed, bt.ind.CrossOver(sma_fast, sma_slow)))
        # Kept for backward compatibility: the first feed's crossover signal.
        self.crossover = self._signals[0][1]

    def next(self):
        for feed, crossover in self._signals:
            if not self.getposition(feed):  # not in the market for this feed
                if crossover > 0:  # fast crossed slow to the upside
                    self.buy(data=feed)  # enter long
            elif crossover < 0:  # in the market & cross to the downside
                self.close(data=feed)  # close long position

In [22]:
# IPython introspection: `??` displays the signature, docstring and source
# file of bt.ind.SMA. Exploration aid only — remove before sharing.
bt.ind.SMA??

Init signature: bt.ind.SMA(*args, **kwargs)
Docstring:
Non-weighted average of the last n periods

Formula:
  - movav = Sum(data, period) / period

See also:
  - http://en.wikipedia.org/wiki/Moving_average#Simple_moving_average
File:           d:\anaconda3\envs\finance\lib\site-packages\backtrader\indicators\sma.py
Type:           MetaMovAvBase
Subclasses:


In [17]:
%%time
# Attach the MA-crossover strategy and run the backtest.
# NOTE: addstrategy appends to this Cerebro. Re-running this cell without
# rebuilding `cerebro` first runs maCross more than once.
cerebro.addstrategy(maCross)
result = cerebro.run()  # strategies returned by cerebro.run()
rasult = result  # old (misspelled) name kept so existing references still work

CPU times: total: 10.2 s
Wall time: 10.3 s


In [18]:
%%time
# Plot all feeds plus the strategy output (the recorded run took ~2.5 min).
# The returned figure list is bound to a name so the cell does not echo
# its [[<Figure ...>]] repr underneath the plot.
figures = cerebro.plot(iplot=False)

CPU times: total: 1min 47s
Wall time: 2min 29s


[[<Figure size 960x720 with 403 Axes>]]

In [118]:
# Fresh Cerebro with only a few stocks, for the line-indexing demo below.
# TestStrategy1 reads data0 and data1, so at least two feeds are needed.
# The frame is grouped once instead of calling df.query(f"code=='{c}'") per code.
# query re-scans the whole frame each time, and its quoted string literal
# cannot match a numeric 'code' column.
N_DEMO_FEEDS = 5
cerebro = bt.Cerebro()
rows_by_code = df.groupby('code', sort=False)
for c in code[:N_DEMO_FEEDS]:
    temp = rows_by_code.get_group(c).set_index('datetime').drop(columns=['code'])
    datafeed1 = bt.feeds.PandasData(dataname=temp, fromdate=st_date, todate=ed_date)
    cerebro.adddata(datafeed1, name=c)
    
class TestStrategy1(bt.Strategy):
    """Teaching strategy that prints how backtrader line indexing behaves.

    Only the second feed (``self.data1``) is inspected, so Cerebro needs at
    least two data feeds. No orders are placed.

    * ``__init__`` prints the ``[0]``, ``[-1]``, ``[-2]``, ``[1]`` and ``[2]``
      positions and some ``get()`` slices before the first ``next`` call.
      The output shows which bars those positions refer to at that point.
    * ``next`` prints:
      - the current bar;
      - the two previous and the two following bars;
      - how many bars have been processed;
      - the buffer length.

    NOTE(review): ``__init__`` reads bar values, which presumably relies on
    the feed being preloaded (Cerebro's default); confirm before using
    ``preload=False``. On the first ``next`` calls the ``[-1]``/``[-2]``
    reads reach before the first bar, so interpret those lines with care.
    """

    def __init__(self):
        self.count = 0  # number of next() calls so far
        # Print the feed's name and the values at several index positions.
        print("------------- init 中的索引位置-------------")
        print("股票code为 ", self.data1._name)
        print("0 索引：",'datetime',self.data1.lines.datetime.date(0), 'close',self.data1.lines.close[0])
        print("-1 索引：",'datetime',self.data1.lines.datetime.date(-1),'close', self.data1.lines.close[-1])
        print("-2 索引",'datetime', self.data1.lines.datetime.date(-2),'close', self.data1.lines.close[-2])
        print("1 索引：",'datetime',self.data1.lines.datetime.date(1),'close', self.data1.lines.close[1])
        print("2 索引",'datetime', self.data1.lines.datetime.date(2),'close', self.data1.lines.close[2])
        print("从 0 开始往前取3天的收盘价：", self.data1.lines.close.get(ago=0, size=3))
        print("从-1开始往前取3天的收盘价：", self.data1.lines.close.get(ago=-1, size=3))
        print("从-2开始往前取3天的收盘价：", self.data1.lines.close.get(ago=-2, size=3))
        print("line的总长度：", self.data1.buflen())

    def next(self):
        data = self.data1
        close = data.lines.close
        dt = data.lines.datetime
        print(f"------------- next 的第{self.count+1}次循环 --------------")
        print("当前时点（今日）：",'datetime',dt.date(0),'close', close[0])
        print("往前推1天（昨日）：",'datetime',dt.date(-1),'close', close[-1])
        print("往前推2天（前日）", 'datetime',dt.date(-2),'close', close[-2])
        print("前日、昨日、今日的收盘价：", close.get(ago=0, size=3))
        # Look-ahead reads ([1], [2]) only have data while loaded bars remain
        # after the current one. On the final bars they would index past the
        # end of the loaded data (expected IndexError, aborting cerebro.run()).
        bars_ahead = data.buflen() - len(data)
        if bars_ahead >= 1:
            print("往后推1天（明日）：",'datetime',dt.date(1),'close', close[1])
        else:
            print("往后推1天（明日）：", "无数据（已到数据末尾）")
        if bars_ahead >= 2:
            print("往后推2天（明后日）", 'datetime',dt.date(2),'close', close[2])
        else:
            print("往后推2天（明后日）", "无数据（已到数据末尾）")
        print("已处理的数据点：", len(data))
        # Report data1's buffer length, consistent with every other line in
        # this method and with __init__ (the feeds can have different lengths).
        print("line的总长度：", data.buflen())
        self.count += 1

# Attach the indexing demo and run it (it only prints; it places no orders).
cerebro.addstrategy(TestStrategy1)
result = cerebro.run()  # strategies returned by cerebro.run()
rasult = result  # old (misspelled) name kept so existing references still work