Skip to content

Commit

Permalink
sklearn api compat
Browse files Browse the repository at this point in the history
  • Loading branch information
mathurinm committed Apr 13, 2018
1 parent 54bc074 commit d48550b
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 29 deletions.
3 changes: 2 additions & 1 deletion celer/dense.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def celer_dense(double[::1, :] X,
print("Log gap %.2e" % gap)

if gap < tol:
print("Early exit, gap: %.2e < %.2e" % (gap, tol))
if verbose:
print("Early exit, gap: %.2e < %.2e" % (gap, tol))
break

set_feature_prios_dense(n_samples, n_features, theta, X,
Expand Down
58 changes: 41 additions & 17 deletions celer/homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from .wrapper import celer


def celer_path(X, y, alphas=None, max_iter=20, gap_freq=10,
max_epochs_inner=50000, p0=10, verbose=1, verbose_inner=1,
tol=1e-6, prune=0):
def celer_path(X, y, eps=1e-3, n_alphas=100, alphas=None, max_iter=20,
gap_freq=10, max_epochs_inner=50000, p0=10, verbose=1,
verbose_inner=1, tol=1e-6, prune=0, return_thetas=False):
"""Compute Lasso path with Celer as inner solver.
Parameters
Expand All @@ -17,7 +17,14 @@ def celer_path(X, y, alphas=None, max_iter=20, gap_freq=10,
y : ndarray, shape (n_samples,)
Target values
alphas : ndarray, shape (n_alphas,)
eps : float, optional
Length of the path. ``eps=1e-3`` means that
``alpha_min = 1e-3 * alpha_max``
n_alphas : int, optional
Number of alphas along the regularization path
alphas : ndarray, optional
List of alphas where to compute the models.
If ``None``, alphas are set automatically.
Expand Down Expand Up @@ -48,38 +55,51 @@ def celer_path(X, y, alphas=None, max_iter=20, gap_freq=10,
prune : 0 | 1
Whether or not to use pruning when growing working sets.
return_thetas : bool
If True, dual variables along the path are returned.
Returns
-------
betas : array, shape (n_alphas, n_features)
alphas : array, shape (n_alphas,)
The alphas along the path where models are computed.
coefs : array, shape (n_features, n_alphas)
Coefficients along the path.
dual_gaps : array, shape (n_alphas,)
Duality gaps returned by the solver along the path.
thetas : array, shape (n_alphas, n_samples)
The dual variables along the path.
final_gaps : array, shape (n_alphas,)
Duality gaps returned by the solver along the path.
"""
if alphas is None:
alpha_max = np.max(np.abs(X.T.dot(y)))
alphas = alpha_max * np.logspace(0, np.log10(eps), n_alphas)
else:
alphas = np.sort(alphas)[::-1]

n_alphas = len(alphas)
n_samples, n_features = X.shape
assert alphas[0] > alphas[-1] # alphas must be given in decreasing order

betas = np.zeros((n_alphas, n_features))
coefs = np.zeros((n_features, n_alphas), order='F') # sklearn API
thetas = np.zeros((n_alphas, n_samples))
final_gaps = np.zeros(n_alphas)
dual_gaps = np.zeros(n_alphas)
all_times = np.zeros(n_alphas)

# skip alpha_max and use decreasing alphas
thetas[0] = y / alphas[0] # don't forget to set this one
thetas[0] = y / alphas[0]
for t in range(1, n_alphas):
if verbose:
print("#" * 60)
print(" ##### Computing %dth alpha" % (t + 1))
print("#" * 60)
if t > 1:
beta_init = betas[t - 1].copy()
beta_init = coefs[:, t - 1].copy()
p_t = max(len(np.where(beta_init != 0)[0]), 1)
else:
beta_init = betas[t]
beta_init = coefs[:, t].copy()
p_t = 10

alpha = alphas[t]
Expand All @@ -91,8 +111,12 @@ def celer_path(X, y, alphas=None, max_iter=20, gap_freq=10,
tol=tol, prune=prune)

all_times[t] = time.time() - t0
betas[t], thetas[t], final_gaps[t] = sol[0], sol[1], sol[2][-1]
if final_gaps[t] > tol:
coefs[:, t], thetas[t], dual_gaps[t] = sol[0], sol[1], sol[2][-1]
if dual_gaps[t] > tol:
print("-----WARNING: solver did not converge, t=%d" % t)
print("gap=%.1e, tol=%.1e" % (final_gaps[t], tol))
return betas, thetas, final_gaps
print("gap=%.1e, tol=%.1e" % (dual_gaps[t], tol))

if return_thetas:
return alphas, coefs, dual_gaps, thetas
else:
return alphas, coefs, dual_gaps
3 changes: 2 additions & 1 deletion celer/sparse.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def celer_sparse(double[:] X_data,
print("Log gap %.2e" % gap)

if gap < tol:
print("Early exit, gap: %.2e < %.2e" % (gap, tol))
if verbose:
print("Early exit, gap: %.2e < %.2e" % (gap, tol))
break

set_feature_prios_sparse(n_features, &theta[0],
Expand Down
10 changes: 5 additions & 5 deletions examples/plot_finance_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@
results = np.zeros([1, len(tols)])
for tol_ix, tol in enumerate(tols):
t0 = time.time()
res = celer_path(X, y, alphas, max_iter=100, gap_freq=gap_freq,
max_epochs_inner=50000, p0=100, verbose=verbose,
verbose_inner=verbose_inner,
tol=tol, prune=prune)
res = celer_path(X, y, alphas=alphas, max_iter=100, gap_freq=gap_freq,
p0=100, verbose=verbose, verbose_inner=verbose_inner,
tol=tol, prune=prune, return_thetas=True)
results[0, tol_ix] = time.time() - t0
betas, thetas, gaps = res
_, coefs, gaps, thetas = res
betas = coefs.T

labels = [r"\sc{CELER}"]
figsize = (7, 4)
Expand Down
10 changes: 5 additions & 5 deletions examples/plot_leukemia_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@
results = np.zeros([2, len(tols)])
for tol_ix, tol in enumerate(tols):
t0 = time.time()
res = celer_path(X, y, alphas, max_iter=100, gap_freq=gap_freq,
max_epochs_inner=50000, p0=100, verbose=verbose,
verbose_inner=verbose_inner,
tol=tol, prune=prune)
res = celer_path(X, y, alphas=alphas, max_iter=100, gap_freq=gap_freq,
p0=100, verbose=verbose, verbose_inner=verbose_inner,
tol=tol, prune=prune, return_thetas=True)
results[0, tol_ix] = time.time() - t0
print('Celer time: %.2f s' % results[0, tol_ix])
betas, thetas, gaps = res
_, coefs, gaps, thetas = res
betas = coefs.T

t0 = time.time()
_, coefs, dual_gaps = lasso_path(X, y, tol=tol, alphas=alphas / n_samples)
Expand Down

0 comments on commit d48550b

Please sign in to comment.