## ADMM 

In [3]:
# define output hist struct 
struct Hist 
    objval 
    r_norm 
    s_norm 
    eps_pri 
    eps_dual 
end 

using LinearAlgebra 

# objective 
function objective(A, b, lambda, x, z) 

    p = ( 1/2 * sum( ( A*x - b ).^2 ) + lambda*norm(z,1) ) 

    return p 
end 

# test function 
A = rand(3,3)  
b = rand(3,1)  
lambda = 0.1  
x = rand(3,1)  
z = rand(3,1) 
p = objective(A, b, lambda, x, z) 
println("objective p = ", p)

# shrinkage 
function shrinkage(x, kappa) 

    z = 0*x ; 
    for i = 1:length(x) 
        z[i] = max( 0, x[i] - kappa ) - max( 0, -x[i] - kappa ) 
    end 

    return z 
end 

# test shrinkage 
kappa = 0.1 ; 
z = shrinkage(x, kappa) ; 
println("shrinkage z = ", z)

# cache factorization 
function factor(A, rho)

    m,n =  size(A) ; 
    if m >= n 
        C = cholesky( A'*A + rho*I ) 
    else
        C = cholesky( I + 1/rho*(A*A') )  
    end 
    L = C.L  
    U = C.U 

    return L, U 
end 

# test 
rho = 0.1 
L, U = factor(A, rho) 

# end 




objective p = 0.5838331746259419


shrinkage z = [0.6155934226635063; 0.12490046185571654; 0.876926091055854;;]


([1.0904441473716642 0.0 0.0; 1.2033143040059537 0.7048310367527997 0.0; 0.9941078076657761 0.037386891338611984 0.5418033629135732], [1.0904441473716642 1.2033143040059537 0.9941078076657761; 0.0 0.7048310367527997 0.037386891338611984; 0.0 0.0 0.5418033629135732])

### Actual ADMM 

In [4]:
function lasso_admm(A, b, lamda, rho, alpha) 
# ------------------------------------------------------------------------
# lasso  Solve lasso problem via ADMM
#
# [z, history] = lasso(A, b, lambda, rho, alpha);
#
# Solves the following problem via ADMM:
#
#   minimize 1/2*|| Ax - b ||_2^2 + \lambda || x ||_1
#
# The solution is returned in the vector x.
#
# history is a structure that contains:
#   objval   = objective function values 
#   r_norm   = primal residual norms 
#   s_norm   = dual residual norms 
#   eps_pri  = tolerances for the primal norms at each iteration
#   eps_dual = tolerance for dual residual norms at each iteration
#
# rho is the augmented Lagrangian parameter.
#
# alpha is the over-relaxation parameter (typical values for alpha are
# between 1.0 and 1.8).
# 
# Reference: 
# http://www.stanford.edu/~boyd/papers/distr_opt_stat_learning_admm.html
# ------------------------------------------------------------------------

    hist = Hist([], [], [], [], [])

    # define constants 
    max_iter = 1000  
    abstol   = 1e-4 
    reltol   = 1e-2 

    # data pre-processing 
    m, n = size(A) 
    Atb = A'*b                          # save matrix-vector multiply 

    # ADMM solver 
    x = 0*b  
    z = 0*b 
    u = 0*b 

    # cache factorization 
    L, U = factor(A, rho) 

    # begin iterations 
    for k = 1:max_iter 

        # x-update 
        q = Atb + rho*(z .- u)           # temp value 
        if m >= n                       # if skinny 
            x = U \ ( L \ q ) 
        else                            # if fat 
            x = q / rho - ( A' * ( U \ ( L \ (A*q) ) ) ) / rho^2 
        end 

        # z-update 
        z_old = z 
        x_hat = alpha*x + (1 .- alpha*z_old) 
        z = shrinkage(x_hat + u, lambda/rho) 

        # u-update 
        u = u + (x_hat .- z) 

        # diagnostics + termination checks 
        p = objective(A, b, lambda, x, z) 
        push!( hist.objval, p )
        push!( hist.r_norm, norm(x - z) )
        push!( hist.s_norm, norm( -rho*(z - z_old) ) )
        push!( hist.eps_pri, sqrt(n)*abstol + reltol*max(norm(x), norm(-z)) ) 
        push!( hist.eps_dual, sqrt(n)*abstol + reltol*norm(rho*u) ) 

        if hist.r_norm[k] < hist.eps_pri[k] && hist.s_norm[k] < hist.eps_dual[k] 
            break 
        end 

    end 

    return z, hist 
end 

# test 
x, hist = lasso_admm(A, b, lambda, 1.0, 1.0) 

([0.6967286567718355; 0.7135363547254007; 0.7305349360538247;;], Hist(Any[0.42022352709181676, 0.25580689098147424, 0.33908366388610217, 0.24233129839989787, 0.300062230997607, 0.24225902996699142, 0.2805599823974597, 0.245431707073546, 0.2703064759890963, 0.24869735908113821  …  0.2564778066548292, 0.25647780665482917, 0.2564778066548292, 0.2564778066548291, 0.2564778066548292, 0.25647780665482917, 0.2564778066548292, 0.2564778066548292, 0.2564778066548292, 0.2564778066548292], Any[1.5588457268119895, 0.34749100522715515, 1.1585240234311287, 0.07046031239351885, 0.912512658889767, 0.1730877257557556, 0.7580240522688378, 0.29107314947936186, 0.6609472609151177, 0.3669419745196143  …  0.4966350253203484, 0.49663502532034814, 0.49663502532034853, 0.496635025320348, 0.49663502532034864, 0.496635025320348, 0.49663502532034853, 0.49663502532034814, 0.4966350253203484, 0.4966350253203483], Any[2.071708609419906, 1.5013255105782728, 1.1905443513323937, 0.9442757348463305, 0.7489583509326901, 