From ba5a168c386fd818845a83858451732082df01c7 Mon Sep 17 00:00:00 2001 From: "Dougal J. Sutherland" Date: Tue, 21 Mar 2017 14:16:18 +0000 Subject: [PATCH] play nicer with discontiguous inputs, skl_groups Features objects --- mmd/mmd.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/mmd/mmd.py b/mmd/mmd.py index 64bf96c..88ce73c 100644 --- a/mmd/mmd.py +++ b/mmd/mmd.py @@ -104,18 +104,26 @@ def _atonce_rbf_mmk(X, Y=None, gammas=1, get_X_diag=False, get_Y_diag=False): return K[0] if scalar_gamma else K +def _get_stacked(X): + if hasattr(X, 'n_samps'): # is an skl_groups Features object + X.make_stacked() + return X.stacked_features, X.n_samps + + X_stacked = np.ascontiguousarray(np.vstack(X)) + X_n_samps = np.array([len(x) for x in X], dtype=np.int32) + return X_stacked, X_n_samps + + def _el_by_el_rbf_mmk(X, Y=None, gammas=1, get_X_diag=False, get_Y_diag=False, n_jobs=1): - X_stacked = np.vstack(X) - X_n_samps = np.array([len(x) for x in X], dtype=np.int32) + X_stacked, X_n_samps = _get_stacked(X) if Y is None or Y is X: Y_stacked = X_stacked Y_n_samps = X_n_samps X_is_Y = True else: - Y_stacked = np.vstack(Y) - Y_n_samps = np.array([len(y) for y in Y], dtype=np.int32) + Y_stacked, Y_n_samps = _get_stacked(Y) X_is_Y = False if all(isinstance(h, logging.NullHandler)