diff --git a/DESCRIPTION b/DESCRIPTION index 4d8fc23..b2cacef 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: fabletools Title: Core Tools for Packages in the 'fable' Framework -Version: 0.2.0 +Version: 0.2.1 Authors@R: c(person(given = "Mitchell", family = "O'Hara-Wild", @@ -23,7 +23,7 @@ Description: Provides tools, helpers and data structures for packages. These tools support a consistent and tidy interface for time series modelling and analysis. License: GPL-3 -URL: http://fabletools.tidyverts.org/, +URL: https://fabletools.tidyverts.org/, https://github.com/tidyverts/fabletools BugReports: https://github.com/tidyverts/fabletools/issues Depends: R (>= 3.1.3) @@ -31,19 +31,18 @@ Imports: tsibble (>= 0.9.0), tibble (>= 1.4.1), ggplot2 (>= 3.0.0), tidyselect, rlang (>= 0.4.5), stats, dplyr (>= 1.0.0), tidyr (>= 1.1.0), generics, R6, utils, vctrs (>= 0.2.2), distributional, progressr, lifecycle -Suggests: covr, crayon, digest, fable (>= 0.2.0), future.apply, knitr, - methods, pillar (>= 1.0.1), feasts (>= 0.1.2), rmarkdown, - scales, spelling, testthat, tsibbledata (>= 0.2.0), lubridate, - SparseM +Suggests: covr, crayon, fable (>= 0.2.0), future.apply, knitr, pillar + (>= 1.0.1), feasts (>= 0.1.2), rmarkdown, spelling, testthat, + tsibbledata (>= 0.2.0), lubridate, Matrix VignetteBuilder: knitr RdMacros: lifecycle ByteCompile: true Encoding: UTF-8 Language: en-GB LazyData: true -RoxygenNote: 7.1.0.9000 +RoxygenNote: 7.1.1 NeedsCompilation: no -Packaged: 2020-06-15 14:28:10 UTC; mitchell +Packaged: 2020-09-01 03:47:04 UTC; mitchell Author: Mitchell O'Hara-Wild [aut, cre], Rob Hyndman [aut], Earo Wang [aut], @@ -51,4 +50,4 @@ Author: Mitchell O'Hara-Wild [aut, cre], George Athanasopoulos [ctb] Maintainer: Mitchell O'Hara-Wild Repository: CRAN -Date/Publication: 2020-06-15 23:40:08 UTC +Date/Publication: 2020-09-03 22:42:11 UTC diff --git a/MD5 b/MD5 index 89c249d..4623a63 100644 --- a/MD5 +++ b/MD5 @@ -1,91 +1,92 @@ 
-0fac29849cf626af7cb3bbe8f352d04b *DESCRIPTION -ae2bf8d0424c16785e56b03afb224454 *NAMESPACE -fcb97c21e758ca994888b87a5359f31b *NEWS.md +b3801cd9079d5757806e02c1504e0b62 *DESCRIPTION +fc1285de4a5ba877565e6fcaa85126ba *NAMESPACE +004e530b961df5e832439acd758a7e05 *NEWS.md a08cdb0d7f88b0e711dec46c77424b96 *R/accessors.R -e44598111eadbe57c4f8d8626c519590 *R/accuracy.R -0c45b598c52a151497f0f06613ab619f *R/aggregate.R +7ad770519af4f9c49c4cc76ee088839a *R/accuracy.R +d2717f0966a627a43b8c489543af600b *R/aggregate.R e9d5258c496dd3e45743862b3f055720 *R/box_cox.R -fb631ed039fefb0d9caa6f767eccfdc0 *R/broom.R +29c4211bf71e15aad4d3edaa39618a3d *R/broom.R cbaef737d494327d79afb92b21ccf254 *R/compat-purrr.R -00290ce1d4def4cdf495e977b73d1092 *R/components.R -e65f5eff4185bdb83aedbc178241c6aa *R/dable.R -8cdcfb08b18516a9d04db59b5aba2745 *R/definitions.R +3ef1dbefefb9435a8d4622f7f72463ef *R/components.R +515d6b02af1d96e4e5a0889f59238567 *R/dable.R +8ef591a2fa72a053a7af2b70ef8a0222 *R/definitions.R 9ad0ddbab62e1b98e2505bc165ed6818 *R/dplyr-dable.R -56390207a415bfa17c5b7aa9a98b9c82 *R/dplyr-fable.R -eaaedd83996a67a5442d82e5292ad906 *R/dplyr-mable.R +5bca8dbd4f8b4824a678d9baee943985 *R/dplyr-fable.R +a1e31597b326c1c9ec0fc5c054ff8018 *R/dplyr-mable.R f0cfa99a08bf99c3f0b2f8b4ab386305 *R/equation.R 7098a40958f70fe046ac023fe5e08167 *R/estimate.R -752d78c0b354c9bc3682c5b4f00afdbb *R/fable.R -031bb53deb9db66fff0f3d037d4aec79 *R/fabletools.R -d0d5299d48daa9c62229473d5194877a *R/features.R -c1ddc18c6abb3acde574c4c770df9d0d *R/fitted.R -c857cac1e4f1e9f5bfda1f119314f800 *R/forecast.R +c3c958ff811f50f4c6704ac529987a0e *R/fable.R +7d82bf977dc75cafa7afdc18ba21caa0 *R/fabletools-package.R +3967749c70628b8ef7145b207ca66817 *R/features.R +52033ada5cc3365cb31a79cc52b9f056 *R/fitted.R +39c0cb745d3adb837ad013c3c5e7c84d *R/forecast.R b49a5f15d5bad65d8d8130ed36a4e9b9 *R/frequency.R -320aa08d31032a75b021609d68947a12 *R/generate.R +759be65e0c74f8b7894d52132d66d5b4 *R/generate.R 6dbd4202a2f0de2202fc735856214508 
*R/guess.R 6727867795e631fa91abbd4a37da1b14 *R/hilo.R c6423a8ee8154af0d8c8cc1a00ea34b5 *R/interpolate.R -43e49e62db50802a7530cdeeca5ad5ea *R/lst_mdl.R -ed0d74eadd98bb301a12f4fceb5b494f *R/mable.R +ec7eb945353800bd0a071b15a124538f *R/lst_mdl.R +57be9a08e2e9838366f4db0b3d6a2689 *R/mable.R 6cae1fdb6734995e8625daba498c6297 *R/mdl_ts.R 58fa875722db1babd25f2b30e8699c91 *R/model.R -9efb24fb9a8856e8f7fafb80a73ddbaf *R/model_combination.R +4aaaf1626ce168def4a4b445af32b14d *R/model_combination.R 29956ecc10adc8b296dcb5c9d2e56755 *R/model_decomposition.R -fd8c797b2def48cafea76ece6bd08c91 *R/model_null.R -62a47d076cc101b18a0396f2ef814007 *R/parse.R -02a8c21504fd39014d18fe98513b1b51 *R/plot.R -89a85b3adc5c44efb35eadcdc3528f81 *R/quantile.R -74547e3c4bf3ae23aa0766d6035df559 *R/reconciliation.R +df6f7d3d9bcc6fc1a989f7d821aef85c *R/model_null.R +7096b5da16c40a6436d3598dfc362d49 *R/parse.R +322519df55c818167f4036c9bc3b17d6 *R/plot.R +1783358ecf218fe0dbd5c35c316b3d87 *R/reconciliation.R 563cc27af1b423c80650eb8508c58877 *R/reexports.R -a8d7152f4787bc0add6d7a3c44cf3d63 *R/refit.R +140d7dd80ce971fea1386b96b90c80c7 *R/refit.R b60139e2f4bb88a829442c577dfddf3e *R/report.R -7a4d781f75bc99e5a113fb24af74be55 *R/residuals.R -0a94fa585d65788323b3ef731a366434 *R/response.R +37902979ade9836ebf0184e100d1f5ea *R/residuals.R +fe30a66e973d94f152eeec90864b4a15 *R/response.R 299bb581f0915f0e9bdb332d501e5c50 *R/specials.R 99b1b30bbc9298074966ad2cce7e9b79 *R/stream.R cbdf350bf426c9729a4078546d47e95c *R/tbl_utils.R +c65cf167631341af48b62aee71a0bf3f *R/temporal_aggregation.R fc5f777d65ed634b0e30f3558edac4ac *R/transform.R 2c2807619d68338517a09daedfa6835f *R/traverse.R -6e5a564e253a89f6ecacff6c4f4b262e *R/utils.R -2552de3b8f21cb5240d78f1812c2da00 *R/vctrs-dable.R -598d19952b87ed011a04cd6d3aebb1e5 *R/vctrs-fable.R +3e5aab622366f3a1386d8e0c5206c122 *R/utils.R +729d8de2770d3f84fe73a05729ae0893 *R/vctrs-dable.R +fa556d0d4264a65d3858420f15df4c0c *R/vctrs-fable.R 35fb655f21fda2dbb3960806937750bd *R/vctrs-mable.R 
-cdd71f8fc729466293aa5efe2e171114 *R/zzz.R -ebcfd1e1968c88ed5cea73cc921e63f9 *README.md -bbf564667749f9da66d023bafffdd621 *build/fabletools.pdf -5976875bd6ad67a36691592f97bdf39b *build/vignette.rds -043674935267964f783f87e833e3be21 *inst/WORDLIST +9eed4b755fab99446a41a2bb44576600 *R/zzz.R +933964bc28a06198d85850d94cee794e *README.md +a382d3e097c86c06ee16e7c3e62323d9 *build/fabletools.pdf +10e18aae8aba14bc6c6ca44a71fdc68a *build/vignette.rds +ad5a76670b9503e13645906221e8d713 *inst/WORDLIST 7bab195f754cc0017bd80d6b93536606 *inst/doc/extension_models.R -8e2d2ed9d388ce65e4bba0ce32d8e6d3 *inst/doc/extension_models.Rmd -7c8ffb5632a7ae985c7ae539a3d1644c *inst/doc/extension_models.html +5a159266a4d4a7fc70c5e5702e4ccd9c *inst/doc/extension_models.Rmd +691fee5909db268275ad0cfd126ed148 *inst/doc/extension_models.html cff23ee87902f0ba1c751ada29f20fc2 *man/MAAPE.Rd 2cb004482006e49a4ec956d90b70ba75 *man/accuracy.Rd +95514a857687441686e19bd0916c94b5 *man/agg_vec.Rd +dc05205882b73f99f1648380c722902d *man/aggregate_index.Rd 6306ce6869a6c1de33c801c85eb4a115 *man/aggregate_key.Rd e968b46684b00b9b8bd4d19a68c26414 *man/aggregation-vctrs.Rd -50f4b8d50b83231fee70c24f8fba4734 *man/as-dable.Rd -7dad2e7d6d82791e99a26207b4f0cbe3 *man/as-fable.Rd +c62dddfed7dc7cef64a9d94c8d9e2952 *man/as-dable.Rd +ea4a261519a8f938bbca31e2c82960ba *man/as-fable.Rd 551b9a81c42c40e09d806b815c9a3892 *man/as_mable.Rd -ea891482ba3bed667d05b0e841e26d93 *man/augment.Rd +2e6e5ffae77da16ecf82d5a2eb366be4 *man/augment.Rd 560bcd671070c6bc8d15308e40bb7adb *man/autoplot.dcmp_ts.Rd 37c1e2a9ecc19e127e3f57071ef31b85 *man/autoplot.fbl_ts.Rd f2269ba3a524bc34aa69f39e3e9bf530 *man/autoplot.tbl_ts.Rd 8c636a927212d183a7c3850bfe0084da *man/bias_adjust.Rd +5936ee99ccba59b5b3d8b902ea2e59cc *man/bottom_up.Rd cab019073996dcc2d5aba968f04fb033 *man/box_cox.Rd 3c109b101def27d5a9504671d485be4b *man/combination_ensemble.Rd cbd465b6b7674866572de9547b4fe6a5 *man/combination_model.Rd 51dd35a03ae3539a0cbb0352b8622fd4 *man/components.Rd 
-8705a27f60a0d6a53be02f4d924732e2 *man/construct_fc.Rd +44b374027aa1634b32e0fbcfa4431971 *man/construct_fc.Rd 311d01a7b31c53ed10632df3d81e6a37 *man/dable-vctrs.Rd -058a7563684b7b684fca0438f08fbff8 *man/dable.Rd +15f9f26ed1c93dd7002271493b6e2ad3 *man/dable.Rd e33da4e74d468e1f7f39bdd49fa0dc6d *man/decomposition_model.Rd 0d140bb5a3b44090769e7af77638089f *man/distribution_accuracy_measures.Rd 630c282ff1e8b7e32c0eeb8a3c801305 *man/distribution_var.Rd -9c997222638d0e8ae8ef4db51ca14a01 *man/distributions.Rd a0fdb24d9d15c5c5630de61a05615a2e *man/estimate.Rd b90659d7ad6d41d261b314c79e9aeeff *man/fable-vctrs.Rd -0adabf14c45279b920c97d24f4bebde8 *man/fable.Rd -7a6c27d5cfc20ca91ec93a0f614a6572 *man/fabletools-package.Rd -34e715e7bbad37d1e83b81a360bea680 *man/fcdist.Rd +c23404d1cf73d4a9f55c8de604b7afcf *man/fable.Rd +37a1869c53b733631056d0b4adca18ce *man/fabletools-package.Rd 0103d444a1b6f7328765a533af6d6553 *man/feature_set.Rd 3f7537e4cf99781018703c6ff641c5fe *man/features.Rd 8012d04e6cc0ebfd1aee0489e94f9146 *man/features_by_pkg.Rd @@ -101,8 +102,9 @@ c3978703d8f40f2679795335715e98f4 *man/figures/lifecycle-experimental.svg 46de21252239c5a23d400eae83ec6b2d *man/figures/lifecycle-retired.svg 6902bbfaf963fbc4ed98b86bda80caa2 *man/figures/lifecycle-soft-deprecated.svg 53b3f893324260b737b3c46ed2a0e643 *man/figures/lifecycle-stable.svg +1c1fe7a759b86dc6dbcbe7797ab8246c *man/figures/lifecycle-superseded.svg 2fedf198e8044afa788fe8bcbb008370 *man/fitted.mdl_df.Rd -706816fd7e91f531df6e1fc8691e28a2 *man/forecast.Rd +8b9d4831aa407715d77e1636e8bb2850 *man/forecast.Rd 676fd82e0d71779c4d0a82460c183616 *man/freq_tools.Rd b081ca74065a947884aa407113fa0c1b *man/generate.mdl_df.Rd 70fb188b1c61ab008f7f771af22a2903 *man/glance.Rd @@ -137,18 +139,20 @@ fe74782d2d7a04c98808abda763e18ff *man/report.Rd 51b8a397bf92000659fe9efaab4a3783 *man/residuals.mdl_df.Rd 04285043236595c6e83ccd810d402608 *man/response.Rd b55489638f729e2380e2d96f4716b1d6 *man/response_vars.Rd +9d66c854d2ee543c71a409b87e98e19d 
*man/skill_score.Rd 83b079e6ae9511eaf3088c462f9d75df *man/stream.Rd 750aaaefa9aa2044f8b40f8a48bdecaf *man/tidy.Rd +74b8186caf6db910744ce117f81e1e32 *man/top_down.Rd 9e68d7156db1ba73e94464d68fd2c308 *man/traverse.Rd 4980cb795255188f627f650622c4bf80 *man/unpack_hilo.Rd 4bf7ba0f683d07e48d465d0a4a558d7f *man/validate_formula.Rd 8176e3fbb47046d5f7220e147a102483 *tests/testthat.R -be2d2b2eeacf65178ee7d71a68c20c92 *tests/testthat/Rplots.pdf +63d96e03c81e75299331381f4ed7825f *tests/testthat/Rplots.pdf d086dba90b013cec5dda344fb7938612 *tests/testthat/setup-data.R b93b054b076aba90b60d8f5a0fd9198d *tests/testthat/setup-models.R -c09335528a353c2b6580e23c17ca0b12 *tests/testthat/test-accuracy.R -ded7db558cbf8d9d29389f75133cb856 *tests/testthat/test-broom.R -b3e726e2bf9dcf6440702e0c5f76894e *tests/testthat/test-combination.R +05dcc8db2deffe943e98d257e71236a4 *tests/testthat/test-accuracy.R +8eb061a803272c60a11e1b0ccea320d2 *tests/testthat/test-broom.R +e9703e4c2e172e3e7ad599b51641c9b6 *tests/testthat/test-combination.R fe6bbb7d8ad033dd4f3c4658bc1edb3e *tests/testthat/test-decomposition-model.R 5445de0d6ca8a6ebf5e15cdf81c7a14a *tests/testthat/test-fable.R 68676adae765873e84c710e4469d8311 *tests/testthat/test-features.R @@ -158,11 +162,9 @@ fe6bbb7d8ad033dd4f3c4658bc1edb3e *tests/testthat/test-decomposition-model.R 5c512358f60a3b13c0f15a11e67a83e4 *tests/testthat/test-interpolate.R 1618029c438b05810e96ba9ae4f72cf9 *tests/testthat/test-mable.R eb0999ceaf3888b6daf5f23699cb2c0c *tests/testthat/test-multivariate.R -0e8783b714f9076ad5a647e76131ab99 *tests/testthat/test-parser.R +935b43ed672ef5b3486ed26075fe3694 *tests/testthat/test-parser.R 8547dc560b56cbf6e8a91a0a5c18bf41 *tests/testthat/test-reconciliation.R 9a5176496e53324c9cfb5bed43de4eb4 *tests/testthat/test-spelling.R 5349d2d78b347edd154ef4669b129f95 *tests/testthat/test-transformations.R 880d0607b40a40f313f0ffa9df705b16 *tests/testthat/test-validate_model.R -8e2d2ed9d388ce65e4bba0ce32d8e6d3 *vignettes/extension_models.Rmd 
-49af686f551427aa4814c36a9fa9c13e *vignettes/temporal.R -f4dbc7f4e6193e6317ef1ab355b80fd3 *vignettes/temporal.html +5a159266a4d4a7fc70c5e5702e4ccd9c *vignettes/extension_models.Rmd diff --git a/NAMESPACE b/NAMESPACE index 99cd253..f6792b0 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,23 +1,22 @@ # Generated by roxygen2: do not edit by hand S3method("$<-",mdl_df) +S3method("==",agg_vec) S3method("[",dcmp_ts) S3method("[",fbl_ts) -S3method("[",fcdist) S3method("[",mdl_df) S3method("names<-",mdl_df) -S3method(Ops,fcdist) S3method(Ops,lst_mdl) S3method(Ops,mdl_defn) S3method(Ops,mdl_ts) S3method(accuracy,fbl_ts) S3method(accuracy,mdl_df) S3method(accuracy,mdl_ts) -S3method(aggregate_index,tbl_ts) S3method(aggregate_key,tbl_ts) S3method(as_dable,tbl_df) S3method(as_dable,tbl_ts) S3method(as_fable,fbl_ts) +S3method(as_fable,forecast) S3method(as_fable,grouped_df) S3method(as_fable,grouped_ts) S3method(as_fable,tbl_df) @@ -36,7 +35,6 @@ S3method(autolayer,tbl_ts) S3method(autoplot,dcmp_ts) S3method(autoplot,fbl_ts) S3method(autoplot,tbl_ts) -S3method(c,fcdist) S3method(coef,mdl_df) S3method(coef,mdl_ts) S3method(common_periods,default) @@ -64,24 +62,27 @@ S3method(features,tbl_ts) S3method(features_all,tbl_ts) S3method(features_at,tbl_ts) S3method(features_if,tbl_ts) +S3method(fitted,"NULL") S3method(fitted,mdl_df) S3method(fitted,mdl_ts) S3method(fitted,model_combination) S3method(fitted,null_mdl) +S3method(forecast,"NULL") S3method(forecast,fbl_ts) S3method(forecast,lst_btmup_mdl) S3method(forecast,lst_mdl) S3method(forecast,lst_mint_mdl) +S3method(forecast,lst_topdwn_mdl) S3method(forecast,mdl_df) S3method(forecast,mdl_ts) S3method(forecast,model_combination) S3method(forecast,null_mdl) S3method(format,agg_vec) -S3method(format,fcdist) S3method(format,lst_mdl) S3method(format,mdl_ts) S3method(fortify,fbl_ts) S3method(gather,mdl_df) +S3method(generate,"NULL") S3method(generate,mdl_df) S3method(generate,mdl_ts) S3method(generate,model_combination) @@ -100,10 +101,10 @@ 
S3method(hilo,fbl_ts) S3method(interpolate,mdl_df) S3method(interpolate,mdl_ts) S3method(invert_transformation,transformation) +S3method(is.na,agg_vec) S3method(key,mdl_df) S3method(key_data,mdl_df) S3method(key_vars,mdl_df) -S3method(length,fcdist) S3method(mable_vars,mdl_df) S3method(model,tbl_ts) S3method(model_sum,decomposition_model) @@ -111,21 +112,22 @@ S3method(model_sum,default) S3method(model_sum,mdl_ts) S3method(model_sum,model_combination) S3method(model_sum,null_mdl) -S3method(print,fcdist) +S3method(pivot_longer,mdl_df) S3method(print,transformation) -S3method(quantile,fcdist) S3method(rbind,dcmp_ts) S3method(rbind,fbl_ts) S3method(reconcile,mdl_df) +S3method(refit,"NULL") S3method(refit,lst_mdl) S3method(refit,mdl_df) S3method(refit,mdl_ts) S3method(refit,null_mdl) -S3method(rep,fcdist) +S3method(report,"NULL") S3method(report,mdl_df) S3method(report,mdl_ts) S3method(report,model_combination) S3method(report,null_mdl) +S3method(residuals,"NULL") S3method(residuals,mdl_df) S3method(residuals,mdl_ts) S3method(residuals,null_mdl) @@ -137,10 +139,13 @@ S3method(response_vars,mdl_df) S3method(select,fbl_ts) S3method(select,grouped_fbl) S3method(select,mdl_df) +S3method(stream,"NULL") S3method(stream,lst_mdl) S3method(stream,mdl_df) S3method(stream,mdl_ts) S3method(stream,null_mdl) +S3method(summarise,fbl_ts) +S3method(summarise,grouped_fbl) S3method(tidy,mdl_df) S3method(tidy,mdl_ts) S3method(tidy,null_mdl) @@ -149,11 +154,11 @@ S3method(transmute,grouped_fbl) S3method(transmute,mdl_df) S3method(ungroup,fbl_ts) S3method(ungroup,grouped_fbl) -S3method(unique,fcdist) S3method(vec_cast,agg_vec.agg_vec) S3method(vec_cast,agg_vec.character) S3method(vec_cast,agg_vec.default) S3method(vec_cast,character.agg_vec) +S3method(vec_cast,character.lst_mdl) S3method(vec_cast,data.frame.dcmp_ts) S3method(vec_cast,data.frame.fbl_ts) S3method(vec_cast,data.frame.mdl_df) @@ -163,15 +168,15 @@ S3method(vec_cast,dcmp_ts.tbl_df) S3method(vec_cast,fbl_ts.data.frame) 
S3method(vec_cast,fbl_ts.fbl_ts) S3method(vec_cast,fbl_ts.tbl_df) -S3method(vec_cast,fcdist) +S3method(vec_cast,lst_mdl.lst_mdl) S3method(vec_cast,mdl_df.data.frame) S3method(vec_cast,mdl_df.mdl_df) S3method(vec_cast,mdl_df.tbl_df) S3method(vec_cast,tbl_df.dcmp_ts) S3method(vec_cast,tbl_df.fbl_ts) S3method(vec_cast,tbl_df.mdl_df) -S3method(vec_cast.fcdist,default) -S3method(vec_cast.fcdist,fcdist) +S3method(vec_cast,tbl_ts.dcmp_ts) +S3method(vec_cast,tbl_ts.fbl_ts) S3method(vec_proxy_compare,agg_vec) S3method(vec_ptype2,agg_vec.agg_vec) S3method(vec_ptype2,agg_vec.character) @@ -186,15 +191,13 @@ S3method(vec_ptype2,dcmp_ts.tbl_df) S3method(vec_ptype2,fbl_ts.data.frame) S3method(vec_ptype2,fbl_ts.fbl_ts) S3method(vec_ptype2,fbl_ts.tbl_df) -S3method(vec_ptype2,fcdist) +S3method(vec_ptype2,lst_mdl.lst_mdl) S3method(vec_ptype2,mdl_df.data.frame) S3method(vec_ptype2,mdl_df.mdl_df) S3method(vec_ptype2,mdl_df.tbl_df) S3method(vec_ptype2,tbl_df.dcmp_ts) S3method(vec_ptype2,tbl_df.fbl_ts) S3method(vec_ptype2,tbl_df.mdl_df) -S3method(vec_ptype2.fcdist,default) -S3method(vec_ptype2.fcdist,fcdist) S3method(vec_ptype_abbr,agg_vec) export("%>%") export(ACF1) @@ -209,6 +212,7 @@ export(MSE) export(RMSE) export(RMSSE) export(accuracy) +export(agg_vec) export(aggregate_key) export(as_dable) export(as_fable) @@ -218,6 +222,7 @@ export(augment) export(autolayer) export(autoplot) export(bias_adjust) +export(bottom_up) export(box_cox) export(combination_ensemble) export(combination_model) @@ -226,10 +231,6 @@ export(components) export(construct_fc) export(dable) export(decomposition_model) -export(dist_mv_normal) -export(dist_normal) -export(dist_sim) -export(dist_unknown) export(distribution_accuracy_measures) export(distribution_var) export(equation) @@ -263,8 +264,6 @@ export(model) export(model_lhs) export(model_rhs) export(model_sum) -export(new_fcdist) -export(new_fcdist_env) export(new_model_class) export(new_model_definition) export(new_specials) @@ -283,8 +282,10 @@ 
export(report) export(response) export(response_vars) export(scaled_pinball_loss) +export(skill_score) export(stream) export(tidy) +export(top_down) export(traverse) export(unpack_hilo) export(validate_formula) @@ -348,16 +349,15 @@ importFrom(ggplot2,labs) importFrom(ggplot2,vars) importFrom(ggplot2,xlab) importFrom(ggplot2,ylab) +importFrom(lifecycle,deprecate_warn) importFrom(stats,fitted) -importFrom(stats,qnorm) importFrom(stats,quantile) importFrom(stats,residuals) importFrom(stats,var) importFrom(tibble,new_tibble) importFrom(tidyr,gather) importFrom(tidyr,nest) +importFrom(tidyr,pivot_longer) importFrom(tidyr,spread) importFrom(tidyr,unnest) -importFrom(utils,combn) -importFrom(vctrs,vec_cast) -importFrom(vctrs,vec_ptype2) +importFrom(tidyselect,all_of) diff --git a/NEWS.md b/NEWS.md index 2554c03..720802c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,56 @@ +# fabletools 0.2.1 + +## New features + +* Added `bottom_up()` forecast reconciliation method. +* Added the `skill_score()` accuracy measure modifier. +* Added `agg_vec()` for manually producing aggregation vectors. + +## Improvements + +* Fixed some inconsistencies in key ordering of model accessors (such as + `augment()`, `tidy()` and `glance()`) with model methods (such as `forecast()` + and `generate()`). +* Improved equality comparison of `agg_vec` classes, aggregated values will now + always match regardless of the value used. +* Using `summarise()` with a fable will now retain the fable class if the + distribution still exists under the same variable name. +* Added `as_fable.forecast()` to convert forecast objects from the forecast + package to work with fable. +* Improved `CRPS()` performance when using sampling distributions (#240). +* Reconciliation now works with hierarchies containing aggregate leaf nodes, + allowing unbalanced hierarchies to be reconciled. +* Produce unique names for unnamed features used with `features()` (#258). 
+* Documentation improvements +* Performance improvements + +## Breaking changes + +* The residuals obtained from the `augment()` function are no longer controlled + by the `type` argument. Response residuals (`y - yhat`) are now always found + in the `.resid` column, and innovation residuals (the model's error) are now + found in the `.innov` column. Response residuals will differ from innovation + residuals when transformations are used, and if the model has non-additive + residuals. +* `dist_*()` functions are now removed, and are completely replaced by the + distributional package. These are removed to prevent masking issues when + loading packages. +* `fortify()` will now return a tibble with the same structure as the + fable, which is more useful for plotting forecast distributions with the + ggdist package. It can no longer be used to extract intervals from the + forecasts, this can be done using `hilo()`, and numerical values from a + `<hilo>` can be extracted with `unpack_hilo()` or `interval$lower`. + +## Bug fixes + +* Fixed issue with aggregated date vectors (#230). +* Fixed display of models in `View()` panel. +* Fixed issue with combination models not inheriting vctrs functionality (#237). +* `aggregate_key()` can now be used with non-syntactic variable names. +* Added tsibble cast methods for fable and dable objects, fixing issues with + tidyverse functionality between datasets of different column orders (#247). +* Fixed `refit()` dropping reconciliation attributes (#251). 
+ # fabletools 0.2.0 ## New features diff --git a/R/accuracy.R b/R/accuracy.R index 7ea1624..e9b8588 100644 --- a/R/accuracy.R +++ b/R/accuracy.R @@ -130,9 +130,10 @@ winkler_score <- function(.dist, .actual, level = 95, na.rm = TRUE, ...){ } #' @rdname interval_accuracy_measures +#' @importFrom stats quantile #' @export pinball_loss <- function(.dist, .actual, level = 95, na.rm = TRUE, ...){ - q <- quantile(.dist, level/100) + q <- stats::quantile(.dist, level/100) loss <- ifelse(.actual>=q, level/100 * (.actual-q), (1-level/100) * (q-.actual)) mean(loss, na.rm = na.rm) } @@ -153,7 +154,7 @@ scaled_pinball_loss <- function(.dist, .actual, .train, level = 95, na.rm = TRUE } scale <- mean(abs(.train), na.rm = na.rm) - q <- quantile(.dist, level/100) + q <- stats::quantile(.dist, level/100) loss <- ifelse(.actual>=q, level/100 * (.actual-q), (1-level/100) * (q-.actual)) mean(loss/scale, na.rm = na.rm) } @@ -184,6 +185,7 @@ percentile_score <- function(.dist, .actual, na.rm = TRUE, ...){ #' @export CRPS <- function(.dist, .actual, n_quantiles = 1000, na.rm = TRUE, ...){ is_normal <- map_lgl(.dist, inherits, "dist_normal") + is_sample <- map_lgl(.dist, inherits, "dist_sample") z <- rep(NA_real_, length(.dist)) if(any(is_normal)){ @@ -193,17 +195,26 @@ CRPS <- function(.dist, .actual, n_quantiles = 1000, na.rm = TRUE, ...){ zn <- sd*(zn*(2*stats::pnorm(zn)-1)+2*stats::dnorm(zn)-1/sqrt(pi)) z[is_normal] <- zn } - - if(any(!is_normal)){ + if(any(is_sample)){ + z[is_sample] <- map2_dbl( + .dist[is_sample], .actual[is_sample], + function(d, y){ + x <- sort(d$x) + m <- length(x) + (2/m) * mean((x-y)*(m*(y% +#' dplyr::filter(index < yearmonth("1979 Jan")) %>% +#' model( +#' ets = ETS(value ~ error("M") + trend("A") + season("A")), +#' lm = TSLM(value ~ trend() + season()) +#' ) %>% +#' forecast(h = "1 year") %>% +#' accuracy(lung_deaths, measures = list(skill = skill_score(MSE))) +#' } +#' +#' @export +skill_score <- function(measure) { + function(...) 
{ + # Compute accuracy measure for forecasts + score <- measure(...) + + # Compute arguments of benchmark method using .train + bench <- list(...) + lag <- bench$.period + n <- length(bench$.train) + y <- bench$.train + + ## Compute point forecast from benchmark + bench$.fc <- rep_len( + y[c(rep(NA, max(0, lag - n)), seq_len(min(n, lag)) + n - min(n, lag))], + length(bench$.fc) + ) + bench$.resid <- bench$.actual - bench$.fc + + # Compute forecast distribution from benchmark + e <- y - c(rep(NA, min(lag, n)), y[seq_len(length(y) - lag)]) + mse <- mean(e^2, na.rm = TRUE) + h <- length(bench$.actual) + fullperiods <- (h - 1) / lag + 1 + steps <- rep(seq_len(fullperiods), rep(lag, fullperiods))[seq_len(h)] + bench$.dist <- distributional::dist_normal(bench$.fc, sqrt(mse * steps)) + ref_score <- do.call(measure, bench) + + 1 - score / ref_score + } +} + #' Evaluate accuracy of a forecast or model #' #' Summarise the performance of the model using accuracy measures. Accuracy @@ -273,6 +344,10 @@ accuracy <- function(object, ...){ #' #' @export accuracy.mdl_df <- function(object, measures = point_accuracy_measures, ...){ + if(is_tsibble(measures)){ + abort("The `measures` argument must contain a list of accuracy measures. +Hint: A tsibble of future values is only required when computing accuracy of a fable. To compute forecast accuracy, you'll need to compute the forecasts first.") + } as_tibble(object) %>% tidyr::pivot_longer(mable_vars(object), names_to = ".model", values_to = "fit") %>% mutate(fit = map(!!sym("fit"), accuracy, measures = measures, ...)) %>% @@ -284,7 +359,7 @@ accuracy.mdl_ts <- function(object, measures = point_accuracy_measures, ...){ dots <- dots_list(...) 
resp <- if(length(object$response) > 1) sym("value") else object$response[[1]] - aug <- as_tibble(augment(object, type = "response")) + aug <- as_tibble(augment(object)) # Compute inputs for each response variable if(length(object$response) > 1){ @@ -362,7 +437,7 @@ accuracy.fbl_ts <- function(object, data, measures = point_accuracy_measures, .. by <- union(index_var(object), by) - if(!(".model" %in% by)){ + if(!(".model" %in% by) & ".model" %in% names(object)){ warn('Accuracy measures should be computed separately for each model, have you forgotten to add ".model" to your `by` argument?') } diff --git a/R/aggregate.R b/R/aggregate.R index 491b2fa..6d2e7e0 100644 --- a/R/aggregate.R +++ b/R/aggregate.R @@ -32,21 +32,15 @@ aggregate_key <- function(.data, .spec, ...){ aggregate_key.tbl_ts <- function(.data, .spec = NULL, ...){#, dev = FALSE){ .spec <- enexpr(.spec) if(is.null(.spec)){ + kv <- syms(key_vars(.data)) message( sprintf("Key structural specification not found, defaulting to `.spec = %s`", - paste(key_vars(.data), collapse = "*")) + paste(kv, collapse = "*")) ) - .spec <- parse_expr(paste(key_vars(.data), collapse = "*")) - } - - # Key combinations - tm <- stats::terms(new_formula(lhs = NULL, rhs = .spec), env = empty_env()) - key_comb <- attr(tm, "factors") - key_vars <- rownames(key_comb) - key_comb <- map(split(key_comb, col(key_comb)), function(x) key_vars[x!=0]) - if(attr(tm, "intercept")){ - key_comb <- c(list(chr()), key_comb) + .spec <- reduce(kv, call2, .fn = "*") } + + key_comb <- parse_agg_spec(.spec) idx <- index2_var(.data) intvl <- interval(.data) @@ -91,6 +85,18 @@ aggregate_key.tbl_ts <- function(.data, .spec = NULL, ...){#, dev = FALSE){ interval = intvl) } +parse_agg_spec <- function(expr){ + # Key combinations + tm <- stats::terms(new_formula(lhs = NULL, rhs = expr), env = empty_env()) + key_comb <- attr(tm, "factors") + key_vars <- sub("^`(.*)`$", "\\1", rownames(key_comb)) + key_comb <- map(split(key_comb, col(key_comb)), function(x) 
key_vars[x!=0]) + if(attr(tm, "intercept")){ + key_comb <- c(list(chr()), key_comb) + } + unname(key_comb) +} + # #' @rdname aggregate_key # #' @@ -103,63 +109,84 @@ aggregate_key.tbl_ts <- function(.data, .spec = NULL, ...){#, dev = FALSE){ # #' library(tsibble) # #' pedestrian %>% # #' aggregate_index() -aggregate_index <- function(.data, .times, ...){ - UseMethod("aggregate_index") -} +# aggregate_index <- function(.data, .times, ...){ +# UseMethod("aggregate_index") +# } +# +# #' @export +# aggregate_index.tbl_ts <- function(.data, .times = NULL, ...){ +# warn("Temporal aggregation is highly experimental. The interface will be refined in the near future.") +# +# require_package("lubridate") +# idx <- index(.data) +# kv <- key_vars(.data) +# +# # Parse times as lubridate::period +# if(is.null(.times)){ +# interval <- with(interval(.data), lubridate::years(year) + +# lubridate::period(3*quarter + month, units = "month") + lubridate::weeks(week) + +# lubridate::days(day) + lubridate::hours(hour) + lubridate::minutes(minute) + +# lubridate::seconds(second) + lubridate::milliseconds(millisecond) + +# lubridate::microseconds(microsecond) + lubridate::nanoseconds(nanosecond)) +# periods <- common_periods(.data) +# .times <- c(set_names(names(periods), names(periods)), list2(!!format(interval(.data)) := interval)) +# } +# .times <- set_names(map(.times, lubridate::as.period), names(.times) %||% .times) +# +# secs <- map_dbl(.times, lubridate::period_to_seconds) +# .times <- .times[order(secs, decreasing = TRUE)] +# +# # Temporal aggregations +# .data <- as_tibble(.data) +# agg_dt <- vctrs::vec_rbind( +# !!!map(seq_along(.times), function(tm){ +# group_data( +# group_by(.data, +# !!!set_names(names(.times), names(.times))[seq_len(tm-1) + 1], +# !!as_string(idx) := lubridate::floor_date(!!idx, .times[[tm]]), +# !!!syms(kv)) +# ) +# }) +# ) +# kv <- setdiff(colnames(agg_dt), c(as_string(idx), ".rows")) +# agg_dt <- agg_dt[c(as_string(idx), kv, ".rows")] +# +# .data <- 
dplyr::new_grouped_df(.data, groups = agg_dt) +# +# # Compute aggregates and repair index attributes +# idx_attr <- attributes(.data[[as_string(idx)]]) +# .data <- ungroup(summarise(.data, ...)) +# attributes(.data[[as_string(idx)]]) <- idx_attr +# +# # Return tsibble +# as_tsibble(.data, key = kv, index = !!idx) %>% +# mutate(!!!set_names(map(kv, function(x) expr(agg_vec(!!sym(x)))), kv)) +# } +#' Create an aggregation vector +#' +#' \lifecycle{maturing} +#' +#' An aggregation vector extends usual vectors by adding <aggregated> values. +#' These vectors are typically produced via the [`aggregate_key()`] function, +#' however it can be useful to create them manually to produce more complicated +#' hierarchies (such as unbalanced hierarchies). +#' +#' @param x The vector of values. +#' @param aggregated A logical vector to identify which values are <aggregated>. +#' +#' @example +#' agg_vec( +#' x = c(NA, "A", "B"), +#' aggregated = c(TRUE, FALSE, FALSE) +#' ) +#' #' @export -aggregate_index.tbl_ts <- function(.data, .times = NULL, ...){ - warn("Temporal aggregation is highly experimental. 
The interface will be refined in the near future.") - - require_package("lubridate") - idx <- index(.data) - kv <- key_vars(.data) - - # Parse times as lubridate::period - if(is.null(.times)){ - interval <- with(interval(.data), lubridate::years(year) + - lubridate::period(3*quarter + month, units = "month") + lubridate::weeks(week) + - lubridate::days(day) + lubridate::hours(hour) + lubridate::minutes(minute) + - lubridate::seconds(second) + lubridate::milliseconds(millisecond) + - lubridate::microseconds(microsecond) + lubridate::nanoseconds(nanosecond)) - periods <- common_periods(.data) - .times <- c(set_names(names(periods), names(periods)), list2(!!format(interval(.data)) := interval)) - } - .times <- set_names(map(.times, lubridate::as.period), names(.times) %||% .times) - - secs <- map_dbl(.times, lubridate::period_to_seconds) - .times <- .times[order(secs, decreasing = TRUE)] - - # Temporal aggregations - .data <- as_tibble(.data) - agg_dt <- vctrs::vec_rbind( - !!!map(seq_along(.times), function(tm){ - group_data( - group_by(.data, - !!!set_names(names(.times), names(.times))[seq_len(tm-1) + 1], - !!as_string(idx) := lubridate::floor_date(!!idx, .times[[tm]]), - !!!syms(kv)) - ) - }) - ) - kv <- setdiff(colnames(agg_dt), c(as_string(idx), ".rows")) - agg_dt <- agg_dt[c(as_string(idx), kv, ".rows")] - - .data <- dplyr::new_grouped_df(.data, groups = agg_dt) - - # Compute aggregates and repair index attributes - idx_attr <- attributes(.data[[as_string(idx)]]) - .data <- ungroup(summarise(.data, ...)) - attributes(.data[[as_string(idx)]]) <- idx_attr - - # Return tsibble - as_tsibble(.data, key = kv, index = !!idx) %>% - mutate(!!!set_names(map(kv, function(x) expr(agg_vec(!!sym(x)))), kv)) -} - agg_vec <- function(x = character(), aggregated = logical(vec_size(x))){ + is_agg <- is_aggregated(x) + x[is_agg] <- NA vec_assert(aggregated, ptype = logical()) - vctrs::new_rcrd(list(x = x, agg = aggregated), class = "agg_vec") + vctrs::new_rcrd(list(x = x, agg = 
is_agg | aggregated), class = "agg_vec") } #' @export @@ -235,7 +262,7 @@ vec_cast.agg_vec <- function(x, to, ...) UseMethod("vec_cast.agg_vec") #' @export vec_cast.agg_vec.agg_vec <- function(x, to, ...) { x <- vec_proxy(x) - if(all(x$agg)) x$x <- rep_len(vec_cast(NA, vec_proxy(to)$x), length(x$x)) + if(all(x$agg)) x$x <- vec_rep(vec_cast(NA, vec_proxy(to)$x), length(x$x)) vec_restore(x, to) } #' @rdname aggregation-vctrs @@ -253,6 +280,34 @@ vec_proxy_compare.agg_vec <- function(x, ...) { vec_proxy(x)[c(2,1)] } +#' @export +`==.agg_vec` <- function(e1, e2){ + e1_agg <- inherits(e1, "agg_vec") + e2_agg <- inherits(e2, "agg_vec") + + if(!e1_agg || !e2_agg){ + x <- list(e1,e2)[[which(!c(e1_agg, e2_agg))]] + is_agg <- x == "" + if(any(is_agg)){ + warn(" character values have been converted to aggregated values. +Hint: If you're trying to compare aggregated values, use `is_aggregated()`.") + } + x <- agg_vec(ifelse(is_agg, NA, x), aggregated = is_agg) + if(!e1_agg) e1 <- x else e2 <- x + } + + x <- vec_recycle_common(e1, e2) + e1 <- vec_proxy(x[[1]]) + e2 <- vec_proxy(x[[2]]) + out <- logical(vec_size(e1)) + (e1$agg & e2$agg) | vec_equal(e1$x, e2$x, na_equal = TRUE) +} + +#' @export +is.na.agg_vec <- function(x) { + is.na(field(x, "x")) & !field(x, "agg") +} + #' Is the element an aggregation of smaller data #' #' @param x An object. 
diff --git a/R/broom.R b/R/broom.R index 75358f3..1ef3559 100644 --- a/R/broom.R +++ b/R/broom.R @@ -14,22 +14,29 @@ #' # Forecasting with an ETS(M,Ad,A) model to Australian beer production #' aus_production %>% #' model(ets = ETS(log(Beer) ~ error("M") + trend("Ad") + season("A"))) %>% -#' augment(type = "response") +#' augment() #' } #' #' @rdname augment #' @export augment.mdl_df <- function(x, ...){ - x <- gather(x, ".model", ".fit", !!!syms(mable_vars(x))) + mbl_vars <- mable_vars(x) kv <- key_vars(x) - x <- transmute(as_tibble(x), !!!syms(kv), !!sym(".model"), - aug = map(!!sym(".fit"), augment, ...)) - unnest_tsbl(x, "aug", parent_key = kv) + x <- mutate(as_tibble(x), + dplyr::across(all_of(mbl_vars), function(x) lapply(x, augment, ...))) + x <- pivot_longer(x, mbl_vars, names_to = ".model", values_to = ".aug") + unnest_tsbl(x, ".aug", parent_key = c(kv, ".model")) } #' @rdname augment +#' @param type Deprecated. #' @export -augment.mdl_ts <- function(x, ...){ +augment.mdl_ts <- function(x, type = NULL, ...){ + if (!is.null(type)) { + lifecycle::deprecate_warn("0.2.1", "fabletools::augment(type = )", + details = "The type argument is now deprecated for changes to broom v0.7.0. 
+Response residuals are now always found in `.resid` and innovation residuals are now found in `.innov`.") + } tryCatch(augment(x[["fit"]], ...), error = function(e){ idx <- index_var(x$data) @@ -43,7 +50,12 @@ augment.mdl_ts <- function(x, ...){ by = c(".response", idx) ) %>% left_join( - gather(residuals(x, ...), ".response", ".resid", + gather(residuals(x, type = "response", ...), ".response", ".resid", + !!!resp, factor_key = TRUE), + by = c(".response", idx) + ) %>% + left_join( + gather(residuals(x, type = "innovation", ...), ".response", ".innov", !!!resp, factor_key = TRUE), by = c(".response", idx) ) @@ -51,7 +63,8 @@ augment.mdl_ts <- function(x, ...){ mutate( set_names(response(x), c(idx, as_string(resp[[1]]))), .fitted = fitted(x, ...)[[".fitted"]], - .resid = residuals(x, ...)[[".resid"]] + .resid = residuals(x, type = "response", ...)[[".resid"]], + .innov = residuals(x, type = "innovation", ...)[[".resid"]], ) } @@ -80,11 +93,11 @@ augment.mdl_ts <- function(x, ...){ #' @rdname glance #' @export glance.mdl_df <- function(x, ...){ - x <- gather(x, ".model", ".fit", !!!syms(mable_vars(x))) - keys <- key(x) - x <- transmute(as_tibble(x), - !!!keys, !!sym(".model"), glanced = map(!!sym(".fit"), glance)) - unnest_tbl(x, "glanced") + mbl_vars <- mable_vars(x) + x <- mutate(as_tibble(x), + dplyr::across(all_of(mbl_vars), function(x) lapply(x, glance, ...))) + x <- pivot_longer(x, mbl_vars, names_to = ".model", values_to = ".glanced") + unnest(x, ".glanced") } #' @rdname glance @@ -114,11 +127,11 @@ glance.mdl_ts <- function(x, ...){ #' @rdname tidy #' @export tidy.mdl_df <- function(x, ...){ - x <- gather(x, ".model", ".fit", !!!syms(mable_vars(x))) - keys <- key(x) - x <- transmute(as_tibble(x), - !!!keys, !!sym(".model"), tidied = map(!!sym(".fit"), tidy)) - unnest_tbl(x, "tidied") + mbl_vars <- mable_vars(x) + x <- mutate(as_tibble(x), + dplyr::across(all_of(mbl_vars), function(x) lapply(x, tidy, ...))) + x <- pivot_longer(x, mbl_vars, names_to = 
".model", values_to = ".tidied") + unnest(x, ".tidied") } #' @rdname tidy diff --git a/R/components.R b/R/components.R index 8e4959b..4320c33 100644 --- a/R/components.R +++ b/R/components.R @@ -26,7 +26,8 @@ #' @rdname components #' @export components.mdl_df <- function(object, ...){ - object <- gather(object, ".model", ".fit", !!!syms(mable_vars(object))) + object <- tidyr::pivot_longer(object, all_of(mable_vars(object)), + names_to = ".model", values_to = ".fit") kv <- key_vars(object) object <- transmute(as_tibble(object), !!!syms(kv), !!sym(".model"), diff --git a/R/dable.R b/R/dable.R index d4f250a..1a5425b 100644 --- a/R/dable.R +++ b/R/dable.R @@ -7,8 +7,8 @@ #' `autoplot()` method for displaying decompositions. Beyond this, a dable #' (`dcmp_ts`) behaves very similarly to a tsibble (`tbl_ts`). #' -#' @inheritParams fable #' @param ... Arguments passed to [tsibble::tsibble()]. +#' @param response The name of the response variable column. #' @param method The name of the decomposition method. #' @param seasons A named list describing the structure of seasonal components #' (such as `period`, and `base`). @@ -95,7 +95,7 @@ tbl_sum.dcmp_ts <- function(x){ #' @export rbind.dcmp_ts <- function(...){ - .Deprecated("bind_rows()") + deprecate_warn("0.2.0", "rbind.fbl_ts()", "bind_rows()") dots <- dots_list(...) 
attrs <- combine_dcmp_attr(dots) diff --git a/R/definitions.R b/R/definitions.R index 6ecbf92..95f3adf 100644 --- a/R/definitions.R +++ b/R/definitions.R @@ -35,7 +35,7 @@ model_definition <- R6::R6Class(NULL, }, recall_lag = function(x, n = 1L, ...){ start <- NULL - if(self$stage == "forecast"){ + if(self$stage %in% c("generate", "forecast")){ x_expr <- enexpr(x) start <- eval_tidy(x_expr, self$recent_data) } diff --git a/R/dplyr-fable.R b/R/dplyr-fable.R index 1b6db65..1c8ec47 100644 --- a/R/dplyr-fable.R +++ b/R/dplyr-fable.R @@ -24,3 +24,23 @@ dplyr_reconstruct.fbl_ts <- function(data, template) { #' @export dplyr_reconstruct.grouped_fbl <- dplyr_reconstruct.fbl_ts + +#' @export +summarise.fbl_ts <- function(.data, ..., .groups = NULL) { + dist_var <- distribution_var(.data) + dist_ptype <- vec_ptype(.data[[dist_var]]) + resp_var <- response_vars(.data) + .data <- summarise(as_tsibble(.data), ..., .groups = .groups) + + # If the distribution is lost, return a tsibble + if(!(dist_var %in% names(.data))) { + if(!vec_is(.data[[dist_var]], dist_ptype)){ + return(.data) + } + } + + build_fable(.data, response = resp_var, distribution = dist_var) +} + +#' @export +summarise.grouped_fbl <- summarise.fbl_ts \ No newline at end of file diff --git a/R/dplyr-mable.R b/R/dplyr-mable.R index 3eab60e..880a361 100644 --- a/R/dplyr-mable.R +++ b/R/dplyr-mable.R @@ -21,7 +21,6 @@ dplyr_col_modify.mdl_df <- function(data, cols) { #' @export dplyr_reconstruct.mdl_df <- function(data, template) { res <- NextMethod() - build_mable(data, - key = !!key_vars(template), - model = !!intersect(mable_vars(template), colnames(res))) + mbl_vars <- names(which(vapply(data, inherits, logical(1L), "lst_mdl"))) + build_mable(data, key = !!key_vars(template), model = !!mbl_vars) } diff --git a/R/fable.R b/R/fable.R index 5081386..c03f0b8 100644 --- a/R/fable.R +++ b/R/fable.R @@ -2,14 +2,13 @@ #' #' A fable (forecast table) data class (`fbl_ts`) which is a tsibble-like data #' structure for 
representing forecasts. In extension to the key and index from -#' the tsibble (`tbl_ts`) class, a fable (`fbl_ts`) must contain columns of -#' point forecasts for the response variable(s), and a single distribution -#' column (`fcdist`). +#' the tsibble (`tbl_ts`) class, a fable (`fbl_ts`) must also contain a single +#' distribution column that uses values from the distributional package. #' #' @param ... Arguments passed to [tsibble::tsibble()]. -#' @param response The response variable(s). A single response can be specified -#' directly via `response = y`, multiple responses should be use `response = c(y, z)`. -#' @param distribution The distribution variable (given as a bare or unquoted variable). +#' @param response The character vector of response variable(s). +#' @param distribution The name of the distribution column (can be provided +#' using a bare expression). #' #' @export fable <- function(..., response, distribution){ @@ -80,6 +79,47 @@ as_fable.fbl_ts <- function(x, response, distribution, ...){ #' @export as_fable.grouped_df <- as_fable.tbl_df +#' @inheritParams forecast.mdl_df +#' @rdname as-fable +#' @export +as_fable.forecast <- function(x, ..., point_forecast = list(.mean = mean)){ + if(is.null(x$upper)){ + # Without intervals, the best guess is the point forecast + dist <- distributional::dist_degenerate(x$mean) + } else { + if(!is.null(x$lambda)){ + x$upper <- box_cox(x$upper, x$lambda) + x$lower <- box_cox(x$lower, x$lambda) + } + warn("Assuming intervals are computed from a normal distribution.") + level <- colnames(x$upper)[1] + level <- as.numeric(gsub("^[^0-9]+|%", "", level))/100 + mid <- (x$upper[,1] - x$lower[,1])/2 + mu <- x$lower[,1] + mid + sigma <- mid/(stats::qnorm((1+level)/2)) + dist <- distributional::dist_normal(mu = as.numeric(mu), sigma = as.numeric(sigma)) + if(!is.null(x$lambda)){ + dist <- distributional::dist_transformed( + dist, + transform = rlang::new_function(exprs(x = ), expr(inv_box_cox(x, !!x$lambda)), env = 
rlang::pkg_env("fabletools")), + inverse = rlang::new_function(exprs(x = ), expr(inv_box_cox(x, !!x$lambda)), env = rlang::pkg_env("fabletools")) + ) + } + } + out <- as_tsibble(x$mean) + dimnames(dist) <- "value" + out[["value"]] <- dist + + point_fc <- compute_point_forecasts(dist, point_forecast) + out[names(point_fc)] <- point_fc + + build_fable( + out, + response = "value", + distribution = "value" + ) +} + build_fable <- function (x, response, distribution) { # If the response (from user input) needs converting response <- eval_tidy(enquo(response)) @@ -95,6 +135,10 @@ build_fable <- function (x, response, distribution) { x, response = response, dist = distribution, model_cn = ".model", class = "fbl_ts") } + if(is.null(dimnames(fbl[[distribution]]))) { + warn("The dimnames of the fable's distribution are missing and have been set to match the response variables.") + dimnames(fbl[[distribution]]) <- response + } validate_fable(fbl) fbl } @@ -127,7 +171,7 @@ validate_fable <- function(fbl){ abort(sprintf("Could not find distribution variable `%s` in the fable. A fable must contain a distribution, if you want to remove it convert to a tsibble with `as_tsibble()`.", chr_dist)) } - vec_is(fbl[[chr_dist]], distributional::new_dist()) + vec_assert(fbl[[chr_dist]], distributional::new_dist(dimnames = response_vars(fbl))) } tbl_sum.fbl_ts <- function(x){ @@ -203,7 +247,7 @@ ungroup.grouped_fbl <- group_by.fbl_ts #' @export rbind.fbl_ts <- function(...){ - .Deprecated("bind_rows()") + deprecate_warn("0.2.0", "rbind.fbl_ts()", "bind_rows()") fbls <- dots_list(...) 
response <- map(fbls, response_vars) dist <- map(fbls, distribution_var) diff --git a/R/fabletools.R b/R/fabletools-package.R similarity index 85% rename from R/fabletools.R rename to R/fabletools-package.R index 57cec4a..8cf7c6b 100644 --- a/R/fabletools.R +++ b/R/fabletools-package.R @@ -12,4 +12,6 @@ globalVariables(".") #' @importFrom dplyr dplyr_row_slice dplyr_col_modify dplyr_reconstruct #' @importFrom dplyr bind_rows bind_cols #' @importFrom tidyr nest unnest gather spread +#' @importFrom tidyselect all_of +#' @importFrom lifecycle deprecate_warn NULL \ No newline at end of file diff --git a/R/features.R b/R/features.R index 88b41f6..3974766 100644 --- a/R/features.R +++ b/R/features.R @@ -28,13 +28,13 @@ features_impl <- function(.tbl, .var, features, ...){ res <- transpose(map(key_dt[[".rows"]], function(i){ out <- do.call(fn_safe, c(list(x[i]), dots[intersect(names(fmls), names(dots))])) if(is.null(names(out[["result"]]))) - names(out[["result"]]) <- rep(".?", length(out[["result"]])) + names(out[["result"]]) <- paste0("..?", seq_along(out[["result"]])) out })) err <- compact(res[["error"]]) tbl <- vctrs::vec_rbind(!!!res[["result"]]) - names(tbl)[names(tbl) == ".?"] <- "" + names(tbl)[grepl("^\\.\\.?", names(tbl))] <- "" if(is.character(nm) && nzchar(nm)){ names(tbl) <- sprintf("%s%s%s", nm, ifelse(nzchar(names(tbl)), "_", ""), names(tbl)) } @@ -73,7 +73,8 @@ features_impl <- function(.tbl, .var, features, ...){ bind_cols( key_dt[-NCOL(key_dt)], - !!!out + !!!out, + .name_repair = "minimal" ) } diff --git a/R/fitted.R b/R/fitted.R index 2919b6a..1ff66da 100644 --- a/R/fitted.R +++ b/R/fitted.R @@ -10,14 +10,12 @@ #' @importFrom stats fitted #' @export fitted.mdl_df <- function(object, ...){ - out <- gather(object, ".model", ".fit", !!!syms(mable_vars(object))) - kv <- key_vars(out) - out <- transmute(as_tibble(out), - !!!syms(kv), - !!sym(".model"), - fitted = map(!!sym(".fit"), fitted, ...) 
- ) - unnest_tsbl(out, "fitted", parent_key = kv) + mbl_vars <- mable_vars(object) + kv <- key_vars(object) + object <- mutate(as_tibble(object), + dplyr::across(all_of(mbl_vars), function(x) lapply(x, fitted, ...))) + object <- pivot_longer(object, mbl_vars, names_to = ".model", values_to = ".fitted") + unnest_tsbl(object, ".fitted", parent_key = c(kv, ".model")) } #' @rdname fitted.mdl_df @@ -30,5 +28,7 @@ fitted.mdl_ts <- function(object, ...){ nm <- if(length(fits) == 1) ".fitted" else map_chr(object$response, expr_name) - transmute(object$data, !!!set_names(fits, nm)) + out <- object$data[index_var(object$data)] + out[nm] <- fits + out } \ No newline at end of file diff --git a/R/forecast.R b/R/forecast.R index 8591439..ae10226 100644 --- a/R/forecast.R +++ b/R/forecast.R @@ -11,7 +11,7 @@ #' A specific forecast interval can be extracted from the distribution using the #' [`hilo()`] function, and multiple intervals can be obtained using [`report()`]. #' These intervals are stored in a single column using the `hilo` class, to -#' extract the numerical upper and lower bounds you can use [`tidyr::unnest()`]. +#' extract the numerical upper and lower bounds you can use [`unpack_hilo()`]. #' #' @param object The time series model used to produce the forecasts #' @@ -36,10 +36,11 @@ forecast <- function(object, ...){ #' A fable containing the following columns: #' - `.model`: The name of the model used to obtain the forecast. Taken from #' the column names of models in the provided mable. -#' - The point forecast, which by default is the mean. The name of this column -#' will be the same as the dependent variable in the model(s). -#' - `.distribution`. A column of objects of class `fcdist`, representing the -#' statistical distribution of the forecast in the given time period. +#' - The forecast distribution. The name of this column will be the same as the +#' dependent variable in the model(s). 
If multiple dependent variables exist, +#' it will be named `.distribution`. +#' - Point forecasts computed from the distribution using the functions in the +#' `point_forecast` argument. #' - All columns in `new_data`, excluding those whose names conflict with the #' above. #' @examples @@ -136,7 +137,7 @@ forecast.mdl_ts <- function(object, new_data = NULL, h = NULL, bias_adjust = NUL h <- NULL } if(!is.null(bias_adjust)){ - warn("The `bias_adjust` argument for forecast() has been deprecated. Please specify the desired point forecasts using `point_forecast`.\nBias adjusted forecasts are forecast means (`point_forecast = 'mean'`), non-adjusted forecasts are medians (`point_forecast = 'median'`)") + deprecate_warn("0.2.0", "forecast(bias_adjust = )", "forecast(point_forecast = )") point_forecast <- if(bias_adjust) list(.mean = mean) else list(.median = stats::median) } if(is.null(new_data)){ @@ -146,7 +147,7 @@ forecast.mdl_ts <- function(object, new_data = NULL, h = NULL, bias_adjust = NUL # Useful variables idx <- index_var(new_data) mv <- measured_vars(new_data) - resp_vars <- map_chr(object$response, expr_name) + resp_vars <- vapply(object$response, expr_name, character(1L), USE.NAMES = FALSE) dist_col <- if(length(resp_vars) > 1) ".distribution" else resp_vars # If there's nothing to forecast, return an empty fable. @@ -173,7 +174,7 @@ Does your model require extra variables to produce forecasts?", e$message)) # Compute forecasts fc <- forecast(object$fit, new_data, specials = specials, ...) 
- dimnames(fc) <- vapply(object$response, expr_name, character(1L), USE.NAMES = FALSE) + dimnames(fc) <- resp_vars # Back-transform forecast distributions bt <- map(object$transformation, function(x){ @@ -181,13 +182,13 @@ Does your model require extra variables to produce forecasts?", e$message)) env <- new_environment(new_data, get_env(bt)) req_vars <- setdiff(all.vars(body(bt)), names(formals(bt))) exists_vars <- map_lgl(req_vars, exists, env) - if(any(!exists_vars)){ - bt <- custom_error(bt, sprintf( -"Unable to find all required variables to back-transform the forecasts (missing %s). -These required variables can be provided by specifying `new_data`.", - paste0("`", req_vars[!exists_vars], "`", collapse = ", ") - )) - } +# if(any(!exists_vars)){ +# bt <- custom_error(bt, sprintf( +# "Unable to find all required variables to back-transform the forecasts (missing %s). +# These required variables can be provided by specifying `new_data`.", +# paste0("`", req_vars[!exists_vars], "`", collapse = ", ") +# )) +# } set_env(bt, env) }) @@ -221,8 +222,10 @@ These required variables can be provided by specifying `new_data`.", #' Construct a new set of forecasts #' -#' Will be deprecated in the future, forecast objects should be produced with -#' either `fable` or `as_fable` functions. +#' \lifecycle{deprecated} +#' +#' This function is deprecated. `forecast()` methods for a model should return +#' a vector of distributions using the distributional package. #' #' Backtransformations are automatically handled, and so no transformations should be specified here. 
#' @@ -232,20 +235,8 @@ These required variables can be provided by specifying `new_data`.", #' #' @export construct_fc <- function(point, sd, dist){ - stopifnot(inherits(dist, "fcdist")) - - dist_env <- dist[[1]]$.env - if(identical(dist_env, env_dist_normal)){ - distributional::dist_normal(map_dbl(dist, `[[`, "mean"), map_dbl(dist, `[[`, "sd")) - } else if(identical(dist_env, env_dist_sim)){ - distributional::dist_sample(flatten(map(dist, `[[`, 1))) - } else if(identical(dist_env, env_dist_unknown)){ - distributional::dist_degenerate(point) - } else if(identical(dist_env, env_dist_mv_normal)){ - distributional::dist_multivariate_normal(map(dist, `[[`, "mean"), map(dist, `[[`, "sd")) - } else { - abort("Unknown forecast distribution type to convert.") - } + lifecycle::deprecate_stop("0.3.0", what = "fabletools::construct_fc()", + details = "The forecast function should now return a vector of distributions from the 'distributional' package.") } compute_point_forecasts <- function(distribution, measures){ diff --git a/R/generate.R b/R/generate.R index 400e898..b95fc5f 100644 --- a/R/generate.R +++ b/R/generate.R @@ -28,12 +28,13 @@ #' } #' @export generate.mdl_df <- function(x, new_data = NULL, h = NULL, times = 1, seed = NULL, ...){ - kv <- c(key_vars(x), ".model") mdls <- mable_vars(x) if(!is.null(new_data)){ new_data <- bind_new_data(x, new_data)[["new_data"]] } - x <- as_tibble(gather(x, ".model", ".sim", !!!syms(mdls))) + kv <- c(key_vars(x), ".model") + x <- tidyr::pivot_longer(as_tibble(x), all_of(mdls), + names_to = ".model", values_to = ".sim") # Evaluate simulations x$.sim <- map2(x[[".sim"]], diff --git a/R/lst_mdl.R b/R/lst_mdl.R index adb99fe..ebd3d41 100644 --- a/R/lst_mdl.R +++ b/R/lst_mdl.R @@ -1,4 +1,4 @@ -list_of_models <- function(x){ +list_of_models <- function(x = list()){ vctrs::new_vctr(x, class = "lst_mdl") } @@ -10,3 +10,19 @@ type_sum.lst_mdl <- function(x){ format.lst_mdl <- function(x, ...){ map_chr(x, function(x) paste0("<", 
model_sum(x), ">")) } + +#' @export +vec_cast.character.lst_mdl <- function(x, to, ...) format(x) + +#' @export +vec_ptype2.lst_mdl.lst_mdl <- function(x, y, ...){ + list_of_models() +} + +#' @export +vec_cast.lst_mdl.lst_mdl <- function(x, to, ...){ + if(!identical(class(x), class(to))){ + abort("Cannot combine model lists with different reconciliation strategies.") + } + x +} \ No newline at end of file diff --git a/R/mable.R b/R/mable.R index d5eee8c..82f4983 100644 --- a/R/mable.R +++ b/R/mable.R @@ -45,7 +45,7 @@ as_mable.data.frame <- function(x, key = NULL, model = NULL, ...){ build_mable(x, key = !!enquo(key), model = !!enquo(model)) } -build_mable <- function (x, key = NULL, key_data = NULL, model) { +build_mable <- function (x, key = NULL, key_data = NULL, model = NULL) { model <- names(tidyselect::eval_select(enquo(model), data = x)) if(length(unique(map(x[model], function(mdl) mdl[[1]]$response))) > 1){ @@ -121,6 +121,22 @@ gather.mdl_df <- function(data, key = "key", value = "value", ..., na.rm = FALSE build_mable(tbl, key = !!kv, model = !!mdls) } +# Adapted from tsibble:::pivot_longer.tbl_ts +#' @importFrom tidyr pivot_longer +#' @export +pivot_longer.mdl_df <- function (data, ..., names_to = "name") { + if (!has_length(names_to)) { + abort("`pivot_longer()` can't accept zero-length `names_to`.") + } + if (".value" %in% names_to) { + abort("`pivot_longer()` can't accept the special \".value\" in `names_to`.") + } + new_key <- c(key_vars(data), names_to) + tbl <- tidyr::pivot_longer(as_tibble(data), ..., names_to = names_to) + build_mable(tbl, key = !!new_key, + model = !!which(vapply(tbl, inherits, logical(1L), "lst_mdl"))) +} + #' @export select.mdl_df <- function (.data, ...){ res <- select(as_tibble(.data), ...) 
diff --git a/R/model_combination.R b/R/model_combination.R index 2fbe83c..a4bd955 100644 --- a/R/model_combination.R +++ b/R/model_combination.R @@ -222,7 +222,7 @@ Ops.mdl_ts <- function(e1, e2){ #' @export Ops.lst_mdl <- function(e1, e2){ - structure(map2(e1, e2, .Generic), class = c("lst_mdl", "list")) + list_of_models(map2(e1, e2, .Generic)) } #' @importFrom stats var diff --git a/R/model_null.R b/R/model_null.R index 74295cb..47c538a 100644 --- a/R/model_null.R +++ b/R/model_null.R @@ -25,43 +25,64 @@ null_model <- function(formula, ...){ is_null_model <- function(x){ if(is_model(x)) return(is_null_model(x[["fit"]])) if(inherits(x, "lst_mdl")) return(map_lgl(x, is_null_model)) - inherits(x, "null_mdl") + is.null(x) || inherits(x, "null_mdl") } #' @export forecast.null_mdl <- function(object, new_data, ...){ h <- NROW(new_data) - construct_fc(rep(NA_real_, h), rep(0, h), dist_unknown(h)) + vec_cast(rep(NA_real_, h), distributional::new_dist()) } +#' @export +forecast.NULL <- forecast.null_mdl + #' @export generate.null_mdl <- function(x, new_data, ...){ mutate(new_data, .sim = NA_real_) } +#' @export +generate.NULL <- generate.null_mdl #' @export stream.null_mdl <- function(object, new_data, ...){ object$n <- object$n + NROW(new_data) object } +#' @export +stream.NULL <- function(object, new_data, ...) { + NULL +} #' @export refit.null_mdl <- function(object, new_data, ...){ object$n <- NROW(new_data) object } +#' @export +refit.NULL <- function(object, new_data, ...) { + NULL +} #' @export residuals.null_mdl <- function(object, ...){ matrix(NA_real_, nrow = object$n, ncol = length(object$vars), dimnames = list(NULL, object$vars)) } +#' @export +residuals.NULL <- function(object, new_data, ...) { + NA_real_ +} #' @export fitted.null_mdl <- function(object, ...){ matrix(NA_real_, nrow = object$n, ncol = length(object$vars), dimnames = list(NULL, object$vars)) } +#' @export +fitted.NULL <- function(object, new_data, ...) 
{ + NA_real_ +} #' @export glance.null_mdl <- function(x, ...){ @@ -77,8 +98,12 @@ tidy.null_mdl <- function(x, ...){ report.null_mdl <- function(object, ...){ cat("NULL model") } +#' @export +report.NULL <- report.null_mdl #' @export model_sum.null_mdl <- function(x){ "NULL model" -} \ No newline at end of file +} +#' @export +report.NULL <- report.null_mdl \ No newline at end of file diff --git a/R/parse.R b/R/parse.R index 8d46b4f..9974bd3 100644 --- a/R/parse.R +++ b/R/parse.R @@ -178,7 +178,7 @@ parse_model_lhs <- function(model){ .g = function(x){ if(all(names(x) != "response") && !is.null(attr(x, "call"))){ # parent_len <- length(eval(attr(x, "call") %||% x[[1]], envir = model$data)) - len <- map_dbl(x, function(y) length(eval(attr(y, "call") %||% y[[1]], envir = model$data, enclos = model$env))) + len <- map_dbl(x, function(y) length(eval(attr(y, "call") %||% y[[1]], envir = model$data, enclos = model$specials))) if(sum(len == max(len)) == 1){ names(x)[which.max(len)] <- "response" } @@ -206,7 +206,7 @@ parse_model_lhs <- function(model){ x <- traverse(x, .f = function(x, y) as.call(c(y[[1]], x)), .g = function(x) x[-1], - .h = function(x) {if(x == response) sym(".x") else x}, + .h = function(x) {if(identical(x, response)) sym(".x") else x}, base = function(x) is_syntactic_literal(x) || is_symbol(x) || x == response ) new_function(args = alist(.x = ), x, env = model$env) diff --git a/R/plot.R b/R/plot.R index d3de8a8..dcc4339 100644 --- a/R/plot.R +++ b/R/plot.R @@ -121,45 +121,11 @@ autolayer.tbl_ts <- function(object, .vars = NULL, ...){ #' @importFrom ggplot2 fortify #' @export fortify.fbl_ts <- function(object, level = c(80, 95)){ - resp <- response_vars(object) - dist <- distribution_var(object) - idx <- index(object) - kv <- key_vars(object) - # if(length(resp) > 1){ - # object <- object %>% - # mutate( - # .response = rep(list(factor(resp)), NROW(object)), - # value = transpose_dbl(list2(!!!syms(resp))) - # ) - # } - # - if(!is.null(level)){ - 
object[as.character(level)] <- map(level, hilo, x = object[[dist]]) - object[resp] <- mean(object[[dist]]) - object <- tidyr::pivot_longer(as_tibble(object), as.character(level), names_to = ".rm", values_to = ".hilo") - - if(length(resp) > 1){ - stop("Plotting multivariate forecasts is not currently supported.") - tidyr::unpack(object, c(dist, ".hilo"), names_sep = "?") - - object <- unnest_tbl(object, c(".response", "value", ".hilo")) - resp <- "value" - kv <- c(kv, ".response") - } - else{ - object[c(".lower", ".upper", ".level")] <- vec_data(object[[".hilo"]]) - } - kv <- c(kv, ".level") - - # Drop temporary col - object[c(".rm", ".hilo")] <- NULL + if(deparse(match.call()) != "fortify.fbl_ts(object = data)"){ + warn("The output of `fortify()` has changed to better suit usage with the ggdist package. +If you're using it to extract intervals, consider using `hilo()` to compute intervals, and `unpack_hilo()` to obtain values.") } - else if (length(resp) > 1) { - object <- unnest_tbl(object, c(".response", "value")) - kv <- c(kv, ".response") - } - - as_tsibble(object, key = !!kv, index = !!idx, validate = FALSE) + return(as_tibble(object)) } #' Plot a set of forecasts @@ -299,7 +265,8 @@ build_fbl_layer <- function(object, data = NULL, level = c(80, 95), gap <- left_join(gap, last_obs, by = key_vars(last_obs)) } if (length(resp_var) > 1) abort("`show_gap = FALSE` is not yet supported for multivariate forecasts.") - gap[[distribution_var(object)]] <- gap[[resp_var]] + gap[[distribution_var(object)]] <- distributional::dist_degenerate(gap[[resp_var]]) + dimnames(gap[[distribution_var(object)]]) <- resp_var gap <- as_fable(gap, index = !!idx, key = key_vars(object), response = resp_var, distribution = distribution_var(object)) diff --git a/R/quantile.R b/R/quantile.R deleted file mode 100644 index 88e255d..0000000 --- a/R/quantile.R +++ /dev/null @@ -1,366 +0,0 @@ -#' Create a forecast distribution object -#' -#' @param ... 
Arguments for `f` function -#' @param .env An environment produced using `new_fcdist_env` -#' -#' @rdname fcdist -#' @export -new_fcdist <- function(..., .env){ - structure( - pmap(dots_list(...), list, .env = .env), - class = c("fcdist", "list") - ) -} - -#' @param quantile A distribution function producing quantiles (such as `qnorm`) -#' @param transformation Transformation to be applied to resulting quantiles -#' from `quantile` -#' @param display Function that is used to format the distribution display -#' -#' @rdname fcdist -#' @export -new_fcdist_env <- function(quantile, transformation = list(identity), display = NULL){ - if(is.null(display)){ - display <- format_dist(as_string(enexpr(quantile))) - } - new_environment( - list(f = quantile, t = transformation, format = display, - trans = any(map_lgl(transformation, compose(`!`, is.name, body)))) - ) -} - -update_fcdist <- function(x, quantile = NULL, transformation = NULL, format_fn = NULL){ - .env_ids <- map_chr(x, function(x) env_label(x[[length(x)]])) - x <- map(split(x, .env_ids), function(dist){ - env <- env_clone(dist[[1]][[length(dist[[1]])]]) - if(!is.null(quantile)){ - env$f <- quantile - } - if(!is.null(transformation)){ - env$t <- transformation - env$trans <- any(map_lgl(transformation, compose(`!`, is.name, body))) - } - if(!is.null(format_fn)){ - env$format_fn <- format_fn - } - map(dist, function(x){x[[length(x)]] <- env; x}) - }) - structure(unsplit(x, .env_ids), class = c("fcdist", "list")) -} - -#' @importFrom stats qnorm -#' @export -Ops.fcdist <- function(e1, e2){ - ok <- switch(.Generic, `+` = , `-` = , `*` = , `/` = TRUE, FALSE) - if (!ok) { - warn(sprintf("`%s` not meaningful for distributions", .Generic)) - return(dist_unknown(max(length(e1), if (!missing(e2)) length(e2)))) - } - if(.Generic == "/" && inherits(e2, "fcdist")){ - warn(sprintf("Cannot divide by a distribution")) - return(dist_unknown(max(length(e1), if (!missing(e2)) length(e2)))) - } - if(.Generic %in% c("-", "+") && 
missing(e2)){ - e2 <- e1 - e1 <- if(.Generic == "+") 1 else -1 - .Generic <- "*" - } - if(.Generic == "-"){ - .Generic <- "+" - e2 <- -e2 - } - else if(.Generic == "/"){ - .Generic <- "*" - e2 <- 1/e2 - } - e_len <- c(length(e1), length(e2)) - if(max(e_len) %% min(e_len) != 0){ - warn("longer object length is not a multiple of shorter object length") - } - if(e_len[[1]] != e_len[[2]]){ - if(which.min(e_len) == 1){ - e1 <- rep(e1, e_len[[2]]) - } - else{ - e2 <- rep(e2, e_len[[1]]) - } - } - - if(is_dist_unknown(e1) || is_dist_unknown(e2)){ - return(dist_unknown(length(e1))) - } - - if(inherits(e1, "fcdist") && inherits(e2, "fcdist")){ - if(.Generic == "*"){ - warn(sprintf("Multiplying forecast distributions is not supported.")) - return(dist_unknown(max(length(e1), if (!missing(e2)) length(e2)))) - } - - grps <- paste(sep = "-", - map_chr(e1, function(x) env_label(x[[length(x)]])), - map_chr(e2, function(x) env_label(x[[length(x)]])) - ) - - e1 <- map2(split(e1, grps), split(e2, grps), function(x, y){ - if(!is_dist_normal(x) || !is_dist_normal(y)){ - warn("Combinations of non-normal forecast distributions is not supported.") - return(dist_unknown(max(length(e1), length(e2)))) - } - x <- transpose(x) %>% map(unlist, recursive = FALSE) - y <- transpose(y) %>% map(unlist, recursive = FALSE) - if(.Generic == "+"){ - x$mean <- x$mean + y$mean - x$sd <- sqrt(x$sd^2 + y$sd^2) - } - transpose(x) - }) - - return(structure(unsplit(e1, grps), class = c("fcdist", "list"))) - } - - if(inherits(e1, "fcdist")){ - dist <- e1 - scalar <- e2 - } else { - dist <- e2 - scalar <- e1 - } - if(!is.numeric(scalar)){ - warn(sprintf("Cannot %s a `%s` with a distribution", switch(.Generic, - `+` = "add", `-` = "subtract", `*` = "multiply", `/` = "divide"), class(scalar))) - return(dist_unknown(length(e1))) - } - - .env_ids <- map_chr(dist, function(x) env_label(x[[length(x)]])) - dist <- map2(split(dist, .env_ids), split(scalar, .env_ids), function(x, y){ - if(!is_dist_normal(x)){ - 
warn("Cannot perform calculations with this non-normal distributions") - return(dist_unknown(length(x))) - } - x <- transpose(x) %>% map(unlist, recursive = FALSE) - if(.Generic == "+"){ - x$mean <- x$mean + y - } - else if(.Generic == "*"){ - x$mean <- x$mean * y - x$sd <- x$sd * y - } - transpose(x) - }) - structure(unsplit(dist, .env_ids), class = c("fcdist", "list")) -} - -type_sum.fcdist <- function(x){ - "dist" -} - -is_vector_s3.fcdist <- function(x){ - TRUE -} - -obj_sum.fcdist <- function(x) { - rep("dist", length(x)) -} - -pillar_shaft.fcdist <- function(x, ...){ - pillar::new_pillar_shaft_simple(format(x), align = "left", min_width = 10) -} - -#' @export -print.fcdist <- function(x, ...) { - print(format(x, ...), quote = FALSE) - invisible(x) -} - -# Brief hack for vctrs support. To be replaced by distributional. -#' @importFrom vctrs vec_ptype2 -#' @method vec_ptype2 fcdist -#' @export -vec_ptype2.fcdist <- function(x, y, ...) UseMethod("vec_ptype2.fcdist", y) -#' @method vec_ptype2.fcdist default -#' @export -vec_ptype2.fcdist.default <- function(x, y, ..., x_arg = "x", y_arg = "y") { - vctrs::vec_default_ptype2(x, y, x_arg = x_arg, y_arg = y_arg) -} -#' @method vec_ptype2.fcdist fcdist -#' @export -vec_ptype2.fcdist.fcdist <- function(x, y, ..., x_arg = "x", y_arg = "y") { - x -} - -#' @importFrom vctrs vec_cast -#' @method vec_cast fcdist -#' @export -vec_cast.fcdist <- function(x, to, ...) UseMethod("vec_cast.fcdist") -#' @method vec_cast.fcdist default -#' @export -vec_cast.fcdist.default <- function(x, to, ...) vctrs::vec_default_cast(x, to) -#' @method vec_cast.fcdist fcdist -#' @export -vec_cast.fcdist.fcdist <- function(x, to, ...) x - -format_dist <- function(fn_nm){ - function(x, ...){ - out <- transpose(x) %>% - imap(function(arg, nm){ - arg <- unlist(arg, recursive = FALSE) - if(!is_list(arg)){ - out <- format(arg, digits = 2, ...) 
- } - else{ - out <- sprintf("%s[%i]", map_chr(arg, tibble::type_sum), map_int(arg, length)) - } - if(is_character(nm)){ - out <- paste0(nm, "=", out) - } - out - }) %>% - invoke("paste", ., sep = ", ") - - # Add dist name q() - sprintf("%s(%s)", fn_nm, out) - } -} - -#' @export -format.fcdist <- function(x, ...){ - .env_ids <- map_chr(x, function(x) possibly(env_label, ".na")(x[[length(x)]])) - split(x, .env_ids) %>% - set_names(NULL) %>% - map(function(x){ - if(!is_environment(x[[1]][[length(x[[1]])]])) return("NA") - out <- x[[1]]$.env$format(map(x, function(x) x[-length(x)])) - if(x[[1]]$.env$trans){ - out <- paste0("t(", out, ")") - } - out - }) %>% - unsplit(.env_ids) -} - -#' @export -`[.fcdist` <- function(x, ...){ - structure(NextMethod(), class = c("fcdist", "list")) -} - -#' @export -c.fcdist <- function(...){ - structure(NextMethod(), class = c("fcdist", "list")) -} - -#' @export -rep.fcdist <- function(x, ...){ - structure(NextMethod(), class = c("fcdist", "list")) -} - -#' @export -unique.fcdist <- function(x, ...){ - structure(NextMethod(), class = c("fcdist", "list")) -} - -#' @export -length.fcdist <- function(x){ - NextMethod() -} - - -#' @export -quantile.fcdist <- function(x, probs = seq(0, 1, 0.25), ...){ - .Deprecated("distributional::quantile") - env <- x[[1]][[length(x[[1]])]] - args <- transpose(x)[-length(x[[1]])] - map(probs, function(prob){ - intr <- do.call(env$f, c(list(prob), as.list(args), dots_list(...))) - if(!is.list(intr)){ - intr <- list(intr) - } - map2(env$t, intr, calc) - }) -} - -format_dist_normal <- function(x, ...){ - args <- transpose(x) %>% - map(unlist) - - # Add dist name q() - sprintf("N(%s, %s)", - format(args$mean, digits = 2, ...), - format(args$sd^2, digits = 2, ...) - ) -} - -env_dist_normal <- new_fcdist_env(function(mean, sd, ...){ - qnorm(..., mean = unlist(mean), sd = unlist(sd)) - }, display = format_dist_normal) - -#' Distributions for intervals -#' -#' @param mean vector of distributional means. 
-#' @param sd vector of distributional standard deviations. -#' @param ... Additional arguments passed on to quantile methods. -#' -#' @rdname distributions -#' -#' @examples -#' dist_normal(rep(3, 10), seq(0, 1, length.out=10)) -#' -#' @export -dist_normal <- function(mean, sd, ...){ - new_fcdist(mean = mean, sd = sd, ..., .env = env_dist_normal) -} - -env_dist_mv_normal <- new_fcdist_env(function(p, mean, sd, var){ - map(transpose(map2(mean, map(sd, diag), qnorm, p = p)), as.numeric) - # abort("Multivariate normal intervals are not currently supported.") -}, display = function(x, ...) rep(sprintf("MVN[%i]", length(x[[1]][["mean"]])), length(x))) - -#' @rdname distributions -#' @export -dist_mv_normal <- function(mean, sd, ...){ - new_fcdist(mean = mean, sd = sd, ..., .env = env_dist_mv_normal) -} - -#' @importFrom stats quantile -qsample <- function(p, x = list(), ...){ - map_dbl(x, function(x) as.numeric(stats::quantile(unlist(x), p, ...))) -} - -env_dist_sim <- new_fcdist_env(qsample, display = format_dist("sim")) - -#' @rdname distributions -#' -#' @param sample a list of simulated values -#' -#' @examples -#' dist_sim(list(rnorm(100), rnorm(100), rnorm(100))) -#' -#' @export -dist_sim <- function(sample, ...){ - new_fcdist(map(sample, list), ..., .env = env_dist_sim) -} - -env_dist_unknown <- new_fcdist_env(function(x, ...) rep(NA, length(x)), - display = function(x, ...) rep("?", length(x))) - -#' @rdname distributions -#' -#' @param n The number of distributions. 
-#' -#' @examples -#' dist_unknown(10) -#' -#' @export -dist_unknown <- function(n, ...){ - new_fcdist(vector("double", n), ..., .env = env_dist_unknown) -} - - -is_dist_normal <- function(dist){ - if(!inherits(dist, "fcdist")) return(FALSE) - identical(dist[[1]]$.env$f, env_dist_normal$f) && !dist[[1]]$.env$trans -} - -is_dist_unknown <- function(dist){ - if(!inherits(dist, "fcdist")) return(FALSE) - identical(dist[[1]]$.env$f, env_dist_unknown$f) && !dist[[1]]$.env$trans -} \ No newline at end of file diff --git a/R/reconciliation.R b/R/reconciliation.R index 60f3059..c7920d3 100644 --- a/R/reconciliation.R +++ b/R/reconciliation.R @@ -52,19 +52,28 @@ reconcile.mdl_df <- function(.data, ...){ min_trace <- function(models, method = c("wls_var", "ols", "wls_struct", "mint_cov", "mint_shrink"), sparse = NULL){ if(is.null(sparse)){ - sparse <- requireNamespace("SparseM", quietly = TRUE) + sparse <- requireNamespace("Matrix", quietly = TRUE) } structure(models, class = c("lst_mint_mdl", "lst_mdl", "list"), method = match.arg(method), sparse = sparse) } -#' @importFrom utils combn #' @export forecast.lst_mint_mdl <- function(object, key_data, new_data = NULL, h = NULL, point_forecast = list(.mean = mean), ...){ method <- object%@%"method" sparse <- object%@%"sparse" + if(sparse){ + require_package("Matrix") + as.matrix <- Matrix::as.matrix + t <- Matrix::t + diag <- function(x) if(is.vector(x)) Matrix::Diagonal(x = x) else Matrix::diag(x) + solve <- Matrix::solve + cov2cor <- Matrix::cov2cor + } else { + cov2cor <- stats::cov2cor + } point_method <- point_forecast point_forecast <- list() @@ -73,15 +82,7 @@ forecast.lst_mint_mdl <- function(object, key_data, if(length(unique(map(fc, interval))) > 1){ abort("Reconciliation of temporal hierarchies is not yet supported.") } - fc_dist <- map(fc, function(x) x[[distribution_var(x)]]) - is_normal <- all(map_lgl(fc_dist, function(x) inherits(x[[1]], "dist_normal"))) - - fc_mean <- as.matrix(invoke(cbind, map(fc_dist, mean))) 
- fc_var <- transpose_dbl(map(fc_dist, distributional::variance)) - # Construct S matrix - ??GA: have moved this here as I need it for Structural scaling - S <- build_smat_rows(key_data) - # Compute weights (sample covariance) res <- map(object, function(x, ...) residuals(x, ...), type = "response") if(length(unique(map_dbl(res, nrow))) > 1){ @@ -91,29 +92,32 @@ forecast.lst_mint_mdl <- function(object, key_data, res <- matrix(invoke(c, map(res, `[[`, 2)), ncol = length(object)) } + # Construct S matrix - ??GA: have moved this here as I need it for Structural scaling + agg_data <- build_key_data_smat(key_data) + n <- nrow(res) covm <- crossprod(stats::na.omit(res)) / n if(method == "ols"){ # OLS - W <- diag(nrow = nrow(covm), ncol = ncol(covm)) + W <- diag(rep(1L, nrow(covm))) } else if(method == "wls_var"){ # WLS variance scaling W <- diag(diag(covm)) } else if (method == "wls_struct"){ # WLS structural scaling - W <- diag(apply(S,1,sum)) + W <- diag(vapply(agg_data$agg,length,integer(1L))) } else if (method == "mint_cov"){ # min_trace covariance W <- covm } else if (method == "mint_shrink"){ # min_trace shrink tar <- diag(apply(res, 2, compose(crossprod, stats::na.omit))/n) - corm <- stats::cov2cor(covm) + corm <- cov2cor(covm) xs <- scale(res, center = FALSE, scale = sqrt(diag(covm))) xs <- xs[stats::complete.cases(xs),] v <- (1/(n * (n - 1))) * (crossprod(xs^2) - 1/n * (crossprod(xs))^2) diag(v) <- 0 - corapn <- stats::cov2cor(tar) + corapn <- cov2cor(tar) d <- (corm - corapn)^2 lambda <- sum(v)/sum(d) lambda <- max(min(lambda, 1), 0) @@ -129,79 +133,170 @@ forecast.lst_mint_mdl <- function(object, key_data, } # Reconciliation matrices - R1 <- stats::cov2cor(W) - W_h <- map(fc_var, function(var) diag(sqrt(var))%*%R1%*%t(diag(sqrt(var)))) - - if(sparse){ - require_package("SparseM") - require_package("methods") - as.matrix <- SparseM::as.matrix - t <- SparseM::t - diag <- SparseM::diag - - row_btm <- key_data %>% - dplyr::filter( - 
!!!map(colnames(key_data[-length(key_data)]), function(x){ - expr(!is_aggregated(!!sym(x))) - }) - ) - row_btm <- vctrs::vec_c(!!!row_btm[[length(row_btm)]]) - row_agg <- seq_len(NROW(key_data))[-row_btm] - - i_pos <- which(as.logical(S[row_btm,])) - S <- SparseM::as.matrix.csr(S) - J <- methods::new("matrix.csr", ra = rep(1,ncol(S)), ja = row_btm, - ia = c((i_pos-1L)%/%ncol(S)+1L, ncol(S) + 1L), dimension = rev(dim(S))) - - U <- cbind(methods::as(diff(dim(J)), "matrix.diag.csr"), SparseM::as.matrix.csr(-S[row_agg,])) - U <- U[, order(c(row_agg, row_btm))] - - P <- J - J%*%W%*%t(U)%*%SparseM::solve(U%*%W%*%t(U), eps = Inf)%*%U + if(sparse){ + row_btm <- agg_data$leaf + row_agg <- seq_len(nrow(key_data))[-row_btm] + S <- Matrix::sparseMatrix( + i = rep(seq_along(agg_data$agg), lengths(agg_data$agg)), + j = vec_c(!!!agg_data$agg), + x = rep(1, sum(lengths(agg_data$agg)))) + J <- Matrix::sparseMatrix(i = S[row_btm,]@i+1, j = row_btm, x = 1L, + dims = rev(dim(S))) + U <- cbind( + Matrix::Diagonal(diff(dim(J))), + -S[row_agg,,drop = FALSE] + ) + U <- U[, order(c(row_agg, row_btm)), drop = FALSE] + Ut <- t(U) + WUt <- W %*% Ut + P <- J - J %*% WUt %*% solve(U %*% WUt, U) + # P <- J - J%*%W%*%t(U)%*%solve(U%*%W%*%t(U))%*%U } else { + S <- matrix(0L, nrow = length(agg_data$agg), ncol = max(vec_c(!!!agg_data$agg))) + S[length(agg_data$agg)*(vec_c(!!!agg_data$agg)-1) + rep(seq_along(agg_data$agg), lengths(agg_data$agg))] <- 1L R <- t(S)%*%solve(W) P <- solve(R%*%S)%*%R } - # Apply to forecasts - fc_mean <- as.matrix(S%*%P%*%t(fc_mean)) - fc_mean <- split(fc_mean, row(fc_mean)) - if(is_normal){ - fc_var <- map(W_h, function(W) diag(S%*%P%*%W%*%t(P)%*%t(S))) - fc_dist <- map2(fc_mean, transpose_dbl(map(fc_var, sqrt)), distributional::dist_normal) - } else { - fc_dist <- map(fc_mean, distributional::dist_degenerate) - } - - # Update fables - map2(fc, fc_dist, function(fc, dist){ - dimnames(dist) <- dimnames(fc[[distribution_var(fc)]]) - fc[[distribution_var(fc)]] <- dist - 
point_fc <- compute_point_forecasts(dist, point_method) - fc[names(point_fc)] <- point_fc - fc - }) + reconcile_fbl_list(fc, S, P, W, point_forecast = point_method) } +#' Bottom up forecast reconciliation +#' +#' \lifecycle{experimental} +#' +#' Reconciles a hierarchy using the bottom up reconciliation method. The +#' response variable of the hierarchy must be aggregated using sums. The +#' forecasted time points must match for all series in the hierarchy. +#' +#' @param models A column of models in a mable. +#' +#' @seealso +#' [`reconcile()`], [`aggregate_key()`] +#' @export bottom_up <- function(models){ structure(models, class = c("lst_btmup_mdl", "lst_mdl", "list")) } -#' @importFrom utils combn #' @export forecast.lst_btmup_mdl <- function(object, key_data, point_forecast = list(.mean = mean), ...){ # Keep only bottom layer - S <- build_smat_rows(key_data) - object <- object[rowSums(S) == 1] + agg_data <- build_key_data_smat(key_data) + + S <- matrix(0L, nrow = length(agg_data$agg), ncol = max(vec_c(!!!agg_data$agg))) + S[length(agg_data$agg)*(vec_c(!!!agg_data$agg)-1) + rep(seq_along(agg_data$agg), lengths(agg_data$agg))] <- 1L + + btm <- which(rowSums(S) == 1) + object <- object[btm] point_method <- point_forecast point_forecast <- list() + # Get base forecasts - fc <- NextMethod() + fc <- vector("list", nrow(S)) + fc[btm] <- NextMethod() + + # Add dummy forecasts to unused levels + fc[seq_along(fc)[-btm]] <- fc[btm[1]] + + P <- matrix(0L, nrow = ncol(S), ncol = nrow(S)) + P[(btm-1L)*nrow(P) + seq_len(nrow(P))] <- 1L + + reconcile_fbl_list(fc, S, P, W = diag(nrow(S)), + point_forecast = point_method) +} + + +#' Top down forecast reconciliation +#' +#' \lifecycle{experimental} +#' +#' Reconciles a hierarchy using the top down reconciliation method. The +#' response variable of the hierarchy must be aggregated using sums. The +#' forecasted time points must match for all series in the hierarchy. +#' +#' @param models A column of models in a mable. 
+#' @param method The reconciliation method to use. +#' +#' @seealso +#' [`reconcile()`], [`aggregate_key()`] +#' +#' @export +top_down <- function(models, method = c("forecast_proportions", "average_proportions", "proportion_averages")){ + structure(models, class = c("lst_topdwn_mdl", "lst_mdl", "list"), + method = match.arg(method)) +} + +#' @export +forecast.lst_topdwn_mdl <- function(object, key_data, + point_forecast = list(.mean = mean), ...){ + method <- object%@%"method" + point_method <- point_forecast + point_forecast <- list() + + # TODO: Add check for grouped hierarchies + agg_data <- build_key_data_smat(key_data) + S <- matrix(0L, nrow = length(agg_data$agg), ncol = max(vec_c(!!!agg_data$agg))) + S[length(agg_data$agg)*(vec_c(!!!agg_data$agg)-1) + rep(seq_along(agg_data$agg), lengths(agg_data$agg))] <- 1L + + # Identify top and bottom level + top <- which.max(rowSums(S)) + btm <- which(rowSums(S) == 1L) + + if(method == "forecast_proportions") { + abort("`method = 'forecast_proportions'` is not yet supported") + fc <- NextMethod() + fc_mean <- lapply(fc, function(x) mean(x[[distribution_var(x)]])) + } else { + # Compute dis-aggregation matrix + history <- lapply(object, function(x) response(x)[[".response"]]) + top_y <- history[[top]] + btm_y <- history[btm] + if (method == "average_proportions") { + prop <- map_dbl(btm_y, function(y) mean(y/top_y)) + } else if (method == "proportion_averages") { + prop <- map_dbl(btm_y, mean) / mean(top_y) + } else { + abort("Unkown `top_down()` reconciliation `method`.") + } + + # Keep only top layer + object <- object[top] + + # Get base forecasts + fc <- vector("list", nrow(S)) + fc[top] <- NextMethod() + + # Add dummy forecasts to unused levels + fc[seq_along(fc)[-top]] <- fc[top] + } + + P <- matrix(0L, nrow = ncol(S), ncol = nrow(S)) + P[,top] <- prop + + reconcile_fbl_list(fc, S, P, W = diag(nrow(S)), + point_forecast = point_method) +} + +reconcile_fbl_list <- function(fc, S, P, W, point_forecast, SP = NULL) { 
if(length(unique(map(fc, interval))) > 1){ abort("Reconciliation of temporal hierarchies is not yet supported.") } + if(!inherits(S, "matrix")) { + # Use sparse functions + require_package("Matrix") + as.matrix <- Matrix::as.matrix + t <- Matrix::t + diag <- function(x) if(is.vector(x)) Matrix::Diagonal(x = x) else Matrix::diag(x) + cov2cor <- Matrix::cov2cor + } else { + cov2cor <- stats::cov2cor + } + if(is.null(SP)) { + SP <- S%*%P + } fc_dist <- map(fc, function(x) x[[distribution_var(x)]]) is_normal <- all(map_lgl(fc_dist, function(x) inherits(x[[1]], "dist_normal"))) @@ -210,61 +305,29 @@ forecast.lst_btmup_mdl <- function(object, key_data, fc_var <- transpose_dbl(map(fc_dist, distributional::variance)) # Apply to forecasts - fc_mean <- as.matrix(S%*%t(fc_mean)) + fc_mean <- as.matrix(SP%*%t(fc_mean)) fc_mean <- split(fc_mean, row(fc_mean)) if(is_normal){ - fc_var <- map(fc_var, function(W) diag(S%*%diag(W)%*%t(S))) + R1 <- cov2cor(W) + W_h <- map(fc_var, function(var) diag(sqrt(var))%*%R1%*%t(diag(sqrt(var)))) + fc_var <- map(W_h, function(W) diag(SP%*%W%*%t(SP))) fc_dist <- map2(fc_mean, transpose_dbl(map(fc_var, sqrt)), distributional::dist_normal) } else { fc_dist <- map(fc_mean, distributional::dist_degenerate) } # Update fables - pmap(list(rep_along(fc_mean, fc[1]), fc_mean, fc_dist), function(fc, point, dist){ + map2(fc, fc_dist, function(fc, dist){ dimnames(dist) <- dimnames(fc[[distribution_var(fc)]]) fc[[distribution_var(fc)]] <- dist - point_fc <- compute_point_forecasts(dist, point_method) + point_fc <- compute_point_forecasts(dist, point_forecast) fc[names(point_fc)] <- point_fc fc }) } -build_smat <- function(key_data){ - row_col <- sym(colnames(key_data)[length(key_data)]) - - fct <- key_data %>% - unnest(!!row_col) %>% - dplyr::arrange(!!row_col) %>% - select(!!expr(-!!row_col)) %>% - dplyr::mutate_all(factor) - - lvls <- invoke(paste, fct[stats::complete.cases(fct),]) - - smat <- map(fct, function(x){ - mat <- rep(0, 
length(x)*length(levels(x))) - i <- which(!is.na(x)) - if(length(i) == length(x) && length(levels(x)) > 1){ - abort("Reconciliation of disjoint hierarchical structures is not yet supported.") - } - j <- as.numeric(x[i]) - mat[i + length(x) * (j-1)] <- 1 - mat <- matrix(mat, nrow = length(x), ncol = length(levels(x)), - dimnames = list(NULL, levels(x))) - mat[is.na(x), ] <- 1 - mat - }) - - join_smat <- function(x, y){ - smat <- map(split(x, col(x)), `*`, y) - smat <- map2(smat, colnames(x), function(S, cn) `colnames<-`(S, paste(cn, colnames(S)))) - invoke(cbind, smat) - } - - reduce(smat, join_smat)[,lvls,drop = FALSE] -} - - build_smat_rows <- function(key_data){ + lifecycle::deprecate_warn("0.2.1", "fabletools::build_smat_rows()", "fabletools::build_key_data_smat()") row_col <- sym(colnames(key_data)[length(key_data)]) smat <- key_data %>% @@ -310,3 +373,46 @@ build_smat_rows <- function(key_data){ return(smat) } + +build_key_data_smat <- function(x){ + kv <- names(x)[-ncol(x)] + agg_shadow <- as_tibble(map(x[kv], is_aggregated)) + grp <- as_tibble(vctrs::vec_group_loc(agg_shadow)) + num_agg <- rowSums(grp$key) + # Initialise comparison leafs with known/guaranteed leafs + x_leaf <- x[vec_c(!!!grp$loc[which(num_agg == min(num_agg))]),] + + # Sort by disaggregation to identify aggregated leafs in order + grp <- grp[order(num_agg),] + + grp$match <- lapply(unname(split(grp, seq_len(nrow(grp)))), function(level){ + disagg_col <- which(!vec_c(!!!level$key)) + agg_idx <- level[["loc"]][[1]] + pos <- vec_match(x_leaf[disagg_col], x[agg_idx, disagg_col]) + pos <- vec_group_loc(pos) + pos <- pos[!is.na(pos$key),] + # Add non-matches as leaf nodes + agg_leaf <- setdiff(seq_along(agg_idx), pos$key) + if(!is_empty(agg_leaf)){ + pos <- vec_rbind( + pos, + structure(list(key = agg_leaf, loc = as.list(seq_along(agg_leaf) + nrow(x_leaf))), + class = "data.frame", row.names = agg_leaf) + ) + x_leaf <<- vec_rbind( + x_leaf, + x[agg_idx[agg_leaf],] + ) + } + pos$loc[order(pos$key)] 
+ }) + if(any(lengths(grp$loc) != lengths(grp$match))) { + abort("An error has occurred when constructing the summation matrix.\nPlease report this bug here: https://github.com/tidyverts/fabletools/issues") + } + idx_leaf <- vec_c(!!!x_leaf$.rows) + x$.rows[unlist(x$.rows)[vec_c(!!!grp$loc)]] <- vec_c(!!!grp$match) + return(list(agg = x$.rows, leaf = idx_leaf)) + # out <- matrix(0L, nrow = nrow(x), ncol = length(idx_leaf)) + # out[nrow(x)*(vec_c(!!!x$.rows)-1) + rep(seq_along(x$.rows), lengths(x$.rows))] <- 1L + # out +} diff --git a/R/refit.R b/R/refit.R index f966308..2ea5f4a 100644 --- a/R/refit.R +++ b/R/refit.R @@ -33,7 +33,10 @@ refit.mdl_df <- function(object, new_data, ...){ #' @export refit.lst_mdl <- function(object, new_data, ...){ - `class<-`(map2(object, new_data, refit, ...), class(object)) + attrb <- attributes(object) + object <- map2(object, new_data, refit, ...) + attributes(object) <- attrb + object } #' @rdname refit diff --git a/R/residuals.R b/R/residuals.R index 45e72eb..ca69a01 100644 --- a/R/residuals.R +++ b/R/residuals.R @@ -9,14 +9,12 @@ #' @importFrom stats residuals #' @export residuals.mdl_df <- function(object, ...){ - out <- gather(object, ".model", ".fit", !!!syms(mable_vars(object))) - kv <- key_vars(out) - out <- transmute(as_tibble(out), - !!!syms(kv), - !!sym(".model"), - residuals = map(!!sym(".fit"), residuals, ...) - ) - unnest_tsbl(out, "residuals", parent_key = kv) + mbl_vars <- mable_vars(object) + kv <- key_vars(object) + object <- mutate(as_tibble(object), + dplyr::across(all_of(mbl_vars), function(x) lapply(x, residuals, ...))) + object <- pivot_longer(object, mbl_vars, names_to = ".model", values_to = ".resid") + unnest_tsbl(object, ".resid", parent_key = c(kv, ".model")) } #' @param type The type of residuals to compute. If `type="response"`, residuals on the back-transformed data will be computed. 
@@ -31,12 +29,10 @@ residuals.mdl_ts <- function(object, type = "innovation", ...){ else{ .resid <- residuals(object$fit, type = type, ...) if(is.null(.resid)){ - warn(sprintf( -'Residuals of type `%s` are not supported for %s models. -Defaulting to `type="response"`', type, model_sum(object))) - .resid <- response(object) - .fits <- fitted(object) - .resid <- as.matrix(.resid[measured_vars(.resid)]) - as.matrix(.fits[measured_vars(.fits)]) +# warn(sprintf( +# 'Residuals of type `%s` are not supported for %s models. +# Defaulting to `type="response"`', type, model_sum(object))) + return(residuals(object, type = "response", ...)) } } .resid <- as.matrix(.resid) @@ -44,5 +40,7 @@ Defaulting to `type="response"`', type, model_sum(object))) .resid <- split(.resid, col(.resid)) nm <- if(length(.resid) == 1) ".resid" else map_chr(object$response, expr_name) - transmute(object$data, !!!set_names(.resid, nm)) + out <- object$data[index_var(object$data)] + out[nm] <- .resid + out } \ No newline at end of file diff --git a/R/response.R b/R/response.R index bed8588..3864bdf 100644 --- a/R/response.R +++ b/R/response.R @@ -13,14 +13,15 @@ response <- function(object, ...){ #' @export response.mdl_df <- function(object, ...){ - out <- gather(object, ".model", ".fit", !!!syms(mable_vars(object))) - kv <- key_vars(out) - out <- transmute(as_tibble(out), + object <- tidyr::pivot_longer(object, all_of(mable_vars(object)), + names_to = ".model", values_to = ".fit") + kv <- c(key_vars(object), ".model") + object <- transmute(as_tibble(object), !!!syms(kv), !!sym(".model"), response = map(!!sym(".fit"), response) ) - unnest_tsbl(out, "response", parent_key = kv) + unnest_tsbl(object, "response", parent_key = kv) } #' @export diff --git a/R/temporal_aggregation.R b/R/temporal_aggregation.R new file mode 100644 index 0000000..3f8d937 --- /dev/null +++ b/R/temporal_aggregation.R @@ -0,0 +1,147 @@ +# Adapted from cut.Date +date_breaks <- function(x, breaks, start_monday = TRUE, offset = 
TRUE){ + # Currently only dates are supported + x <- as.Date(x) + by2 <- strsplit(breaks, " ", fixed = TRUE)[[1L]] + if (length(by2) > 2L || length(by2) < 1L) + stop("invalid specification of 'breaks'") + valid <- pmatch(by2[length(by2)], c("days", "weeks", + "months", "years", "quarters")) + if (is.na(valid)) + stop("invalid specification of 'breaks'") + start <- as.POSIXlt(min(x, na.rm = TRUE)) + if (valid == 1L) + incr <- 1L + if (valid == 2L) { + start$mday <- start$mday - start$wday + if (start_monday) + start$mday <- start$mday + ifelse(start$wday > + 0L, 1L, -6L) + start$isdst <- -1L + incr <- 7L + } + if (valid == 3L) { + start$mday <- 1L + start$isdst <- -1L + end <- as.POSIXlt(max(x, na.rm = TRUE)) + step <- if (length(by2) == 2L) + as.integer(by2[1L]) + else 1L + end <- as.POSIXlt(end + (31 * step * 86400)) + end$mday <- 1L + end$isdst <- -1L + breaks <- as.Date(seq(start, end, breaks)) + } + else if (valid == 4L) { + start$mon <- 0L + start$mday <- 1L + start$isdst <- -1L + end <- as.POSIXlt(max(x, na.rm = TRUE)) + step <- if (length(by2) == 2L) + as.integer(by2[1L]) + else 1L + end <- as.POSIXlt(end + (366 * step * 86400)) + end$mon <- 0L + end$mday <- 1L + end$isdst <- -1L + breaks <- as.Date(seq(start, end, breaks)) + } + else if (valid == 5L) { + qtr <- rep(c(0L, 3L, 6L, 9L), each = 3L) + start$mon <- qtr[start$mon + 1L] + start$mday <- 1L + start$isdst <- -1L + maxx <- max(x, na.rm = TRUE) + end <- as.POSIXlt(maxx) + step <- if (length(by2) == 2L) + as.integer(by2[1L]) + else 1L + end <- as.POSIXlt(end + (93 * step * 86400)) + end$mon <- qtr[end$mon + 1L] + end$mday <- 1L + end$isdst <- -1L + breaks <- as.Date(seq(start, end, paste(step * 3L, + "months"))) + lb <- length(breaks) + if (maxx < breaks[lb - 1]) + breaks <- breaks[-lb] + } + else { + start <- as.Date(start) + if (length(by2) == 2L) + incr <- incr * as.integer(by2[1L]) + maxx <- max(x, na.rm = TRUE) + breaks <- seq(start, maxx + incr, breaks) + breaks <- breaks[seq_len(1L + 
max(which(breaks <= + maxx)))] + } + if(offset == "end" || (is.logical(offset) && offset)) { + breaks <- breaks + (x[length(x)] - breaks[length(breaks)]) + } else if (offset == "start") { + breaks <- breaks + (x[1] - breaks[1]) + } + breaks +} + +bin_date <- function(time, breaks, offset){ + if(is.character(breaks) && length(breaks) == 1){ + breaks <- date_breaks(time, breaks, offset = offset) + } + bincode <- .bincode(unclass(as.Date(time)), unclass(breaks), right = FALSE) + list( + bin = bincode, + breaks = breaks, + complete_size = diff(breaks) + ) +} + +#' Expand a dataset to include temporal aggregates +#' +#' \lifecycle{experimental} +#' +#' This feature is very experimental. It currently allows for temporal +#' aggregation of daily data as a proof of concept. +#' +#' @inheritParams aggregate_key +#' @param .window Temporal aggregations to include. The default (NULL) will +#' automatically identify appropriate temporal aggregations. This can be +#' specified in several ways (see details). +#' @param .offset Offset the temporal aggregation windows to align with the start +#' or end of the data. If FALSE, no offset will be applied (giving common +#' breakpoints for temporal bins.) +#' @param .bin_size Temporary. Define the number of observations in each temporal bucket +#' +#' @details +#' The aggregation `.window` can be specified in several ways: +#' * A character string, containing one of "day", "week", "month", "quarter" or +#' "year". This can optionally be preceded by a (positive or negative) integer +#' and a space, or followed by "s". +#' * A number, taken to be in days. +#' * A [`difftime`] object. 
+#' +#' @examples +#' library(tsibble) +#' pedestrian %>% +#' # Currently only supports daily data +#' index_by(Date) %>% +#' dplyr::summarise(Count = sum(Count)) %>% +#' # Compute weekly aggregates +#' fabletools:::aggregate_index("1 week", Count = sum(Count)) +aggregate_index <- function(.data, .window, ..., .offset = "end", .bin_size = NULL){ + idx <- index_var(.data) + # Compute temporal bins and bin sizes + new_index <- bin_date(.data[[idx]], .window, .offset) + if(!is.null(.bin_size)) new_index$complete_size <- vec_recycle(.bin_size, length(new_index$complete_size)) + as_tibble(.data) %>% + # Compute groups of temporal bins + group_by( + !!idx := !!{new_index$breaks[new_index$bin]}, + !!!key(.data) + ) %>% + # Keep only complete windows, currently assumes daily base interval + filter(dplyr::n() == (!!new_index$complete_size)[match((!!sym(idx))[1], !!new_index$breaks)]) %>% + # Compute aggregates + summarise(..., .groups = "drop") %>% + # Rebuild tsibble + as_tsibble(key = key_vars(.data), index = !!index(.data)) +} diff --git a/R/utils.R b/R/utils.R index e35b767..aabdc12 100644 --- a/R/utils.R +++ b/R/utils.R @@ -37,25 +37,25 @@ make_future_data <- function(.data, h = NULL){ } if(is.null(h)) n <- n*2 - # tsibble::new_data(.data, round(n)) + tsibble::new_data(.data, round(n)) # Re-implemented here using a simpler/faster method - n <- round(n) - - idx <- index_var(.data) - itvl <- interval(.data) - tunit <- default_time_units(itvl) - - idx_max <- max(.data[[idx]]) - if(is.factor(idx_max)){ - abort("Cannot automatically create `new_data` from a factor/ordered time index. 
Please provide `new_data` directly.") - } - - .data <- list2(!!idx := seq(idx_max, by = tunit, length.out = n+1)[-1]) - build_tsibble_meta( - new_tibble(.data, nrow = n), - key_data = new_tibble(list(.rows = list(seq_len(n))), nrow = 1), - index = idx, index2 = idx, ordered = TRUE, interval = itvl - ) + # n <- round(n) + # + # idx <- index_var(.data) + # itvl <- interval(.data) + # tunit <- default_time_units(itvl) + # + # idx_max <- max(.data[[idx]]) + # if(is.factor(idx_max)){ + # abort("Cannot automatically create `new_data` from a factor/ordered time index. Please provide `new_data` directly.") + # } + # + # .data <- list2(!!idx := seq(idx_max, by = tunit, length.out = n+1)[-1]) + # build_tsibble_meta( + # new_tibble(.data, nrow = n), + # key_data = new_tibble(list(.rows = list(seq_len(n))), nrow = 1), + # index = idx, index2 = idx, ordered = TRUE, interval = itvl + # ) } bind_new_data <- function(object, new_data){ diff --git a/R/vctrs-dable.R b/R/vctrs-dable.R index b85cdf1..13ab81c 100644 --- a/R/vctrs-dable.R +++ b/R/vctrs-dable.R @@ -102,3 +102,12 @@ vec_cast.tbl_df.dcmp_ts <- function(x, to, ...) { vec_cast.data.frame.dcmp_ts <- function(x, to, ...) { df_cast(x, to, ...) } + +#' @export +vec_cast.tbl_ts.dcmp_ts <- function(x, to, ...) { + tbl <- tib_cast(x, to, ...) + build_tsibble( + tbl, key = key_vars(to), index = index_var(to), index2 = index2_var(to), + ordered = TRUE, validate = TRUE, .drop = key_drop_default(to) + ) +} \ No newline at end of file diff --git a/R/vctrs-fable.R b/R/vctrs-fable.R index 5153efb..b988058 100644 --- a/R/vctrs-fable.R +++ b/R/vctrs-fable.R @@ -85,6 +85,15 @@ vec_cast.fbl_ts.tbl_df <- function(x, to, ...) { build_fable(tsbl, response = response_vars(to), distribution = distribution_var(to)) } +#' @export +vec_cast.tbl_ts.fbl_ts <- function(x, to, ...) { + tbl <- tib_cast(x, to, ...) 
+ build_tsibble( + tbl, key = key_vars(to), index = index_var(to), index2 = index2_var(to), + ordered = TRUE, validate = TRUE, .drop = key_drop_default(to) + ) +} + #' @export vec_cast.fbl_ts.data.frame <- vec_cast.fbl_ts.tbl_df diff --git a/R/zzz.R b/R/zzz.R index 8408e3c..7d15b36 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,17 +1,11 @@ # nocov start .onLoad <- function(...) { register_s3_method("pillar", "type_sum", "mdl_ts") - register_s3_method("pillar", "type_sum", "fcdist") register_s3_method("pillar", "type_sum", "lst_mdl") register_s3_method("pillar", "type_sum", "fbl_ts") - register_s3_method("pillar", "obj_sum", "fcdist") - - register_s3_method("pillar", "pillar_shaft", "fcdist") register_s3_method("pillar", "pillar_shaft", "agg_vec") - register_s3_method("pillar", "is_vector_s3", "fcdist") - register_s3_method("tibble", "tbl_sum", "dcmp_ts") register_s3_method("tibble", "tbl_sum", "mdl_df") register_s3_method("tibble", "tbl_sum", "fbl_ts") diff --git a/README.md b/README.md index fc26ace..c3ee9d9 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,6 @@ You can install the **development** version from [GitHub](https://github.com/tidyverts/fabletools): ``` r -# install.packages("devtools") -devtools::install_github("tidyverts/fabletools") +# install.packages("remotes") +remotes::install_github("tidyverts/fabletools") ``` diff --git a/build/fabletools.pdf b/build/fabletools.pdf index a34b03e..53e7a70 100644 Binary files a/build/fabletools.pdf and b/build/fabletools.pdf differ diff --git a/build/vignette.rds b/build/vignette.rds index 37847fe..87d5eff 100644 Binary files a/build/vignette.rds and b/build/vignette.rds differ diff --git a/inst/WORDLIST b/inst/WORDLIST index 1314b80..36bc89c 100644 --- a/inst/WORDLIST +++ b/inst/WORDLIST @@ -10,6 +10,7 @@ etc forecast's formals fourier +ggdist Heeyoung hilo horison diff --git a/inst/doc/extension_models.Rmd b/inst/doc/extension_models.Rmd index 3938f68..7853bb0 100644 --- a/inst/doc/extension_models.Rmd +++ 
b/inst/doc/extension_models.Rmd @@ -2,7 +2,7 @@ title: "Extending fabletools: Models" output: rmarkdown::html_vignette vignette: > - %\VignetteIndexEntry{extension_models} + %\VignetteIndexEntry{Extending fabletools: Models} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8} --- diff --git a/inst/doc/extension_models.html b/inst/doc/extension_models.html index d319f58..ddf5fad 100644 --- a/inst/doc/extension_models.html +++ b/inst/doc/extension_models.html @@ -14,6 +14,22 @@ Extending fabletools: Models + @@ -429,6 +445,11 @@

The training function

Methods for models

+++++ diff --git a/man/agg_vec.Rd b/man/agg_vec.Rd new file mode 100644 index 0000000..e22d687 --- /dev/null +++ b/man/agg_vec.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aggregate.R +\name{agg_vec} +\alias{agg_vec} +\title{Create an aggregation vector} +\usage{ +agg_vec(x = character(), aggregated = logical(vec_size(x))) +} +\arguments{ +\item{x}{The vector of values.} + +\item{aggregated}{A logical vector to identify which values are <aggregated>.} +} +\description{ +\lifecycle{maturing} +} +\details{ +An aggregation vector extends usual vectors by adding <aggregated> values. +These vectors are typically produced via the \code{\link[=aggregate_key]{aggregate_key()}} function, +however it can be useful to create them manually to produce more complicated +hierarchies (such as unbalanced hierarchies). +} diff --git a/man/aggregate_index.Rd b/man/aggregate_index.Rd new file mode 100644 index 0000000..5c42bff --- /dev/null +++ b/man/aggregate_index.Rd @@ -0,0 +1,56 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/temporal_aggregation.R +\name{aggregate_index} +\alias{aggregate_index} +\title{Expand a dataset to include temporal aggregates} +\usage{ +aggregate_index(.data, .window, ..., .offset = "end", .bin_size = NULL) +} +\arguments{ +\item{.data}{A tsibble.} + +\item{.window}{Temporal aggregations to include. The default (NULL) will +automatically identify appropriate temporal aggregations. This can be +specified in several ways (see details).} + +\item{...}{<\code{\link[dplyr:dplyr_data_masking]{data-masking}}> Name-value pairs of summary +functions. The name will be the name of the variable in the result. + +The value can be: +\itemize{ +\item A vector of length 1, e.g. \code{min(x)}, \code{n()}, or \code{sum(is.na(y))}. +\item A vector of length \code{n}, e.g. \code{quantile()}. +\item A data frame, to add multiple columns from a single expression. 
+}} + +\item{.offset}{Offset the temporal aggregation windows to align with the start +or end of the data. If FALSE, no offset will be applied (giving common +breakpoints for temporal bins.)} + +\item{.bin_size}{Temporary. Define the number of observations in each temporal bucket} +} +\description{ +\lifecycle{experimental} +} +\details{ +This feature is very experimental. It currently allows for temporal +aggregation of daily data as a proof of concept. + +The aggregation \code{.window} can be specified in several ways: +\itemize{ +\item A character string, containing one of "day", "week", "month", "quarter" or +"year". This can optionally be preceded by a (positive or negative) integer +and a space, or followed by "s". +\item A number, taken to be in days. +\item A \code{\link{difftime}} object. +} +} +\examples{ +library(tsibble) +pedestrian \%>\% + # Currently only supports daily data + index_by(Date) \%>\% + dplyr::summarise(Count = sum(Count)) \%>\% + # Compute weekly aggregates + fabletools:::aggregate_index("1 week", Count = sum(Count)) +} diff --git a/man/as-dable.Rd b/man/as-dable.Rd index 03b59f3..75b4196 100644 --- a/man/as-dable.Rd +++ b/man/as-dable.Rd @@ -17,8 +17,7 @@ as_dable(x, ...) \item{...}{Additional arguments passed to methods} -\item{response}{The response variable(s). A single response can be specified -directly via \code{response = y}, multiple responses should be use \code{response = c(y, z)}.} +\item{response}{The character vector of response variable(s).} \item{method}{The name of the decomposition method.} diff --git a/man/as-fable.Rd b/man/as-fable.Rd index 02157e1..232f16b 100644 --- a/man/as-fable.Rd +++ b/man/as-fable.Rd @@ -7,6 +7,7 @@ \alias{as_fable.tbl_df} \alias{as_fable.fbl_ts} \alias{as_fable.grouped_df} +\alias{as_fable.forecast} \title{Coerce to a fable object} \usage{ as_fable(x, ...) @@ -20,16 +21,23 @@ as_fable(x, ...) \method{as_fable}{fbl_ts}(x, response, distribution, ...) 
\method{as_fable}{grouped_df}(x, response, distribution, ...) + +\method{as_fable}{forecast}(x, ..., point_forecast = list(.mean = mean)) } \arguments{ \item{x}{Object to be coerced to a fable (\code{fbl_ts})} \item{...}{Additional arguments passed to methods} -\item{response}{The response variable(s). A single response can be specified -directly via \code{response = y}, multiple responses should be use \code{response = c(y, z)}.} +\item{response}{The character vector of response variable(s).} + +\item{distribution}{The name of the distribution column (can be provided +using a bare expression).} -\item{distribution}{The distribution variable (given as a bare or unquoted variable).} +\item{point_forecast}{The point forecast measure(s) which should be returned +in the resulting fable. Specified as a named list of functions which accept +a distribution and return a vector. To compute forecast medians, you can use +\code{list(.median = median)}.} } \description{ Coerce to a fable object diff --git a/man/augment.Rd b/man/augment.Rd index cad9c17..c9cb815 100644 --- a/man/augment.Rd +++ b/man/augment.Rd @@ -7,12 +7,14 @@ \usage{ \method{augment}{mdl_df}(x, ...) -\method{augment}{mdl_ts}(x, ...) +\method{augment}{mdl_ts}(x, type = NULL, ...) 
} \arguments{ \item{x}{A mable.} \item{...}{Arguments for model methods.} + +\item{type}{Deprecated.} } \description{ Uses a fitted model to augment the response variable with fitted values and @@ -26,7 +28,7 @@ library(tsibbledata) # Forecasting with an ETS(M,Ad,A) model to Australian beer production aus_production \%>\% model(ets = ETS(log(Beer) ~ error("M") + trend("Ad") + season("A"))) \%>\% - augment(type = "response") + augment() } } diff --git a/man/bottom_up.Rd b/man/bottom_up.Rd new file mode 100644 index 0000000..286da78 --- /dev/null +++ b/man/bottom_up.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/reconciliation.R +\name{bottom_up} +\alias{bottom_up} +\title{Bottom up forecast reconciliation} +\usage{ +bottom_up(models) +} +\arguments{ +\item{models}{A column of models in a mable.} +} +\description{ +\lifecycle{experimental} +} +\details{ +Reconciles a hierarchy using the bottom up reconciliation method. The +response variable of the hierarchy must be aggregated using sums. The +forecasted time points must match for all series in the hierarchy. +} +\seealso{ +\code{\link[=reconcile]{reconcile()}}, \code{\link[=aggregate_key]{aggregate_key()}} +} diff --git a/man/construct_fc.Rd b/man/construct_fc.Rd index 37c6be8..0f26957 100644 --- a/man/construct_fc.Rd +++ b/man/construct_fc.Rd @@ -14,9 +14,11 @@ construct_fc(point, sd, dist) \item{dist}{The forecast distribution (typically produced using \code{new_fcdist})} } \description{ -Will be deprecated in the future, forecast objects should be produced with -either \code{fable} or \code{as_fable} functions. +\lifecycle{deprecated} } \details{ +This function is deprecated. \code{forecast()} methods for a model should return +a vector of distributions using the distributional package. + Backtransformations are automatically handled, and so no transformations should be specified here. 
} diff --git a/man/dable.Rd b/man/dable.Rd index 42c61b2..f0ac2a1 100644 --- a/man/dable.Rd +++ b/man/dable.Rd @@ -9,8 +9,7 @@ dable(..., response, method = NULL, seasons = list(), aliases = list()) \arguments{ \item{...}{Arguments passed to \code{\link[tsibble:tsibble]{tsibble::tsibble()}}.} -\item{response}{The response variable(s). A single response can be specified -directly via \code{response = y}, multiple responses should be use \code{response = c(y, z)}.} +\item{response}{The name of the response variable column.} \item{method}{The name of the decomposition method.} diff --git a/man/distributions.Rd b/man/distributions.Rd deleted file mode 100644 index c6ff59b..0000000 --- a/man/distributions.Rd +++ /dev/null @@ -1,39 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/quantile.R -\name{dist_normal} -\alias{dist_normal} -\alias{dist_mv_normal} -\alias{dist_sim} -\alias{dist_unknown} -\title{Distributions for intervals} -\usage{ -dist_normal(mean, sd, ...) - -dist_mv_normal(mean, sd, ...) - -dist_sim(sample, ...) - -dist_unknown(n, ...) -} -\arguments{ -\item{mean}{vector of distributional means.} - -\item{sd}{vector of distributional standard deviations.} - -\item{...}{Additional arguments passed on to quantile methods.} - -\item{sample}{a list of simulated values} - -\item{n}{The number of distributions.} -} -\description{ -Distributions for intervals -} -\examples{ -dist_normal(rep(3, 10), seq(0, 1, length.out=10)) - -dist_sim(list(rnorm(100), rnorm(100), rnorm(100))) - -dist_unknown(10) - -} diff --git a/man/fable.Rd b/man/fable.Rd index 4719182..661f553 100644 --- a/man/fable.Rd +++ b/man/fable.Rd @@ -9,15 +9,14 @@ fable(..., response, distribution) \arguments{ \item{...}{Arguments passed to \code{\link[tsibble:tsibble]{tsibble::tsibble()}}.} -\item{response}{The response variable(s). 
A single response can be specified -directly via \code{response = y}, multiple responses should be use \code{response = c(y, z)}.} +\item{response}{The character vector of response variable(s).} -\item{distribution}{The distribution variable (given as a bare or unquoted variable).} +\item{distribution}{The name of the distribution column (can be provided +using a bare expression).} } \description{ A fable (forecast table) data class (\code{fbl_ts}) which is a tsibble-like data structure for representing forecasts. In extension to the key and index from -the tsibble (\code{tbl_ts}) class, a fable (\code{fbl_ts}) must contain columns of -point forecasts for the response variable(s), and a single distribution -column (\code{fcdist}). +the tsibble (\code{tbl_ts}) class, a fable (\code{fbl_ts}) must also contain a single +distribution column that uses values from the distributional package. } diff --git a/man/fabletools-package.Rd b/man/fabletools-package.Rd index 941d16f..45c36bc 100644 --- a/man/fabletools-package.Rd +++ b/man/fabletools-package.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/fabletools.R +% Please edit documentation in R/fabletools-package.R \docType{package} \name{fabletools-package} \alias{fabletools} @@ -14,7 +14,7 @@ Provides tools, helpers and data structures for \seealso{ Useful links: \itemize{ - \item \url{http://fabletools.tidyverts.org/} + \item \url{https://fabletools.tidyverts.org/} \item \url{https://github.com/tidyverts/fabletools} \item Report bugs at \url{https://github.com/tidyverts/fabletools/issues} } diff --git a/man/fcdist.Rd b/man/fcdist.Rd deleted file mode 100644 index 3b9306f..0000000 --- a/man/fcdist.Rd +++ /dev/null @@ -1,26 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/quantile.R -\name{new_fcdist} -\alias{new_fcdist} -\alias{new_fcdist_env} -\title{Create a forecast distribution object} -\usage{ -new_fcdist(..., .env) - 
-new_fcdist_env(quantile, transformation = list(identity), display = NULL) -} -\arguments{ -\item{...}{Arguments for \code{f} function} - -\item{.env}{An environment produced using \code{new_fcdist_env}} - -\item{quantile}{A distribution function producing quantiles (such as \code{qnorm})} - -\item{transformation}{Transformation to be applied to resulting quantiles -from \code{quantile}} - -\item{display}{Function that is used to format the distribution display} -} -\description{ -Create a forecast distribution object -} diff --git a/man/figures/lifecycle-superseded.svg b/man/figures/lifecycle-superseded.svg new file mode 100644 index 0000000..75f24f5 --- /dev/null +++ b/man/figures/lifecycle-superseded.svg @@ -0,0 +1 @@ + lifecyclelifecyclesupersededsuperseded \ No newline at end of file diff --git a/man/forecast.Rd b/man/forecast.Rd index b0ef9da..2a07b77 100644 --- a/man/forecast.Rd +++ b/man/forecast.Rd @@ -48,10 +48,11 @@ A fable containing the following columns: \itemize{ \item \code{.model}: The name of the model used to obtain the forecast. Taken from the column names of models in the provided mable. -\item The point forecast, which by default is the mean. The name of this column -will be the same as the dependent variable in the model(s). -\item \code{.distribution}. A column of objects of class \code{fcdist}, representing the -statistical distribution of the forecast in the given time period. +\item The forecast distribution. The name of this column will be the same as the +dependent variable in the model(s). If multiple dependent variables exist, +it will be named \code{.distribution}. +\item Point forecasts computed from the distribution using the functions in the +\code{point_forecast} argument. \item All columns in \code{new_data}, excluding those whose names conflict with the above. } @@ -69,7 +70,7 @@ The forecasts returned contain both point forecasts and their distribution. 
A specific forecast interval can be extracted from the distribution using the \code{\link[=hilo]{hilo()}} function, and multiple intervals can be obtained using \code{\link[=report]{report()}}. These intervals are stored in a single column using the \code{hilo} class, to -extract the numerical upper and lower bounds you can use \code{\link[tidyr:nest]{tidyr::unnest()}}. +extract the numerical upper and lower bounds you can use \code{\link[=unpack_hilo]{unpack_hilo()}}. } \examples{ if (requireNamespace("fable", quietly = TRUE)) { diff --git a/man/skill_score.Rd b/man/skill_score.Rd new file mode 100644 index 0000000..86f7bfe --- /dev/null +++ b/man/skill_score.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/accuracy.R +\name{skill_score} +\alias{skill_score} +\title{Forecast skill score measure} +\usage{ +skill_score(measure) +} +\arguments{ +\item{measure}{The accuracy measure to use in computing the skill score.} +} +\description{ +This function converts other error metrics such as \code{MSE} into a skill score. +The reference or benchmark forecasting method is the Naive method for +non-seasonal data, and the seasonal naive method for seasonal data. 
+} +\examples{ + +skill_score(MSE) + +if (requireNamespace("fable", quietly = TRUE)) { +library(fable) +library(tsibble) + +lung_deaths <- as_tsibble(cbind(mdeaths, fdeaths)) +lung_deaths \%>\% + dplyr::filter(index < yearmonth("1979 Jan")) \%>\% + model( + ets = ETS(value ~ error("M") + trend("A") + season("A")), + lm = TSLM(value ~ trend() + season()) + ) \%>\% + forecast(h = "1 year") \%>\% + accuracy(lung_deaths, measures = list(skill = skill_score(MSE))) +} + +} diff --git a/man/top_down.Rd b/man/top_down.Rd new file mode 100644 index 0000000..4c589dd --- /dev/null +++ b/man/top_down.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/reconciliation.R +\name{top_down} +\alias{top_down} +\title{Top down forecast reconciliation} +\usage{ +top_down( + models, + method = c("forecast_proportions", "average_proportions", "proportion_averages") +) +} +\arguments{ +\item{models}{A column of models in a mable.} + +\item{method}{The reconciliation method to use.} +} +\description{ +\lifecycle{experimental} +} +\details{ +Reconciles a hierarchy using the top down reconciliation method. The +response variable of the hierarchy must be aggregated using sums. The +forecasted time points must match for all series in the hierarchy. 
+} +\seealso{ +\code{\link[=reconcile]{reconcile()}}, \code{\link[=aggregate_key]{aggregate_key()}} +} diff --git a/tests/testthat/Rplots.pdf b/tests/testthat/Rplots.pdf index 1c611a9..9072290 100644 Binary files a/tests/testthat/Rplots.pdf and b/tests/testthat/Rplots.pdf differ diff --git a/tests/testthat/test-accuracy.R b/tests/testthat/test-accuracy.R index 11ad734..5d432c5 100644 --- a/tests/testthat/test-accuracy.R +++ b/tests/testthat/test-accuracy.R @@ -1,5 +1,14 @@ context("test-accuracy") +test_that("accuracy() hints", { + skip_if_not_installed("fable") + expect_error( + accuracy(mbl, us_deaths), + "To compute forecast accuracy, you'll need to compute the forecasts first.", + fixed = TRUE + ) +}) + test_that("In-sample accuracy", { skip_if_not_installed("fable") @@ -16,7 +25,7 @@ test_that("In-sample accuracy", { expect_true(!any(map_lgl(acc, compose(any, is.na)))) expect_equal( as.list(acc), - as_tibble(augment(mbl, type = "response")) %>% + as_tibble(augment(mbl)) %>% group_by(.model) %>% summarise(.type = "Training", ME = mean(.resid), RMSE = sqrt(mean(.resid^2)), MAE = mean(abs(.resid)), MPE = mean(.resid/value*100), diff --git a/tests/testthat/test-broom.R b/tests/testthat/test-broom.R index a5a2b05..8e4b94c 100644 --- a/tests/testthat/test-broom.R +++ b/tests/testthat/test-broom.R @@ -6,17 +6,20 @@ test_that("augment", { aug <- augment(mbl) expect_equal(aug$index, us_deaths_tr$index) expect_equal(aug$.fitted, fitted(mbl)$.fitted) - expect_equal(aug$.resid, residuals(mbl)$.resid) + expect_equal(aug$.resid, residuals(mbl, type ="response")$.resid) + expect_equal(aug$.innov, residuals(mbl)$.resid) aug <- augment(mbl_multi) expect_equal(aug$index, lung_deaths_long_tr$index) expect_equal(aug$.fitted, fitted(mbl_multi)$.fitted) - expect_equal(aug$.resid, residuals(mbl_multi)$.resid) + expect_equal(aug$.resid, residuals(mbl_multi, type = "response")$.resid) + expect_equal(aug$.innov, residuals(mbl_multi)$.resid) aug <- augment(mbl_complex) 
expect_equal(aug$index, rep(lung_deaths_long_tr$index, 2)) expect_equal(aug$.fitted, fitted(mbl_complex)$.fitted) - expect_equal(aug$.resid, residuals(mbl_complex)$.resid) + expect_equal(aug$.resid, residuals(mbl_complex, type = "response")$.resid) + expect_equal(aug$.innov, residuals(mbl_complex)$.resid) aug <- augment(mbl_mv) expect_equal(aug$index, rep(lung_deaths_wide_tr$index, 2)) @@ -34,8 +37,8 @@ test_that("glance", { expect_equal(gl_multi$key, c("fdeaths", "mdeaths")) gl_complex <- glance(mbl_complex) expect_equal(NROW(gl_complex), 4) - expect_equal(gl_complex$key, rep(c("fdeaths", "mdeaths"), 2)) - expect_equal(gl_multi[-2], gl_complex[c(1,2), names(gl_multi)][-2]) + expect_equal(gl_complex$key, rep(c("fdeaths", "mdeaths"), each = 2)) + expect_equal(gl_multi[-2], gl_complex[c(1,3), names(gl_multi)][-2]) gl_mv <- glance(mbl_mv) expect_equal(NROW(gl_mv), 1) diff --git a/tests/testthat/test-combination.R b/tests/testthat/test-combination.R index 4e4c777..2ae8fff 100644 --- a/tests/testthat/test-combination.R +++ b/tests/testthat/test-combination.R @@ -7,13 +7,13 @@ test_that("Combination modelling", { transmute(combination = (ets + ets)/2) expect_equal( - augment(mbl_cmbn, type = "response")[,-1], - augment(mbl, type = "response")[,-1] + select(augment(mbl_cmbn), -.model, -.innov), + select(augment(mbl), -.model, -.innov) ) expect_equivalent( - unclass(ggplot2::fortify(mbl_cmbn %>% forecast(h = 12))[,-1]), - unclass(ggplot2::fortify(fbl)[,-1]) + forecast(mbl_cmbn, h = 12)[,-1], + fbl[,-1] ) mbl_cmbn <- us_deaths_tr %>% @@ -25,8 +25,8 @@ test_that("Combination modelling", { fbl_cmbn <- forecast(mbl_cmbn) expect_equivalent( - unclass(ggplot2::fortify(fbl_cmbn)[1:48, -1]), - unclass(ggplot2::fortify(fbl_cmbn)[49:96, -1]) + fbl_cmbn[1:24, -1], + fbl_cmbn[25:48, -1] ) mbl_cmbn <- us_deaths_tr %>% diff --git a/tests/testthat/test-parser.R b/tests/testthat/test-parser.R index 54e5dfd..04276ff 100644 --- a/tests/testthat/test-parser.R +++ 
b/tests/testthat/test-parser.R @@ -91,6 +91,16 @@ test_that("Model parsing scope", { expect_equal(mdl[[1]][[1]]$response[[1]], sym("value")) + # Transformation from scalar in function env + mdl <- eval({ + {function() { + scale <- pi + model(us_deaths, no_specials(value/scale)) + }} () + }, envir = new_environment(list(no_specials = no_specials))) + + expect_equal(mdl[[1]][[1]]$response[[1]], sym("value")) + # Specials missing values expect_warning( eval({ diff --git a/vignettes/extension_models.Rmd b/vignettes/extension_models.Rmd index 3938f68..7853bb0 100644 --- a/vignettes/extension_models.Rmd +++ b/vignettes/extension_models.Rmd @@ -2,7 +2,7 @@ title: "Extending fabletools: Models" output: rmarkdown::html_vignette vignette: > - %\VignetteIndexEntry{extension_models} + %\VignetteIndexEntry{Extending fabletools: Models} %\VignetteEngine{knitr::rmarkdown} %\VignetteEncoding{UTF-8} --- diff --git a/vignettes/temporal.R b/vignettes/temporal.R deleted file mode 100644 index 719da1a..0000000 --- a/vignettes/temporal.R +++ /dev/null @@ -1,29 +0,0 @@ -## ---- include = FALSE--------------------------------------------------------- -knitr::opts_chunk$set( - collapse = TRUE, - comment = "#>" -) - -## ----setup-------------------------------------------------------------------- -library(fabletools) - -## ----------------------------------------------------------------------------- -library(ggplot2) -library(lubridate) -library(tidyr) -granular <- tibble::tibble( - interval = ordered(c("hour", "day", "week", "fortnight", "month"), levels = c("hour", "day", "week", "fortnight", "month")), - times = list( - seq(ymd_hms("1970-01-01 00:00:00"), ymd_hms("1970-03-01 00:00:00"), by = "1 hour"), - seq(ymd_hms("1970-01-01 00:00:00"), ymd_hms("1970-03-01 00:00:00"), by = "1 day"), - seq(ymd_hms("1970-01-01 00:00:00"), ymd_hms("1970-03-01 00:00:00"), by = "1 week"), - seq(ymd_hms("1970-01-01 00:00:00"), ymd_hms("1970-03-01 00:00:00"), by = "2 weeks"), - seq(ymd_hms("1970-01-01 
00:00:00"), ymd_hms("1970-03-01 00:00:00"), by = "1 month") - ) -) %>% - unnest(times) - -granular %>% - ggplot(aes(x = times, y = interval)) + - geom_point() - diff --git a/vignettes/temporal.html b/vignettes/temporal.html deleted file mode 100644 index 3c1b7f5..0000000 --- a/vignettes/temporal.html +++ /dev/null @@ -1,354 +0,0 @@ - - - - - - - - - - - - - - -Temporal aggregation - - - - - - - - - - - - - - - - - - - - - -

Temporal aggregation

- - - -
library(fabletools)
-
library(ggplot2)
-#> Warning: package 'ggplot2' was built under R version 3.6.3
-library(lubridate)
-#> Warning: package 'lubridate' was built under R version 3.6.3
-#> 
-#> Attaching package: 'lubridate'
-#> The following object is masked from 'package:fabletools':
-#> 
-#>     interval
-#> The following objects are masked from 'package:base':
-#> 
-#>     date, intersect, setdiff, union
-library(tidyr)
-granular <- tibble::tibble(
-  interval = ordered(c("hour", "day", "week", "fortnight", "month"), levels = c("hour", "day", "week", "fortnight", "month")),
-  times = list(
-    seq(ymd_hms("1970-01-01 00:00:00"), ymd_hms("1970-03-01 00:00:00"), by = "1 hour"),
-    seq(ymd_hms("1970-01-01 00:00:00"), ymd_hms("1970-03-01 00:00:00"), by = "1 day"),
-    seq(ymd_hms("1970-01-01 00:00:00"), ymd_hms("1970-03-01 00:00:00"), by = "1 week"),
-    seq(ymd_hms("1970-01-01 00:00:00"), ymd_hms("1970-03-01 00:00:00"), by = "2 weeks"),
-    seq(ymd_hms("1970-01-01 00:00:00"), ymd_hms("1970-03-01 00:00:00"), by = "1 month")
-  )
-) %>% 
-  unnest(times)
-
-granular %>% 
-  ggplot(aes(x = times, y = interval)) + 
-  geom_point()
-

-

Offset each aggregation period to end at the last observation. Incomplete aggregations should not be created. To potentially reconcile aggregations which don’t nest exactly, also include lower levels of disaggregation.

- - - - - - - - - - -
Method