# Pliable Lasso

In [None]:
import numpy as np
import scipy as sp
from scipy import stats, linalg
import matplotlib
import matplotlib.pyplot as plt
from itertools import cycle, product

%matplotlib inline
%load_ext autoreload
%autoreload 1

# add parent directory to load spmimage
import sys
sys.path.append('..')

In [None]:
def _soft_threshold(X: np.ndarray, thresh: float) -> np.ndarray:
    """Apply element-wise soft-thresholding to ``X``.

    Entries whose absolute value is at most ``thresh`` become 0; every other
    entry is moved towards zero by ``thresh``.
    """
    shrunk = X - thresh * np.sign(X)
    within_band = np.abs(X) <= thresh
    return np.where(within_band, 0, shrunk)

def solve_quad_eq(a, b, c):
    """Return both roots of ``a*x**2 + b*x + c = 0`` from the quadratic formula.

    The root using ``+sqrt(discriminant)`` comes first, the ``-sqrt`` root
    second. The square root is taken with ``np.sqrt``, so no special handling
    is done for a negative discriminant or for ``a == 0``.
    """
    discriminant = b**2 - 4*a*c
    sqrt_disc = np.sqrt(discriminant)
    plus_root = (-b + sqrt_disc) / (2 * a)
    minus_root = (-b - sqrt_disc) / (2 * a)
    return plus_root, minus_root

In [None]:
def _bccd_update_block(j, X, W, r, beta, theta, lam, alpha, t_init, eps):
    """Update ``beta[j]`` and ``theta[:, j]`` in place; return the new residual.

    Parameters
    ----------
    j : int
        Index of the feature (column of ``X``) whose block is updated.
    X : ndarray, shape (n_samples, n_features)
    W : list of ndarray
        ``W[j]`` is column ``j`` of ``X`` multiplied element-wise with every
        column of ``Z``.
    r : ndarray, shape (n_samples,)
        Current residual; it is not modified, a new array is returned.
    beta, theta : ndarray
        Current coefficients, modified in place for block ``j`` only.
    lam, alpha : float
        ``(1 - alpha) * lam`` and ``alpha * lam`` are the thresholds used below.
    t_init : float
        Initial step size of the backtracking search.
    eps : float
        Tolerance on how well the selected norm pair must satisfy its
        defining equations.

    Raises
    ------
    SystemExit
        Via ``sys.exit(1)`` when no candidate norm pair satisfies the
        equations within ``eps`` or the selected pair is negative.
    """
    n_samples = X.shape[0]
    x_j = X[:, j]

    # Partial residual: the residual with block j's current contribution
    # added back, i.e. the residual the model would have if block j were zero.
    rbj = r + beta[j] * x_j + W[j].dot(theta[:, j])

    # Screening test 1: the whole block (beta_j, theta_j) is zero.
    if (abs(np.inner(x_j, rbj) / n_samples) <= (1 - alpha) * lam
            and np.linalg.norm(_soft_threshold(W[j].T.dot(rbj) / n_samples,
                                               alpha * lam))
            <= 2.0 * (1 - alpha) * lam):
        beta[j] = 0
        theta[:, j] = 0
        # The residual must reflect the removed block, otherwise it goes
        # stale as soon as beta[j] or theta[:, j] was previously non-zero.
        return rbj

    # Screening test 2: theta_j is zero and beta_j has a closed form.
    beta_j = n_samples / (x_j**2).sum() \
        * _soft_threshold(np.inner(x_j, rbj) / n_samples, (1 - alpha) * lam)
    if (np.linalg.norm(_soft_threshold(W[j].T.dot(rbj - beta_j * x_j)
                                       / n_samples, alpha * lam))
            <= (1 - alpha) * lam):
        beta[j] = beta_j
        theta[:, j] = 0
        # Start from rbj so the old theta contribution is removed as well.
        return rbj - beta_j * x_j

    # Otherwise take a gradient step on the block with backtracking on t.
    t = t_init
    grbeta = -1 / n_samples * np.inner(x_j, r)
    grtheta = -1 / n_samples * W[j].T.dot(r)
    while True:
        c = t * (1 - alpha) * lam
        g1 = abs(beta[j] - t * grbeta)
        g2 = np.linalg.norm(_soft_threshold(theta[:, j] - t * grtheta,
                                            t * alpha * lam))
        # Candidate norms (a, b) are built from the two roots of this
        # quadratic; the pair that best satisfies the two norm equations
        # (smallest absolute error `val`) is selected below.
        r1, r2 = solve_quad_eq(1, 2 * c, 2 * c * g2 - g1**2 - g2**2)
        a_list = [g1 * r1 / (c + r1), g1 * r2 / (c + r2),
                  g1 * r1 / (c + r2), g1 * r2 / (c + r1)]
        b_list = [r1 * (c - g2) / (c + r1), r2 * (c - g2) / (c + r2),
                  r1 * (c - g2) / (c + r2), r2 * (c - g2) / (c + r1)]
        min_val = 10e9
        a_hat = 0
        b_hat = 0
        for a, b in product(a_list, b_list):
            gamma = np.sqrt(a**2 + b**2)
            val = (abs((1 + c / gamma) * a - g1)
                   + abs((1 + c * (1 / b + 1 / gamma)) * b - g2))
            if val < min_val:
                min_val = val
                a_hat = a
                b_hat = b
        if min_val > eps:
            print(min_val)
            print('Error: No solution was found in the equation.', file=sys.stderr)
            sys.exit(1)
        if a_hat < 0 or b_hat < 0:
            print('Error: Negative solution was found in the equation.', file=sys.stderr)
            sys.exit(1)
        gamma = np.sqrt(a_hat**2 + b_hat**2)
        beta_j = (beta[j] - t * grbeta) / (1 + c / gamma)
        theta_j = (_soft_threshold(theta[:, j] - t * grtheta, t * alpha * lam)
                   / (1 + c * (1 / b_hat + 1 / gamma)))
        d_beta = beta_j - beta[j]
        d_theta = theta_j - theta[:, j]
        r_new = r - d_beta * x_j - W[j].dot(d_theta)
        # Sufficient-decrease test for step size t. `<=` (not `<`) so that a
        # zero step, where both sides are equal, is accepted instead of
        # shrinking t forever.
        if (((r_new**2).sum() / n_samples - (r**2).sum() / n_samples
             - 2.0 * (d_beta * grbeta + np.inner(d_theta, grtheta))
             - (d_beta**2 + (d_theta**2).sum()) / t) <= 0.0):
            beta[j] = beta_j
            theta[:, j] = theta_j
            return r_new
        t = t * 0.9


def _bccd(X, Z, y, lam, alpha, t_init=0.1, max_iter=1000, tol=1e-6):
    """Fit main-effect and interaction coefficients by cyclic block updates.

    NOTE(review): the name and the notebook title suggest this is blockwise
    coordinate descent for the pliable lasso -- confirm against the reference
    the notebook follows.

    An unpenalised least-squares fit of ``y`` on ``[1, Z]`` gives ``beta0`` and
    ``theta0``. The features are then visited cyclically and each block
    ``(beta[j], theta[:, j])`` is updated from the current residual.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
    Z : ndarray, shape (n_samples, K)
    y : ndarray, shape (n_samples,)
    lam : float
        Penalty level; enters the thresholds as ``(1 - alpha) * lam`` and
        ``alpha * lam``.
    alpha : float
        Weight splitting ``lam`` between the two thresholds (presumably in
        [0, 1] -- not checked here).
    t_init : float, default 0.1
        Initial step size of the backtracking search.
    max_iter : int, default 1000
        Maximum number of full sweeps over the features.
    tol : float, default 1e-6
        Stop when no block changed by more than ``tol`` (Euclidean norm of the
        change in ``(beta[j], theta[:, j])``) during a full sweep.

    Returns
    -------
    beta0 : float
    theta0 : ndarray, shape (K,)
    beta : ndarray, shape (n_features,)
    theta : ndarray, shape (K, n_features)

    ``beta`` and ``theta`` are also printed, as before. If ``max_iter`` sweeps
    are exhausted the current estimate is returned as is.
    """
    n_samples, n_features = X.shape[:2]
    K = Z.shape[1]
    # W[j]: column j of X multiplied element-wise with each column of Z.
    W = [X[:, j].reshape(n_samples, 1) * Z for j in range(n_features)]
    beta = np.zeros(n_features)
    theta = np.zeros((K, n_features))
    ones = np.ones(n_samples)
    eps = 1e-3

    # least square regression of y on an intercept and Z
    Ztil = np.hstack((ones.reshape(-1, 1), Z))
    beta0_theta0 = linalg.solve(Ztil.T.dot(Ztil), Ztil.T.dot(y))
    beta0, theta0 = beta0_theta0[0], beta0_theta0[1:]

    # Residual after removing the unpenalised part.
    r = y - beta0 * ones - Z.dot(theta0)

    # Sweep-based stopping: the residual norm alone practically never falls
    # below eps on noisy data, so an unbounded cycle would not terminate.
    residual_vanished = False
    for _ in range(max_iter):
        sweep_change = 0.0
        for j in range(n_features):
            if np.linalg.norm(r) < eps:
                residual_vanished = True
                break
            beta_old = beta[j]
            theta_old = theta[:, j].copy()
            r = _bccd_update_block(j, X, W, r, beta, theta,
                                   lam, alpha, t_init, eps)
            block_change = np.hypot(beta[j] - beta_old,
                                    np.linalg.norm(theta[:, j] - theta_old))
            sweep_change = max(sweep_change, block_change)
        if residual_vanished or sweep_change < tol:
            break
    print(beta)
    print(theta)
    return beta0, theta0, beta, theta

In [None]:
# Example: simulate a response with one main effect (x[:, 0]) and one
# interaction (x[:, 0] * z[:, 2]), then run the solver on it.
n = 20
p = 3
nz = 3
x = np.random.normal(0, 1, (n, p))
# Column means / standard deviations of the raw draws, taken before
# standardisation.
mx = np.mean(x, axis=0)
sx = np.std(x, axis=0)
x = stats.zscore(x)
z = np.random.normal(0, 1, (n, nz))
mz = np.mean(z, axis=0)
sz = np.std(z, axis=0)
z = stats.zscore(z)
# Independent Gaussian noise for every sample. `np.random.normal(n)` would
# draw a single value with mean n and add the same offset to all of y.
y = 4 * x[:, 0] + 5 * x[:, 0] * z[:, 2] + 3 * np.random.normal(0, 1, n)
_bccd(x, z, y, 1.0, 0.5)