## 导入相关库

In [1]:
# 导入相关库
import pandas as pd # 读取数据及数据处理
import numpy as np # 数据处理
import matplotlib.pyplot as plt # 绘图设置
# 我使用的 matplotlib 版本没有 finance 子模块，所以用 mpl_finance 中导入 candlestick_ohlc，如果你的 matplotlib 版本里有的话也可以从 matplotlib 里导入
from mpl_finance import candlestick_ohlc # 绘制 k 线图
from matplotlib.pylab import date2num # matplotlib 时间格式
from sqlalchemy import create_engine # 从数据库中读取数据
from matplotlib.dates import DateFormatter, MonthLocator # 时间格式
%matplotlib inline

## 读取数据

<font size=4>如果要从数据库读取数据，请使用下面的 `readsql` 函数；如果读取本地文件，直接用 `pandas` 读取即可（例如 `pd.read_csv`）。</font>

In [2]:
# Read a whole database table into a DataFrame.
# The tables are large, so reading can take a while.
def readsql(table, user='blacktech1', pw='blacktech1', host='132.232.102.34', port='3306', database='anxin_heikeji'):
    '''
    Load every row of `table` from a MySQL database into a DataFrame.

    Parameters
    ----------
    table    : name of the table to read (e.g. 'stock_daily_tiny').
    user, pw, host, port, database :
               settings used to build the mysql+pymysql SQLAlchemy URL.

    Returns
    -------
    pandas.DataFrame holding the table's contents.
    '''
    # NOTE(review): the credentials and host are hard-coded defaults; consider
    # supplying them from the environment or a config file instead.
    connect_info = 'mysql+pymysql://{}:{}@{}:{}/{}?charset=utf8'.format(user, pw, host, port, database)
    engine = create_engine(connect_info)
    try:
        # A bare table name makes pandas read the whole table (read_sql
        # dispatches to read_sql_table), so no SQL text has to be built here.
        return pd.read_sql(sql=table, con=engine)
    finally:
        # Release the engine's connection pool instead of leaking one per call.
        engine.dispose()

## 绘图

In [3]:
# Drawing style of an auxiliary line, keyed by its `linearity` code.
# Per the chart title: solid = trend likely finished, dashed with a circle =
# trend near its end (unconfirmed), dashed = trend still running,
# orange = a larger-level trend from the same start point.
#   linestyle   - matplotlib line style of the segment
#   color       - colour of the segment (and of its end-point circle)
#   end_marker  - draw a hollow circle at the segment's end point
#   border_font - fontdict for the "[lower-upper]" annotation (None = default)
_AUX_LINE_STYLES = {
    0: {'linestyle': '--', 'color': 'black', 'end_marker': False, 'border_font': None},
    0.5: {'linestyle': '--', 'color': 'black', 'end_marker': True,
          'border_font': {'size': 16, 'color': 'blue'}},
    1: {'linestyle': '-', 'color': 'black', 'end_marker': False,
        'border_font': {'size': 16, 'color': 'blue'}},
    2: {'linestyle': '--', 'color': '#FF7F00', 'end_marker': False,
        'border_font': {'size': 16, 'color': 'blue'}},
    2.5: {'linestyle': '-.', 'color': '#FF7F00', 'end_marker': True,
          'border_font': {'size': 16, 'color': 'blue'}},
}


def _to_num(values):
    """Convert a column of date-like values to matplotlib date numbers."""
    return pd.to_datetime(values).apply(lambda d: date2num(d.to_pydatetime()))


def kplot(dataset, subdata, stockID, date, startdate=None, enddate=None):
    '''
    Draw a daily candlestick (k-line) chart for one stock, overlaid with the
    auxiliary trend lines recorded for that stock on a given date.

    Parameters
    ----------
    dataset   : daily k-line DataFrame (e.g. stock_daily_tiny); the columns
                read are StockID, date, open, high, low, close.
    subdata   : auxiliary-line DataFrame (e.g. cycle_low_frequency_stock); the
                columns read are stockid, date, startdate, enddate, startprice,
                endprice, linearity, type, shortterm, longterm, lowerborder,
                upperborder.
    stockID   : stock code, e.g. 'SH600519'.
    date      : only auxiliary lines whose `date` equals this day are drawn,
                format 'yyyy-mm-dd'.
    startdate : optional first day of the chart, 'yyyy-mm-dd' (inclusive).
    enddate   : optional last day of the chart, 'yyyy-mm-dd' (inclusive).

    Returns the result of plt.show() (None).

    Raises ValueError when no daily bars remain after filtering.

    Needs pandas as pd, matplotlib.pyplot as plt, date2num, DateFormatter,
    MonthLocator, mpl_finance.candlestick_ohlc and the module-level readsql()
    helper in scope.
    '''
    # 1. Keep only the requested stock. .copy() so the helper columns added
    #    below are not written into a view of the caller's frame
    #    (avoids pandas' SettingWithCopyWarning).
    k_dataset = dataset[dataset.StockID == stockID].copy()
    k_subdata = subdata[subdata.stockid == stockID].copy()

    # 2. Convert date columns to matplotlib date numbers.
    k_dataset['date2num'] = _to_num(k_dataset['date'])  # rename if the date column differs
    k_subdata['date2num'] = _to_num(k_subdata['date'])
    k_subdata['startdate2num'] = _to_num(k_subdata['startdate'])
    k_subdata['enddate2num'] = _to_num(k_subdata['enddate'])

    # 3. Keep only the auxiliary lines recorded on `date`.
    k_subdata = k_subdata[k_subdata.date2num == date2num(pd.to_datetime(date))]

    # 4. Choose the time window of the chart.
    last_day = k_dataset.date2num.max()
    first_aux_start = k_subdata.startdate2num.min()
    if startdate:
        start_num = date2num(pd.to_datetime(startdate))
        k_dataset = k_dataset[k_dataset.date2num > start_num - 1]
    elif last_day - first_aux_start > 365:
        # The earliest auxiliary line starts more than a year before the last
        # bar: begin 30 days before it so the whole line is visible.
        k_dataset = k_dataset[k_dataset.date2num > first_aux_start - 30]
    else:
        # Default: the most recent 365 days of bars.
        k_dataset = k_dataset[k_dataset.date2num > last_day - 365]

    if enddate:
        end_num = date2num(pd.to_datetime(enddate))
        # Inclusive upper bound, mirroring the inclusive start bound above.
        k_dataset = k_dataset[k_dataset.date2num < end_num + 1]

    if k_dataset.empty:
        raise ValueError('no daily data for {} in the selected date range'.format(stockID))

    # 5. (date, open, high, low, close) tuples for candlestick_ohlc;
    #    adjust the column names here if the dataset uses different ones.
    dohlc = [tuple(row) for row in k_dataset[['date2num', 'open', 'high', 'low', 'close']].values]

    # 6. Look up the stock's short name; fall back to the code if it is missing.
    info = readsql('stock_base_info')
    names = info.loc[info.stockid == stockID, 'stockname'].values
    stock_name = names[0] if len(names) else stockID

    # 7. Plot. Set the font before any text is created so Chinese labels render.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    fig, ax = plt.subplots(figsize=(20, 9))
    fig.subplots_adjust(bottom=0.2)  # leave room for the rotated date labels

    # One tick per month, labelled as year-month-day.
    ax.xaxis.set_major_locator(MonthLocator())
    ax.xaxis.set_minor_locator(MonthLocator())
    ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))

    candlestick_ohlc(ax, dohlc, width=0.6, colorup='r', colordown='g')
    plt.setp(ax.get_xticklabels(), rotation=45, horizontalalignment='right')

    # Auxiliary lines; rows with an unknown linearity code are not drawn.
    aux_columns = ['linearity', 'startdate2num', 'enddate2num', 'startprice', 'endprice',
                   'startdate', 'type', 'shortterm', 'longterm', 'lowerborder', 'upperborder']
    for (linearity, x0, x1, y0, y1, start_label, line_type,
         short_term, long_term, lower, upper) in zip(*(k_subdata[c].values for c in aux_columns)):
        style = _AUX_LINE_STYLES.get(linearity)
        if style is None:
            continue
        ax.plot((x0, x1), (y0, y1),
                linestyle=style['linestyle'], color=style['color'],
                label='{}-{}-({},{})'.format(start_label, line_type, int(short_term), int(long_term)))
        ax.text(x0, y0, start_label, fontdict={'size': 16})
        if style['end_marker']:
            # Hollow circle: facecolors='none' replaces the old c='' idiom,
            # which newer matplotlib versions reject.
            ax.scatter(x1, y1, marker='o', facecolors='none', edgecolors=style['color'])
        if lower != 0:
            ax.text(x1, y1, '[{}-{}]'.format(lower, upper), fontdict=style['border_font'])

    ax.set_title('{}({})-(日线图{})  实线 - 趋势或已结束； 带圈虚线 - 趋势或接近末端，待确认； 虚线 - 趋势或仍在运行； 橙色线 - 对应起点的更大级别的趋势'.format(stock_name, stockID, date), fontsize=15)
    ax.set_xlabel('时间', fontsize=20)
    ax.set_ylabel('价格', fontsize=20)

    # Pad the right edge by 10% of the time span for the end-point labels.
    x_min, x_max = k_dataset.date2num.min(), k_dataset.date2num.max()
    ax.set_xlim(x_min, x_max + (x_max - x_min) / 10)
    # Pad the price axis; no int() truncation, which collapsed the range
    # for prices below ~1.
    low, high = k_dataset.low.min(), k_dataset.high.max()
    price_range = high - low
    ax.set_ylim(low - price_range / 8, high + price_range / 7)

    ax.legend(fontsize=16)
    fig.tight_layout()
    return plt.show()
    
    