/
GainsLift.java
executable file
·278 lines (258 loc) · 11.4 KB
/
GainsLift.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
package hex;
import hex.quantile.Quantile;
import hex.quantile.QuantileModel;
import water.*;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.PrettyPrint;
import water.util.TwoDimTable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.TreeSet;
/**
 * Gains/Lift table for a binary response.
 * <p>
 * Holds the input vectors (predicted scores, actual labels and optional observation weights),
 * the score thresholds derived from quantiles of the predictions, and the per-group statistics
 * that are filled in by {@code exec(Job)} and rendered by {@code createTwoDimTable()}.
 */
public class GainsLift extends Iced {
// Lower score thresholds of the groups, highest first; assigned by init().
private double[] _quantiles;
//INPUT
// Requested number of groups; when not positive, init() falls back to a fixed list of 16 quantile probabilities.
public int _groups = -1;
// Actual labels; init() replaces this reference with the result of toCategoricalVec().
public Vec _labels;
// Predicted scores; init() rejects categorical vectors here.
public Vec _preds; //of length N, n_i = N/GROUPS
// Optional observation weights; null means unweighted.
public Vec _weights;
//OUTPUT
// All output fields below are populated by exec(Job); index i refers to the i-th group.
public double[] response_rates; // p_i = e_i/n_i
public double[] avg_scores; // s_i
public double avg_response_rate; // P
public double avg_score; // S
public long[] events; // e_i
public long[] observations; // n_i
// Most recent table produced by createTwoDimTable().
TwoDimTable table;
/**
 * Creates an unweighted gains/lift computation.
 *
 * @param preds  predicted scores
 * @param labels actual class labels
 */
public GainsLift(Vec preds, Vec labels) {
  this(preds, labels, (Vec) null);
}
/**
 * Creates a gains/lift computation with optional observation weights.
 * Only the references are stored here; nothing is validated or computed.
 *
 * @param preds   predicted scores
 * @param labels  actual class labels
 * @param weights observation weights, or {@code null} for none
 */
public GainsLift(Vec preds, Vec labels, Vec weights) {
  this._weights = weights;
  this._labels = labels;
  this._preds = preds;
}
/**
 * Validates the inputs, aligns the vectors when they belong to different groups and
 * computes the score thresholds ({@code _quantiles}) from the predictions.
 * <p>
 * The null check now runs before {@code _labels} is first dereferenced, so a missing label
 * vector is reported as an {@link IllegalArgumentException} rather than a NullPointerException.
 * The former single-pass branch guarded by a constant {@code false} flag was unreachable and
 * has been dropped; thresholds are always computed with the Quantile model.
 *
 * @param job optional owning job; when non-null and not done, the quantile model is trained
 *            nested inside it, otherwise it is trained standalone
 * @throws IllegalArgumentException if an input is missing, the lengths differ, the labels are
 *         not binary integer classes, the predictions are categorical, or the weights are
 *         not numeric
 */
private void init(Job job) throws IllegalArgumentException {
  if (_labels == null || _preds == null)
    throw new IllegalArgumentException("Missing actualLabels or predictedProbs!");
  // NOTE(review): the Vec returned by toCategoricalVec() is not registered with Scope here,
  // unlike the aligned copies below - confirm who owns/deletes it.
  _labels = _labels.toCategoricalVec();
  if (_labels == null) // kept from the original: guards against a failed conversion
    throw new IllegalArgumentException("Missing actualLabels or predictedProbs!");
  if (_labels.length() != _preds.length())
    throw new IllegalArgumentException("Both arguments must have the same length ("+ _labels.length()+"!="+ _preds.length()+")!");
  if (!_labels.isInt())
    throw new IllegalArgumentException("Actual column must be integer class labels!");
  if (_labels.cardinality() != -1 && _labels.cardinality() != 2)
    throw new IllegalArgumentException("Actual column must contain binary class labels, but found cardinality " + _labels.cardinality() + "!");
  if (_preds.isCategorical())
    throw new IllegalArgumentException("Predicted probabilities cannot be class labels, expect probabilities.");
  if (_weights != null && !_weights.isNumeric())
    throw new IllegalArgumentException("Observation weights must be numeric.");

  // The vectors are from different groups => align them, but properly delete it after computation
  if (!_labels.group().equals(_preds.group())) {
    _preds = _labels.align(_preds);
    Scope.track(_preds);
    if (_weights != null) {
      _weights = _labels.align(_weights);
      Scope.track(_weights);
    }
  }

  _quantiles = computeThresholds(job);
}

/** Returns the quantile probabilities to request, highest first. */
private double[] quantileProbs() {
  if (_groups <= 0)
    return new double[]{0.99, 0.98, 0.97, 0.96, 0.95, 0.9, 0.85, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0};
  double[] probs = new double[_groups];
  for (int i = 0; i < _groups; ++i) {
    probs[i] = (_groups - i - 1.) / _groups; // This is 0.9, 0.8, 0.7, 0.6, ..., 0.1, 0 for 10 groups
  }
  return probs;
}

/** Trains a temporary Quantile model on the predictions and returns its distinct quantiles in descending order. */
private double[] computeThresholds(Job job) {
  Frame fr = null;
  QuantileModel qm = null;
  try {
    QuantileModel.QuantileParameters qp = new QuantileModel.QuantileParameters();
    if (_weights == null) {
      fr = new Frame(Key.<Frame>make(), new String[]{"predictions"}, new Vec[]{_preds});
    } else {
      fr = new Frame(Key.<Frame>make(), new String[]{"predictions", "weights"}, new Vec[]{_preds, _weights});
      qp._weights_column = "weights";
    }
    DKV.put(fr);
    qp._train = fr._key;
    qp._probs = quantileProbs();
    qm = job != null && !job.isDone() ? new Quantile(qp, job).trainModelNested(null) : new Quantile(qp).trainModel().get();

    // Several probabilities can map to the same score; keep each threshold only once.
    TreeSet<Double> uniques = new TreeSet<>();
    for (double q : qm._output._quantiles[0]) uniques.add(q);
    double[] thresholds = new double[uniques.size()];
    int i = 0;
    for (Iterator<Double> it = uniques.descendingIterator(); it.hasNext(); ) thresholds[i++] = it.next();
    return thresholds;
  } finally {
    // The helper frame and model are only needed to obtain the quantiles.
    if (qm != null) qm.remove();
    if (fr != null) DKV.remove(fr._key);
  }
}
/** Runs the computation without an owning job; same as calling {@code exec(null)}. */
public void exec() {
  exec((Job) null);
}
/**
 * Validates the inputs, computes the thresholds and fills the output fields
 * ({@code response_rates}, {@code avg_scores}, {@code avg_response_rate}, {@code avg_score},
 * {@code events}, {@code observations}).
 * <p>
 * {@code init(job)} now runs inside the {@code try}, so the scope opened by
 * {@code Scope.enter()} is always closed, even when validation throws.
 *
 * @param job optional owning job, passed through to {@code init}
 * @throws IllegalArgumentException if the inputs fail validation in {@code init}
 */
public void exec(Job job) {
  Scope.enter();
  try {
    init(job); //check parameters and obtain _quantiles from _preds
    GainsLiftBuilder gt = new GainsLiftBuilder(_quantiles);
    gt = (_weights != null) ? gt.doAll(_labels, _preds, _weights) : gt.doAll(_labels, _preds);
    response_rates = gt.response_rates();
    avg_scores = gt.avg_scores();
    avg_response_rate = gt.avg_response_rate();
    avg_score = gt.avg_score();
    events = gt.events();
    observations = gt.observations();
  } finally { // Delete adaptation vectors
    Scope.exit();
  }
}
/** Renders the gains/lift table as text, or returns an empty string when no table can be built. */
@Override public String toString() {
  final TwoDimTable rendered = createTwoDimTable();
  if (rendered == null) {
    return "";
  }
  return rendered.toString();
}
/**
 * Builds the gains/lift table from the output fields, one row per group, and stores it in
 * the {@code table} field.
 *
 * @return the new table, or {@code null} when there are no response rates yet or the average
 *         response rate is NaN
 */
public TwoDimTable createTwoDimTable() {
  if (response_rates == null || Double.isNaN(avg_response_rate)) return null;

  final String description =
      "Avg response rate: " + PrettyPrint.formatPct(avg_response_rate) + ", avg score: " + PrettyPrint.formatPct(avg_score);
  final int groups = events.length;
  final String[] colHeaders = {"Group", "Cumulative Data Fraction", "Lower Threshold", "Lift", "Cumulative Lift", "Response Rate", "Score", "Cumulative Response Rate", "Cumulative Score", "Capture Rate", "Cumulative Capture Rate", "Gain", "Cumulative Gain", "Kolmogorov Smirnov"};
  final String[] colTypes = {"int", "double", "double", "double", "double", "double", "double", "double", "double", "double", "double", "double", "double", "double"};
  final String[] colFormats = {"%d", "%.8f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f","%5f"};
  TwoDimTable table = new TwoDimTable("Gains/Lift Table", description, new String[groups], colHeaders, colTypes, colFormats, null);

  final double P = avg_response_rate; // E/N
  final long N = ArrayUtils.sum(observations);
  final long E = Math.round(N * P);
  final double nonEvents = (double) (N - E);

  // Running totals over the groups processed so far.
  long cumEvents = 0;
  long cumObs = 0;
  double cumScore = 0;

  for (int g = 0; g < groups; ++g) {
    final long ev = events[g];
    final long obs = observations[g];
    final double rate = response_rates[g];
    final double score = avg_scores[g];

    cumEvents += ev;
    cumObs += obs;
    cumScore += obs * score;

    final double cumRate = (double) cumEvents / cumObs;
    final double lift = rate / P;       // can be NaN if P==0
    final double cumLift = cumRate / P; // can be NaN if P==0
    final double cumCapture = cumEvents / (double) E;
    // With a response rate of 1 there are no non-events, so their cumulative share stays zero.
    final double cumNonEventShare = nonEvents == 0 ? 0 : (cumObs - cumEvents) / nonEvents;

    table.set(g, 0, g + 1);                          // group
    table.set(g, 1, (double) cumObs / N);            // cumulative_data_fraction
    table.set(g, 2, _quantiles[g]);                  // lower_threshold
    table.set(g, 3, lift);                           // lift
    table.set(g, 4, cumLift);                        // cumulative_lift
    table.set(g, 5, rate);                           // response_rate
    table.set(g, 6, score);                          // score
    table.set(g, 7, cumRate);                        // cumulative_response_rate
    table.set(g, 8, cumScore / cumObs);              // cumulative_score
    table.set(g, 9, (double) ev / E);                // capture_rate
    table.set(g, 10, cumCapture);                    // cumulative_capture_rate
    table.set(g, 11, 100 * (lift - 1));              // gain
    table.set(g, 12, 100 * (cumLift - 1));           // cumulative gain
    table.set(g, 13, cumCapture - cumNonEventShare); // Kolmogorov-Smirnov metric

    // Sanity checks once the last group has been added: the totals must match the overall figures.
    if (g == groups - 1) {
      assert (cumObs == N) : "Cumulative data fraction must be 1.0, but is " + (double) cumObs / N;
      assert (cumEvents == E) : "Cumulative capture rate must be 1.0, but is " + (double) cumEvents / E;
      if (!Double.isNaN(cumLift)) assert (Math.abs(cumLift - 1.0) < 1e-8) : "Cumulative lift must be 1.0, but is " + cumLift;
      assert (Math.abs(cumRate - avg_response_rate) < 1e-8) : "Cumulative response rate must be " + avg_response_rate + ", but is " + cumRate;
    }
  }

  this.table = table;
  return table;
}
/**
 * Compute Gains table via MRTask.
 * <p>
 * Each row is assigned to the first group {@code t} whose threshold satisfies
 * {@code pr >= _thresh[t]} (and {@code pr < _thresh[t-1]} for {@code t > 0}); per group the
 * task accumulates the weighted observation count, event count and score sum. Rows with a
 * missing label or prediction, or with a zero or missing (NaN) weight, are ignored.
 * <p>
 * NOTE(review): the counters are {@code long} and are updated with {@code += w} where
 * {@code w} is a double, so every single addition is narrowed back to {@code long};
 * fractional weights are therefore truncated per row - confirm this is intended.
 */
public static class GainsLiftBuilder extends MRTask<GainsLiftBuilder> {
  /* @OUT response_rates */
  public final double[] response_rates() { return _response_rates; }
  public final double avg_response_rate() { return _avg_response_rate; }
  public final double avg_score() { return _avg_score; }
  public final long[] events(){ return _events; }
  public final long[] observations(){ return _observations; }
  public final double[] avg_scores() { return _avg_scores; }

  /* @IN quantiles/thresholds */
  final private double[] _thresh;

  private long[] _events;            // weighted number of rows with label 1, per group
  private long[] _observations;      // weighted number of rows, per group
  private long _avg_response;        // weighted number of rows with label 1, over all rows
  private double _avg_response_rate; // set in postGlobal()
  private double _avg_score;         // weighted score sum; turned into a mean in postGlobal()
  private double[] _response_rates;  // set in postGlobal()
  private double[] _avg_scores;      // weighted score sums per group; turned into means in postGlobal()

  /**
   * @param thresh lower score thresholds of the groups; a private copy is kept
   */
  public GainsLiftBuilder(double[] thresh) {
    _thresh = thresh.clone();
  }

  /** Unweighted variant: every row counts with weight 1. */
  @Override public void map( Chunk ca, Chunk cp) { map(ca, cp, (Chunk)null); }

  /**
   * Accumulates the statistics of one chunk.
   *
   * @param ca actual labels, must be 0 or 1 where present
   * @param cp predicted scores
   * @param cw observation weights, or {@code null} for weight 1
   * @throws IllegalArgumentException if a label is neither 0 nor 1
   */
  @Override public void map( Chunk ca, Chunk cp, Chunk cw) {
    _events = new long[_thresh.length];
    _observations = new long[_thresh.length];
    _avg_scores = new double[_thresh.length];
    _avg_response = 0;
    _avg_score = 0;
    final int len = Math.min(ca._len, cp._len);
    for( int i=0; i < len; i++ ) {
      if (ca.isNA(i)) continue;
      final int a = (int)ca.at8(i);
      if (a != 0 && a != 1) throw new IllegalArgumentException("Invalid values in actualLabels: must be binary (0 or 1).");
      if (cp.isNA(i)) continue;
      final double pr = cp.atd(i);
      final double w = cw!=null?cw.atd(i):1;
      perRow(pr, a, w);
    }
  }

  /**
   * Adds one row to the running statistics.
   *
   * @param pr predicted score, must not be NaN
   * @param a  actual label (1 counts as an event)
   * @param w  observation weight; rows with weight 0 or NaN are skipped
   */
  public void perRow(double pr, int a, double w) {
    // A NaN weight must not reach the long counters: `long += NaN` narrows to 0 and would
    // silently wipe everything accumulated so far. Treat it like a zero weight instead.
    if (w == 0 || Double.isNaN(w)) return;
    assert (!Double.isNaN(pr));
    //for-loop is faster than binary search for small number of thresholds
    for( int t=0; t < _thresh.length; t++ ) {
      if (pr >= _thresh[t] && (t==0 || pr <_thresh[t-1])) {
        _observations[t]+=w;
        _avg_scores[t]+=w*pr;
        if (a == 1) _events[t]+=w;
        break;
      }
    }
    if (a == 1) _avg_response+=w;
    _avg_score += w*pr;
  }

  /** Merges the partial sums of another task into this one. */
  @Override public void reduce(GainsLiftBuilder other) {
    ArrayUtils.add(_events, other._events);
    ArrayUtils.add(_observations, other._observations);
    ArrayUtils.add(_avg_scores, other._avg_scores);
    _avg_response += other._avg_response;
    _avg_score += other._avg_score;
  }

  /** Turns the accumulated sums into per-group and overall rates and averages. */
  @Override public void postGlobal(){
    _response_rates = new double[_thresh.length];
    for (int i=0; i<_response_rates.length; ++i) {
      // Empty groups report 0 instead of dividing by zero.
      _response_rates[i] = _observations[i] == 0 ? 0 : (double) _events[i] / _observations[i];
      _avg_scores[i] = _observations[i] == 0 ? 0 : _avg_scores[i] / (double)_observations[i];
    }
    final long total = ArrayUtils.sum(_observations);
    // With no counted observations both values become NaN.
    _avg_response_rate = (double)_avg_response / total;
    _avg_score /= total;
  }
}
}