## ADMM 

In [2]:
using LinearAlgebra 

# objective 
function objective(A, b, lambda, x, z) 

    p = ( 1/2 * sum( ( A*x - b ).^2 ) + lambda*norm(z,1) ) 

    return p 
end 

# test function 
A = rand(3,3)  
b = rand(3,1)  
lambda = 0.1  
x = rand(3,1)  
z = rand(3,1) 
p = objective(A, b, lambda, x, z) 
println("objective p = ", p)

# shrinkage 
function shrinkage(x, kappa) 

    z = 0*x ; 
    for i = 1:length(x) 
        z[i] = max( 0, x[i] - kappa ) - max( 0, -x[i] - kappa ) 
    end 

    return z 
end 

# test shrinkage 
kappa = 0.1 ; 
z = shrinkage(x, kappa) ; 
println("shrinkage z = ", z)

# cache factorization 
function factor(A, rho)

    m,n =  size(A) ; 
    if m >= n 
        C = cholesky( A'*A + rho*I ) 
    else
        C = cholesky( I + 1/rho*(A*A') )  
    end 
    L = C.L  
    U = C.U 

    return L, U 
end 

# test 
rho = 0.1 
L, U = factor(A, rho) 

# end 




objective p = 0.6745532015682119


shrinkage z = 

[0.26021863023068526; 0.5729073986522815; 0.7170654763947101;;]


([0.9375839696474381 0.0 0.0; 0.340629284695301 0.39189272386963125 0.0; 0.6597315747935305 0.2140180791722751 0.9137190976284412], [0.9375839696474381 0.340629284695301 0.6597315747935305; 0.0 0.39189272386963125 0.2140180791722751; 0.0 0.0 0.9137190976284412])

### Actual ADMM 

In [3]:
# define output hist struct 
struct Hist 
    objval 
    r_norm 
    s_norm 
    eps_pri 
    eps_dual 
end 
hist = Hist([], [], [], [], [])

function lasso_admm(A, b, lamda, rho, alpha) 
# ------------------------------------------------------------------------
# lasso  Solve lasso problem via ADMM
#
# [z, history] = lasso(A, b, lambda, rho, alpha);
#
# Solves the following problem via ADMM:
#
#   minimize 1/2*|| Ax - b ||_2^2 + \lambda || x ||_1
#
# The solution is returned in the vector x.
#
# history is a structure that contains:
#   objval   = objective function values 
#   r_norm   = primal residual norms 
#   s_norm   = dual residual norms 
#   eps_pri  = tolerances for the primal norms at each iteration
#   eps_dual = tolerance for dual residual norms at each iteration
#
# rho is the augmented Lagrangian parameter.
#
# alpha is the over-relaxation parameter (typical values for alpha are
# between 1.0 and 1.8).
# 
# Reference: 
# http://www.stanford.edu/~boyd/papers/distr_opt_stat_learning_admm.html
# ------------------------------------------------------------------------

    # define constants 
    max_iter = 1000  
    abstol   = 1e-4 
    reltol   = 1e-2 

    # data pre-processing 
    m, n = size(A) 
    Atb = A'*b                          # save matrix-vector multiply 

    # ADMM solver 
    x = 0*b  
    z = 0*b 
    u = 0*b 

    # cache factorization 
    L, U = factor(A, rho) 

    # begin iterations 
    for k = 1:max_iter 

        # x-update 
        q = Atb + rho*(z - u)           # temp value 
        if m >= n                       # if skinny 
            x = U \ ( L \ q ) 
        else                            # if fat 
            x = q / rho - ( A' * ( U \ ( L \ (A*q) ) ) ) / rho^2 
        end 

        # z-update 
        z_old = z 
        x_hat = alpha*x + (1 .- alpha*z_old) 
        z = shrinkage(x_hat + u, lambda/rho) 

        # u-update 
        u = u + (x_hat - z) 

        # diagnostics + termination checks 
        p = objective(A, b, lambda, x, z) 
        push!( hist.objval, p )
        push!( hist.r_norm, norm(x - z) )
        push!( hist.s_norm, norm( -rho*(z - z_old) ) )
        push!( hist.eps_pri, sqrt(n)*abstol + reltol*max(norm(x), norm(-z)) ) 
        push!( hist.eps_dual, sqrt(n)*abstol + reltol*norm(rho*u) ) 

        if hist.r_norm[k] < hist.eps_pri[k] && hist.s_norm[k] < hist.eps_dual[k] 
            break 
        end 

    end 

    return z, hist
end 

# test 
x, hist = lasso_admm(A, b, lambda, 1.0, 1.0) 

MethodError: MethodError: no method matching -(::Int64, ::Matrix{Float64})
For element-wise subtraction, use broadcasting with dot syntax: scalar .- array
Closest candidates are:
  -(::Union{Int128, Int16, Int32, Int64, Int8, UInt128, UInt16, UInt32, UInt64, UInt8}) at int.jl:85
  -(::T, !Matched::T) where T<:Union{Int128, Int16, Int32, Int64, Int8, UInt128, UInt16, UInt32, UInt64, UInt8} at int.jl:86
  -(!Matched::SparseArrays.AbstractSparseMatrixCSC, ::Array) at C:\Users\junet\AppData\Local\Programs\Julia-1.8.3\share\julia\stdlib\v1.8\SparseArrays\src\sparsematrix.jl:1834
  ...