ENH use weights in penalties (#131)
* Pass memoryviews and access n_samples/n_features as foo.shape

* failing test (probably due to wrong screening)

* Fix screening

* fix wrong parameter order
mathurinm committed Nov 6, 2020
1 parent 1783b7c commit 76090e0
Showing 7 changed files with 164 additions and 90 deletions.
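
For reference, the objective implemented by this change (stated in the updated `Lasso` docstring below) replaces the uniform L1 penalty with a per-feature weighted one:

```latex
\min_{w} \; \frac{1}{2\, n_{\mathrm{samples}}} \lVert y - Xw \rVert_2^2
        + \alpha \sum_{j} \mathrm{weights}_j \, |w_j|
```

Setting all weights to 1 recovers the standard Lasso. The same weighted penalty is threaded through the sparse logistic regression solver, although `newton_celer` below still instantiates its `weights_pen` vector as all ones, so user-facing weights appear to be exposed only for the Lasso in this commit.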
31 changes: 16 additions & 15 deletions celer/PN_logreg.pyx
@@ -15,7 +15,7 @@ from libc.math cimport fabs, sqrt, exp
 
 from .cython_utils cimport fdot, faxpy, fcopy, fposv, fscal, fnrm2
 from .cython_utils cimport (primal, dual, create_dual_pt, create_accel_pt,
-                            sigmoid, ST, LOGREG, compute_dual_scaling,
+                            sigmoid, ST, LOGREG, dnorm_l1,
                             compute_Xw, compute_norms_X_col, set_prios)
 
 cdef:
@@ -48,6 +48,7 @@ def newton_celer(
 
     cdef int n_samples = y.shape[0]
     cdef int n_features = w.shape[0]
+    cdef floating[:] weights_pen = np.ones(n_features, dtype=dtype)
     cdef int[:] all_features = np.arange(n_features, dtype=np.int32)
     cdef floating[:] prios = np.empty(n_features, dtype=dtype)
     cdef int[:] WS
@@ -97,14 +98,13 @@ def newton_celer(
     cdef bint positive = 0
 
     for t in range(max_iter):
-        p_obj = primal(LOGREG, alpha, n_samples, &Xw[0], &y[0], n_features,
-                       &w[0])
+        p_obj = primal(LOGREG, alpha, Xw, y, w, weights_pen)
 
         # theta = y * sigmoid(-y * Xw) / alpha
         create_dual_pt(LOGREG, n_samples, alpha, &theta[0], &Xw[0], &y[0])
-        norm_Xtheta = compute_dual_scaling(
+        norm_Xtheta = dnorm_l1(
             is_sparse, theta, X, X_data, X_indices, X_indptr,
-            screened, X_mean, center, positive)
+            screened, X_mean, weights_pen, center, positive)
 
         if norm_Xtheta > 1.:
             tmp = 1. / norm_Xtheta
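
For intuition, here is a minimal NumPy sketch of what this dual-point construction and rescaling do, for dense X and ignoring screening, centering, and the positive-only case (the function name and signature are illustrative, not the library API):

```python
import numpy as np


def sigmoid(t):
    return 1. / (1. + np.exp(-t))


def rescale_dual_point(X, y, Xw, alpha, weights):
    """Feasible dual point for weighted sparse logistic regression.

    Start from theta = y * sigmoid(-y * Xw) / alpha, then shrink it so the
    weighted dual norm max_j |x_j^T theta| / weights_j is at most 1, i.e.
    |x_j^T theta| <= weights_j holds for every feature j.
    """
    theta = y * sigmoid(-y * Xw) / alpha
    norm_Xtheta = np.max(np.abs(X.T @ theta) / weights)  # assumes weights > 0
    if norm_Xtheta > 1.:
        theta /= norm_Xtheta
    return theta
```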
@@ -164,9 +164,9 @@ def newton_celer(
             for i in range(n_samples):
                 exp_Xw[i] = exp(Xw[i])
 
-            norm_Xtheta_acc = compute_dual_scaling(
+            norm_Xtheta_acc = dnorm_l1(
                 is_sparse, theta_acc, X, X_data, X_indices, X_indptr,
-                screened, X_mean, center, positive)
+                screened, X_mean, weights_pen, center, positive)
 
             if norm_Xtheta_acc > 1.:
                 tmp = 1. / norm_Xtheta_acc
@@ -188,7 +188,8 @@ def newton_celer(
 
 
         set_prios(is_sparse, theta, X, X_data, X_indices, X_indptr,
-                  norms_X_col, prios, screened, radius, &n_screened, 0)
+                  norms_X_col, weights_pen, prios, screened, radius,
+                  &n_screened, 0)
 
         if prune:
             if t == 0:
@@ -218,8 +219,8 @@ def newton_celer(
 
         PN_logreg(is_sparse, w, WS, X, X_data, X_indices, X_indptr, y,
                   alpha, tol_inner, Xw, exp_Xw, low_exp_Xw,
-                  aux, is_positive_label, X_mean, center, blitz_sc,
-                  verbose_in, max_pn_iter)
+                  aux, is_positive_label, X_mean, weights_pen, center,
+                  blitz_sc, verbose_in, max_pn_iter)
 
     return np.asarray(w), np.asarray(theta), np.asarray(gaps[:t + 1])

@@ -234,7 +235,8 @@ cpdef int PN_logreg(
         floating tol_inner, floating[:] Xw,
         floating[:] exp_Xw, floating[:] low_exp_Xw, floating[:] aux,
         int[:] is_positive_label, floating[:] X_mean,
-        bint center, bint blitz_sc, int verbose_in, int max_pn_iter):
+        floating[:] weights_pen, bint center, bint blitz_sc, int verbose_in,
+        int max_pn_iter):
 
     cdef int n_samples = Xw.shape[0]
     cdef int ws_size = WS.shape[0]
@@ -360,16 +362,15 @@ cpdef int PN_logreg(
 
             else:
                 # rescale aux to create dual point
-                norm_Xaux = compute_dual_scaling(
+                norm_Xaux = dnorm_l1(
                     is_sparse, aux, X, X_data, X_indices, X_indptr,
-                    notin_WS, X_mean, center, 0)
+                    notin_WS, X_mean, weights_pen, center, 0)
 
                 for i in range(n_samples):
                     aux[i] /= max(1, norm_Xaux)
 
                 d_obj = dual(LOGREG, n_samples, alpha, 0, &aux[0], &y[0])
-                p_obj = primal(LOGREG, alpha, n_samples, &Xw[0], &y[0],
-                               n_features, &w[0])
+                p_obj = primal(LOGREG, alpha, Xw, y, w, weights_pen)
 
                 gap = p_obj - d_obj
                 if verbose_in:
13 changes: 6 additions & 7 deletions celer/cython_utils.pxd
@@ -9,13 +9,12 @@ cdef int LOGREG
 cdef floating ST(floating, floating) nogil
 
 cdef floating dual(int, int, floating, floating, floating *, floating *) nogil
-cdef floating primal(int, floating, int, floating *, floating *,
-                     int, floating *) nogil
+cdef floating primal(int, floating, floating[:], floating [:],
+                     floating [:], floating[:]) nogil
 cdef void create_dual_pt(int, int, floating, floating *, floating *, floating *) nogil
 
 cdef floating Nh(floating) nogil
 cdef floating sigmoid(floating) nogil
-# cdef floating log_1pexp(floating) nogil
 
 cdef floating fdot(int *, floating *, int *, floating *, int *) nogil
 cdef floating fasum(int *, floating *, int *) nogil
@@ -25,7 +24,7 @@ cdef void fcopy(int *, floating *, int *, floating *, int *) nogil
 cdef void fscal(int *, floating *, floating *, int *) nogil
 
 cdef void fposv(char *, int *, int *, floating *,
-                    int *, floating *, int *, int *) nogil
+                int *, floating *, int *, int *) nogil
 
 cdef int create_accel_pt(
     int, int, int, int, floating, floating *, floating *,
@@ -43,11 +42,11 @@ cpdef void compute_norms_X_col(
     floating[:], int[:], int[:], floating[:])
 
 
-cdef floating compute_dual_scaling(
+cpdef floating dnorm_l1(
     bint, floating[:], floating[::1, :], floating[:],
-    int[:], int[:], int[:], floating[:], bint, bint) nogil
+    int[:], int[:], int[:], floating[:], floating[:], bint, bint) nogil
 
 
 cdef void set_prios(
     bint, floating[:], floating[::1, :], floating[:], int[:],
-    int[:], floating[:], floating[:], int[:], floating, int *, bint) nogil
+    int[:], floating[:], floating[:], floating[:], int[:], floating, int *, bint) nogil
59 changes: 36 additions & 23 deletions celer/cython_utils.pyx
@@ -118,36 +118,46 @@ cdef inline floating sigmoid(floating x) nogil:
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
-cdef floating primal_logreg(floating alpha, int n_samples, floating * Xw,
-                            floating * y, int n_features, floating * w) nogil:
+cdef floating primal_logreg(
+        floating alpha, floating[:] Xw, floating[:] y, floating[:] w,
+        floating[:] weights) nogil:
     cdef int inc = 1
-    cdef floating p_obj = alpha * fasum(&n_features, &w[0], &inc)
-    cdef int i = 0
+    cdef int n_samples = Xw.shape[0]
+    cdef int n_features = w.shape[0]
+    cdef floating p_obj = 0.
+    cdef int i, j
     for i in range(n_samples):
         p_obj += log_1pexp(- y[i] * Xw[i])
+    for j in range(n_features):
+        p_obj += alpha * weights[j] * fabs(w[j])
     return p_obj
 
 
 # todo check normalization by 1 / n_samples everywhere
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
-cdef floating primal_lasso(floating alpha, int n_samples, floating * R,
-                           int n_features, floating * w) nogil:
+cdef floating primal_lasso(
+        floating alpha, floating[:] R, floating[:] w,
+        floating[:] weights) nogil:
+    cdef int n_samples = R.shape[0]
+    cdef int n_features = w.shape[0]
     cdef int inc = 1
-    cdef floating p_obj = alpha * fasum(&n_features, w, &inc)
-    p_obj += fdot(&n_samples, R, &inc, R, &inc) / (2. * n_samples)
+    cdef int j
+    cdef floating p_obj = 0.
+    p_obj = fdot(&n_samples, &R[0], &inc, &R[0], &inc) / (2. * n_samples)
+    for j in range(n_features):
+        p_obj += alpha * weights[j] * fabs(w[j])
     return p_obj
 
 
 cdef floating primal(
-        int pb, floating alpha, int n_samples, floating * R, floating * y,
-        int n_features, floating * w) nogil:
+        int pb, floating alpha, floating[:] R, floating[:] y,
+        floating[:] w, floating[:] weights) nogil:
     if pb == LASSO:
-        return primal_lasso(alpha, n_samples, &R[0], n_features, &w[0])
+        return primal_lasso(alpha, R, w, weights)
     else:
-        return primal_logreg(alpha, n_samples, &R[0], &y[0], n_features,
-                             &w[0])
+        return primal_logreg(alpha, R, y, w, weights)
 
 
 @cython.boundscheck(False)
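
A plain-NumPy rendering of these two weighted primal objectives may help; this is a sketch for checking the math, not the library API (`np.logaddexp(0, t)` computes `log(1 + exp(t))` stably):

```python
import numpy as np


def primal_lasso_np(alpha, R, w, weights):
    """Weighted Lasso primal: ||R||^2 / (2 n) + alpha * sum_j weights_j |w_j|."""
    n_samples = R.shape[0]
    return R @ R / (2. * n_samples) + alpha * np.sum(weights * np.abs(w))


def primal_logreg_np(alpha, Xw, y, w, weights):
    """Weighted logistic primal: sum_i log(1 + exp(-y_i (Xw)_i)) + same penalty."""
    return np.logaddexp(0, -y * Xw).sum() + alpha * np.sum(weights * np.abs(w))
```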
@@ -348,10 +358,11 @@ cpdef void compute_Xw(
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
-cdef floating compute_dual_scaling(
+cpdef floating dnorm_l1(
     bint is_sparse, floating[:] theta, floating[::1, :] X,
     floating[:] X_data, int[:] X_indices, int[:] X_indptr, int[:] skip,
-    floating[:] X_mean, bint center, bint positive) nogil:
+    floating[:] X_mean, floating[:] weights, bint center,
+    bint positive) nogil:
     """compute norm(X[:, ~skip].T.dot(theta), ord=inf)"""
     cdef int n_samples = theta.shape[0]
     cdef int n_features = skip.shape[0]
@@ -368,7 +379,7 @@
 
     # max over feature for which skip[j] == False
     for j in range(n_features):
-        if skip[j]:
+        if skip[j] or weights[j] == 0:
            continue
         if is_sparse:
             startptr = X_indptr[j]
@@ -383,7 +394,7 @@
 
         if not positive:
             Xj_theta = fabs(Xj_theta)
-        scal = max(scal, Xj_theta)
+        scal = max(scal, Xj_theta / weights[j])
     return scal
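
Equivalently, for a dense X the quantity `dnorm_l1` returns is just the weighted dual norm below (a sketch; the Cython version additionally handles sparse X, column centering, and the positive-only case):

```python
import numpy as np


def dnorm_l1_dense(theta, X, skip, weights):
    """max_j |x_j^T theta| / weights_j over features that are kept."""
    scal = 0.
    for j in range(X.shape[1]):
        # Mirror the Cython code: ignore screened features and zero weights,
        # which would otherwise divide by zero.
        if skip[j] or weights[j] == 0:
            continue
        scal = max(scal, abs(X[:, j] @ theta) / weights[j])
    return scal
```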


@@ -393,16 +404,17 @@ cdef floating compute_dual_scaling(
 cdef void set_prios(
     bint is_sparse, floating[:] theta,
     floating[::1, :] X, floating[:] X_data, int[:] X_indices, int[:] X_indptr,
-    floating[:] norms_X_col, floating[:] prios, int[:] screened, floating radius,
-    int * n_screened, bint positive) nogil:
+    floating[:] norms_X_col, floating[:] weights, floating[:] prios,
+    int[:] screened, floating radius, int * n_screened, bint positive) nogil:
     cdef int i, j, startptr, endptr
     cdef floating Xj_theta
     cdef int n_samples = theta.shape[0]
     cdef int n_features = prios.shape[0]
 
-    # TODO we do not substract theta_sum, which seems to indicate that theta is always centered...
+    # TODO we do not substract theta_sum, which seems to indicate that theta
+    # is always centered...
     for j in range(n_features):
-        if screened[j] or norms_X_col[j] == 0.:
+        if screened[j] or norms_X_col[j] == 0. or weights[j] == 0.:
             prios[j] = 10000
             continue
         if is_sparse:
@@ -414,10 +426,11 @@ cdef void set_prios(
         else:
             Xj_theta = fdot(&n_samples, &theta[0], &inc, &X[0, j], &inc)
 
+
         if positive:
-            prios[j] = fabs(Xj_theta - 1.) / norms_X_col[j]
+            prios[j] = fabs(Xj_theta - weights[j]) / norms_X_col[j]
         else:
-            prios[j] = (1. - fabs(Xj_theta)) / norms_X_col[j]
+            prios[j] = (weights[j] - fabs(Xj_theta)) / norms_X_col[j]
 
         if prios[j] > radius:
             screened[j] = True
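
The priority measures how far `x_j^T theta` is from saturating the (now weighted) dual constraint `|x_j^T theta| <= weights_j`, normalized by the column norm; features farther from the boundary than the gap-derived `radius` are screened. A dense, non-positive sketch under the same caveats as above:

```python
import numpy as np


def set_prios_dense(theta, X, norms_X_col, weights, screened, radius):
    """Gap-safe screening sketch: distance of x_j^T theta to the boundary
    of the weighted dual constraint, scaled by ||x_j||."""
    n_features = X.shape[1]
    prios = np.empty(n_features)
    for j in range(n_features):
        if screened[j] or norms_X_col[j] == 0. or weights[j] == 0.:
            prios[j] = 10000  # never prioritized for the working set
            continue
        Xj_theta = X[:, j] @ theta
        prios[j] = (weights[j] - abs(Xj_theta)) / norms_X_col[j]
        if prios[j] > radius:
            screened[j] = True  # feature can be safely discarded
    return prios
```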
14 changes: 10 additions & 4 deletions celer/dropin_sklearn.py
@@ -25,7 +25,7 @@ class Lasso(Lasso_sklearn):
 
     The optimization objective for Lasso is::
 
-    (1 / (2 * n_samples)) * ||y - X w||^2_2 + alpha * ||w||_1
+    (1 / (2 * n_samples)) * ||y - X w||^2_2 + alpha * \sum_j weights_j |w_j|
 
     Parameters
     ----------
@@ -58,6 +58,10 @@ class Lasso(Lasso_sklearn):
     fit_intercept : bool, optional (default=True)
         Whether or not to fit an intercept.
 
+    weights : array, shape (n_features,), optional (default=None)
+        Weights used in the L1 penalty part of the Lasso objective.
+        If None, weights equal to 1 are used.
+
     normalize : bool, optional (default=False)
         This parameter is ignored when ``fit_intercept`` is set to False.
         If True, the regressors X will be normalized before regression by
@@ -110,7 +114,8 @@ class Lasso(Lasso_sklearn):
 
     def __init__(self, alpha=1., max_iter=100, max_epochs=50000, p0=10,
                  verbose=0, tol=1e-4, prune=True, fit_intercept=True,
-                 normalize=False, warm_start=False, positive=False):
+                 weights=None, normalize=False, warm_start=False,
+                 positive=False):
         super(Lasso, self).__init__(
             alpha=alpha, tol=tol, max_iter=max_iter,
             fit_intercept=fit_intercept, normalize=normalize,
@@ -120,15 +125,16 @@ def __init__(self, alpha=1., max_iter=100, max_epochs=50000, p0=10,
         self.p0 = p0
         self.prune = prune
         self.positive = positive
+        self.weights = weights
 
     def path(self, X, y, alphas, coef_init=None, return_n_iter=True, **kwargs):
         """Compute Lasso path with Celer."""
         results = celer_path(
             X, y, "lasso", alphas=alphas, coef_init=coef_init,
             max_iter=self.max_iter, return_n_iter=return_n_iter,
             max_epochs=self.max_epochs, p0=self.p0, verbose=self.verbose,
-            tol=self.tol, prune=self.prune, positive=self.positive,
-            X_scale=kwargs.get('X_scale', None),
+            tol=self.tol, prune=self.prune, weights=self.weights,
+            positive=self.positive, X_scale=kwargs.get('X_scale', None),
             X_offset=kwargs.get('X_offset', None))
 
         return results
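
With this change merged, per-feature penalties can be set from the drop-in estimator. A small illustrative example on random data (the weight values are arbitrary):

```python
import numpy as np
from celer import Lasso

rng = np.random.RandomState(0)
X = rng.randn(50, 100)
y = rng.randn(50)

# Penalize the second half of the features twice as hard as the first half.
weights = np.ones(100)
weights[50:] = 2.

clf = Lasso(alpha=0.1, weights=weights).fit(X, y)
print(np.sum(clf.coef_ != 0), "nonzero coefficients")
```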