Skip to content

Commit

Permalink
Merge pull request #212 from fasiondog/feature/factor
Browse files Browse the repository at this point in the history
调整 MF get_score 接口; SE/PF/AF 微调; update get_part add *args
  • Loading branch information
fasiondog committed Mar 30, 2024
2 parents fe69d5c + d775fc2 commit 7d482d9
Show file tree
Hide file tree
Showing 11 changed files with 245 additions and 140 deletions.
5 changes: 3 additions & 2 deletions hikyuu/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,13 +545,14 @@ def remove_hub(name):
HubManager().remove_hub(name)
def get_part(name, **kwargs):
def get_part(name, *args, **kwargs):
"""获取指定策略部件

:param str name: 策略部件名称
:param args: 其他部件相关参数
:param kwargs: 其他部件相关参数
"""
return HubManager().get_part(name, **kwargs)
return HubManager().get_part(name, *args, **kwargs)
def get_hub_path(name):
Expand Down
20 changes: 19 additions & 1 deletion hikyuu/trade_sys/trade_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from hikyuu.core import (
System, SystemPart, ConditionBase, EnvironmentBase, MoneyManagerBase,
ProfitGoalBase, SelectorBase, SignalBase, SlippageBase, StoplossBase, AllocateFundsBase
ProfitGoalBase, SelectorBase, SignalBase, SlippageBase, StoplossBase, AllocateFundsBase,
MultiFactorBase
)


Expand Down Expand Up @@ -180,6 +181,23 @@ def crtAF(allocate_weight_func, params={}, name='crtAF'):
return meta_x(name, params)
# ------------------------------------------------------------------
# multi_factor
# ------------------------------------------------------------------
def crtMF(calculate_func, params={}, name='crtMF'):
"""
快速多因子合成算法

:param calculate_func: 合成算法
:param {} params: 参数字典
:param str name: 自定义名称
:return: 自定义多因子合成算法实例
"""
meta_x = type(name, (MultiFactorBase, ), {'__init__': part_init, '_clone': part_clone})
meta_x._calculate = calculate_func
return meta_x(name, params)
# ------------------------------------------------------------------
# slippage
# ------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ SystemWeightList AllocateFundsBase::_adjust_with_running(
} else {
// 非延迟卖出的系统,立即强制卖出并回收资金
auto tr = sys->sellForceOnClose(date, MAX_DOUBLE, PART_ALLOCATEFUNDS);
HKU_DEBUG_IF(trace && tr.isNull(), "[AF] failed to sell: {}", sys->name());
// HKU_DEBUG_IF(trace && tr.isNull(), "[AF] failed to sell: {}", sys->name());
if (!tr.isNull()) {
auto sub_tm = sys->getTM();
auto sub_cash = sub_tm->currentCash();
Expand Down
42 changes: 32 additions & 10 deletions hikyuu_cpp/hikyuu/trade_sys/factor/MultiFactorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ const ScoreRecordList& MultiFactorBase::getScore(const Datetime& d) {
return m_stk_factor_by_date[iter->second];
}

ScoreRecordList MultiFactorBase::getScore(const Datetime& date, size_t start, size_t end) {
ScoreRecordList MultiFactorBase::getScores(const Datetime& date, size_t start, size_t end,
std::function<bool(const ScoreRecord&)>&& filter) {
ScoreRecordList ret;
HKU_IF_RETURN(start >= end, ret);

Expand All @@ -190,23 +191,44 @@ ScoreRecordList MultiFactorBase::getScore(const Datetime& date, size_t start, si
end = cross.size();
}

ret.resize(end - start);
for (size_t i = start; i < end; i++) {
ret[i] = cross[i];
if (filter) {
for (size_t i = start; i < end; i++) {
if (filter(cross[i])) {
ret.emplace_back(cross[i]);
}
}
} else {
for (size_t i = start; i < end; i++) {
ret.emplace_back(cross[i]);
}
}

return ret;
}

ScoreRecordList MultiFactorBase::getScore(const Datetime& date,
std::function<bool(const ScoreRecord&)> filter) {
ScoreRecordList MultiFactorBase::getScores(
const Datetime& date, size_t start, size_t end,
std::function<bool(const Datetime&, const ScoreRecord&)>&& filter) {
ScoreRecordList ret;
const auto& all_scores = getScore(date);
for (const auto& score : all_scores) {
if (filter(score)) {
ret.emplace_back(score);
HKU_IF_RETURN(start >= end, ret);

const auto& cross = getScore(date);
if (end == Null<size_t>() || end > cross.size()) {
end = cross.size();
}

if (filter) {
for (size_t i = start; i < end; i++) {
if (filter(date, cross[i])) {
ret.emplace_back(cross[i]);
}
}
} else {
for (size_t i = start; i < end; i++) {
ret.emplace_back(cross[i]);
}
}

return ret;
}

Expand Down
16 changes: 13 additions & 3 deletions hikyuu_cpp/hikyuu/trade_sys/factor/MultiFactorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,20 @@ class HKU_API MultiFactorBase : public enable_shared_from_this<MultiFactorBase>
/** 获取指定日期截面的所有因子值,已经降序排列 */
const ScoreRecordList& getScore(const Datetime&);

ScoreRecordList getScore(const Datetime& date, size_t start, size_t end = Null<size_t>());
/**
* 获取指定日期截面 [start, end] 范围内的因子值(评分), 并通过filer进行过滤
* @param date 指定日期
* @param start 排序起始点
* @param end 排序起始点(不含该点)
* @param filter 过滤函数
*/
ScoreRecordList getScores(
const Datetime& date, size_t start, size_t end = Null<size_t>(),
std::function<bool(const ScoreRecord&)>&& filter = std::function<bool(const ScoreRecord&)>());

/** 获取指定日期截面的所有因子值, 并通过指定的filer进行过滤 */
ScoreRecordList getScore(const Datetime& date, std::function<bool(const ScoreRecord&)> filter);
ScoreRecordList getScores(const Datetime& date, size_t start, size_t end = Null<size_t>(),
std::function<bool(const Datetime&, const ScoreRecord&)>&& filter =
std::function<bool(const Datetime&, const ScoreRecord&)>());

/** 获取所有截面数据,已按降序排列 */
const vector<ScoreRecordList>& getAllScores();
Expand Down
2 changes: 1 addition & 1 deletion hikyuu_cpp/hikyuu/trade_sys/portfolio/Portfolio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ void Portfolio::_runMoment(const Datetime& date, bool adjust) {
//----------------------------------------------------------------------
for (auto& sys : m_delay_adjust_sys_list) {
auto tr = sys.sys->sellForceOnOpen(date, sys.weight, PART_PORTFOLIO);
HKU_DEBUG_IF(trace && tr.isNull(), "[PF] Failed to force sell: {}", sys.sys->name());
// HKU_DEBUG_IF(trace && tr.isNull(), "[PF] Failed to force sell: {}", sys.sys->name());
if (!tr.isNull()) {
HKU_INFO_IF(trace, "[PF] Delay adjust sell: {}", tr);
m_tm->addTradeRecord(tr);
Expand Down
107 changes: 66 additions & 41 deletions hikyuu_cpp/hikyuu/trade_sys/selector/SelectorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,36 @@ HKU_API std::ostream& operator<<(std::ostream& os, const SelectorPtr& st) {
}

SelectorBase::SelectorBase() : m_name("SelectorBase") {
// 是否单独执行原型系统,仅限用于测试目的
setParam<bool>("run_proto_sys", false);
initParam();
}

SelectorBase::SelectorBase(const string& name) : m_name(name) {
// 是否单独执行原型系统
setParam<bool>("run_proto_sys", false);
initParam();
}

SelectorBase::~SelectorBase() {}

void SelectorBase::initParam() {
// 通常原型系统不参与计算,但某些特殊的场景,需要依赖于伴生系统策略,
// 此时可以认为实际执行的系统行为跟随伴生系统的买卖交易,如依赖于SG进行选择
// (不过由于仅依赖SG的场景不严谨,因为原型和实际系统的SG是一样的)
// 此时,需要在自身计算之前执行原型系统,然后SE自行时可以使用。
// 而对于实际系统和被跟随的系统完全不一样的情况,可以自行设计特殊的SE。
setParam<bool>("depend_on_proto_sys", false); // 此种情况,需要原型系统可独立运行
}

void SelectorBase::baseCheckParam(const string& name) const {}
void SelectorBase::paramChanged() {}

void SelectorBase::paramChanged() {
m_calculated = false;
m_proto_calculated = false;
}

void SelectorBase::removeAll() {
m_pro_sys_list = SystemList();
m_real_sys_list = SystemList();
m_pro_sys_list.clear();
m_real_sys_list.clear();
m_calculated = false;
m_proto_calculated = false;
}

void SelectorBase::reset() {
Expand All @@ -52,6 +65,9 @@ void SelectorBase::reset() {

m_real_sys_list.clear();
_reset();

m_calculated = false;
m_proto_calculated = false;
}

SelectorPtr SelectorBase::clone() {
Expand All @@ -70,6 +86,10 @@ SelectorPtr SelectorBase::clone() {

p->m_params = m_params;
p->m_name = m_name;
p->m_query = m_query;
p->m_proto_query = m_proto_query;
p->m_calculated = m_calculated;
p->m_proto_calculated = m_proto_calculated;

p->m_real_sys_list.reserve(m_real_sys_list.size());
for (const auto& sys : m_real_sys_list) {
Expand All @@ -83,52 +103,57 @@ SelectorPtr SelectorBase::clone() {
return p;
}

void SelectorBase::calculate(const SystemList& sysList, const KQuery& query) {
m_real_sys_list = sysList;
if (getParam<bool>("run_proto_sys")) {
// 用于手工测试
void SelectorBase::calculate(const SystemList& pf_realSysList, const KQuery& query) {
HKU_IF_RETURN(m_calculated && m_query == query, void());

m_query = query;
m_real_sys_list = pf_realSysList;

// 需要依赖于运行系统,在自身运算之前完成计算
if (getParam<bool>("depend_on_proto_sys")) {
calculate_proto(query);
}

_calculate();
m_calculated = true;
}

void SelectorBase::calculate_proto(const KQuery& query) {
if (m_proto_query != query && !m_proto_calculated) {
for (auto& sys : m_pro_sys_list) {
sys->run(query);
}
m_proto_calculated = true;
m_proto_query = query;
}
_calculate();
}

bool SelectorBase::addStock(const Stock& stock, const SystemPtr& protoSys) {
HKU_ERROR_IF_RETURN(stock.isNull(), false, "Try add Null stock, will be discard!");
HKU_ERROR_IF_RETURN(!protoSys, false, "Try add Null protoSys, will be discard!");
HKU_ERROR_IF_RETURN(!protoSys->getMM(), false, "protoSys has not MoneyManager!");
HKU_ERROR_IF_RETURN(!protoSys->getSG(), false, "protoSys has not Siganl!");
SYSPtr sys = protoSys->clone();
// 每个系统独立,不共享 tm
sys->setParam<bool>("shared_tm", false);
void SelectorBase::addStock(const Stock& stock, const SystemPtr& protoSys) {
HKU_CHECK(!stock.isNull(), "The input stock is null!");
HKU_CHECK(protoSys, "The input stock is null!");
HKU_CHECK(protoSys->getMM(), "protoSys missing MoneyManager!");
HKU_CHECK(protoSys->getSG(), "protoSys missing Siganl!");
HKU_CHECK(!protoSys->getParam<bool>("shared_tm"), "Unsupport shared TM for protoSys!");
if (getParam<bool>("depend_on_proto_sys")) {
HKU_CHECK(protoSys->getTM(),
"Scenarios that depend on prototype systems need to specify a TM!");
}

auto proto = protoSys;
proto->forceResetAll();
SYSPtr sys = proto->clone();
sys->reset();
sys->setStock(stock);
m_pro_sys_list.emplace_back(sys);
return true;
}

bool SelectorBase::addStockList(const StockList& stkList, const SystemPtr& protoSys) {
HKU_ERROR_IF_RETURN(!protoSys, false, "Try add Null protoSys, will be discard!");
HKU_ERROR_IF_RETURN(!protoSys->getMM(), false, "protoSys has not MoneyManager!");
HKU_ERROR_IF_RETURN(!protoSys->getSG(), false, "protoSys has not Signal!");
SYSPtr newProtoSys = protoSys->clone();
// 复位清除之前的数据,避免因原有数据过多导致下面循环时速度过慢
// 每个系统独立,不共享 tm
newProtoSys->setParam<bool>("shared_tm", false);
newProtoSys->reset();
StockList::const_iterator iter = stkList.begin();
for (; iter != stkList.end(); ++iter) {
if (iter->isNull()) {
HKU_WARN("Try add Null stock, will be discard!");
continue;
}
m_calculated = false;
m_proto_calculated = false;
}

SYSPtr sys = newProtoSys->clone();
sys->setStock(*iter);
m_pro_sys_list.emplace_back(sys);
void SelectorBase::addStockList(const StockList& stkList, const SystemPtr& protoSys) {
for (const auto& stk : stkList) {
addStock(stk, protoSys);
}
return true;
}

} /* namespace hku */
16 changes: 13 additions & 3 deletions hikyuu_cpp/hikyuu/trade_sys/selector/SelectorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class HKU_API SelectorBase : public enable_shared_from_this<SelectorBase> {
* @param protoSys 交易系统策略原型
* @return 如果 protoSys 无效 或 stock 无效,则返回 false, 否则返回 true
*/
bool addStock(const Stock& stock, const SystemPtr& protoSys);
void addStock(const Stock& stock, const SystemPtr& protoSys);

/**
* 加入一组相同交易策略的股票
Expand All @@ -59,7 +59,7 @@ class HKU_API SelectorBase : public enable_shared_from_this<SelectorBase> {
* @param protoSys 交易系统策略原型
* @return 如果 protoSys 无效则返回false,否则返回 true
*/
bool addStockList(const StockList& stkList, const SystemPtr& protoSys);
void addStockList(const StockList& stkList, const SystemPtr& protoSys);

/**
* @brief 获取原型系统列表
Expand Down Expand Up @@ -102,10 +102,20 @@ class HKU_API SelectorBase : public enable_shared_from_this<SelectorBase> {
virtual bool isMatchAF(const AFPtr& af) = 0;

/* 仅供PF调用,由PF通知其实际运行的系统列表,并启动计算 */
void calculate(const SystemList& sysList, const KQuery& query);
void calculate(const SystemList& pf_realSysList, const KQuery& query);

void calculate_proto(const KQuery& query);

private:
void initParam();

protected:
string m_name;
bool m_calculated{false}; // 是否已计算过
bool m_proto_calculated{false};
KQuery m_query;
KQuery m_proto_query;

SystemList m_pro_sys_list; // 原型系统列表
SystemList m_real_sys_list; // PF组合中实际运行的系统,有PF执行时设定,顺序与原型列表一一对应

Expand Down

0 comments on commit 7d482d9

Please sign in to comment.