forked from rocketlaunchr/dataframe-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.go
60 lines (47 loc) · 1.54 KB
/
predict.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
// Copyright 2018-20 PJ Engineering and Business Solutions Pty. Ltd. All rights reserved.
package hw
import (
"context"
dataframe "github.com/rocketlaunchr/dataframe-go"
"github.com/rocketlaunchr/dataframe-go/forecast"
)
// Predict forecasts the next n values for the loaded data.
//
// The returned series is named after hw.sf and holds the n forecasted values
// in order. When hw.cfg.ConfidenceLevels is non-empty, the second return value
// holds one forecast.Confidence per forecasted value, keyed by confidence
// level; when no levels are configured it is nil. If ctx is cancelled or
// expires while forecasting, Predict returns ctx.Err() with nil results.
func (hw *HoltWinters) Predict(ctx context.Context, n uint) (*dataframe.SeriesFloat64, []forecast.Confidence, error) {
	name := hw.sf.Name(dataframe.DontLock)
	nsf := dataframe.NewSeriesFloat64(name, &dataframe.SeriesInit{Capacity: int(n)})

	levels := hw.cfg.ConfidenceLevels

	// Confidence intervals are only reported when levels are configured.
	// Leaving the slice nil otherwise avoids building n empty maps that would
	// be thrown away. With levels configured the slice is always non-nil,
	// even for n == 0.
	var cnfdnce []forecast.Confidence
	if len(levels) > 0 {
		cnfdnce = make([]forecast.Confidence, 0, n)
	}

	// n is unsigned, so zero is the only "nothing to forecast" case.
	if n == 0 {
		return nsf, cnfdnce, nil
	}

	// Snapshot of the trained state used by every forecast step.
	// NOTE(review): assumes period > 0 and len(seasonals) >= period —
	// presumably guaranteed by configuration/loading; verify against callers.
	var (
		st        = hw.tstate.smoothingLevel
		seasonals = hw.tstate.seasonalComps
		trnd      = hw.tstate.trendLevel
		period    = int(hw.cfg.Period)
	)

	for i := uint(0); i < n; i++ {
		if err := ctx.Err(); err != nil {
			return nil, nil, err
		}

		// m is the 1-based forecast step; the seasonal component cycles
		// through the stored components every period steps.
		m := int(i + 1)
		base := st + float64(m)*trnd
		seasonal := seasonals[(m-1)%period]

		var fval float64
		if hw.cfg.SeasonalMethod == Multiplicative {
			fval = base * seasonal
		} else {
			fval = base + seasonal
		}
		nsf.Append(fval, dataframe.DontLock)

		if len(levels) == 0 {
			continue
		}

		cis := make(map[float64]forecast.ConfidenceInterval, len(levels))
		for _, level := range levels {
			// NOTE(review): the last argument is the total horizon n, not the
			// per-step index m — confirm this is what DriftConfidenceInterval
			// expects.
			cis[level] = forecast.DriftConfidenceInterval(fval, level, hw.tstate.rmse, hw.tstate.T, n)
		}
		cnfdnce = append(cnfdnce, cis)
	}

	return nsf, cnfdnce, nil
}