diff --git a/src/SparseSC/utils/match_space.py b/src/SparseSC/utils/match_space.py index f1ab31d..eab6d43 100644 --- a/src/SparseSC/utils/match_space.py +++ b/src/SparseSC/utils/match_space.py @@ -204,7 +204,7 @@ def _MTLasso_MatchSpace( transformer = SelMatchSpace(m_sel) return transformer, V[m_sel], v_pen, (V, varselectorfit) -def D_LassoCV_MatchSpace_factory(v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5): +def D_LassoCV_MatchSpace_factory(v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5, fit_args={}): """ Return a MatchSpace function that will fit a MultiTaskLassoCV for Y ~ X and Lasso of D_full ~ X_full and then combines the coefficients into weights using y_V_share @@ -224,6 +224,7 @@ def _D_LassoCV_MatchSpace_wrapper(X, Y, **kwargs): n_v_cv=n_v_cv, sample_frac=sample_frac, y_V_share=y_V_share, + fit_args=fit_args, **kwargs ) @@ -231,7 +232,7 @@ def _D_LassoCV_MatchSpace_wrapper(X, Y, **kwargs): def _D_LassoCV_MatchSpace( - X, Y, X_full, D_full, v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5, **kwargs + X, Y, X_full, D_full, v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5, fit_args={}, **kwargs ): # pylint: disable=missing-param-doc, unused-argument if sample_frac < 1: N_y = X.shape[0] @@ -242,7 +243,7 @@ def _D_LassoCV_MatchSpace( sample_d = np.random.choice(N_d, int(sample_frac * N_d), replace=False) X_full = X_full[sample_d, :] D_full = D_full[sample_d] - y_varselectorfit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens).fit( + y_varselectorfit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens, **fit_args).fit( X, Y ) y_V = np.sqrt( @@ -465,7 +466,7 @@ def transform(self, X): return M -def MTLassoMixed_MatchSpace_factory(v_pens=None, n_v_cv=5): +def MTLassoMixed_MatchSpace_factory(v_pens=None, n_v_cv=5, fit_args={}): """ Return a MatchSpace function that will fit a MultiTaskLassoCV for Y ~ X with the penalization optimized to reduce errors on goal units @@ -476,17 +477,17 @@ def MTLassoMixed_MatchSpace_factory(v_pens=None, n_v_cv=5): def _MTLassoMixed_MatchSpace_wrapper(X, Y, fit_model_wrapper, **kwargs): return _MTLassoMixed_MatchSpace( - X, Y, fit_model_wrapper, v_pens=v_pens, n_v_cv=n_v_cv, **kwargs + X, Y, fit_model_wrapper, v_pens=v_pens, n_v_cv=n_v_cv, fit_args=fit_args, **kwargs ) return _MTLassoMixed_MatchSpace_wrapper def _MTLassoMixed_MatchSpace( - X, Y, fit_model_wrapper, v_pens=None, n_v_cv=5, **kwargs + X, Y, fit_model_wrapper, v_pens=None, n_v_cv=5, fit_args={}, **kwargs ): # pylint: disable=missing-param-doc, unused-argument # Note that MultiTaskLasso(CV).path with the same alpha doesn't produce same results as MultiTaskLasso(CV) - mtlasso_cv_fit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens).fit( + mtlasso_cv_fit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens, **fit_args).fit( X, Y ) # V_cv = np.sqrt(np.sum(np.square(mtlasso_cv_fit.coef_), axis=0)) #n_tasks x n_features -> n_feature