Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/SparseSC/utils/match_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _MTLasso_MatchSpace(
transformer = SelMatchSpace(m_sel)
return transformer, V[m_sel], v_pen, (V, varselectorfit)

def D_LassoCV_MatchSpace_factory(v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5):
def D_LassoCV_MatchSpace_factory(v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5, fit_args={}):
"""
Return a MatchSpace function that will fit a MultiTaskLassoCV for Y ~ X and Lasso of D_full ~ X_full
and then combines the coefficients into weights using y_V_share
Expand All @@ -224,14 +224,15 @@ def _D_LassoCV_MatchSpace_wrapper(X, Y, **kwargs):
n_v_cv=n_v_cv,
sample_frac=sample_frac,
y_V_share=y_V_share,
fit_args=fit_args,
**kwargs
)

return _D_LassoCV_MatchSpace_wrapper


def _D_LassoCV_MatchSpace(
X, Y, X_full, D_full, v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5, **kwargs
X, Y, X_full, D_full, v_pens=None, n_v_cv=5, sample_frac=1, y_V_share=0.5, fit_args={}, **kwargs
): # pylint: disable=missing-param-doc, unused-argument
if sample_frac < 1:
N_y = X.shape[0]
Expand All @@ -242,7 +243,7 @@ def _D_LassoCV_MatchSpace(
sample_d = np.random.choice(N_d, int(sample_frac * N_d), replace=False)
X_full = X_full[sample_d, :]
D_full = D_full[sample_d]
y_varselectorfit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens).fit(
y_varselectorfit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens, **fit_args).fit(
X, Y
)
y_V = np.sqrt(
Expand Down Expand Up @@ -465,7 +466,7 @@ def transform(self, X):
return M


def MTLassoMixed_MatchSpace_factory(v_pens=None, n_v_cv=5):
def MTLassoMixed_MatchSpace_factory(v_pens=None, n_v_cv=5, fit_args={}):
"""
Return a MatchSpace function that will fit a MultiTaskLassoCV for Y ~ X with the penalization optimized to reduce errors on goal units
Expand All @@ -476,17 +477,17 @@ def MTLassoMixed_MatchSpace_factory(v_pens=None, n_v_cv=5):

def _MTLassoMixed_MatchSpace_wrapper(X, Y, fit_model_wrapper, **kwargs):
return _MTLassoMixed_MatchSpace(
X, Y, fit_model_wrapper, v_pens=v_pens, n_v_cv=n_v_cv, **kwargs
X, Y, fit_model_wrapper, v_pens=v_pens, n_v_cv=n_v_cv, fit_args=fit_args, **kwargs
)

return _MTLassoMixed_MatchSpace_wrapper


def _MTLassoMixed_MatchSpace(
X, Y, fit_model_wrapper, v_pens=None, n_v_cv=5, **kwargs
X, Y, fit_model_wrapper, v_pens=None, n_v_cv=5, fit_args={}, **kwargs
): # pylint: disable=missing-param-doc, unused-argument
# Note that MultiTaskLasso(CV).path with the same alpha doesn't produce same results as MultiTaskLasso(CV)
mtlasso_cv_fit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens).fit(
mtlasso_cv_fit = MultiTaskLassoCV(normalize=True, cv=n_v_cv, alphas=v_pens, **fit_args).fit(
X, Y
)
# V_cv = np.sqrt(np.sum(np.square(mtlasso_cv_fit.coef_), axis=0)) #n_tasks x n_features -> n_feature
Expand Down