In [1]:
import os
import pickle
import skimage
import numpy as np
from scipy import sparse
import matplotlib.pyplot as plt

from scfw.frank_wolfe import frank_wolfe
from scfw.scopt import scopt
import scfw.phase_retrival as pr

In [2]:
def norm(x):
    return np.imag(x)**2 + np.real(x)**2

In [3]:
pic_data = np.load('../data/pic28.npy')

In [4]:
p = len(pic_data)
n = 5 * p

mu, sigma = 0, 1
#a = sparse.random(n, p, density=0.5, data_rvs=np.random.randn)
a = np.random.randn(n, p).astype('float16') 
A = np.array([np.dot(a_i.reshape(-1, 1), np.conj(a_i).reshape(1, -1)) for a_i in a])

In [5]:
x_sol = pic_data.flatten()
X_sol = np.dot(x_sol.reshape(-1, 1), np.conj(x_sol).reshape(1, -1))

lambda_val = [norm(np.vdot(a_i, x_sol)) for a_i in a]

y = np.array([np.random.poisson(lambda_i) for lambda_i in lambda_val])
c = np.mean(y)

x_0 = np.random.normal(0, 1, p)
X_0 = np.dot(x_0.reshape(-1, 1), np.conj(x_0).reshape(1, -1))

In [6]:
#first set of parameters
Mf = 2
nu = 3
#running parameters

terminate_tol = 1e-10

FW_params={
    'iter_FW': 100,
    'line_search_tol': 1e-10,
}


sc_params={
    #parameters for SCOPT
    'iter_SC': 1000,
    'Lest': 'estimate', #estimate L
    'use_two_phase': True,
    #FISTA parameters
    'fista_type': 'mfista',
    'fista_tol': 1e-5,
    'fista_iter': 1000,
    #Conjugate Gradient Parameters
    'conj_grad_tol': 1e-2,
    'conj_grad_iter': 100,
}

In [7]:
func_x = lambda X: pr.phase_val(A, X, y)
func_beta = lambda X, S, beta, extra_param, extra_param_s: pr.phase_val(A, (1 - beta) * X + beta * S, y)
grad_x = lambda X, trace_sum: pr.phase_gradient(A, X, y, trace_sum)
grad_beta = lambda X, S, beta, trace_sum, trace_sum_s: pr.phase_gradient(A, (1 - beta) * X + beta * S, y)
hess_mult_x = lambda S, trace_sum: pr.hess_mult(A, y, S, trace_sum)
hess_mult_vec_x = lambda S, trace_sum: pr.hess_mult_vec(A, y, S, trace_sum)
extra_func = lambda X: np.array([np.trace(A_i.dot(X)).real for A_i in A])
linear_oracle = lambda grad: pr.linear_oracle(grad, c)
prox_func = lambda X, L: pr.proj_map(X, c) #used for SCOPT

In [8]:
results = {}

In [9]:
#run_alpha_policies = ["standard", "line_search", "icml", "backtracking"]
run_alpha_policies = ["backtracking"]

for policy in run_alpha_policies:
    x, alpha_hist, Gap_hist, Q_hist, time_hist  = frank_wolfe(func_x,
                       func_beta,                                       
                       grad_x,
                       grad_beta,
                       hess_mult_x,
                       extra_func,                                                    
                       Mf,
                       nu,
                       linear_oracle,                                                    
                       X_0,
                       FW_params,
                       #hess=hess_x, 
                       #lloo_oracle=llo_oracle,                                                 
                       alpha_policy=policy,                                                    
                       eps=terminate_tol, 
                       print_every=1, 
                       debug_info=False)
    
    results[policy] = {
        'x': x,
        'alpha_hist': alpha_hist,
        'Gap_hist': Gap_hist,
        'Q_hist': Q_hist,
        'time_hist': time_hist,
    }

********* Algorithm starts *********
-0.10651486756010109


SystemExit: Error!

In [27]:
with open('./pic_res.pkl', 'wb') as f:
    pickle.dump(results, f)

In [28]:
results.keys()

dict_keys(['standard', 'icml'])

In [25]:
results['icml']

{'x': array([[ 1.77620885e+00, -1.81581028e+00, -1.36961784e+00, ...,
         -3.07557857e+00, -8.85648319e-02,  6.43768571e-01],
        [-1.81581028e+00,  1.85629465e+00,  1.40015413e+00, ...,
          3.14415008e+00,  9.05394276e-02, -6.58121702e-01],
        [-1.36961784e+00,  1.40015413e+00,  1.05609936e+00, ...,
          2.37154954e+00,  6.82915042e-02, -4.96403856e-01],
        ...,
        [-3.07557857e+00,  3.14415008e+00,  2.37154954e+00, ...,
          5.32549060e+00,  1.53353644e-01, -1.11471172e+00],
        [-8.85648319e-02,  9.05394276e-02,  6.82915042e-02, ...,
          1.53353644e-01,  4.41599504e-03, -3.20994096e-02],
        [ 6.43768571e-01, -6.58121702e-01, -4.96403856e-01, ...,
         -1.11471172e+00, -3.20994096e-02,  2.33327277e-01]]),
 'alpha_hist': [9.342377943120246e-20],
 'Gap_hist': [2.4193019663839593e+22],
 'Q_hist': [-53681919593.08092, -53681919593.08092],
 'time_hist': [0, 216.33529686927795]}