In [1]:
import baostock as bs
import pandas as pd
import talib as ta
import matplotlib.pyplot as plt
import BaoStockUtil
import math

import datetime

from RSI import DayRSI,WeekRSI,MonthRSI
from Stock import Stock
import dbutil
import KlineService

from IPython.core.debugger import set_trace

RSI_OVER_BUY = 80
RSI_OVER_SELL = 20
RSI_OVER_BUY_12 = 75
RSI_OVER_SELL_12 = 25
RSI_OVER_BUY_24 = 70
RSI_OVER_SELL_24 = 30
RSI_MIDDLE = 50

In [2]:
def findLatestRSIDate(period):
    mydb = dbutil.connectDB()
    collection = mydb[chooseRSICollection(period)]
    cursor = collection.find().sort("date",-1).limit(1)
    df =  pd.DataFrame(list(cursor))
    if df.empty:
        return "1970-01-01"
    return df["date"][0]

def clearRSI(period):
    mydb = dbutil.connectDB()
    collection = mydb[chooseRSICollection(period)]
    collection.delete_many()

##
#  从数据库读指定日期RSI数据
#
#
def readRSI(period, stockCode, specifiedDate):
    mydb = dbutil.connectDB()
    collection = mydb[chooseRSICollection(period)]
    if type(specifiedDate) == str:
        specifiedDate = datetime.datetime.strptime(specifiedDate, "%Y-%m-%d")
    cursor = collection.find({"code":stockCode,"date":specifiedDate})
    df =  pd.DataFrame(list(cursor))
    return df
    
##
#  写RSI数据库
#
#
def writeRSIToDb(period, stockCode, stockName, rsi_df):
    dataList = []
    for index,rsi in rsi_df.iterrows():
        rsiDate = rsi['date']
        rsiDF = readRSI(period, stockCode, rsiDate)
        #非空说明库里有数据，没法重复插入，调过
        if not rsiDF.empty:
            continue
        if period == "day":
            rsiObj = DayRSI(stockCode, stockName)
        elif period == "week":
            rsiObj = WeekRSI(stockCode, stockName)
        elif period == "month":
            rsiObj = MonthRSI(stockCode, stockName)
        elif period == "5m":
            rsiObj = FiveMinRSI(stockCode, stockName)
        elif period == "15m":
            rsiObj = FiftyMinRSI(stockCode, stockName)
        elif period == "30m":
            rsiObj = ThirtyMinRSI(stockCode, stockName)
        elif period == "60m":
            rsiObj = SixtyMinRSI(stockCode, stockName)

        rsiObj.date = rsiDate
        rsiObj.rsi_6 = rsi['rsi_6']
        rsiObj.rsi_12 = rsi['rsi_12']
        rsiObj.rsi_24 = rsi['rsi_24']
        rsiObj.overBuy = rsi['overBuyFlag']
        rsiObj.overSell = rsi['overSellFlag']
        
        dataList.append(rsiObj.__dict__)
        
    mydb = dbutil.connectDB()
    collection = mydb[chooseRSICollection(period)]
    if len(dataList) > 0:
        collection.insert_many(dataList)
    else:
        raise RuntimeError("RSI数据为空")

##
#  选择不同的Kline Collection
#
def chooseRSICollection(period):
    periodRSICollection = {
        "day" : "RSI_Day",
        "week" : "RSI_Week",
        "month" : "RSI_Month",
        "5m" : "RSI_5m",
        "15m" : "RSI_15m",
        "30m" : "RSI_30m",
        "60m" : "RSI_60m"
    }
    return periodRSICollection.get(period)


def computeRSI(klineDataFrame):
    # 剔除停盘数据
    klineDataFrame = klineDataFrame[klineDataFrame['tradeStatus'] == '1']
    rsi_12days = ta.RSI(klineDataFrame['closePrice'],timeperiod=12)
    rsi_6days = ta.RSI(klineDataFrame['closePrice'],timeperiod=6)
    rsi_24days = ta.RSI(klineDataFrame['closePrice'],timeperiod=24)
    
    rsiFrame = pd.DataFrame(klineDataFrame, columns=["date"])
    rsiFrame['rsi_6'] = rsi_6days
    rsiFrame['rsi_12'] = rsi_12days
    rsiFrame['rsi_24'] = rsi_24days
    ##添加参考线位置
    rsiFrame['overBuy'] = RSI_OVER_BUY
    rsiFrame['overSell'] = RSI_OVER_SELL
    rsiFrame['middle'] = RSI_MIDDLE

    # RSI超卖和超买
    rsi_buy_position = rsiFrame['rsi_12'] > RSI_OVER_BUY_12
    rsi_sell_position = rsiFrame['rsi_12'] < RSI_OVER_SELL_12
    rsiFrame.loc[rsi_buy_position[(rsi_buy_position == True) & (rsi_buy_position.shift() == False)].index, 'overBuyFlag'] = 'Yes'
    rsiFrame.loc[rsi_sell_position[(rsi_sell_position == True) & (rsi_sell_position.shift() == False)].index, 'overSellFlag'] = 'Yes'
    return rsiFrame

##
#  计算自起始日期起的RSI
#
#
def computeAllRSIDataOfPeriod(period, startDate):
    clearRSI(period)
    stockDict = KlineService.allStocks()
    endDate = str(datetime.date.today())
    processCount = 0
    failCount = 0
    for key,stock in stockDict.items():
        processCount = processCount + 1
        try:
            df = KlineService.readStockKline(key, period, startDate, endDate)
            rsi_df = computeRSI(df)
            writeRSIToDb(period, key, stock["name"], rsi_df)
        except BaseException as e:
            failCount = failCount + 1
            print ("download " + key + " error:" + str(e))
        
        if processCount % 100 == 0 and processCount > 0:
            print ("download process:", processCount, " of ", len(stockDict) ," failed:", failCount)
    return True

##
#  计算指定日期的RSI
#
#
def computeAllRSIData(period, specifiedDateStr):
    BaoStockUtil.customLogin()
    specifiedDate = datetime.datetime.strptime(specifiedDateStr, "%Y-%m-%d")
    today = datetime.date.today()
    #如果把时间设成未来，自动调成今天
    if specifiedDate > datetime.datetime.today():
        specifiedDate = datetime.date.today()
    #避免跨年问题，直接从去年开始取
    startDate = specifiedDate - datetime.timedelta(days = 365)
    #取交易日列表，用作倒推周期使用
    rs = bs.query_trade_dates(start_date=datetime.datetime.strftime(startDate, "%Y-%m-%d"), end_date = specifiedDate)
    BaoStockUtil.customLogout()
    if rs.error_code != '0':
        raise RuntimeError("交易日api调用失败了:" + rs.error_code)
    tradeDates = []
    while (rs.error_code == '0') & rs.next():
        row = rs.get_row_data()
        if row[1] == "1":
            tradeDates.append(row[0])
    if len(tradeDates) == 0:
        raise RuntimeError("取不到最新的交易日")
    
    #检查起始日期的RSI算好没,如果起始日期比库里的最近RSI时间还要大
    #说明有数据缺漏，要从库里的最近RSI时间起算，补齐缺漏数据
    rsiLatestDate = findLatestRSIDate(period)
    rsiLatestDateStr = datetime.datetime.strftime(rsiLatestDate, "%Y-%m-%d")
#     rsiLatestDate = datetime.datetime.strptime(rsiLatestDateStr, "%Y-%m-%d")
    
    #若期望计算的日期比库里最新日期还早，无需更新
    if rsiLatestDate > specifiedDate:
        raise RuntimeError(specifiedDateStr + " 的 " + period + " RSI的计算已经完成，无需重新计算")
    
    #找到指定日期以及rsi存量数据最近日期在交易日周期的序号
    latestDateIndex = -1
    specifiedDateIndex = -1
    for i in range(0, len(tradeDates)):
        currDate = tradeDates[i]
        if latestDateIndex > 0 and specifiedDateIndex > 0:
            break
        if currDate == rsiLatestDateStr:
            latestDateIndex = i
        elif currDate == specifiedDateStr:
            specifiedDateIndex = i
            
    
    #rsi最新日期在取值周期内存在，说明库里有数据
    #否则该周期的数据为空，需要全部重算
    dateDiff = -1
    if latestDateIndex > 0:
        #指定日期也在取值周期内存在，计算两者差的日期
        #日期差1天以上，说明中间有缺漏
        if specifiedDateIndex > 0:
            dateDiff = specifiedDateIndex - latestDateIndex
    else:
        #重算就从数据库里的K线数据最早日期开始
        return computeAllRSIDataOfPeriod(period, "2017-01-01")
    
    processCount = 0
    failCount = 0
    startDateIndex = -1
    
    set_trace()
    if dateDiff > 1:
        daysBefore = computeRSIDataStartTradeDateRange(period, rsiLatestDateStr)
        startDateIndex = latestDateIndex - daysBefore
    else:
        daysBefore = computeRSIDataStartTradeDateRange(period, specifiedDateStr)
        startDateIndex = specifiedDateIndex - daysBefore
    
    #起始日期index负数，说明rsi数据不够，直接从K线起始日重算
    if startDateIndex < 0:
        #重算就从数据库里的K线数据最早日期开始
        return computeAllRSIDataOfPeriod(period, "2017-01-01")
    
    startDateStr = tradeDates[startDateIndex]
    print("compute rsi tradeDates from ", startDateStr, "to", specifiedDateStr)
    return computeAllRSIDataOfPeriod(period, startDateStr)
#     stockDict = KlineService.allStocks()
    

#算出计算本周期下指定数据需要的起始交易日
#每个交易日一共4小时，所以取4小时为一天，而不是24小时
#每个计算周期一共至少需要24个节点，分钟线RSI统一除以4*60=240分钟算出所需计算数据天数，最少为一天
#日线不用除分钟
## TODO 周线没想好怎么算，更别说月线了。
def computeRSIDataStartTradeDateRange(period, specifiedDate):
    daysBefore = 0
    if period.endswith("m"):
        daysBefore = math.ceil(24 * (int(period.replace("m", "")) + 1) / (60 * 4))
    elif period == "day":
        daysBefore = 24
    else:
        raise RuntimeError("周期有误")
    return daysBefore

In [3]:
# downloadAllKlineDataOfSingleDay("2019-09-24")

# downloadAllKlineDataOfPeriod("day", "2017-01-01")
# downloadAllStocks("2019-09-23")
# df = allStocks()

computeAllRSIDataOfPeriod("day", "2017-01-01")

# computeAllRSIData("day", "2019-09-27")

login success!
logout success!
> [0;32m<ipython-input-2-5ec54f1913f5>[0m(194)[0;36mcomputeAllRSIData[0;34m()[0m
[0;32m    192 [0;31m[0;34m[0m[0m
[0m[0;32m    193 [0;31m    [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 194 [0;31m    [0;32mif[0m [0mdateDiff[0m [0;34m>[0m [0;36m1[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    195 [0;31m        [0mdaysBefore[0m [0;34m=[0m [0mcomputeRSIDataStartTradeDateRange[0m[0;34m([0m[0mperiod[0m[0;34m,[0m [0mrsiLatestDateStr[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    196 [0;31m        [0mstartDateIndex[0m [0;34m=[0m [0mlatestDateIndex[0m [0;34m-[0m [0mdaysBefore[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> c
compute rsi tradeDates from  2019-08-20 to 2019-09-27
download process: 100  of  4201  failed: 0
download process: 200  of  4201  failed: 0
download sh.600074 error:inputs are all NaN
download process: 300  of  4201  failed: 1
download sh.600145 er

True