Skip to content

Commit

Permalink
Change in behavior in glm beta constraints - when ignoring constant/b…
Browse files Browse the repository at this point in the history
…ad columns, remove them from beta_constraints as well.
  • Loading branch information
Tomas Nykodym committed Nov 25, 2015
1 parent a06b9af commit 315201e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
25 changes: 15 additions & 10 deletions h2o-algos/src/main/java/hex/glm/GLM.java
Expand Up @@ -247,9 +247,14 @@ public synchronized TwoDimTable to2dTable() {
m.put(names[i], i);
int[] newMap = MemoryManager.malloc4(dom.length);
for (int i = 0; i < map.length; ++i) {
if(_removedCols.contains(dom[map[i]])){
newMap[i] = -1;
continue;
}
Integer I = m.get(dom[map[i]]);
if (I == null)
if (I == null) {
throw new IllegalArgumentException("Unrecognized coefficient name in beta-constraint file, unknown name '" + dom[map[i]] + "'");
}
newMap[i] = I;
}
map = newMap;
Expand All @@ -262,35 +267,35 @@ public synchronized TwoDimTable to2dTable() {
error("beta_constraints","Unknown column name '" + s + "'");
if ((v = beta_constraints.vec("beta_start")) != null) {
betaStart = MemoryManager.malloc8d(_dinfo.fullN() + (_dinfo._intercept ? 1 : 0));
for (int i = 0; i < (int) v.length(); ++i)
betaStart[map == null ? i : map[i]] = v.at(i);
for (int i = 0; i < (int) v.length(); ++i) if(map[i] != -1)
betaStart[map[i]] = v.at(i);
}
if ((v = beta_constraints.vec("beta_given")) != null) {
betaGiven = MemoryManager.malloc8d(_dinfo.fullN() + (_dinfo._intercept ? 1 : 0));
for (int i = 0; i < (int) v.length(); ++i)
for (int i = 0; i < (int) v.length(); ++i) if(map[i] != -1)
betaGiven[map == null ? i : map[i]] = v.at(i);
}
if ((v = beta_constraints.vec("upper_bounds")) != null) {
betaUB = MemoryManager.malloc8d(_dinfo.fullN() + (_dinfo._intercept ? 1 : 0));
Arrays.fill(betaUB, Double.POSITIVE_INFINITY);
for (int i = 0; i < (int) v.length(); ++i)
for (int i = 0; i < (int) v.length(); ++i) if(map[i] != -1)
betaUB[map == null ? i : map[i]] = v.at(i);
}
if ((v = beta_constraints.vec("lower_bounds")) != null) {
betaLB = MemoryManager.malloc8d(_dinfo.fullN() + (_dinfo._intercept ? 1 : 0));
Arrays.fill(betaLB, Double.NEGATIVE_INFINITY);
for (int i = 0; i < (int) v.length(); ++i)
for (int i = 0; i < (int) v.length(); ++i) if(map[i] != -1)
betaLB[map == null ? i : map[i]] = v.at(i);
}
if ((v = beta_constraints.vec("rho")) != null) {
rho = MemoryManager.malloc8d(_dinfo.fullN() + (_dinfo._intercept ? 1 : 0));
for (int i = 0; i < (int) v.length(); ++i)
for (int i = 0; i < (int) v.length(); ++i) if(map[i] != -1)
rho[map == null ? i : map[i]] = v.at(i);
}
// mean override (for data standardization)
if ((v = beta_constraints.vec("mean")) != null) {
for(int i = 0; i < v.length(); ++i) {
if(!v.isNA(i)) {
if(!v.isNA(i) && map[i] != -1) {
int idx = map == null ? i : map[i];
if (idx > _dinfo.numStart() && idx < _dinfo.fullN()) {
_dinfo._normSub[idx - _dinfo.numStart()] = v.at(i);
Expand All @@ -303,7 +308,7 @@ public synchronized TwoDimTable to2dTable() {
// standard deviation override (for data standardization)
if ((v = beta_constraints.vec("std_dev")) != null) {
for (int i = 0; i < v.length(); ++i) {
if (!v.isNA(i)) {
if (!v.isNA(i) && map[i] != -1) {
int idx = map == null ? i : map[i];
if (idx > _dinfo.numStart() && idx < _dinfo.fullN()) {
_dinfo._normMul[idx - _dinfo.numStart()] = 1.0/v.at(i);
Expand Down Expand Up @@ -999,7 +1004,7 @@ public void callback(H2OCountedCompleter h2OCountedCompleter) {
double objVal(double likelihood, double[] beta, double lambda) {
double alpha = _parms._alpha[0];
double proximalPen = 0;
if (_bc._betaGiven != null) {
if (_bc._betaGiven != null && _bc._rho != null) {
for (int i = 0; i < _bc._betaGiven.length; ++i) {
double diff = beta[i] - _bc._betaGiven[i];
proximalPen += diff * diff * _bc._rho[i];
Expand Down
3 changes: 3 additions & 0 deletions h2o-core/src/main/java/hex/ModelBuilder.java
Expand Up @@ -20,6 +20,7 @@
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

/**
Expand Down Expand Up @@ -1001,6 +1002,7 @@ public void checkDistributions() {
}
}

protected transient HashSet<String> _removedCols = new HashSet<String>();
abstract class FilterCols {
final int _specialVecs; // special vecs to skip at the end
public FilterCols(int n) {_specialVecs = n;}
Expand All @@ -1014,6 +1016,7 @@ void doIt( Frame f, String msg, boolean expensive ) {
if( any ) msg += ", "; // Log dropped cols
any = true;
msg += f._names[i];
_removedCols.add(f._names[i]);
f.remove(i);
i--; // Re-run at same iteration after dropping a col
}
Expand Down

0 comments on commit 315201e

Please sign in to comment.