-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
Gam.cs
176 lines (153 loc) · 7.52 KB
/
Gam.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Microsoft.ML.Data;
namespace Samples.Dynamic.Trainers.BinaryClassification
{
public static class Gam
{
// This sample depends on the Microsoft.ML.FastTree NuGet package, which has
// to be installed separately:
// https://www.nuget.org/packages/Microsoft.ML.FastTree/
public static void Example()
{
    // The MLContext is the starting point for every ML.NET operation. It
    // serves as the catalog of available operations, handles exception
    // tracking and logging, and is the source of randomness.
    var context = new MLContext();

    // Build the synthetic dataset and expose it as an IDataView.
    var examples = GenerateData();
    var dataView = context.Data.LoadFromEnumerable(examples);

    // Split the data into a training set and a validation set.
    var split = context.Data.TrainTestSplit(dataView);
    var trainingData = split.TrainSet;
    var validationData = split.TestSet;

    // Configure the GAM trainer with a deliberately small number of bins.
    // With this setting the range of every feature is divided into 16
    // discrete regions during training. The regions are not evenly spaced,
    // and the final model may contain fewer bins because neighboring bins
    // with identical values are combined. In general, we recommend using at
    // least the default number of bins, as a small number of bins limits
    // the capacity of the model.
    var gamTrainer = context.BinaryClassification.Trainers.Gam(
        maximumBinCountPerFeature: 16);

    // Fit the model using both the training and the validation set. GAM can
    // use a technique called pruning to tune the model to the validation
    // set after training, which improves generalization.
    var fittedModel = gamTrainer.Fit(trainingData, validationData);

    // Pull the GAM parameters out of the fitted model.
    var gamParameters = fittedModel.Model.SubModel;

    // The parameters of the Generalized Additive Model can now be inspected
    // to understand the fit and potentially learn about the dataset. Start
    // with the bias, which represents the average prediction for the
    // training data.
    Console.WriteLine($"Average prediction: {gamParameters.Bias:0.00}");

    // Writes the shape function learned for a single feature. The function
    // is a set of bins and the corresponding function values; you can think
    // of a GAM as building a bar chart or lookup table for each feature.
    void WriteShapeFunction(int featureIndex)
    {
        // Separate consecutive features with an empty line.
        Console.WriteLine();

        // Upper bound of every bin of this feature.
        var upperBounds = gamParameters.GetBinUpperBounds(featureIndex);

        // Effect of every bin, i.e. the function value inside that bin.
        var effects = gamParameters.GetBinEffects(featureIndex);

        Console.WriteLine($"Feature{featureIndex}");
        for (int bin = 0; bin < upperBounds.Count; bin++)
        {
            Console.WriteLine(
                $"x < {upperBounds[bin]:0.00} => {effects[bin]:0.000}");
        }
    }

    // Next, look at the shape functions the model has learned. As in a
    // linear model there is one independent response per feature. Unlike a
    // linear model, each response is a generic function rather than a line.
    // Because a bias term is included, every feature response represents
    // the deviation from the average prediction as a function of the
    // feature value.
    for (int feature = 0; feature < gamParameters.NumberOfShapeFunctions;
        feature++)
    {
        WriteShapeFunction(feature);
    }

    // Expected output:
    // Average prediction: 0.82
    //
    // Feature0
    // x < -0.44 => 0.286
    // x < -0.38 => 0.225
    // x < -0.32 => 0.048
    // x < -0.26 => -0.110
    // x < -0.20 => -0.116
    // x < 0.18 => -0.143
    // x < 0.25 => -0.115
    // x < 0.31 => -0.005
    // x < 0.37 => 0.097
    // x < 0.44 => 0.263
    // x < ∞ => 0.284
    //
    // Feature1
    // x < 0.00 => -0.350
    // x < 0.24 => 0.875
    // x < 0.31 => -0.138
    // x < ∞ => -0.188

    // Let's consider this output. To score a given example, we look up, for
    // each feature, the first bin whose inequality is satisfied by the
    // feature value. Looking at the whole function gives a sense of how the
    // model responds to the variable on a global level. The model can be
    // seen to reconstruct the parabolic function (Feature0) and the
    // step-wise function (Feature1), shifted with respect to the average
    // expected output over the training set. Very few bins are used to
    // model the second feature because the GAM model discards unchanged
    // bins to create smaller models. One last thing to notice is that these
    // feature functions can be noisy. While we know that Feature0, the
    // parabola, should be symmetric, this is not captured exactly in the
    // model. This is due to noise in the data. Common practice is to use
    // resampling methods to estimate a confidence interval at each bin,
    // which helps to determine whether an effect is real or just sampling
    // noise. See for example: Tan, Caruana, Hooker, and Lou.
    // "Distill-and-Compare: Auditing Black-Box Models Using Transparent
    // Model Distillation." arXiv:1710.06169,
    // https://arxiv.org/abs/1710.06169
}
/// <summary>
/// One labeled example of the synthetic dataset used by this sample: a
/// boolean label together with a two-element feature vector.
/// </summary>
private class Data
{
// Binary target of the example. GenerateData computes it by thresholding a
// sigmoid of the two shape-function values plus noise at 0.5.
public bool Label { get; set; }
// Feature values of the example. The attribute declares a vector of size 2,
// matching the two-element array GenerateData assigns here.
[VectorType(2)]
public float[] Features { get; set; }
}
/// <summary>
/// Lazily produces the synthetic dataset for the GAM sample. Every example
/// carries two independent pseudo-random features centered on zero. The
/// first feature (<c>Features[0]</c>) influences the label through a
/// parabola centered around 0, the second feature (<c>Features[1]</c>)
/// through a simple piecewise function.
/// </summary>
/// <param name="numExamples">The number of examples to generate.</param>
/// <param name="seed">The seed for the random number generator used to
/// produce data.</param>
/// <returns>A deferred sequence of <paramref name="numExamples"/> labeled
/// examples. Each enumeration creates a new generator from
/// <paramref name="seed"/>, so the sequence is the same every time it is
/// enumerated.</returns>
private static IEnumerable<Data> GenerateData(int numExamples = 25000,
    int seed = 1)
{
    var random = new Random(seed);

    for (int produced = 0; produced < numExamples; produced++)
    {
        // The draw order (first feature, second feature, label noise) fixes
        // which random numbers end up where, so it must not be changed.
        float first = NextCentered();
        float second = NextCentered();
        float noise = NextCentered();

        yield return new Data
        {
            // The label is the noisy sum of both shape functions, squashed
            // by the sigmoid and thresholded at 0.5.
            Label = Sigmoid(Parabola(first) + SimplePiecewise(second)
                + noise) > 0.5,
            Features = new float[2] { first, second }
        };
    }

    // Next pseudo-random value, shifted by -0.5 so that it is centered on 0.
    float NextCentered() => (float)(random.NextDouble() - 0.5);
}
/// <summary>
/// Shape function of the first feature: the square of the input.
/// </summary>
/// <param name="x">The feature value.</param>
/// <returns><paramref name="x"/> multiplied by itself.</returns>
private static float Parabola(float x)
{
    return x * x;
}
/// <summary>
/// Shape function of the second feature: a single step that is 1 for inputs
/// from 0 (inclusive) up to 0.25 (exclusive) and 0 everywhere else.
/// </summary>
/// <param name="x">The feature value.</param>
/// <returns>1 when <paramref name="x"/> lies inside the step, otherwise 0.
/// </returns>
private static float SimplePiecewise(float x)
{
    // A NaN input fails both comparisons and therefore yields 0.
    bool insideStep = x >= 0 && x < 0.25;
    return insideStep ? 1 : 0;
}
/// <summary>
/// Logistic sigmoid, 1 / (1 + e^(-x)).
/// </summary>
/// <param name="x">The input value.</param>
/// <returns>The sigmoid of <paramref name="x"/>; 0.5 for an input of 0.
/// </returns>
private static double Sigmoid(double x)
{
    double decay = Math.Exp(-x);
    return 1.0 / (1.0 + decay);
}
}
}