-
Notifications
You must be signed in to change notification settings - Fork 562
/
GLM2.java
1681 lines (1534 loc) · 72.7 KB
/
GLM2.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
package hex.glm;
import dontweave.gson.JsonObject;
import hex.FrameTask.DataInfo;
import hex.FrameTask.DataInfo.TransformType;
import hex.GridSearch.GridSearchProgress;
import hex.glm.GLMModel.GLMXValidationTask;
import hex.glm.GLMModel.Submodel;
import hex.glm.GLMParams.Family;
import hex.glm.GLMParams.Link;
import hex.glm.GLMTask.GLMInterceptTask;
import hex.glm.GLMTask.GLMIterationTask;
import hex.glm.GLMTask.YMUTask;
import hex.glm.LSMSolver.ADMMSolver;
import jsr166y.CountedCompleter;
import water.*;
import water.H2O.H2OCallback;
import water.H2O.H2OCountedCompleter;
import water.H2O.H2OEmptyCompleter;
import water.api.DocGen;
import water.api.ParamImportance;
import water.api.RequestServer.API_VERSION;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.ModelUtils;
import water.util.RString;
import water.util.Utils;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Properties;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
public class GLM2 extends Job.ModelJobWithoutClassificationField {
public static final double LS_STEP = .9;
static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
public static DocGen.FieldDoc[] DOC_FIELDS;
public static final String DOC_GET = "GLM2";
public final String _jobName;
transient public boolean _done = false;
// API input parameters BEGIN ------------------------------------------------------------
@API(help="Column to be used as an offset, if you have one.", required=false, filter=responseFilter.class, json = true)
public Vec offset = null;
class responseFilter extends SpecialVecSelect { responseFilter() { super("source"); } }
@API(help = "Family.", filter = Default.class, json=true, importance = ParamImportance.CRITICAL)
protected Family family = Family.gaussian;
@API(help = "", filter = Default.class, json=true, importance = ParamImportance.SECONDARY)
protected Link link = Link.family_default;
@API(help = "Tweedie variance power", filter = Default.class, json=true, importance = ParamImportance.SECONDARY)
protected double tweedie_variance_power;
// Sets the Tweedie variance power, derives the link power as (1 - variance power),
// and rebuilds _glm so the new powers take effect immediately.
public void setTweediePower(double pwr){
tweedie_variance_power = pwr;
tweedie_link_power = 1 - tweedie_variance_power;
_glm = new GLMParams(family,tweedie_variance_power,link,tweedie_link_power);
}
@API(help="prior probability for y==1. To be used only for logistic regression iff the data has been sampled and the mean of response does not reflect reality.",filter=Default.class, importance = ParamImportance.EXPERT)
protected double prior = -1; // -1 is magic value for default value which is mean(y) computed on the current dataset
@API(help="disable line search in all cases.",filter=Default.class, importance = ParamImportance.EXPERT, hide = true)
protected boolean disable_line_search = false; // -1 is magic value for default value which is mean(y) computed on the current dataset
private double _iceptAdjust = 0; // adjustment due to the prior
@API(help = "validation folds", filter = Default.class, lmin=0, lmax=100, json=true, importance = ParamImportance.CRITICAL)
protected int n_folds;
@API(help = "distribution of regularization between L1 and L2.", filter = Default.class, json=true, importance = ParamImportance.SECONDARY)
protected double [] alpha = new double[]{0.5};
public final double DEFAULT_LAMBDA = 1e-5;
@API(help = "regularization strength", filter = Default.class, json=true, importance = ParamImportance.SECONDARY)
protected double [] lambda = new double[]{DEFAULT_LAMBDA};
@API(help="use lambda search starting at lambda max, given lambda is then interpreted as lambda min",filter=Default.class, importance = ParamImportance.SECONDARY)
protected boolean lambda_search;
@API(help="number of lambdas to be used in a search",filter=Default.class, importance = ParamImportance.EXPERT)
protected int nlambdas = 100;
@API(help="min lambda used in lambda search, specified as a ratio of lambda_max",filter=Default.class, importance = ParamImportance.EXPERT)
protected double lambda_min_ratio = -1;
@API(help="lambda_search stop condition: stop training when model has more than than this number of predictors (or don't use this option if -1).",filter=Default.class, importance = ParamImportance.EXPERT)
protected int max_predictors = -1;
public void setLambda(double l){ lambda = new double []{l};}
private double _currentLambda = Double.POSITIVE_INFINITY;
public int MAX_ITERATIONS_PER_LAMBDA = 10;
@API(help="use strong rules to filter out inactive columns",filter=Default.class, importance = ParamImportance.SECONDARY)
protected boolean strong_rules = true;
// intentionally not declared as API now
int sparseCoefThreshold = 1000; // if more than this number of predictors, result vector of coefficients will be stored sparse
double [] beta_start = null;
@API(help = "Standardize numeric columns to have zero mean and unit variance.", filter = Default.class, json=true, importance = ParamImportance.CRITICAL)
protected boolean standardize = true;
@API(help = "Include intercept term in the model.", filter = Default.class, json=true, importance = ParamImportance.CRITICAL)
protected boolean intercept = true;
@API(help = "Restrict coefficients to be non-negative.", filter = Default.class, json=true, importance = ParamImportance.CRITICAL)
protected boolean non_negative = false;
@API(help="lower bounds for coefficients",filter=Default.class,hide=true)
protected Frame beta_constraints = null;
@API(help="By default, first factor level is skipped from the possible set of predictors. Set this flag if you want use all of the levels. Needs sufficient regularization to solve!",filter=Default.class, importance = ParamImportance.SECONDARY)
protected boolean use_all_factor_levels = false;
/**
* Whether to compute variable importances for input features, based on the absolute
* value of the coefficients. For safety this should only be done if
* use_all_factor_levels, because an important factor level can be skipped and not
* appear if !use_all_factor_levels.
*/
@API(help = "Compute variable importances for input features. NOTE: If use_all_factor_levels is off the importance of the base level will NOT be shown.", filter = Default.class, json=true, importance = ParamImportance.SECONDARY)
public boolean variable_importances = false;
@API(help = "beta_eps", filter = Default.class, json=true, importance = ParamImportance.SECONDARY)
protected double beta_epsilon = DEFAULT_BETA_EPS;
@API(help = "max-iterations", filter = Default.class, lmin=1, lmax=1000000, json=true, importance = ParamImportance.CRITICAL)
public int max_iter = 100;
@API(help="use line search (slower speed, to be used if glm does not converge otherwise)",filter=Default.class, importance = ParamImportance.SECONDARY)
protected boolean higher_accuracy = false;
// API input parameters END ------------------------------------------------------------
// API output parameters BEGIN ------------------------------------------------------------
@API(help = "", json=true, importance = ParamImportance.SECONDARY)
private double [] _wgiven;
@API(help = "", json=true, importance = ParamImportance.SECONDARY)
private double _proximalPenalty;
@API(help = "", json=true, importance = ParamImportance.SECONDARY)
private double [] _beta;
@API(help = "", json=true, importance = ParamImportance.SECONDARY)
private boolean _runAllLambdas = true;
@API(help = "Tweedie link power", json=true, importance = ParamImportance.SECONDARY)
double tweedie_link_power;
@API(help = "lambda_value max", json=true, importance = ParamImportance.SECONDARY)
double lambda_max = Double.NaN;
double lambda_min = Double.NaN;
long _nobs = 0;
private double _nullDeviance;
public static int MAX_PREDICTORS = 7000;
// API output parameters END ------------------------------------------------------------
private static double GLM_GRAD_EPS = 1e-4; // done (converged) if subgrad < this value.
private boolean highAccuracy(){return higher_accuracy;}
public GLM2 setHighAccuracy(){
higher_accuracy = true;
return this;
}
private Key _progressKey;
private DataInfo _srcDinfo;
private int [] _activeCols;
private boolean _allIn;
private DataInfo _activeData;
public GLMParams _glm;
private boolean _grid;
private double ADMM_GRAD_EPS = 1e-4; // default addm gradietn eps
private static final double MIN_ADMM_GRAD_EPS = 1e-5; // min admm gradient eps
int _lambdaIdx = -1;
private double _addedL2;
private boolean _failedLineSearch;
public static final double DEFAULT_BETA_EPS = 5e-5;
private double _ymu;
private int _iter;
// UI registration hook: swaps the display positions of the "ignored_cols" and
// "offset" arguments, then tells the ignored-cols multi-selector to exclude
// whichever vec was chosen as the offset (so it cannot be both).
@Override protected void registered(API_VERSION ver) {
super.registered(ver);
Argument c = find("ignored_cols");
Argument r = find("offset");
int ci = _arguments.indexOf(c);
int ri = _arguments.indexOf(r);
// swap the two arguments in place
_arguments.set(ri, c);
_arguments.set(ci, r);
((FrameKeyMultiVec) c).ignoreVec((FrameKeyVec)r);
}
// Penalized objective value for one iteration: mean residual deviance plus the
// elastic-net penalty (0.5 * l2pen * ||b||^2 + l1pen * |b|_1) and the proximal penalty.
private double objval(GLMIterationTask glmt){
return glmt._val.residual_deviance / glmt._nobs + 0.5 * l2pen() * l2norm(glmt._beta) + l1pen() * l1norm(glmt._beta) + proxPen(glmt._beta);
}
// Snapshots the current iteration; if the task carries no gradient, fills it in by
// contracting the full-space gradient onto the active columns.
// NOTE(review): parameter i is ignored — the field _iter is used instead; confirm intended.
private IterationInfo makeIterationInfo(int i, GLMIterationTask glmt, final int [] activeCols, double [] gradient){
IterationInfo ii = new IterationInfo(_iter, glmt,activeCols,gradient);
if(ii._glmt._grad == null)
ii._glmt._grad = contractVec(gradient,activeCols);
return ii;
}
// Immutable-ish snapshot of one GLM iteration: the cloned iteration task, the active
// column subset it ran on, and (optionally) the gradient in the full predictor space.
private static class IterationInfo extends Iced {
final int _iter;
private double [] _fullGrad;
// Returns a copy of the full-space gradient with the l2 (ridge) penalty derivative
// added for the current beta; null when no full gradient was recorded.
public double [] fullGrad(double alpha, double lambda){
if(_fullGrad == null)return null;
double [] res = _fullGrad.clone();
double l2 = (1-alpha)*lambda; // no 0.5 mul here since we're adding derivative of 0.5*|b|^2
if(_activeCols != null)
// active-column case: last beta element is the intercept, hence length-1
for(int i = 0; i < _glmt._beta.length-1; ++i)
res[_activeCols[i]] += _glmt._beta[i]*l2;
else for(int i = 0; i < _glmt._beta.length; ++i) {
res[i] += _glmt._beta[i]*l2;
}
return res;
}
private final GLMIterationTask _glmt;
final int [] _activeCols;
IterationInfo(int i, GLMIterationTask glmt, final int [] activeCols, double [] gradient){
_iter = i;
_glmt = glmt.clone();
assert _glmt._grad != null;
_activeCols = activeCols;
_fullGrad = gradient;
// NOTE: _glmt._beta CAN BE NULL (unlikely but possible, if activeCols were empty)
assert _glmt._val != null:"missing validation";
}
}
private IterationInfo _lastResult;
// JSON view of the job; reports lambda as "automatic" when none was supplied,
// rather than omitting the field entirely.
@Override
public JsonObject toJSON() {
JsonObject jo = super.toJSON();
if (lambda == null) jo.addProperty("lambda_value", "automatic"); //better than not printing anything if lambda_value=null
return jo;
}
@Override public Key defaultDestKey(){
return null;
}
@Override public Key defaultJobKey() {return null;}
public GLM2() {_jobName = "";}
// Simple value holder bundling the training inputs for a GLM run: the frame, the
// response vec, an optional offset vec, and the standardize/intercept flags.
public static class Source {
public final Frame fr;
public final Vec response;
public final Vec offset;
public final boolean standardize;
public final boolean intercept;
// Convenience: intercept defaults to true, offset to null.
public Source(Frame fr,Vec response, boolean standardize){ this(fr,response,standardize,true,null);}
// Convenience: offset defaults to null.
public Source(Frame fr,Vec response, boolean standardize, boolean intercept){ this(fr,response,standardize,intercept,null);}
public Source(Frame fr,Vec response, boolean standardize, boolean intercept, Vec offset){
this.fr = fr;
this.response = response;
this.offset = offset;
this.standardize = standardize;
this.intercept = intercept;
}
}
public GLM2(String desc, Key jobKey, Key dest, Source src, Family family){
this(desc,jobKey,dest,src,family,Link.family_default);
}
public GLM2(String desc, Key jobKey, Key dest, Source src, Family family, Link l){
this(desc, jobKey, dest, src, family, l, 0, false);
}
// Primary programmatic constructor: wires job/destination keys and copies the
// training configuration out of the Source holder. The job name carries a "[0]"
// suffix when cross-validation (nfolds > 1) is requested.
public GLM2(String desc, Key jobKey, Key dest, Source src, Family family, Link l, int nfolds, boolean highAccuracy) {
job_key = jobKey;
description = desc;
destination_key = dest;
this.offset = src.offset;
this.intercept = src.intercept;
this.family = family;
this.link = l;
n_folds = nfolds;
source = src.fr;
this.response = src.response;
this.standardize = src.standardize;
_jobName = dest.toString() + ((nfolds > 1)?("[" + 0 + "]"):"");
higher_accuracy = highAccuracy;
}
public GLM2 doInit(){
init();
return this;
}
public GLM2 setNonNegative(boolean val){
non_negative = val;
return this;
}
public GLM2 setRegularization(double [] alpha, double [] lambda){
this.alpha = alpha;
this.lambda = lambda;
return this;
}
public GLM2 setBetaConstraints(Frame f){
beta_constraints = f;
return this;
}
/**
 * Formats a double array as a comma-separated list, e.g. "1.0, 2.5".
 *
 * @param arr array to format; may be null
 * @return "(null)" when {@code arr} is null, otherwise the joined values
 */
static String arrayToString (double[] arr) {
  if (arr == null) {
    return "(null)";
  }
  // StringBuilder instead of StringBuffer: the builder is method-local,
  // so the synchronized StringBuffer buys nothing.
  StringBuilder sb = new StringBuilder();
  for (int i = 0; i < arr.length; i++) {
    if (i > 0) {
      sb.append(", ");
    }
    sb.append(arr[i]);
  }
  return sb.toString();
}
public transient float [] thresholds = ModelUtils.DEFAULT_THRESHOLDS;
/** Return the query link to this page */
public static String link(Key k, String content) {
RString rs = new RString("<a href='GLM2.query?source=%$key'>%content</a>");
rs.replace("key", k.toString());
rs.replace("content", content);
return rs.toString();
}
public GLMGridSearch gridSearch(){
return new GLMGridSearch(4, this, destination_key).fork();
}
private transient AtomicBoolean _jobdone = new AtomicBoolean(false);
/**
 * Cancels the job: unlocks the source frame (unless this run is owned by a grid
 * search), removes the progress record, deletes any cross-validation models hanging
 * off the partially-built model, and removes the model itself.
 */
@Override public void cancel(String msg){
  if(!_grid) {
    source.unlock(self());
  }
  DKV.remove(_progressKey);
  Value v = DKV.get(destination_key);
  if(v != null){
    GLMModel m = v.get();
    Key [] xvals = m.xvalModels();
    if(xvals != null)
      for(Key k:xvals)
        DKV.remove(k);
  }
  // Single unconditional remove covers both the model-present and model-absent cases.
  // (Previously destination_key was removed twice when a model existed.)
  DKV.remove(destination_key);
  super.cancel(msg);
}
/** @return true iff {@code ary} is in non-decreasing order. */
private boolean sorted(int [] ary){
  for(int i = 1; i < ary.length; ++i)
    if(ary[i] < ary[i-1])
      return false;
  return true;
}
// Computes the model intercept in the presence of an offset column by fixed-point
// iteration: start from ymu adjusted by the (de-normalized) offset mean, then
// repeatedly re-estimate via GLMInterceptTask until the change is below 1e-4
// or 10 iterations have run.
// NOTE(review): after the loop this returns icpt, i.e. the second-to-last iterate,
// not the final icpt2 — confirm that is intentional.
private double computeIntercept(DataInfo dinfo, double ymu, Vec offset, Vec response){
double mul = 1, sub = 0;
int vecId = dinfo._adaptedFrame.find(offset);
// undo the standardization applied to the offset column, if any
if(dinfo._normMul != null)
mul = dinfo._normMul[vecId-dinfo._cats];
if(dinfo._normSub != null)
sub = dinfo._normSub[vecId-dinfo._cats];
double icpt = ymu - (offset.mean() - sub)*mul;
double icpt2 = new GLMInterceptTask(_glm,sub,mul,icpt).doAll(offset,response)._icpt;
double diff = icpt2 - icpt;
int iter = 0;
while((1e-4 < diff || diff < -1e-4) && ++iter <= 10){
icpt = icpt2;
icpt2 = new GLMInterceptTask(_glm,sub,mul,icpt).doAll(offset,response)._icpt;
diff = icpt2 - icpt;
}
return icpt;
}
private transient Frame source2; // adapted source with reordered (and removed) vecs we do not want to push back into KV
private int _noffsets = 0;
private int _intercept = 1; // 1 or 0
private double [] _lbs;
private double [] _ubs;
private double [] _bgs;
private double [] _rho;
boolean toEnum = false;
/** Allocates a double array of length {@code sz} with every element set to {@code val}. */
private double [] makeAry(int sz, double val){
  final double [] out = MemoryManager.malloc8d(sz);
  for(int i = 0; i < out.length; ++i)
    out[i] = val;
  return out;
}
/**
 * Scatters {@code src[i]} into {@code tgt[map[i]]} for every i whose mapping is
 * valid (i.e. not the -1 sentinel). Returns {@code tgt} for chaining.
 */
private double [] mapVec(double [] src, double [] tgt, int [] map){
  int i = 0;
  for(double value : src){
    final int dst = map[i++];
    if(dst != -1)
      tgt[dst] = value;
  }
  return tgt;
}
// Validates all input parameters and prepares the training data:
//  1. resolves family/link defaults and builds _glm,
//  2. builds source2 (source minus ignored cols, with the offset re-positioned),
//  3. sanity-checks the response against the chosen family,
//  4. builds _srcDinfo with the requested standardization (offset left untouched),
//  5. parses the optional beta_constraints frame into _lbs/_ubs/_bgs/_rho,
//  6. applies the non_negative flag by clamping lower bounds at 0.
// Any RuntimeException triggers cleanup() before being rethrown.
@Override public void init(){
try {
super.init();
// gamma models are numerically touchy — force line-search mode
if (family == Family.gamma)
setHighAccuracy();
if (link == Link.family_default)
link = family.defaultLink;
_intercept = intercept ? 1 : 0;
tweedie_link_power = 1 - tweedie_variance_power;// TODO
if (tweedie_link_power == 0) link = Link.log;
_glm = new GLMParams(family, tweedie_variance_power, link, tweedie_link_power);
source2 = new Frame(source);
assert sorted(ignored_cols);
source2.remove(ignored_cols);
if(offset != null)
source2.remove(source2.find(offset)); // remove offset and add it later explicitly (so that it does not interfere with DataInfo.prepareFrame)
if (nlambdas == -1)
nlambdas = 100;
if (lambda_search && lambda.length > 1)
throw new IllegalArgumentException("Can not supply both lambda_search and multiple lambdas. If lambda_search is on, GLM expects only one value of lambda_value, representing the lambda_value min (smallest lambda_value in the lambda_value search).");
// check the response
if (response.isEnum() && family != Family.binomial)
throw new IllegalArgumentException("Invalid response variable, trying to run regression with categorical response!");
switch (family) {
case poisson:
case tweedie:
if (response.min() < 0)
throw new IllegalArgumentException("Illegal response column for family='" + family + "', response must be >= 0.");
break;
case gamma:
if (response.min() <= 0)
throw new IllegalArgumentException("Invalid response for family='Gamma', response must be > 0!");
break;
case binomial:
if (response.min() < 0 || response.max() > 1)
throw new IllegalArgumentException("Illegal response column for family='Binomial', response must in <0,1> range!");
break;
default:
//pass
}
// binomial with a numeric non-0/1 response must be converted to an enum column
toEnum = family == Family.binomial && (!response.isEnum() && (response.min() < 0 || response.max() > 1));
Frame fr = DataInfo.prepareFrame(source2, response, new int[0], toEnum, true, true);
if(offset != null){ // now put the offset just in front of response
int id = source.find(offset);
String name = source.names()[id];
String responseName = fr.names()[fr.numCols()-1];
Vec responseVec = fr.remove(fr.numCols()-1);
fr.add(name, offset);
fr.add(responseName,responseVec);
_noffsets = 1;
}
TransformType dt = TransformType.NONE;
if (standardize)
// without an intercept we can only rescale, not shift, the predictors
dt = intercept ? TransformType.STANDARDIZE : TransformType.DESCALE;
_srcDinfo = new DataInfo(fr, 1, intercept, use_all_factor_levels || lambda_search, dt, DataInfo.TransformType.NONE);
if(offset != null && dt != TransformType.NONE) { // do not standardize offset
if(_srcDinfo._normMul != null)
_srcDinfo._normMul[_srcDinfo._normMul.length-1] = 1;
if(_srcDinfo._normSub != null)
_srcDinfo._normSub[_srcDinfo._normSub.length-1] = 0;
}
if (!intercept && _srcDinfo._cats > 0)
throw new IllegalArgumentException("Models with no intercept are only supported with all-numeric predictors.");
_activeData = _srcDinfo;
if (higher_accuracy) setHighAccuracy();
// Parse the optional beta_constraints frame: a "names" enum column identifies the
// coefficient each row constrains; remaining columns supply the bound/penalty values.
if (beta_constraints != null) {
Vec v;
v = beta_constraints.vec("names");
// for now only enums allowed here
String [] dom = v.domain();
String [] names = Utils.append(_srcDinfo.coefNames(), "Intercept");
int [] map = Utils.asInts(v);
if(!Arrays.deepEquals(dom,names)) { // need mapping
HashMap<String,Integer> m = new HashMap<String, Integer>();
for(int i = 0; i < names.length; ++i)
m.put(names[i],i);
int [] newMap = MemoryManager.malloc4(dom.length);
for(int i = 0; i < map.length; ++i) {
Integer I = m.get(dom[map[i]]);
newMap[i] = I == null?-1:I;
}
map = newMap;
}
final int numoff = _srcDinfo.numStart();
if((v = beta_constraints.vec("lower_bounds")) != null) {
_lbs = map == null ? Utils.asDoubles(v) : mapVec(Utils.asDoubles(v), makeAry(names.length, Double.NEGATIVE_INFINITY), map);
// for(int i = 0; i < _lbs.length; ++i)
// if(_lbs[i] > 0) throw new IllegalArgumentException("lower bounds must be non-positive");
// scale the bounds into the standardized coefficient space
if(_srcDinfo._normMul != null) {
for (int i = numoff; i < _srcDinfo.fullN(); ++i) {
if (Double.isInfinite(_lbs[i])) continue;
_lbs[i] /= _srcDinfo._normMul[i - numoff];
}
}
}
if((v = beta_constraints.vec("upper_bounds")) != null) {
_ubs = map == null ? Utils.asDoubles(v) : mapVec(Utils.asDoubles(v), makeAry(names.length, Double.POSITIVE_INFINITY), map);
System.out.println("upper bounds = " + Arrays.toString(_ubs));
// for(int i = 0; i < _ubs.length; ++i)
// if (_ubs[i] < 0) throw new IllegalArgumentException("lower bounds must be non-positive");
if(_srcDinfo._normMul != null)
for(int i = numoff; i < _srcDinfo.fullN(); ++i) {
if(Double.isInfinite(_ubs[i]))continue;
_ubs[i] /= _srcDinfo._normMul[i - numoff];
}
}
if(_lbs != null && _ubs != null) {
for(int i = 0 ; i < _lbs.length; ++i)
if(_lbs[i] > _ubs[i])
throw new IllegalArgumentException("Invalid upper/lower bounds: lower bounds must be <= upper bounds for all variables.");
}
if((v = beta_constraints.vec("beta_given")) != null) {
_bgs = map == null ? Utils.asDoubles(v) : mapVec(Utils.asDoubles(v), makeAry(names.length, 0), map);
// fold the de-meaning of standardized predictors into the given intercept
if(_srcDinfo._normMul != null) {
double norm = 0;
for (int i = numoff; i < _srcDinfo.fullN(); ++i) {
norm += _bgs[i] * _srcDinfo._normSub[i-numoff];
_bgs[i] /= _srcDinfo._normMul[i-numoff];
}
if(_intercept == 1)
_bgs[_bgs.length-1] -= norm;
}
}
if((v = beta_constraints.vec("rho")) != null)
_rho = map == null?Utils.asDoubles(v):mapVec(Utils.asDoubles(v),makeAry(names.length,0),map);
else if(_bgs != null)
throw new IllegalArgumentException("Missing vector of penalties (rho) in beta_constraints file.");
}
if (non_negative) { // make sure lb is >= 0
if (_lbs == null)
_lbs = new double[_srcDinfo.fullN()];
for (int i = 0; i < _lbs.length; ++i)
if (_lbs[i] < 0)
_lbs[i] = 0;
}
} catch(RuntimeException e) {
e.printStackTrace();
cleanup();
throw e;
}
}
// Job cleanup: if init() converted the response to an enum (toEnum), that temporary
// vec is the last one in the adapted frame and must be removed here.
@Override protected void cleanup(){
super.cleanup();
if(toEnum && _srcDinfo != null){
Futures fs = new Futures();
_srcDinfo._adaptedFrame.lastVec().remove(fs);
fs.blockForPending();
}
}
@Override protected boolean filterNaCols(){return true;}
// REST entry point: validates inputs, then either launches a grid search (when
// multiple alpha values were supplied) or forks a single GLM run, redirecting the
// browser to the corresponding progress/result page. Errors become an error Response.
@Override protected Response serve() {
try {
init();
if (alpha.length > 1) { // grid search
if (destination_key == null) destination_key = Key.make("GLMGridResults_" + Key.make());
if (job_key == null) job_key = Key.make((byte) 0, Key.JOB, H2O.SELF);
GLMGridSearch j = gridSearch();
_fjtask = j._fjtask;
assert _fjtask != null;
return GLMGridView.redirect(this, j.dest());
} else {
if (destination_key == null) destination_key = Key.make("GLMModel_" + Key.make());
if (job_key == null) job_key = Key.make("GLM2Job_" + Key.make());
fork();
assert _fjtask != null;
return GLMProgress.redirect(this, job_key, dest());
}
}catch(Throwable ex){
return Response.error(ex.getMessage());
}
}
/**
 * Infinity-norm of (b1 - b2): the largest absolute component-wise difference.
 * Returns Double.MAX_VALUE when b1 is null or empty (forces "not converged").
 */
private static double beta_diff(double[] b1, double[] b2) {
  if(b1 == null || b1.length == 0)
    return Double.MAX_VALUE;
  double max = Math.abs(b1[0] - b2[0]);
  for( int i = 1; i < b1.length; ++i ) {
    final double d = Math.abs(b1[i] - b2[i]);
    if(d > max)
      max = d;
  }
  return max;
}
//private static double beta_diff(double[] b1, double[] b2) {
// double res = 0;
// for(int i = 0; i < b1.length; ++i)
// res += (b1[i]-b2[i])*(b1[i]-b2[i]);
// return res;
//}
// Progress record kept in the DKV: _done work units out of _total.
// NOTE(review): the method name "progess" is a typo, but callers (progress() below)
// use it as-is — renaming would require touching them too.
private static class GLM2_Progress extends Iced{
final long _total;
double _done;
public GLM2_Progress(int total){_total = total;
assert _total > 0:"total = " + _total;
}
// Fraction complete in [0,1], truncated to whole percent steps.
public float progess(){
return 0.01f*((int)(100*_done/(double)_total));
}
}
// Atomic DKV update that adds _i work units to a GLM2_Progress record.
// A null old value is left as-is (the progress key may already be removed).
private static class GLM2_ProgressUpdate extends TAtomic<GLM2_Progress> {
final int _i;
public GLM2_ProgressUpdate(){_i = 1;}
public GLM2_ProgressUpdate(int i){_i = i;}
@Override
public GLM2_Progress atomic(GLM2_Progress old) {
if(old == null)return old;
old._done += _i;
return old;
}
}
// Job progress in [0,1]: reads the GLM2_Progress record from the DKV and clamps
// the reported fraction at 1. Returns 0 when the record is missing.
@Override public float progress(){
if(isDone())return 1.0f;
Value v = DKV.get(_progressKey);
if(v == null)return 0;
float res = v.<GLM2_Progress>get().progess();
if(res > 1f)
res = 1f;
return res;
}
/**
 * Squared L2 norm (sum of squares) of the given vector; 0 when beta is null.
 * Note: despite the name this returns ||beta||^2, not its square root — callers
 * (e.g. objval) rely on the squared value.
 */
protected double l2norm(double[] beta){
  // bug fix: the guard previously tested the unrelated field _beta, which could
  // NPE on a null argument (or wrongly return 0 for a non-null one).
  if(beta == null)return 0;
  double l2 = 0;
  for (double aBeta : beta) l2 += aBeta * aBeta;
  return l2;
}
/** L1 norm (sum of absolute values) of the given vector; 0 when beta is null. */
protected double l1norm(double[] beta){
  // bug fix: the guard previously tested the unrelated field _beta, which could
  // NPE on a null argument (or wrongly return 0 for a non-null one).
  if(beta == null)return 0;
  double sum = 0;
  for (double aBeta : beta) sum += Math.abs(aBeta);
  return sum;
}
// Expands a beta vector from the active-column subspace back into the full predictor
// space (length fullN + intercept - offsets); inactive positions stay 0 and the
// intercept (last element) is carried over.
// NOTE(review): the final loop writes 1s into the INPUT array `beta`, not into `res`
// — this mutates the caller's vector and has no effect on the returned value;
// confirm whether `res` was intended.
private final double [] expandVec(double [] beta, final int [] activeCols){
assert beta != null;
if (activeCols == null)
return beta;
double[] res = MemoryManager.malloc8d(_srcDinfo.fullN() + _intercept -_noffsets);
int i = 0;
for (int c = 0; c < activeCols.length-_noffsets; ++c)
res[_activeCols[c]] = beta[i++];
if(_intercept == 1)
res[res.length - 1] = beta[beta.length - 1];
for(int j = beta.length-_noffsets; j < beta.length-1; ++j)
beta[j] = 1;
return res;
}
// Convenience overload: contract using the model's configured intercept flag.
private final double [] contractVec(double [] beta, final int [] activeCols){ return contractVec(beta,activeCols,_intercept);}
/**
 * Projects a full-space beta onto the active columns (minus trailing offsets),
 * optionally carrying the intercept (last element) through. Returns null for a
 * null input and a defensive copy when there is no active-column subset.
 */
private final double [] contractVec(double [] beta, final int [] activeCols, int intercept){
  if(beta == null)
    return null;
  if(activeCols == null)
    return beta.clone();
  final int n = activeCols.length - _noffsets;
  final double [] out = MemoryManager.malloc8d(n + intercept);
  for(int j = 0; j < n; ++j)
    out[j] = beta[activeCols[j]];
  if(intercept == 1)
    out[out.length - 1] = beta[beta.length - 1];
  return out;
}
/**
 * Re-maps beta from the oldActiveCols subspace to the activeCols subspace by going
 * through the full predictor space. No-op (returns beta unchanged) when the active
 * sets are equal or beta is null.
 */
private final double [] resizeVec(double[] beta, final int[] activeCols, final int[] oldActiveCols){
  if(beta == null || Arrays.equals(activeCols, oldActiveCols))
    return beta;
  final double [] expanded = expandVec(beta, oldActiveCols);
  return activeCols == null ? expanded : contractVec(expanded, activeCols, _intercept);
}
// protected boolean needLineSearch(final double [] beta,double objval, double step){
// Decides whether the latest iteration needs a (slower) line search:
//  - never for gaussian family or when line search is disabled,
//  - always when the task produced NaNs/Infs (xy, gradient, gram, or deviance),
//  - otherwise defer to the Armijo check against the previous iteration's
//    objective and gradient (re-mapped if the active-column set changed).
protected boolean needLineSearch(final GLMIterationTask glmt) {
if(disable_line_search)
return false;
if(_glm.family == Family.gaussian)
return false;
if(glmt._beta == null)
return false;
if (Utils.hasNaNsOrInfs(glmt._xy) || (glmt._grad != null && Utils.hasNaNsOrInfs(glmt._grad)) || (glmt._gram != null && glmt._gram.hasNaNsOrInfs())) {
return true;
}
if(glmt._val != null && Double.isNaN(glmt._val.residualDeviance())){
return true;
}
if(glmt._val == null) // no validation info, no way to decide
return false;
// previous gradient, contracted onto the current active columns if they differ
final double [] grad = Arrays.equals(_activeCols,_lastResult._activeCols)
?_lastResult._glmt.gradient(alpha[0],_currentLambda)
:contractVec(_lastResult.fullGrad(alpha[0],_currentLambda),_activeCols);
return needLineSearch(1, objval(_lastResult._glmt),objval(glmt),diff(glmt._beta,_lastResult._glmt._beta),grad);
}
/** Component-wise difference x - y; when y is null, returns a copy of x. */
private static double [] diff(double [] x, double [] y){
  if(y == null)
    return x.clone();
  final double [] d = MemoryManager.malloc8d(x.length);
  for(int i = 0; i < d.length; ++i)
    d[i] = x[i] - y[i];
  return d;
}
public static final double c1 = 1e-2;
// protected boolean needLineSearch(final double [] beta,double objval, double step){
// Armijo line-search rule enhanced with generalized gradient to handle l1 pen
// Armijo (sufficient-decrease) test, generalized to handle the l1 penalty:
// builds the first-order model f_hat = objOld + step * <gradOld, pk> and requires
// objNew <= f_hat + ||pk||^2 / (2*step); returns true (search needed) otherwise.
protected final boolean needLineSearch(double step, final double objOld, final double objNew, final double [] pk, final double [] gradOld){
// line search
double f_hat = 0;
// directional derivative along the step pk
for(int i = 0; i < pk.length; ++i)
f_hat += gradOld[i] * pk[i];
f_hat = step*f_hat + objOld;
return objNew > (f_hat + 1/(2*step)*l2norm(pk));
}
// Callback for a batched line-search task: glmt._glmts holds candidate iterations at
// geometrically shrinking steps (LS_STEP^k). The first admissible candidate restarts
// the main iteration in high-accuracy mode; if none is admissible, the search is
// declared failed and we jump straight to the KKT check with the best beta we have.
private class LineSearchIteration extends H2OCallback<GLMTask.GLMLineSearchTask> {
final GLMIterationTask _glmt;
LineSearchIteration(GLMIterationTask glmt, CountedCompleter cmp){super((H2OCountedCompleter)cmp); cmp.addToPendingCount(1); _glmt = glmt;}
@Override public void callback(final GLMTask.GLMLineSearchTask glmt) {
assert getCompleter().getPendingCount() >= 1:"unexpected pending count, expected 1, got " + getCompleter().getPendingCount();
double step = LS_STEP;
for(int i = 0; i < glmt._glmts.length; ++i){
// last candidate is also accepted if it at least improves on the previous objective
if(!needLineSearch(glmt._glmts[i]) || (i == glmt._glmts.length-1 && objval(glmt._glmts[i]) < objval(_lastResult._glmt))){
LogInfo("line search: found admissible step = " + step + ", objval = " + objval(glmt._glmts[i]));
setHighAccuracy();
new GLMIterationTask(_noffsets,GLM2.this.self(),_activeData,_glm,true,true,true,glmt._glmts[i]._beta,_ymu,1.0/_nobs,thresholds, new Iteration(getCompleter(),false,false)).asyncExec(_activeData._adaptedFrame);
return;
}
step *= LS_STEP;
}
LogInfo("line search: did not find admissible step, smallest step = " + step + ", objval = " + objval(glmt._glmts[glmt._glmts.length-1]) + ", old objval = " + objval(_lastResult._glmt));
// check if objval of smallest step is below the previous step, if so, go on
LogInfo("Line search did not find feasible step, converged.");
_failedLineSearch = true;
// fall back to the last good iteration's state before running the KKT check
GLMIterationTask res = highAccuracy()?_lastResult._glmt:_glmt;
if(_activeCols != _lastResult._activeCols && !Arrays.equals(_activeCols,_lastResult._activeCols)) {
_activeCols = _lastResult._activeCols;
_activeData = _srcDinfo.filterExpandedColumns(_activeCols);
}
checkKKTAndComplete(getCompleter(),res,res._beta,true);
}
}
// Computes the elastic-net subgradient in place (via ADMMSolver.subgrad) and returns
// its infinity norm, logging it as the convergence measure.
// Note: `grad` is mutated by the subgrad call.
protected double checkGradient(final double [] newBeta, final double [] grad){
// check the gradient
ADMMSolver.subgrad(alpha[0], _currentLambda, newBeta, grad);
double err = 0;
// max |grad[i]| without calling Math.abs
for(double d:grad)
if(d > err) err = d;
else if(d < -err) err = -d;
LogInfo("converged with max |subgradient| = " + err);
return err;
}
/**
 * Logs {@code msg} prefixed with the job context (destination key, iteration,
 * current lambda) and returns the full, prefixed line.
 */
private String LogInfo(String msg){
  final String line = "GLM2[dest=" + dest() + ", iteration=" + _iter + ", lambda = " + _currentLambda + "]: " + msg;
  Log.info(line);
  return line;
}
// Expands the (possibly active-columns-only) beta to the full column space, appends
// offset coefficients, de-normalizes coefficients if the predictors were standardized
// or descaled, and stores the resulting submodel for the current lambda.
// Returns the expanded (still-normalized) full beta.
private double [] setSubmodel(final double[] newBeta, GLMValidation val, H2OCountedCompleter cmp){
int intercept = (this.intercept ?1:0);
// NOTE(review): if newBeta == null this ternary selects newBeta.clone() and throws NPE;
// the null check looks intended to short-circuit — confirm callers never pass null here.
double [] fullBeta = (_activeCols == null || newBeta == null)?newBeta.clone():expandVec(newBeta,_activeCols);
if(val != null) val.null_deviance = _nullDeviance;
if(this.intercept)
fullBeta[fullBeta.length-1] += _iceptAdjust; // undo the intercept shift applied during setup
if(_noffsets > 0){
// Append offset coefficients (fixed at 1) just before the intercept, moving the
// intercept to the new last slot.
fullBeta = Arrays.copyOf(fullBeta,fullBeta.length + _noffsets);
if(this.intercept)
fullBeta[fullBeta.length-1] = fullBeta[fullBeta.length-intercept-_noffsets];
for(int i = fullBeta.length-intercept-_noffsets; i < fullBeta.length-intercept; ++i)
fullBeta[i] = 1;//_srcDinfo.applyTransform(i,1);
}
final double [] newBetaDeNorm;
final int numoff = _srcDinfo.numStart(); // index of the first numeric (non-categorical) coefficient
if(_srcDinfo._predictor_transform == DataInfo.TransformType.STANDARDIZE) {
assert this.intercept;
newBetaDeNorm = fullBeta.clone();
double norm = 0.0; // Reverse any normalization on the intercept
// denormalize only the numeric coefs (categoricals are not normalized)
for( int i=numoff; i< fullBeta.length-intercept; i++ ) {
double b = newBetaDeNorm[i]* _srcDinfo._normMul[i-numoff];
norm += b* _srcDinfo._normSub[i-numoff]; // Also accumulate the intercept adjustment
newBetaDeNorm[i] = b;
}
if(this.intercept)
newBetaDeNorm[newBetaDeNorm.length-1] -= norm;
} else if (_srcDinfo._predictor_transform == TransformType.DESCALE) {
// DESCALE only rescales (no mean shift), so there is no intercept correction.
assert !this.intercept;
newBetaDeNorm = fullBeta.clone();
for( int i=numoff; i< fullBeta.length; i++ )
newBetaDeNorm[i] *= _srcDinfo._normMul[i-numoff];
} else
newBetaDeNorm = null;
// Store de-normalized beta as primary and normalized beta as secondary when a
// transform was applied; otherwise store the single (untransformed) beta.
GLMModel.setSubmodel(cmp, dest(), _currentLambda, newBetaDeNorm == null ? fullBeta : newBetaDeNorm, newBetaDeNorm == null ? null : fullBeta, _iter, System.currentTimeMillis() - start_time, _srcDinfo.fullN() >= sparseCoefThreshold, val);
return fullBeta;
}
private transient long _callbackStart = 0; // timestamp (ms) when the last Iteration callback started; used to report Gram computation time
private transient double _rho_mul = 1.0; // multiplier for the ADMM rho penalty (currently only used by commented-out parallel-solver experiment)
private transient double _gradientEps = ADMM_GRAD_EPS; // adaptive gradient tolerance passed to the ADMM solver; tightened/relaxed per iteration
/**
 * Returns the beta of the last iteration result, re-mapped onto the current active
 * columns; with no usable previous result, returns a null-model beta whose intercept
 * slot holds {@code linkInv(ymu)}.
 */
private double [] lastBeta(int noffsets){
  if(_lastResult == null || _lastResult._glmt._beta == null) {
    // No previous beta: start from the null model.
    final int len = _activeCols == null ? _srcDinfo.fullN()+1-noffsets : _activeCols.length+1;
    final double [] beta = MemoryManager.malloc8d(len);
    beta[len-1] = _glm.linkInv(_ymu);
    return beta;
  }
  // Re-map the previous beta from its active-column set onto the current one.
  return resizeVec(_lastResult._glmt._beta, _activeCols, _lastResult._activeCols);
}
/**
 * Final validation for the current lambda: computes the full gradient (over all columns)
 * at the candidate beta and checks the KKT conditions. Inactive columns whose
 * subgradient violates the KKT bound are added back to the active set and the iteration
 * is restarted; otherwise the submodel is stored and this lambda is marked done.
 *
 * @param cc               completer to attach the async gradient task to
 * @param glmt             iteration task whose result is being finalized
 * @param newBeta          candidate beta on the active columns (null => zero beta)
 * @param failedLineSearch true when finalizing after a failed line search; skips
 *                         relaxing the KKT tolerance from active-column subgradients
 */
protected void checkKKTAndComplete(final CountedCompleter cc, final GLMIterationTask glmt, final double [] newBeta, final boolean failedLineSearch){
H2OCountedCompleter cmp = (H2OCountedCompleter)cc;
final double [] fullBeta = newBeta == null?MemoryManager.malloc8d(_srcDinfo.fullN()+_intercept-_noffsets):expandVec(newBeta,_activeCols);
// now we need full gradient (on all columns) using this beta
new GLMIterationTask(_noffsets,GLM2.this.self(), _srcDinfo,_glm,false,true,true,fullBeta,_ymu,1.0/_nobs,thresholds, new H2OCallback<GLMIterationTask>(cmp) {
@Override public String toString(){
// BUGFIX: parenthesize the null check — '+' binds tighter than '==', so the original
// compared the concatenated string to null (always false) and NPE'd on a null completer.
return "checkKKTAndComplete.Callback, completer = " + (getCompleter() == null?"null":getCompleter().toString());
}
@Override
public void callback(final GLMIterationTask glmt2) {
// first check KKT conditions!
final double [] grad = glmt2.gradient(alpha[0],_currentLambda);
if(Utils.hasNaNsOrInfs(grad)){
_failedLineSearch = true;
// TODO: add warning and break the lambda search? Or throw Exception?
}
glmt._val = glmt2._val;
// Cache this result (with the penalty-free full gradient) for the next lambda.
_lastResult = makeIterationInfo(_iter,glmt2,null,glmt2.gradient(alpha[0],0));
// check the KKT conditions and filter data for next lambda_value
// check the gradient
double[] subgrad = grad.clone();
ADMMSolver.subgrad(alpha[0], _currentLambda, fullBeta, subgrad);
double grad_eps = GLM_GRAD_EPS;
if (!failedLineSearch &&_activeCols != null) {
// Relax the tolerance to the largest |subgradient| seen on the active columns
// (active columns are allowed to sit at the solver's achieved precision).
for (int c = 0; c < _activeCols.length-_noffsets; ++c)
if (subgrad[_activeCols[c]] > grad_eps) grad_eps = subgrad[_activeCols[c]];
// BUGFIX: condition indexed subgrad[c] instead of subgrad[_activeCols[c]],
// inconsistent with the read above and the value assigned below.
else if (subgrad[_activeCols[c]] < -grad_eps) grad_eps = -subgrad[_activeCols[c]];
// Collect inactive columns whose subgradient violates the KKT bound.
int[] failedCols = new int[64];
int fcnt = 0;
for (int i = 0; i < grad.length - 1; ++i) {
if (Arrays.binarySearch(_activeCols, i) >= 0) continue; // active columns already checked
if (subgrad[i] > grad_eps || -subgrad[i] > grad_eps) {
if (fcnt == failedCols.length)
failedCols = Arrays.copyOf(failedCols, failedCols.length << 1); // grow geometrically
failedCols[fcnt++] = i;
}
}
if (fcnt > 0) {
// KKT violated: enlarge the active set with the failed columns and re-iterate.
final int n = _activeCols.length;
final int[] oldActiveCols = _activeCols;
_activeCols = Arrays.copyOf(_activeCols, _activeCols.length + fcnt);
for (int i = 0; i < fcnt; ++i)
_activeCols[n + i] = failedCols[i];
Arrays.sort(_activeCols); // keep sorted for binarySearch above
LogInfo(fcnt + " variables failed KKT conditions check! Adding them to the model and continuing computation.(grad_eps = " + grad_eps + ", activeCols = " + (_activeCols.length > 100?"lost":Arrays.toString(_activeCols)));
_activeData = _srcDinfo.filterExpandedColumns(_activeCols);
// NOTE: tricky completer game here:
// We expect 0 pending in this method since this is the end-point, ( actually it's racy, can be 1 with pending 1 decrement from the original Iteration callback, end result is 0 though)
// while iteration expects pending count of 1, so we need to increase it here (Iteration itself adds 1 but 1 will be subtracted when we leave this method since we're in the callback which is called by onCompletion!
// [unlike at the start of nextLambda call when we're not inside onCompletion]))
getCompleter().addToPendingCount(1);
new GLMIterationTask(_noffsets,GLM2.this.self(), _activeData, _glm, true, true, true, resizeVec(newBeta, _activeCols, oldActiveCols), _ymu, glmt._reg, thresholds, new Iteration(getCompleter())).asyncExec(_activeData._adaptedFrame);
return;
}
}
// KKT satisfied: flush remaining progress ticks, persist the submodel and finish.
int diff = MAX_ITERATIONS_PER_LAMBDA - _iter + _iter1;
if(diff > 0)
new GLM2_ProgressUpdate(diff).fork(_progressKey); // update progress
GLM2.this.setSubmodel(newBeta, glmt2._val,(H2OCountedCompleter)getCompleter().getCompleter());
_done = true;
LogInfo("computation of current lambda done in " + (System.currentTimeMillis() - GLM2.this.start_time) + "ms");
assert _lastResult._fullGrad != null;
}
}).asyncExec(_srcDinfo._adaptedFrame);
}
private class Iteration extends H2OCallback<GLMIterationTask> {
public final long _iterationStartTime; // time (ms) this iteration's Gram task was launched; used to log Gram computation time
final boolean _countIteration; // whether this callback increments _iter (false when resuming after a line search)
final boolean _checkLineSearch; // whether to test the line-search condition; false forces convergence handling instead of a second line search
// Default iteration callback: counts the iteration and allows line search.
public Iteration(CountedCompleter cmp){ this(cmp,true,true);}
public Iteration(CountedCompleter cmp, boolean countIteration, boolean checkLineSearch) {
  super((H2OCountedCompleter) cmp);
  // Register this callback as pending work on the parent completer.
  cmp.addToPendingCount(1);
  _countIteration = countIteration;
  _checkLineSearch = checkLineSearch;
  // Record launch time so the callback can log how long the Gram computation took.
  _iterationStartTime = System.currentTimeMillis();
}
@Override public void callback(final GLMIterationTask glmt){
if( !isRunning(self()) ) throw new JobCancelledException();
assert _activeCols == null || glmt._beta == null || glmt._beta.length == (_activeCols.length+_intercept-glmt._noffsets):LogInfo("betalen = " + glmt._beta.length + ", activecols = " + _activeCols.length + " noffsets = " + glmt._noffsets);
assert _activeCols == null || _activeCols.length == _activeData.fullN();
assert getCompleter().getPendingCount() >= 1 : LogInfo("unexpected pending count, expected >= 1, got " + getCompleter().getPendingCount()); // will be decreased by 1 after we leave this callback
if (_countIteration) ++_iter;
_callbackStart = System.currentTimeMillis();
double gerr = Double.NaN;
boolean hasNaNs = glmt._gram.hasNaNsOrInfs() || Utils.hasNaNsOrInfs(glmt._xy);
boolean needLineSearch = hasNaNs || _checkLineSearch && needLineSearch(glmt);
if (glmt._val != null && glmt._computeGradient) { // check gradient
final double[] grad = glmt.gradient(alpha[0], _currentLambda);
ADMMSolver.subgrad(alpha[0], _currentLambda, glmt._beta, grad);
gerr = 0;
for (double d : grad)
gerr += d*d;
if(gerr <= GLM_GRAD_EPS*GLM_GRAD_EPS || (needLineSearch && gerr <= 5*ADMM_GRAD_EPS*ADMM_GRAD_EPS)){
LogInfo("converged by reaching small enough gradient, with max |subgradient| = " + gerr );
checkKKTAndComplete(getCompleter(),glmt, glmt._beta,false);
return;
}
}
if(needLineSearch){
if(!_checkLineSearch){ // has to converge here
LogInfo("Line search did not progress, converged.");
checkKKTAndComplete(getCompleter(),glmt, glmt._beta,true);
return;
}
LogInfo("invoking line search");
new GLMTask.GLMLineSearchTask(_noffsets, GLM2.this.self(), _activeData,_glm, lastBeta(_noffsets), glmt._beta, 1e-4, _ymu, _nobs, new LineSearchIteration(glmt,getCompleter())).asyncExec(_activeData._adaptedFrame);
return;
}
if(glmt._grad != null)
_lastResult = makeIterationInfo(_iter,glmt,_activeCols,null);
if(glmt._newThresholds != null) {
thresholds = Utils.join(glmt._newThresholds[0], glmt._newThresholds[1]);
Arrays.sort(thresholds);
}
final double [] newBeta = MemoryManager.malloc8d(glmt._xy.length);
long t1 = System.currentTimeMillis();
ADMMSolver slvr = new ADMMSolver(lambda_max, _currentLambda,alpha[0], _gradientEps, _addedL2);
if(_lbs != null)
slvr._lb = _activeCols == null?contractVec(_lbs,_activeCols,0):_lbs;
if(_ubs != null)
slvr._ub = _activeCols == null?contractVec(_ubs,_activeCols,0):_ubs;
if(_bgs != null && _rho != null) {
slvr._wgiven = _activeCols == null ? contractVec(_bgs, _activeCols, 0) : _bgs;
slvr._proximalPenalties = _activeCols == null ? contractVec(_rho, _activeCols, 0) : _rho;
}
slvr.solve(glmt._gram,glmt._xy,glmt._yy,newBeta,Math.max(1e-8*lambda_max,_currentLambda*alpha[0]));
// print all info about iteration
LogInfo("Gram computed in " + (_callbackStart - _iterationStartTime) + "ms, " + (Double.isNaN(gerr)?"":"gradient = " + gerr + ",") + ", step = " + 1 + ", ADMM: " + slvr.iterations + " iterations, " + (System.currentTimeMillis() - t1) + "ms (" + slvr.decompTime + "), subgrad_err=" + slvr.gerr);
// int [] iBlocks = new int[]{8,16,32,64,128,256,512,1024};
// int [] rBlocks = new int[]{1,2,4,8,16,32,64,128};
// for(int i:iBlocks)
// for(int r:rBlocks){
// long ttx = System.currentTimeMillis();
// try {
// slvr.gerr = Double.POSITIVE_INFINITY;
// ADMMSolver.ParallelSolver pslvr = slvr.parSolver(glmt._gram, glmt._wy, newBeta, _currentLambda * alpha[0] * _rho_mul, i, r);
// pslvr.invoke();
// System.out.println("iBlock = " + i + ", rBlocsk = " + r + "ms");
// LogInfo("ADMM: " + pslvr._iter + " iterations, " + (System.currentTimeMillis() - ttx) + "ms (" + slvr.decompTime + "), subgrad_err=" + slvr.gerr);
// } catch(Throwable t){
// System.out.println("iBlock = " + i + ", rBlocsk = " + r + " failed! err = " + t);
// }
// }
if (slvr._addedL2 > _addedL2) LogInfo("added " + (slvr._addedL2 - _addedL2) + "L2 penalty");
new GLM2_ProgressUpdate().fork(_progressKey); // update progress
_gradientEps = Math.max(ADMM_GRAD_EPS, Math.min(slvr.gerr, 0.01));
_addedL2 = slvr._addedL2;
if (Utils.hasNaNsOrInfs(newBeta)) {
throw new RuntimeException(LogInfo("got NaNs and/or Infs in beta"));
} else {
final double bdiff = beta_diff(glmt._beta, newBeta);
if(_glm.family == Family.gaussian && _glm.link == Link.identity) {
checkKKTAndComplete(getCompleter(),glmt, newBeta, false);
return;
} else if (bdiff < beta_epsilon || _iter >= max_iter) { // Gaussian is non-iterative and gradient is ADMMSolver's gradient => just validate and move on to the next lambda_value
int diff = (int) Math.log10(bdiff);
int nzs = 0;
for (int i = 0; i < glmt._beta.length; ++i)
if (glmt._beta[i] != 0) ++nzs;
LogInfo("converged (reached a fixed point with ~ 1e" + diff + " precision), got " + nzs + " nzs");
checkKKTAndComplete(getCompleter(),glmt, newBeta, false); // NOTE: do not use newBeta here, it has not been checked and can lead to NaNs in KKT check, redoing line search, coming up with the same beta and so on.