Skip to content

Commit

Permalink
play nicer with discontiguous inputs, skl_groups Features objects
Browse files Browse the repository at this point in the history
  • Loading branch information
djsutherland committed Mar 21, 2017
1 parent 3f3a1d7 commit ba5a168
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions mmd/mmd.py
Expand Up @@ -104,18 +104,26 @@ def _atonce_rbf_mmk(X, Y=None, gammas=1, get_X_diag=False, get_Y_diag=False):
return K[0] if scalar_gamma else K


def _get_stacked(X):
if hasattr(X, 'n_samps'): # is an skl_groups Features object
X.make_stacked()
return X.stacked_features, X.n_samps

X_stacked = np.ascontiguousarray(np.vstack(X))
X_n_samps = np.array([len(x) for x in X], dtype=np.int32)
return X_stacked, X_n_samps


def _el_by_el_rbf_mmk(X, Y=None, gammas=1, get_X_diag=False, get_Y_diag=False,
n_jobs=1):
X_stacked = np.vstack(X)
X_n_samps = np.array([len(x) for x in X], dtype=np.int32)
X_stacked, X_n_samps = _get_stacked(X)

if Y is None or Y is X:
Y_stacked = X_stacked
Y_n_samps = X_n_samps
X_is_Y = True
else:
Y_stacked = np.vstack(Y)
Y_n_samps = np.array([len(y) for y in Y], dtype=np.int32)
Y_stacked, Y_n_samps = _get_stacked(Y)
X_is_Y = False

if all(isinstance(h, logging.NullHandler)
Expand Down

0 comments on commit ba5a168

Please sign in to comment.