-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
TimeSeriesDirectApi.cs
131 lines (114 loc) · 5.21 KB
/
TimeSeriesDirectApi.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Collections.Generic;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
using Xunit;
namespace Microsoft.ML.Tests
{
public sealed class TimeSeries
{
    // Number of slots in the "Change" output column. The seasonality test below labels
    // them, in order: alert, raw score, p-value score, martingale score.
    private const int ChangeVectorLength = 4;

    /// <summary>
    /// Row type used to read the detector output ("Change" column) back out of the data view.
    /// </summary>
    public class Prediction
    {
        [VectorType(ChangeVectorLength)]
        public double[] Change;
    }

    /// <summary>
    /// Single-column input row; the detectors below read it through the "Value" source column.
    /// </summary>
    private sealed class Data
    {
        public float Value;

        public Data(float value)
        {
            Value = value;
        }
    }

    [Fact]
    public void ChangeDetection()
    {
        using (var env = new ConsoleEnvironment(conc: 1))
        {
            const int size = 10;

            // Arrange: a flat segment of 5s followed by a ramp 5, 6.1, 7.2, ...
            var data = new List<Data>(size);
            for (int i = 0; i < size / 2; i++)
            {
                data.Add(new Data(5));
            }
            for (int i = 0; i < size / 2; i++)
            {
                data.Add(new Data((float)(5 + i * 1.1)));
            }

            // Create the view only after the list is fully populated, so the test does not
            // rely on the view observing later mutations of its backing list.
            var dataView = env.CreateStreamingDataView(data);

            var args = new IidChangePointDetector.Arguments()
            {
                Confidence = 80,
                Source = "Value",
                Name = "Change",
                ChangeHistoryLength = size
            };

            // Act: train, then transform the same data.
            var detector = new IidChangePointEstimator(env, args).Fit(dataView);
            var output = detector.Transform(dataView);

            // Assert: the first four output rows, flattened row-major. These expectations are
            // written to full double precision, so they are compared exactly.
            var expectedValues = new List<double>() { 0, 5, 0.5, 5.1200000000000114E-08, 0, 5, 0.4999999995, 5.1200000046080209E-08, 0, 5, 0.4999999995, 5.1200000092160303E-08,
                0, 5, 0.4999999995, 5.12000001382404E-08};
            AssertChangeRows(output.AsEnumerable<Prediction>(env, true), expectedValues, precision: null);
        }
    }

    [Fact]
    public void ChangePointDetectionWithSeasonality()
    {
        using (var env = new ConsoleEnvironment(conc: 1))
        {
            const int ChangeHistorySize = 10;
            const int SeasonalitySize = 10;
            const int NumberOfSeasonsInTraining = 5;
            const int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize;

            // Arrange: several repetitions of the season 0..9, then a steep ramp 0, 100, 200, ...
            var data = new List<Data>();
            for (int j = 0; j < NumberOfSeasonsInTraining; j++)
            {
                for (int i = 0; i < SeasonalitySize; i++)
                {
                    data.Add(new Data(i));
                }
            }
            for (int i = 0; i < ChangeHistorySize; i++)
            {
                data.Add(new Data(i * 100));
            }

            // Create the view only after the list is fully populated.
            var dataView = env.CreateStreamingDataView(data);

            var args = new SsaChangePointDetector.Arguments()
            {
                Confidence = 95,
                Source = "Value",
                Name = "Change",
                ChangeHistoryLength = ChangeHistorySize,
                TrainingWindowSize = MaxTrainingSize,
                SeasonalWindowSize = SeasonalitySize
            };

            // Act: train, then transform the same data.
            var detector = new SsaChangePointEstimator(env, args).Fit(dataView);
            var output = detector.Transform(dataView);

            // Assert: the first four output rows, flattened row-major
            // (alert, raw score, p-value score, martingale score per row), to 7 decimal places.
            var expectedValues = new List<double>() { 0, -3.31410598754883, 0.5, 5.12000000000001E-08, 0, 1.5700820684432983, 5.2001145245395008E-07,
                0.012414560443710681, 0, 1.2854313254356384, 0.28810801662678009, 0.02038940454467935, 0, -1.0950627326965332, 0.36663890634019225, 0.026956459625565483};
            AssertChangeRows(output.AsEnumerable<Prediction>(env, true), expectedValues, precision: 7);
        }
    }

    /// <summary>
    /// Compares detector output rows against a flat, row-major list of expected slot values,
    /// <see cref="ChangeVectorLength"/> values per row. Only as many rows as there are
    /// expectations are read from <paramref name="predictions"/>.
    /// </summary>
    /// <param name="predictions">Output rows. The enumerator obtained from it is always disposed.</param>
    /// <param name="expectedValues">Expected slot values; the count must be a multiple of <see cref="ChangeVectorLength"/>.</param>
    /// <param name="precision">Decimal places for the comparison, or <c>null</c> for exact equality.</param>
    private static void AssertChangeRows(IEnumerable<Prediction> predictions, IReadOnlyList<double> expectedValues, int? precision)
    {
        // A partial trailing row would otherwise index past the end of expectedValues.
        Assert.Equal(0, expectedValues.Count % ChangeVectorLength);

        int index = 0;
        using (var enumerator = predictions.GetEnumerator())
        {
            // Test the bound before MoveNext so no extra row is pulled once every expectation is consumed.
            while (index < expectedValues.Count && enumerator.MoveNext())
            {
                Prediction row = enumerator.Current;
                Assert.Equal(ChangeVectorLength, row.Change.Length);

                for (int slot = 0; slot < ChangeVectorLength; slot++)
                {
                    double expected = expectedValues[index++];
                    if (precision.HasValue)
                    {
                        Assert.Equal(expected, row.Change[slot], precision.Value);
                    }
                    else
                    {
                        Assert.Equal(expected, row.Change[slot]);
                    }
                }
            }
        }

        // Fail, rather than pass silently, when the output produced fewer rows than expected.
        Assert.Equal(expectedValues.Count, index);
    }
}
}