In [None]:
import numpy as np
import matplotlib.pyplot as plt
from prox import prox_dp
from sklearn.datasets import make_spd_matrix,make_sparse_spd_matrix

In [None]:
class LossFunc:
    """Composite objective ``g(x) + h(x)`` with a quadratic and an l1 term.

    * ``g(x) = 0.5 * x^T Q x - b^T x``  (smooth part)
    * ``h(x) = lam * ||x||_1``          (non-smooth part)

    The class also provides the gradient of ``g``, the proximal operator of
    ``t * h`` (soft-thresholding) and the generalized gradient of a
    proximal step restricted to selected coordinates.
    """

    # initialization
    def __init__(self,Q,b,lam,tol=1e-10):
        """Store the problem data.

        Parameters
        ----------
        Q : matrix used in the quadratic term of ``g``.
        b : vector used in the linear term of ``g``.
        lam : weight of the l1 penalty ``h``.
        tol : stored on the instance; not read by any method of this class.
        """
        self.Q = Q
        self.b = b
        self.lam = lam
        self.tol = tol

    def g(self,x):
        """Return the smooth part ``0.5 * x^T Q x - b^T x``."""
        return 0.5*np.dot(x.T.dot(self.Q),x)-np.dot(self.b,x)

    def h(self,x):
        """Return the l1 penalty ``lam * ||x||_1``."""
        return self.lam*np.linalg.norm(x,1)

    def gGrad(self,x):
        """Return ``Q x - b`` (the gradient of ``g`` when ``Q`` is symmetric)."""
        return self.Q.dot(x) - self.b

    def obj(self,x):
        """Return the full objective ``g(x) + h(x)``."""
        return self.g(x)+self.h(x)

    def Gt(self,x,t,idx):
        """Return the generalized gradient of a proximal step on ``x[idx]``.

        A candidate point is built by replacing ``x[idx]`` with
        ``prox(x[idx] - t * gGrad(x)[idx], t)``; the result is
        ``(x - candidate) / t``, which is zero outside ``idx``.

        ``x`` itself is left untouched: the step is applied to a copy so a
        caller can evaluate several step sizes ``t`` from the same point.
        """
        x_new = x.copy()
        x_new[idx] = self.prox(x[idx]-t*self.gGrad(x)[idx],t)
        return (x - x_new)/t

    def soft_threshold(self,a,b):
        """Shrink ``a`` towards zero by ``b``.

        Returns ``a - b`` if ``a > b``, ``a + b`` if ``a < -b`` and zero
        when ``|a| <= b``.  Works element-wise on arrays as well as on
        scalars, and a NaN input yields NaN rather than ``None``.
        """
        return np.sign(a)*np.maximum(np.abs(a)-b,0)

    def prox(self,x,t):
        """Proximal operator of ``t * h``: soft-threshold ``x`` at ``lam * t``."""
        return self.soft_threshold(x,self.lam*t)


In [None]:
# Setup and initial values
SEED = 0                 # fixed seed so the problem instance is reproducible
np.random.seed(SEED)

n = 100                  # problem dimension
x = np.random.rand(n)    # starting point shared by both solvers
# Q = make_spd_matrix(n, random_state=SEED)
# `alpha` is keyword-only in current scikit-learn; passing it positionally fails.
Q = make_sparse_spd_matrix(n, alpha=0.99, random_state=SEED)
b = np.random.randn(n)
lam = 0.1                # weight of the l1 penalty
bta = 0.8                # factor by which the backtracking step size is shrunk
# Stopping threshold on the per-sweep objective decrease.
# NOTE(review): 1e-20 is far below double-precision resolution for objective
# values of ordinary magnitude, so the loops in effect stop only once the
# objective no longer decreases at all — confirm this is intended.
epi = 1e-20

In [None]:
# Proximal GD Backtracking
#
# Cyclic coordinate-wise proximal gradient steps.  For each coordinate the
# step size t starts at 1 and is multiplied by `bta` until the
# sufficient-decrease condition on the smooth part g holds.

L = LossFunc(Q,b,lam)
obj_0 = L.obj(x)
x_0 = x.copy()   # keep the starting point so the next solver can reuse it
cot1 = 0         # number of step-size trials (line-search evaluations)

while(True):
    for i in range(n):
        t = 1
        # The iterate must stay fixed while candidate step sizes are tried.
        # Gt is given a copy so that the search never moves x itself, and
        # g(x) / gGrad(x) are evaluated once at the current point.
        g_x = L.g(x)
        grad_x = L.gGrad(x)
        G = L.Gt(x.copy(),t,i)
        # Shrink t while the trial point produces a NaN value of g.
        while(np.isnan(L.g(x-t*G))):
            t = bta*t
            G = L.Gt(x.copy(),t,i)
        # Backtracking: g(x - tG) <= g(x) - t <grad g(x), G> + t/2 ||G||^2
        while (L.g(x-t*G) > g_x-t*np.dot(grad_x,G)+0.5*t*(np.linalg.norm(G)**2)):
            t = bta*t
            cot1 = cot1+1
            G = L.Gt(x.copy(),t,i)
        cot1 = cot1+1
        # Apply the accepted step exactly once.
        x[i] = L.prox(x[i]-t*grad_x[i],t)
    obj_1 = L.obj(x)
    print(obj_1)
    # Stop when a full sweep no longer decreases the objective by at least epi.
    if (obj_0 - obj_1) < epi:
        break
    obj_0 = obj_1

print(cot1)
print(x)

In [None]:
# Coordinate Proximal GD
#
# Cyclic coordinate descent with the fixed step t = 1 / Q[i, i] for
# coordinate i, followed by soft-thresholding.

L = LossFunc(Q,b,lam)
# Restore the starting point BEFORE measuring the baseline objective.
# Otherwise obj_0 is taken at the previous solver's final iterate and the
# stopping test can trigger after a single sweep.
x = x_0.copy()
obj_0 = L.obj(x)
cot2 = 0   # number of coordinate updates performed

while(True):
    for i in range(n):
        t = 1/Q[i,i]
        x[i] = L.prox(x[i]-t*L.gGrad(x)[i],t)
        cot2 = cot2+1
    obj_1 = L.obj(x)

    # Stop when a full sweep no longer decreases the objective by at least epi.
    if (obj_0 - obj_1) < epi:
        break

    obj_0 = obj_1

print(cot2)
print(x)