In [1]:
import numpy as np
import time
import matplotlib.pyplot as plt
%matplotlib inline

In [47]:
def SAGA(A, b, fx, gradf, prox, parameter):
    # Parameter setting
    n            = A.shape[0]                       # Number of observations
    d            = A.shape[1]                       # Decision variable dimension
    epoch_max    = parameter['epoch_max']           # Max epoch number
    gamma        = parameter['gamma']               # Learning rate
    tol          = parameter['tol']                 # convergence tolerance
    try:    
        x        = parameter['x0']                  # Initial condition
    except: 
        x        = np.zeros((d,1))
    try:
        use_perm = parameter['perm']                # if true use permutation; otherwise use randomization
    except:
        use_perm = True
    # Output initialization    
    info              = dict()
    info['iter_time'] = list()
    info['fx']        = list()
    # Initialization
    g_phi    = np.array([gradf(x,i) for i in range(n)])
    g_phi_av = gradf(x) / n
    flag     = True
    for epoch in range(epoch_max):        
        if not flag: break
        perm = np.random.permutation(n)
        for j in range(n):
            t = time.time()
            if epoch == 0:
                i = j
            else:
                if use_perm:
                    if epoch % 2 == 0:
                        i = j 
                    else:
                        i = perm[j]
                else:
                    i = np.random.randint(0, n)
            # Update the next iteration
            gx         = gradf(x,i)
            w_next     = x - gamma * (gx - g_phi[i,:] + g_phi_av)
            g_phi_av   = g_phi_av - g_phi[i,:] / n + gx / n
            g_phi[i,:] = gx             
            x_next     = prox(w_next, gamma)
            # Save information
            info['iter_time'].append(time.time() - t)
            info['fx'].append(fx(x))   
            print('Suboptimality is ' + str(abs(fx(x)-fx(x_nature)) ))
            # Check optimality condition
            if np.abs(fx(x_next)-fx(x)) <= tol:
                flag = False
                break
            # Prepare the next iteration
            x = x_next 
            
        
    info['epoch'] = epoch
    info['iter_time'] = iter_time
    info['fx'] = fx_all
    return (x, info)

In [48]:
# Example: Least squares
n        = 1000
d        = 500
A        = np.random.normal(0,1,(n,d))
x_nature = np.random.normal(0,1,(d,1))
x_nature = x_nature / np.linalg.norm(x_nature,2)
b        = np.dot(A,x_nature)
# b        = b + np.random.normal(0,np.sqrt((np.mean(b)**2 + np.var(b))/1000),(n,1))

In [49]:
fx    = lambda x: 0.5 * np.sum((np.dot(A,x) - b)**2)
gradf = lambda x, i=None : np.dot(A.T, np.dot(A,x) - b) if i == None else A[i,:] * (np.dot(A[i,:],x)-b[i])
prox  = lambda x, gamma=None: x

In [50]:
parameter              = dict()
parameter['epoch_max'] = 100
# mu = []
# L  = []
# for i in range(n):
#     w, v= np.linalg.eig(np.dot(A[i,:].T,A[i,:]))
#     mu.append(np.min(w))
#     L.append(np.max(w))
w, v                   = np.linalg.eig(np.dot(A.T,A))
mu                     = np.min(w)
L                      = np.max(w)
parameter['gamma']     = 10 / (2*(mu*n+L))
parameter['tol']       = 1e-6
parameter['x0']        = np.zeros((d,1))
parameter['perm']      = True

In [51]:
x, info = SAGA(A, b, fx, gradf, prox, parameter)

Suboptimality is 485.086560973
Suboptimality is 242504.419271
Suboptimality is 242465.561887
Suboptimality is 242426.708019
Suboptimality is 242387.858231
Suboptimality is 242349.011093
Suboptimality is 242310.168682
Suboptimality is 242271.330051
Suboptimality is 242232.495128
Suboptimality is 242193.666527
Suboptimality is 242154.840872
Suboptimality is 242116.015403
Suboptimality is 242077.196189
Suboptimality is 242038.381303
Suboptimality is 241999.569595
Suboptimality is 241960.758324
Suboptimality is 241921.961542
Suboptimality is 241883.161612
Suboptimality is 241844.366287
Suboptimality is 241805.571996
Suboptimality is 241766.781926
Suboptimality is 241727.986909
Suboptimality is 241689.20512
Suboptimality is 241650.436074


KeyboardInterrupt: 

In [34]:
x = np.random.normal(0,1,10)

In [35]:
x

array([ 0.15948377, -1.18945244,  0.60642859, -0.2723193 , -0.68921355,
        0.25135357,  1.13242419, -0.53999769, -0.70894418, -1.42978669])

In [40]:
np.linalg.norm(x / np.linalg.norm(x,1),2)

0.3664467077115009

In [37]:
np.linalg.norm(x,2)

2.5575796021610731