In [1]:
import jax
import jax.numpy as jnp

In [2]:
def soft_threshold(beta, gamma):
    """Apply the soft-thresholding operator elementwise.

    Entries whose magnitude does not exceed ``gamma`` are mapped to zero;
    every other entry keeps its sign and has its magnitude reduced by
    ``gamma``.

    Args:
        beta: value(s) to shrink.
        gamma: threshold(s), broadcast against ``beta``.

    Returns:
        The shrunken value(s).
    """
    magnitude = jnp.abs(beta)
    shrunk = magnitude - gamma
    # boolean mask zeroes out everything at or below the threshold
    survives = magnitude > gamma
    return jnp.sign(beta) * shrunk * survives

In [3]:
def quadratic_minimizer_1d(j, beta, Q, l):
    """Minimize the quadratic in coordinate ``j`` with all others held fixed.

    Solves ``2 * Q[j, j] * b + 2 * sum_{k != j} Q[k, j] * beta[k] + l[j] = 0``
    for ``b``.

    Args:
        j: index of the coordinate being optimized.
        beta: current coefficient vector; ``beta[j]`` itself does not
            influence the result.
        Q: matrix of the quadratic term (column ``j`` is used).
        l: vector of the linear term.

    Returns:
        The minimizing value for coordinate ``j``, or 0 when ``Q[j, j]`` is 0.
    """
    curvature = Q[j, j]
    # contribution of every coordinate except j
    cross_term = Q[:, j] @ beta - curvature * beta[j]
    stationary_point = -(2 * cross_term + l[j]) / (2 * curvature)
    return jnp.where(curvature != 0, stationary_point, 0)

In [4]:
def quadratic_threshold_1d(j, Q, w, lam):
    """Compute the soft-threshold level for coordinate ``j``.

    Args:
        j: index of the coordinate.
        Q: matrix of the quadratic term; only ``Q[j, j]`` is read.
        w: per-coordinate penalty weights.
        lam: global penalty strength.

    Returns:
        ``lam * w[j] / (2 * Q[j, j])``, or 0 when ``Q[j, j]`` is 0.
    """
    curvature = Q[j, j]
    level = (lam * w[j]) / (2 * curvature)
    return jnp.where(curvature != 0, level, 0)

In [5]:
@jax.jit
def step_quadratic_lasso_cd(beta, Q, l, w, lam, alternate_conditions = None):
    """Run one cyclic coordinate-descent sweep for the weighted quadratic lasso.

    Coordinates are visited in index order and each update is computed from
    the most recent values of all other coordinates, so coordinate ``j``
    already sees the updates made to coordinates ``0 .. j-1`` in this sweep.

    Args:
        beta: coefficient vector at the start of the sweep.
        Q: matrix of the quadratic term.
        l: vector of the linear term.
        w: per-coordinate penalty weights; coordinates with ``w[j] == 0``
            receive the plain (unthresholded) one-dimensional minimizer.
        lam: global penalty strength.
        alternate_conditions: optional vector selecting which coordinates
            to update (nonzero entry = update). When omitted, ``beta`` itself
            is used, i.e. only coordinates that are nonzero at the start of
            the sweep are updated.

    Returns:
        The coefficient vector after the sweep.
    """
    # the selection mask is fixed at the start of the sweep
    cond = beta if alternate_conditions is None else alternate_conditions
    beta_new = beta
    for j in range(len(beta)):
        # Read from beta_new, not beta: using the stale input here would make
        # every coordinate update simultaneous instead of sequential.
        unpenalized = quadratic_minimizer_1d(j, beta_new, Q, l)
        penalized = soft_threshold(unpenalized,
                                   quadratic_threshold_1d(j, Q, w, lam))
        update = jnp.where(cond[j] != 0,
                           jnp.where(w[j] != 0, penalized, unpenalized),
                           beta_new[j])
        beta_new = beta_new.at[j].set(update)
    return beta_new

In [6]:
# vectorized - all coordinate updates performed at once... faster but not guaranteed to converge!
# @jax.jit
# def step_quadratic_lasso_cd(beta, Q, l, w, lam):
#     unpenalized = (-2 * (Q.T @ beta - jnp.diag(Q) * beta) - l) / (2 * jnp.diag(Q))
#     treshold = lam * w / (2 * jnp.diag(Q))
#     return soft_threshold(unpenalized, treshold)

In [7]:
def has_converged(beta, beta_new, tolerance):
    """Report whether two iterates are close in the L1 sense.

    Args:
        beta: previous iterate.
        beta_new: current iterate.
        tolerance: strict upper bound on the summed absolute change.

    Returns:
        A boolean scalar: True when the summed absolute difference between
        the iterates is below ``tolerance``.
    """
    total_change = jnp.sum(jnp.abs(beta_new - beta))
    return total_change < tolerance

In [8]:
def fails_quadratic_kkt0(beta, Q, l, w, lam, tolerance):
    """Flag coordinates violating the stationarity condition for nonzero entries.

    The condition checked per coordinate is
    ``|2 * (Q @ beta) + l + lam * w * sign(beta)| < tolerance``.

    Args:
        beta: coefficient vector.
        Q: matrix of the quadratic term.
        l: vector of the linear term.
        w: per-coordinate penalty weights.
        lam: global penalty strength.
        tolerance: strict bound on the absolute residual.

    Returns:
        A 0/1 vector with 1 where the condition fails.
    """
    smooth_gradient = 2 * Q @ beta + l
    penalty_gradient = lam * w * jnp.sign(beta)
    residual = jnp.abs(smooth_gradient + penalty_gradient)
    return jnp.where(residual < tolerance, 0, 1)

In [9]:
def fails_quadratic_kkt1(beta, Q, l, w, lam, tolerance):
    """Flag coordinates violating the subgradient bound for zero entries.

    The condition checked per coordinate is
    ``|2 * (Q @ beta) + l| - lam * w < tolerance``.

    Args:
        beta: coefficient vector.
        Q: matrix of the quadratic term.
        l: vector of the linear term.
        w: per-coordinate penalty weights.
        lam: global penalty strength.
        tolerance: strict bound on the excess of the gradient magnitude
            over the penalty level.

    Returns:
        A 0/1 vector with 1 where the condition fails.
    """
    gradient_size = jnp.abs(2 * Q @ beta + l)
    excess = gradient_size - lam * w
    return jnp.where(excess < tolerance, 0, 1)

In [10]:
# TODO: Currently doesn't work if there are any zeros on the diagonal
def fit_quadratic_lasso(beta_0, Q, l, w, lam, max_iter_outer, max_iter_inner, tolerance_cd, tolerance_kkt):
    """Solve a weighted quadratic lasso problem by active-set coordinate descent.

    Each outer iteration (1) runs coordinate-descent sweeps over the currently
    nonzero coordinates until the iterates stop moving, (2) performs a single
    sweep over the coordinates that are exactly zero so they get a chance to
    enter the active set, and (3) stops as soon as no coordinate fails its
    KKT check.

    Args:
        beta_0: starting coefficient vector.
        Q: matrix of the quadratic term.
        l: vector of the linear term.
        w: per-coordinate penalty weights (0 = unpenalized coordinate).
        lam: global penalty strength.
        max_iter_outer: maximum number of outer (active-set) iterations.
        max_iter_inner: maximum number of sweeps per outer iteration.
        tolerance_cd: L1 change below which the inner loop stops.
        tolerance_kkt: tolerance passed to the KKT checks.

    Returns:
        The last iterate. It satisfies the KKT checks unless
        ``max_iter_outer`` was exhausted first.
    """
    beta = beta_0.copy()

    for _ in range(max_iter_outer):
        # run coordinate descent on the active set until convergence
        for _ in range(max_iter_inner):
            beta_prev = beta
            beta = step_quadratic_lasso_cd(beta_prev, Q, l, w, lam)
            if has_converged(beta, beta_prev, tolerance_cd):
                break

        # Cycle through the null set once to perturb the stationary solution.
        # Unpenalized coordinates (w == 0) that are exactly zero must be
        # included as well: the active-set sweep skips them, so excluding them
        # here would leave them stuck at zero with a permanently failing KKT
        # check.
        cond = jnp.where(beta == 0, 1, 0)
        beta = step_quadratic_lasso_cd(beta, Q, l, w, lam, cond)

        # now check KKT conditions: stationarity for nonzero coordinates,
        # the subgradient bound for zero coordinates
        kkt0 = fails_quadratic_kkt0(beta, Q, l, w, lam, tolerance_kkt)
        kkt1 = fails_quadratic_kkt1(beta, Q, l, w, lam, tolerance_kkt)
        kkt = jnp.where(beta != 0, kkt0, kkt1)
        if kkt.sum() == 0:
            break

    return beta

In [11]:
def mls(y, yhat):
    """Mean of squared differences between ``y`` and ``yhat``.

    Args:
        y: observed values.
        yhat: predicted values.

    Returns:
        The average of ``(y - yhat) ** 2`` over all elements.
    """
    residual = y - yhat
    squared_error = jnp.square(residual)
    return jnp.mean(squared_error)

In [12]:
def bce(y, yhat):
    """Binary cross-entropy averaged over the observations.

    Computes ``-(1 / len(y)) * sum(y * log(yhat) + (1 - y) * log(1 - yhat))``.

    Args:
        y: binary targets (0 or 1).
        yhat: predicted probabilities. Values of exactly 0 or 1 make the
            corresponding logarithm infinite, so probabilities should lie
            strictly between 0 and 1.

    Returns:
        A scalar: the summed negative log-likelihood divided by ``len(y)``.
    """
    # log1p(-yhat) == log(1 - yhat), but stays accurate when yhat is tiny
    log_likelihood = y * jnp.log(yhat) + (1 - y) * jnp.log1p(-yhat)
    return -jnp.sum(log_likelihood) / len(y)

In [13]:
def get_loss_fn(loss):
    """Resolve a loss specifier to a loss function.

    Args:
        loss: ``'linear'`` or ``'mls'`` for mean squared loss,
            ``'logistic'`` or ``'bce'`` for binary cross-entropy, or any
            other object (e.g. a custom callable).

    Returns:
        ``mls`` or ``bce`` for the recognized names; otherwise ``loss``
        itself, unchanged.
    """
    if loss in ('linear', 'mls'):
        resolved = mls
    elif loss in ('logistic', 'bce'):
        resolved = bce
    else:
        # anything unrecognized is passed straight through
        resolved = loss
    return resolved

In [14]:
# Fit a penalized model by repeatedly building a local quadratic approximation
# of the loss around the current `beta` and solving it with
# `fit_quadratic_lasso`.
#
# NOTE(review): this cell is unfinished. It stops at an `if` statement with no
# body (the notebook records "SyntaxError: incomplete input"), never assigns
# `beta_new` back to `beta`, and has no `return`.
# NOTE(review): `get_model`, `get_loss_gradient`, `get_loss_hessian` and
# `get_penalty_gradient` are not defined in the visible part of the notebook;
# `gamma` and `tau` are neither parameters nor defined anywhere visible.
def fit_penalized(beta_0, 
                  X, 
                  y, 
                  loss, 
                  lam, 
                  alpha, 
                  max_iter_outer, 
                  max_iter_inner, 
                  tolerance_cd, 
                  tolerance_kkt, 
                  tolerance_loss,
                  model = None):
    # fall back to a model derived from the loss specifier when none is given
    if not model:
        model = get_model(loss)
    beta = beta_0.copy()
    loss_fn = get_loss_fn(loss)

    # NOTE(review): this iterates over `max_iter_outer` itself, whereas
    # `fit_quadratic_lasso` uses `range(max_iter_outer)` - presumably
    # `range(...)` is intended here too.
    for _ in max_iter_outer:
        # apply local quadratic approximation
        g = get_loss_gradient(beta, X, y, loss)
        H = get_loss_hessian(beta, X, y, loss)
        # NOTE(review): `jpn` is not a name imported by this notebook (the
        # import is `jax.numpy as jnp`) - likely a typo. If this is meant to
        # be an all-zero weight vector, every `w` term below vanishes.
        w = jpn.zeros(len(beta))

        Q = H/2 + (1 - alpha) * lam * jnp.diag(w) # for adding ridge regression
        l = g - H @ beta + alpha * w @ get_penalty_gradient(beta, lam, gamma, tau) - lam * w * jnp.sign(beta)

        # solve the local quadratic approximation
        # NOTE(review): `max_iterouter` is undefined - presumably
        # `max_iter_outer`.
        beta_new = fit_quadratic_lasso(beta, Q, l, w, lam * alpha, max_iter_outer = max_iterouter, max_iter_inner = max_iter_inner, tolerance_cd = tolerance_cd, tolerance_kkt = tolerance_kkt)

        # if the solution didn't lower the loss, find one that does with a golden section search
        # NOTE(review): `loss_fn` is called as (prediction, y) here, but `mls`
        # and `bce` are declared as (y, yhat); the order is irrelevant for
        # `mls` but not for `bce`.
        if loss_fn(model(beta_new, X), y) > loss_fn(model(beta, X), y) + tolerance_loss:
            
            

SyntaxError: incomplete input (899020045.py, line 33)

In [15]:
# Small two-coordinate test problem for fit_quadratic_lasso.
beta = jnp.array([1.,1.])  # starting point; both coordinates are nonzero
Q = jnp.array([[1,2],[2,5]])  # symmetric 2x2 matrix of the quadratic term
l = jnp.array([1,1])  # vector of the linear term
w = jnp.array([0,1.])  # penalty weights: the first coordinate is unpenalized
lam = 0.1  # global penalty strength
j = 0  # coordinate index; not referenced by the later cells visible here

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [16]:
# approximate solution from the coordinate-descent solver
opt = fit_quadratic_lasso(
    beta, Q, l, w, lam,
    max_iter_outer=100,
    max_iter_inner=100,
    tolerance_cd=1e-6,
    tolerance_kkt=1e-6,
)

In [17]:
# correct solution is [-1.4, 0.45] (according to the notorious pen & paper algorithm);
# the value displayed below should agree with it up to the solver tolerances
opt

Array([-1.399998  ,  0.44999918], dtype=float32)