In [None]:
import traceback
import socket
import pandas as pd
import time

from pathlib import Path
from datetime import datetime, timedelta
from collections import defaultdict
from typing import Optional, Sequence, List
from rpc.client import RpcClient
from rpc.utility import INTERVAL_ADJUSTMENT_MAP
from rpc.utility import (get_duration, extract_vt_symbol, to_rq_symbol, to_vt_symbol, handle_df,
                         ts_to_dt, strip_digt,
                         load_json, save_json)

def get_trading_symbols() -> set:
    """Return the set of ``order_book_id`` values from ``all_instruments`` as of now."""
    instruments = all_instruments(date=datetime.now())
    return set(instruments["order_book_id"])


def get_exchange_map() -> pd.Series:
    """
    Build a mapping from futures underlying code to exchange code.

    Only the first instrument seen for each ``underlying_symbol`` is kept, and
    the result is a Series indexed by ``underlying_symbol``.
    Example: {"BU": "SHFE"}
    """
    futures = all_instruments(type='Future', date=datetime.now())
    one_per_underlying = futures.drop_duplicates(subset='underlying_symbol')
    return one_per_underlying.set_index('underlying_symbol', drop=True)['exchange']


def bt_symbol_to_vt_symbol(bt_symbol: str) -> str:
    """Convert a backtest continuous-contract code (e.g. ``RB888``) to a vt_symbol.

    Only valid for continuous-contract codes used in backtesting, not for
    tradable contract codes.

    The exchange is looked up in the module-level ``exchange_dict`` using the
    code passed through ``strip_digt`` (presumably removing the digits) and
    upper-cased, e.g. ``"RB888"`` -> ``"RB888.<exchange>"``.

    Raises:
        KeyError: if the underlying is missing from ``exchange_dict``.
    """
    exchange = exchange_dict[strip_digt(bt_symbol).upper()]
    # Must use the argument itself; the previous `rq_symbol` was undefined here
    # and made every call fail with NameError.
    return f"{bt_symbol}.{exchange}"
    
    
def get_data(rq_symbol: str, rq_interval: str, start_date: datetime, end_date: datetime=None) -> dict:
    """Fetch one contract's bars for a date range, serialized for RPC transfer.

    The frame from ``get_price`` is post-processed by ``handle_df`` and
    converted with ``to_dict(orient="records")``. Despite the ``dict``
    annotation, the return value is therefore a list with one dict per row.
    """
    # A date without a time part makes RQ stop at the previous close; pushing
    # the end one day later reaches the current time / today's night session.
    query_end = end_date + timedelta(1) if end_date else end_date

    ohlcv_fields = ["open", "high", "low", "close", "volume"]
    raw_bars = get_price(
        rq_symbol,
        frequency=rq_interval,
        fields=ohlcv_fields,
        start_date=start_date,
        end_date=query_end,
        adjust_type="none",
    )
    cleaned = handle_df(raw_bars, rq_interval)
    return cleaned.to_dict(orient="records")
    

def gen_start_end_pair(start: datetime, end: datetime) -> List[tuple]:
    """
    Build (month_start, month_end) pairs for every month-end within [start, end].

    Boundaries follow calendar months: each pair starts one day after the
    previous month-end, even when ``start`` falls mid-month. A trailing partial
    month is skipped when ``end`` is not itself a month-end (e.g. end=2014-08-01
    yields 2014-07-31 as the last month-end).
    """
    # MonthEnd() is what the 'M' alias meant; pandas >= 2.2 deprecates 'M' in
    # favour of 'ME', and the offset object works across pandas versions.
    month_ends = pd.date_range(start=start, end=end, freq=pd.offsets.MonthEnd())
    # Shifting one period back gives the previous month-end; +1 day then lands
    # on the first day of the same month as each month-end.
    month_starts = (ts_to_dt(ts) + timedelta(1) for ts in month_ends.shift(-1))
    return list(zip(month_starts, map(ts_to_dt, month_ends)))
    

def init_client(host: str, port: int, authkey: bytes):
    """Create an RPC client for ``host:port`` with ``authkey``, connect it, and return it."""
    rpc_client = RpcClient(host, port, authkey)
    rpc_client.connect()
    return rpc_client


def save_all_data_by_period(client: RpcClient, rq_interval: str, start: datetime, end: datetime, symbols: Optional[Sequence[str]] = None):
    """Transfer each symbol's bars for the whole [start, end] range in one request.

    Intended for coarse intervals whose data volume is small enough to fetch
    in a single call.

    Args:
        client: connected RPC client exposing ``save_to_database``.
        rq_interval: RQ frequency string, e.g. "1d".
        start, end: range passed to ``get_data``.
        symbols: continuous-contract codes such as "BU888". If None, nothing
            is transferred (the automatic lookup is commented out).
    """
    if symbols is None:
        # Automatic symbol discovery (get_update_symbol) is commented out, so
        # there is no work to do. Previously execution fell through and
        # crashed iterating over None.
        print('without mission')
        return

    period_label = f"{start.strftime('%Y%m')}-{end.strftime('%Y%m')}"
    for rq_symbol in symbols:
        vt_symbol = bt_symbol_to_vt_symbol(rq_symbol)
        data_dict = get_data(rq_symbol, rq_interval, start, end)
        client.save_to_database(data_dict, vt_symbol, rq_interval)
        print(f"{vt_symbol}-{rq_interval} {period_label} 数据保存成功")


def save_all_data_by_month(client: RpcClient, rq_interval: str, start: datetime, end: datetime, symbols: Optional[Sequence[str]] = None):
    """Transfer bars month by month (meant for 1m data, which is large).

    Months already listed for a vt_symbol in the module-level ``collected_dict``
    (keys: vt_symbol, values: "YYYYMM" strings) are skipped. Each successful
    month is appended there, and the ledger is written to 'collected.json'
    even if the run is interrupted, so progress is not lost.

    Args:
        client: connected RPC client exposing ``save_to_database``.
        rq_interval: RQ frequency string, e.g. "1m".
        start, end: overall range, split into months by ``gen_start_end_pair``.
        symbols: continuous-contract codes such as "TA888". If None, nothing
            is transferred (the automatic lookup is commented out).
    """
    if symbols is None:
        # Automatic symbol discovery (get_update_symbol) is commented out;
        # previously execution fell through and crashed iterating over None.
        print('without mission')
        return

    try:
        # The month segments depend only on start/end, so build them once.
        pairs = gen_start_end_pair(start, end)
        for rq_symbol in symbols:
            vt_symbol = bt_symbol_to_vt_symbol(rq_symbol)
            collected = collected_dict.get(vt_symbol, [])
            print(f'{vt_symbol}任务简报：\n任务：{start.strftime("%Y%m")}-{end.strftime("%Y%m")}\n已收集合约:{collected}')
            for s, e in pairs:
                flag_name = s.strftime("%Y%m")
                if flag_name in collected:
                    continue

                try:
                    data_dict = get_data(rq_symbol, rq_interval, s, e)
                    client.save_to_database(data_dict, vt_symbol, rq_interval)
                except Exception:
                    # One failed month must not abort the batch; report and move on.
                    # (A bare except here used to swallow KeyboardInterrupt too.)
                    traceback.print_exc()
                    continue

                print(f"{vt_symbol}-{flag_name}数据保存成功")
                collected_dict.setdefault(vt_symbol, []).append(flag_name)
                print("休息2秒")
                time.sleep(2)  # throttle requests between months
    finally:
        # Persist progress even when an error or interrupt escapes the loop.
        save_json('collected.json', collected_dict)
    print("已收集数据更新成功")

    
def get_dominant_in_periods(underlying: str, backtest_start: datetime, backtest_end: datetime,
                            csv_path: str = 'dominant_data.csv') -> pd.DataFrame:
    """
    Return the dominant-contract segments of ``underlying`` covering a backtest window.

    Segments are read from ``csv_path``; columns 1 and 2 are parsed as dates, and
    the 'underlying', 'start' and 'end' columns are used. The selection starts at
    the last segment beginning before ``backtest_start``. That segment is skipped
    when its 'end' is less than 37 days after ``backtest_start``. The selection
    stops before the first segment beginning after ``backtest_end``.

    Args:
        underlying: product code, case-insensitive (e.g. "rb").
        backtest_start, backtest_end: backtest window.
        csv_path: segment table; defaults to the original hard-coded file.

    Returns:
        A copy of the selected rows; it keeps the 'index' column that
        ``reset_index`` adds. The result is empty for an unknown underlying.
    """
    underlying = underlying.upper()
    seg = pd.read_csv(csv_path, parse_dates=[1, 2])

    sel = seg[seg['underlying'] == underlying].copy()
    sel.reset_index(inplace=True)  # index labels now equal row positions
    passed = sel[sel['start'] < backtest_start]
    after = sel[sel['start'] > backtest_end]

    if passed.empty:
        # The window starts before the first known segment (or the underlying
        # is unknown): begin at the first row instead of raising IndexError.
        first_idx = 0
    else:
        first_idx = passed.index.values[-1]
        if passed.iloc[-1]['end'] - backtest_start < timedelta(days=37):
            first_idx += 1
    last_idx = after.index.values[0] if not after.empty else len(sel)
    return sel.iloc[first_idx:last_idx].copy()


def get_all_dominant_data(client: RpcClient, start:datetime, end:datetime, underlying: str, rq_interval:str, back_days:int=31):
    """
    Fetch and upload bars for each dominant-contract segment of ``underlying``.

    Segments come from ``get_dominant_in_periods(underlying, start, end)``. For
    each one, data for its 'dominant' contract is requested from ``back_days``
    days before the segment's 'start'. No end date is passed to ``get_data``,
    so each request uses get_data's default end. The bars are then saved under
    the segment's 'vt_symbol'.
    """
    # Previously hard-coded to 'RB', silently ignoring the argument.
    segments = get_dominant_in_periods(underlying, start, end)

    for _, row in segments.iterrows():
        # Distinct name so the `start` parameter is not clobbered in the loop.
        fetch_start = row['start'] - timedelta(days=back_days)
        rq_symbol = row['dominant']
        vt_symbol = row['vt_symbol']
        data_dict = get_data(rq_symbol, rq_interval, fetch_start)
        client.save_to_database(data_dict, vt_symbol, rq_interval)
        # The former "{rq_interval}}" left a lone '}' -> f-string SyntaxError.
        print(f"{vt_symbol}-{rq_interval} 数据保存成功")
        
    
# Module-level state, loaded once when this cell runs.
# collected_dict: upload progress ledger (vt_symbol -> months "YYYYMM" already
# saved), read and extended by save_all_data_by_month.
# connect_setting: connection parameters ('host_home', 'port', 'authkey').
collected_dict, connect_setting = load_json('collected.json'), load_json('connect.json')
# Underlying code -> exchange code, consulted by bt_symbol_to_vt_symbol.
exchange_dict = get_exchange_map()

### 连接信息

In [None]:
# Resolve the configured host name to an IP once; the later cells pass
# host_home/port/authkey to init_client.
host_home = socket.gethostbyname(connect_setting['host_home'])
port = connect_setting['port']
authkey = bytes(connect_setting['authkey'], 'ascii')
print(datetime.now(), host_home)

### 按月分段获取分钟数据

In [None]:
def get_1mbar_by_month(rq_interval: str = "1m",
                       symbols: Optional[Sequence[str]] = None,
                       start: datetime = datetime(2014, 1, 1),
                       end: datetime = datetime(2014, 8, 1)):
    """Download bars month by month and push them to the RPC server.

    The defaults reproduce the original hard-coded task (TA888, 1m,
    2014-01-01 .. 2014-08-01). The client is always closed before it is
    returned.
    """
    if symbols is None:
        symbols = ['TA888']

    client = init_client(host_home, port, authkey)
    if client:
        try:
            save_all_data_by_month(client, rq_interval, start, end, symbols)
        except Exception:
            # Was `except: pass`, which hid every failure (and Ctrl-C);
            # report it like the other download cells do.
            traceback.print_exc()
        finally:
            client.close()

    return client


client = get_1mbar_by_month()

### 一次性获取长周期(1h以上)数据

In [None]:
def get_long_period_bar_once(rq_interval: str = "1d",
                             symbols: Optional[Sequence[str]] = None,
                             start: datetime = datetime(2019, 1, 1),
                             end: datetime = datetime(2019, 12, 1)):
    """Download coarse-interval bars in a single request per symbol and push them.

    The defaults reproduce the original hard-coded task (BU888, 1d,
    2019-01-01 .. 2019-12-01). The client is always closed before it is
    returned.
    """
    if symbols is None:
        symbols = ['BU888']

    client = init_client(host_home, port, authkey)
    if client:
        try:
            save_all_data_by_period(client, rq_interval, start, end, symbols)
        except Exception:
            # A bare except would also swallow KeyboardInterrupt.
            traceback.print_exc()
        finally:
            client.close()

    return client


client = get_long_period_bar_once()

### 获取分段主力合约的长周期数据

In [None]:
# Fetch the segmented dominant-contract bars for one underlying.
rq_interval = "1h"
underlying = 'RB'

start = datetime(2018, 1, 1)
end = datetime(2019, 5, 1)

client = init_client(host_home, port, authkey)
if client:
    try:
        get_all_dominant_data(client, start, end, underlying, rq_interval)
    except Exception:
        # Report failures; a bare except would also trap KeyboardInterrupt.
        traceback.print_exc()
    finally:
        client.close()
