diff --git a/docs/CHANGELOG.asciidoc b/docs/CHANGELOG.asciidoc index 9c35475f60..823e3f0708 100644 --- a/docs/CHANGELOG.asciidoc +++ b/docs/CHANGELOG.asciidoc @@ -33,6 +33,7 @@ === Bug Fixes * Correct inference model definition for MSLE regression models. (See {ml-pull}1375[#1375].) +* Fix restoration of change detectors after seasonality change. (See {ml-pull}1391[#1391].) * Fix cause of SIGSEGV of classification and regression. (See {ml-pull}1379[#1379].) == {es} version 7.8.1 diff --git a/include/maths/CTimeSeriesChangeDetector.h b/include/maths/CTimeSeriesChangeDetector.h index 5e3b9f1429..eb0aa2134b 100644 --- a/include/maths/CTimeSeriesChangeDetector.h +++ b/include/maths/CTimeSeriesChangeDetector.h @@ -136,6 +136,9 @@ class MATHS_EXPORT CUnivariateTimeSeriesChangeDetector { using TMinMaxAccumulator = CBasicStatistics::CMinMax; using TRegression = CLeastSquaresOnlineRegression<1, double>; + //! Initialise the m_ChangeModels vector + void initChangeModels(TPriorPtr residualModel); + private: //! The minimum amount of time we need to observe before //! selecting a change model. @@ -227,6 +230,9 @@ class MATHS_EXPORT CUnivariateChangeModel { //! Get a checksum for this object. virtual uint64_t checksum(uint64_t seed) const = 0; + //! Get the time series residual model member variable. + const TPriorPtr& residualModelPtr() const; + protected: CUnivariateChangeModel(const CUnivariateChangeModel& other, const TDecompositionPtr& trendModel, @@ -257,8 +263,6 @@ class MATHS_EXPORT CUnivariateChangeModel { const CPrior& residualModel() const; //! Get the time series residual model. CPrior& residualModel(); - //! Get the time series residual model member variable. - const TPriorPtr& residualModelPtr() const; private: using TMeanVarAccumulator = CBasicStatistics::SSampleMeanVar::TAccumulator; diff --git a/lib/maths/CTimeSeriesChangeDetector.cc b/lib/maths/CTimeSeriesChangeDetector.cc index 19c5e74bb8..797865af5b 100644 --- a/lib/maths/CTimeSeriesChangeDetector.cc +++ b/lib/maths/CTimeSeriesChangeDetector.cc @@ -98,11 +98,16 @@ CUnivariateTimeSeriesChangeDetector::CUnivariateTimeSeriesChangeDetector( : m_MinimumTimeToDetect{minimumTimeToDetect}, m_MaximumTimeToDetect{maximumTimeToDetect}, m_MinimumDeltaBicToDetect{minimumDeltaBicToDetect}, m_SampleCount{0}, m_DecisionFunction{0.0}, m_TrendModel{trendModel->clone()} { + this->initChangeModels(residualModel); +} + +void CUnivariateTimeSeriesChangeDetector::initChangeModels(TPriorPtr residualModel) { + m_ChangeModels.clear(); m_ChangeModels.push_back( - std::make_unique(trendModel, residualModel)); + std::make_unique(m_TrendModel, residualModel)); m_ChangeModels.push_back( std::make_unique(m_TrendModel, residualModel)); - if (trendModel->seasonalComponents().size() > 0) { + if (m_TrendModel->seasonalComponents().size() > 0) { m_ChangeModels.push_back(std::make_unique( m_TrendModel, residualModel, -core::constants::HOUR)); m_ChangeModels.push_back(std::make_unique( @@ -146,10 +151,13 @@ bool CUnivariateTimeSeriesChangeDetector::acceptRestoreTraverser( RESTORE_SETUP_TEARDOWN(MAX_TIME_TAG, core_t::TTime time, core::CStringUtils::stringToType(traverser.value(), time), m_TimeRange.add(time)) - RESTORE(TREND_MODEL_TAG, traverser.traverseSubLevel(std::bind( - CTimeSeriesDecompositionStateSerialiser(), - std::cref(params.s_DecompositionParams), - std::ref(m_TrendModel), std::placeholders::_1))) + RESTORE_SETUP_TEARDOWN(TREND_MODEL_TAG, /**/, + traverser.traverseSubLevel(std::bind( + CTimeSeriesDecompositionStateSerialiser(), + std::cref(params.s_DecompositionParams), + std::ref(m_TrendModel), std::placeholders::_1)), + this->initChangeModels((*model)->residualModelPtr()); + model = m_ChangeModels.begin()) RESTORE_SETUP_TEARDOWN( CHANGE_MODEL_TAG, TChangeModelPtr restoredModel{(*model)->clone(m_TrendModel)}, traverser.traverseSubLevel(std::bind(