Skip to content

Commit

Permalink
Merge pull request #230 from fasiondog/feature/tmpstock
Browse files Browse the repository at this point in the history
开放 stock 属性可写,供使用外部数据源时使用
  • Loading branch information
fasiondog committed Apr 17, 2024
2 parents ac2bd2d + 339ab7b commit 19ab50e
Show file tree
Hide file tree
Showing 11 changed files with 457 additions and 220 deletions.
135 changes: 64 additions & 71 deletions hikyuu/extend.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#
# 对 C++ 引出类和函数进行扩展, pybind11 对小函数到导出效率不如 python 直接执行
#

import numpy as np
import pandas as pd
from datetime import *
from .util.slice import list_getitem
from .core import *

# ------------------------------------------------------------------
Expand Down Expand Up @@ -293,94 +293,87 @@ def new_Query_init(self, start=0, end=None, ktype=Query.DAY, recover_type=Query.

Query.__init__ = new_Query_init


# ------------------------------------------------------------------
# 增加转化为 np.array、pandas.DataFrame 的功能
# ------------------------------------------------------------------

try:
import numpy as np
import pandas as pd

def KData_to_np(kdata):
"""转化为numpy结构数组"""
if kdata.get_query().ktype in ('DAY', 'WEEK', 'MONTH', 'QUARTER', 'HALFYEAR', 'YEAR'):
k_type = np.dtype(
{
'names': ['datetime', 'open', 'high', 'low', 'close', 'amount', 'volume'],
'formats': ['datetime64[D]', 'd', 'd', 'd', 'd', 'd', 'd']
}
)
else:
k_type = np.dtype(
{
'names': ['datetime', 'open', 'high', 'low', 'close', 'amount', 'volume'],
'formats': ['datetime64[ms]', 'd', 'd', 'd', 'd', 'd', 'd']
}
)
return np.array(
[(k.datetime.datetime(), k.open, k.high, k.low, k.close, k.amount, k.volume) for k in kdata], dtype=k_type
def KData_to_np(kdata):
"""转化为numpy结构数组"""
if kdata.get_query().ktype in ('DAY', 'WEEK', 'MONTH', 'QUARTER', 'HALFYEAR', 'YEAR'):
k_type = np.dtype(
{
'names': ['datetime', 'open', 'high', 'low', 'close', 'amount', 'volume'],
'formats': ['datetime64[D]', 'd', 'd', 'd', 'd', 'd', 'd']
}
)
else:
k_type = np.dtype(
{
'names': ['datetime', 'open', 'high', 'low', 'close', 'amount', 'volume'],
'formats': ['datetime64[ms]', 'd', 'd', 'd', 'd', 'd', 'd']
}
)
return np.array(
[(k.datetime.datetime(), k.open, k.high, k.low, k.close, k.amount, k.volume) for k in kdata], dtype=k_type
)

def KData_to_df(kdata):
"""转化为pandas的DataFrame"""
return pd.DataFrame.from_records(KData_to_np(kdata), index='datetime')

KData.to_np = KData_to_np
KData.to_df = KData_to_df
def KData_to_df(kdata):
"""转化为pandas的DataFrame"""
return pd.DataFrame.from_records(KData_to_np(kdata), index='datetime')

def PriceList_to_np(data):
"""仅在安装了numpy模块时生效,转换为numpy.array"""
return np.array(data, dtype='d')

def PriceList_to_df(data):
"""仅在安装了pandas模块时生效,转换为pandas.DataFrame"""
return pd.DataFrame(data.to_np(), columns=('Value', ))
KData.to_np = KData_to_np
KData.to_df = KData_to_df

PriceList.to_np = PriceList_to_np
PriceList.to_df = PriceList_to_df

def DatetimeList_to_np(data):
"""仅在安装了numpy模块时生效,转换为numpy.array"""
return np.array(data, dtype='datetime64[D]')
def DatetimeList_to_np(data):
"""仅在安装了numpy模块时生效,转换为numpy.array"""
return np.array(data, dtype='datetime64[D]')

def DatetimeList_to_df(data):
"""仅在安装了pandas模块时生效,转换为pandas.DataFrame"""
return pd.DataFrame(data.to_np(), columns=('Datetime', ))

DatetimeList.to_np = DatetimeList_to_np
DatetimeList.to_df = DatetimeList_to_df
def DatetimeList_to_df(data):
"""仅在安装了pandas模块时生效,转换为pandas.DataFrame"""
return pd.DataFrame(data.to_np(), columns=('Datetime', ))

def TimeLine_to_np(data):
"""转化为numpy结构数组"""
t_type = np.dtype({'names': ['datetime', 'price', 'vol'], 'formats': ['datetime64[ms]', 'd', 'd']})
return np.array([(t.date.datetime(), t.price, t.vol) for t in data], dtype=t_type)

def TimeLine_to_df(kdata):
"""转化为pandas的DataFrame"""
return pd.DataFrame.from_records(TimeLine_to_np(kdata), index='datetime')
DatetimeList.to_np = DatetimeList_to_np
DatetimeList.to_df = DatetimeList_to_df

TimeLineList.to_np = TimeLine_to_np
TimeLineList.to_df = TimeLine_to_df

def TransList_to_np(data):
"""转化为numpy结构数组"""
t_type = np.dtype(
{
'names': ['datetime', 'price', 'vol', 'direct'],
'formats': ['datetime64[ms]', 'd', 'd', 'd']
}
)
return np.array([(t.date.datetime(), t.price, t.vol, t.direct) for t in data], dtype=t_type)
def TimeLine_to_np(data):
"""转化为numpy结构数组"""
t_type = np.dtype({'names': ['datetime', 'price', 'vol'], 'formats': ['datetime64[ms]', 'd', 'd']})
return np.array([(t.date.datetime(), t.price, t.vol) for t in data], dtype=t_type)


def TimeLine_to_df(kdata):
"""转化为pandas的DataFrame"""
return pd.DataFrame.from_records(TimeLine_to_np(kdata), index='datetime')


TimeLineList.to_np = TimeLine_to_np
TimeLineList.to_df = TimeLine_to_df


def TransList_to_np(data):
"""转化为numpy结构数组"""
t_type = np.dtype(
{
'names': ['datetime', 'price', 'vol', 'direct'],
'formats': ['datetime64[ms]', 'd', 'd', 'd']
}
)
return np.array([(t.date.datetime(), t.price, t.vol, t.direct) for t in data], dtype=t_type)


def TransList_to_df(kdata):
"""转化为pandas的DataFrame"""
return pd.DataFrame.from_records(TransList_to_np(kdata), index='datetime')
def TransList_to_df(kdata):
"""转化为pandas的DataFrame"""
return pd.DataFrame.from_records(TransList_to_np(kdata), index='datetime')

TransList.to_np = TransList_to_np
TransList.to_df = TransList_to_df

except:
pass
TransList.to_np = TransList_to_np
TransList.to_df = TransList_to_df

# ------------------------------------------------------------------
# 增强 Parameter
Expand Down
38 changes: 14 additions & 24 deletions hikyuu/indicator/indicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,30 +62,20 @@ def indicator_getitem(data, i):
Indicator.__iter__ = indicator_iter


try:
import numpy as np
import pandas as pd

def indicator_to_df(indicator):
"""转化为pandas.DataFrame"""
if indicator.get_result_num() == 1:
return pd.DataFrame(indicator.to_np(), columns=[indicator.name])

data = {}
name = indicator.name
columns = []
for i in range(indicator.get_result_num()):
data[name + str(i)] = indicator.get_result(i)
columns.append(name + str(i + 1))
return pd.DataFrame(data, columns=columns)

Indicator.to_df = indicator_to_df

except:
print(
"warning:can't import numpy or pandas lib, ",
"you can't use method Inidicator.to_np() and to_df!"
)
def indicator_to_df(indicator):
"""转化为pandas.DataFrame"""
if indicator.get_result_num() == 1:
return pd.DataFrame(indicator.to_np(), columns=[indicator.name])
data = {}
name = indicator.name
columns = []
for i in range(indicator.get_result_num()):
data[name + str(i)] = indicator.get_result(i)
columns.append(name + str(i + 1))
return pd.DataFrame(data, columns=columns)


Indicator.to_df = indicator_to_df


def concat_to_df(dates, ind_list, head_stock_code=True, head_ind_name=False):
Expand Down
24 changes: 12 additions & 12 deletions hikyuu/test/MoneyManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# -*- coding: utf8 -*-
# gb18030

#===============================================================================
# ===============================================================================
# 作者:fasiondog
# 历史:1)20130316, Added by fasiondog
#===============================================================================
# ===============================================================================

import unittest

Expand All @@ -18,7 +18,7 @@ def __init__(self):
self.set_param("n", 10)
self._m_flag = False

def getBuyNumber(self, datetime, stock, price, risk):
def get_buy_num(self, datetime, stock, price, risk):
if self._m_flag:
return 10
else:
Expand All @@ -44,14 +44,14 @@ def test_ConditionBase(self):
self.assertEqual(p.get_param("n"), 10)
p.set_param("n", 20)
self.assertEqual(p.get_param("n"), 20)
self.assertEqual(p.getBuyNumber(Datetime(200101010000), stock, 10.0, 0.0), 20)
self.assertEqual(p.get_buy_num(Datetime(200101010000), stock, 10.0, 0.0), 20)
p.reset()
self.assertEqual(p.getBuyNumber(Datetime(200101010000), stock, 10.0, 0.0), 10)
self.assertEqual(p.get_buy_num(Datetime(200101010000), stock, 10.0, 0.0), 10)

p_clone = p.clone()
self.assertEqual(p_clone.name, "MoneyManagerPython")
self.assertEqual(p_clone.get_param("n"), 20)
self.assertEqual(p_clone.getBuyNumber(Datetime(200101010000), stock, 10, 0.0), 10)
self.assertEqual(p_clone.get_buy_num(Datetime(200101010000), stock, 10, 0.0), 10)

p.set_param("n", 1)
p_clone.set_param("n", 3)
Expand All @@ -63,18 +63,18 @@ def testCrtMM(self):
pass


def testgetBuyNumber(self, datetime, stock, price, risk):
def testget_buy_num(self, datetime, stock, price, risk, part):
return 10.0 if datetime == Datetime(200101010000) else 0.0


class TestCrtMM(unittest.TestCase):
def test_crt_mm(self):
p = crtMM(testCrtMM, params={'n': 10}, name="TestMM")
p.getBuyNumber = testgetBuyNumber
p = crtMM(testget_buy_num, testCrtMM, params={'n': 10}, name="TestMM")
p.tm = crtTM(Datetime(200101010000))
self.assertEqual(p.name, "TestMM")
stock = sm['sh000001']
self.assertEqual(p.getBuyNumber(p, Datetime(200101010000), stock, 1.0, 1.0), 10.0)
self.assertEqual(p.getBuyNumber(p, Datetime(200101020000), stock, 1.0, 1.0), 0.0)
self.assertEqual(p.get_buy_num(Datetime(200101010000), stock, 1.0, 1.0, SystemPart.MM), 10.0)
self.assertEqual(p.get_buy_num(Datetime(200101020000), stock, 1.0, 1.0, SystemPart.MM), 0.0)

p_clone = p.clone()
self.assertEqual(p_clone.name, "TestMM")
Expand All @@ -85,4 +85,4 @@ def suite():


def suiteTestCrtMM():
return unittest.TestLoader().loadTestsFromTestCase(TestCrtMM)
return unittest.TestLoader().loadTestsFromTestCase(TestCrtMM)
30 changes: 14 additions & 16 deletions hikyuu/test/Slippage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# -*- coding: utf8 -*-
# gb18030

#===============================================================================
# ===============================================================================
# 作者:fasiondog
# 历史:1)20130321, Added by fasiondog
#===============================================================================
# ===============================================================================

import unittest

Expand All @@ -17,12 +17,12 @@ def __init__(self):
super(SlippagePython, self).__init__("SlippagePython")
self._x = 0

def getRealBuyPrice(self, datetime, price):
def get_real_buy_price(self, datetime, price):
if self._x < 10:
return 0.0
return 1.0

def getRealSellPrice(self, datetime, price):
def get_real_sell_price(self, datetime, price):
if self._x < 10:
return 0.0
return 1.0
Expand All @@ -43,14 +43,14 @@ class SlippageTest(unittest.TestCase):
def test_SlippageBase(self):
p = SlippagePython()
self.assertEqual(p.name, "SlippagePython")
self.assertEqual(p.getRealBuyPrice(Datetime(200101010000), 1.0), 0.0)
self.assertEqual(p.getRealSellPrice(Datetime(200101010000), 1.0), 0.0)
self.assertEqual(p.get_real_buy_price(Datetime(200101010000), 1.0), 0.0)
self.assertEqual(p.get_real_sell_price(Datetime(200101010000), 1.0), 0.0)

self.assertEqual(p._x, 0)
p._x = 10
self.assertEqual(p._x, 10)
self.assertEqual(p.getRealBuyPrice(Datetime(200101010000), 1.0), 1.0)
self.assertEqual(p.getRealSellPrice(Datetime(200101010000), 1.0), 1.0)
self.assertEqual(p.get_real_buy_price(Datetime(200101010000), 1.0), 1.0)
self.assertEqual(p.get_real_sell_price(Datetime(200101010000), 1.0), 1.0)
p.reset()
self.assertEqual(p._x, 0)

Expand All @@ -66,18 +66,16 @@ def test_crtSL_func(self):
pass


def test_getRealBuyPrice_func(self, datetime, price):
def test_get_real_buy_price_func(self, datetime, price):
return 10.0 if datetime == Datetime(200101010000) else 0.0


class TestCrtSL(unittest.TestCase):
def test_crtSL(self):
p = crtSL(test_crtSL_func, params={'n': 10}, name="TestSL")
p.getRealBuyPrice = test_getRealBuyPrice_func
def test_crtSP(self):
p = crtSP(test_get_real_buy_price_func, test_crtSL_func, params={'n': 10}, name="TestSL")
self.assertEqual(p.name, "TestSL")
self.assertEqual(p.getRealBuyPrice(p, Datetime(200101010000), 1.0), 10.0)
self.assertEqual(p.getRealBuyPrice(p, Datetime(200101020000), 1.0), 0.0)

self.assertEqual(p.get_real_buy_price(Datetime(200101010000), 1.0), 10.0)
self.assertEqual(p.get_real_buy_price(Datetime(200101020000), 1.0), 0.0)
p_clone = p.clone()
self.assertEqual(p_clone.name, "TestSL")

Expand All @@ -87,4 +85,4 @@ def suite():


def suiteTestCrtSL():
return unittest.TestLoader().loadTestsFromTestCase(TestCrtSL)
return unittest.TestLoader().loadTestsFromTestCase(TestCrtSL)

0 comments on commit 19ab50e

Please sign in to comment.