Skip to content

Commit

Permalink
Fix sanity filter
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Sep 2, 2019
1 parent 3c8842a commit 57da1d3
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions convoys/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,14 @@ def fit(self, X, B, T, W=None):

if W is None:
W = [1] * len(X)
XBTW = [(x, b, t, w) for x, b, t, w in zip(X, B, T, W)
if t > 0 and 0 <= float(b) <= 1 and w >= 0]
if len(XBTW) < len(X):
n_removed = len(X) - len(XBTW)
X, B, T, W = (Z if type(Z) == numpy.ndarray else numpy.array(Z)
for Z in (X, B, T, W))
keep_indexes = (T > 0) & (B >= 0) & (B <= 1) & (W >= 0)
if sum(keep_indexes) < X.shape[0]:
n_removed = X.shape[0] - sum(keep_indexes)
warnings.warn('Warning! Removed %d/%d entries from inputs where '
'T <= 0 or B not 0/1 or W < 0' % (n_removed, len(X)))
X, B, T, W = (numpy.array([z[i] for z in XBTW], dtype=numpy.float32)
for i in range(4))
X, B, T, W = (Z[keep_indexes] for Z in (X, B, T, W))
n_features = X.shape[1]

# scipy.optimize and emcee forces the the parameters to be a vector:
Expand Down

0 comments on commit 57da1d3

Please sign in to comment.