Dev #11 (Merged)

2 changes: 1 addition & 1 deletion README.rst
@@ -22,5 +22,5 @@ Philosophy
License
=======

-New BSD. Same one as SciPy.
+New BSD

65 changes: 41 additions & 24 deletions copt/stochastic.py
@@ -6,6 +6,7 @@
from numba import njit
from copt.utils import norm_rows

+import concurrent.futures

@njit
def f_squared(p, y):
@@ -52,6 +53,11 @@ def prox_L1(step_size: float, x: np.ndarray, low: int, high: int):
x[j] = np.fmax(x[j] - step_size, 0) - np.fmax(- x[j] - step_size, 0)


+@njit
+def f_L1(x):
+    return np.sum(np.abs(x))

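For context, `prox_L1` above is the in-place soft-thresholding operator for the L1 penalty whose value the new `f_L1` helper computes. A minimal, self-contained sketch of that closed form (`soft_threshold` is illustrative, not part of copt):

```python
import numpy as np

def soft_threshold(v, t):
    # closed-form prox of t * ||.||_1; same formula prox_L1 applies in place
    return np.fmax(v - t, 0) - np.fmax(-v - t, 0)

v = np.array([-2.0, -0.5, 0.0, 0.3, 1.5])
print(soft_threshold(v, 1.0))  # [-1.  0.  0.  0.  0.5]
```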

def compute_step_size(loss: str, A, alpha: float, step_size_factor=4) -> float:
"""
Helper function to compute the step size for common loss
@@ -77,7 +83,6 @@ def compute_step_size(loss: str, A, alpha: float, step_size_factor=4) -> float:
raise NotImplementedError('loss %s is not implemented' % loss)

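The body of `compute_step_size` is collapsed in this diff, so the following is only a hedged sketch of the standard rule it presumably implements for the logistic case: the per-sample gradients are Lipschitz with constant `L = max_i ||a_i||^2 / 4 + alpha`, and SAGA-type methods take a step size that is a constant fraction of `1 / L` (`logistic_step_size` is illustrative, not the copt function):

```python
import numpy as np
from scipy import sparse

def logistic_step_size(A, alpha, step_size_factor=4):
    # squared row norms, for dense arrays and scipy sparse matrices alike
    if sparse.issparse(A):
        sq_row_norms = np.asarray(A.multiply(A).sum(axis=1)).ravel()
    else:
        sq_row_norms = (np.asarray(A) ** 2).sum(axis=1)
    L = 0.25 * sq_row_norms.max() + alpha
    return 1.0 / (step_size_factor * L)
```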


def fmin_SAGA(
fun: Callable, fun_deriv: Callable, A, b, x0: np.ndarray,
alpha: float=0., beta: float=0., g_prox: Callable=None, step_size: float=-1,
@@ -156,7 +161,7 @@ def g_func(x, *args):
A = sparse.csr_matrix(A)
if g_blocks is None:
g_blocks = np.zeros(n_features, dtype=np.int64)
-    epoch_iteration, trace_loss = _epoch_factory_sparse_SAGA_fast(
+    epoch_iteration, trace_loss = _epoch_factory_sparse_SAGA(
fun, g_func, fun_deriv, g_prox, g_blocks, A, b, alpha, beta)

start_time = datetime.now()
@@ -171,31 +176,34 @@

# .. iterate on epochs ..
for it in range(max_iter):
-        epoch_iteration(
-            x, memory_gradient, gradient_average, np.random.permutation(n_samples),
-            step_size)
+        # TODO: needs to be adapted in the sparse case
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = []
+            for _ in range(n_jobs):
+                futures.append(executor.submit(
+                    epoch_iteration, x, memory_gradient, gradient_average,
+                    np.random.permutation(n_samples), step_size))
+            concurrent.futures.wait(futures)

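This loop is the asynchronous core of the PR: because `epoch_iteration` is compiled with `@njit(nogil=True)` (see the change further down), the worker threads genuinely run in parallel on the shared arrays, Hogwild-style, rather than serializing on the GIL. A minimal, self-contained sketch of the pattern, with an illustrative `noisy_update` standing in for the real epoch function:

```python
import concurrent.futures
import numpy as np
from numba import njit

@njit(nogil=True)
def noisy_update(x, order, step_size):
    # lock-free updates on a shared array; the occasional races are
    # tolerated by asynchronous SGD/SAGA-type methods
    for i in order:
        x[i] -= step_size * x[i]

x = np.random.randn(1000)
with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [
        executor.submit(noisy_update, x, np.random.permutation(x.size), 0.1)
        for _ in range(4)]
    concurrent.futures.wait(futures)
```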
grad = gradient_average + alpha * x
z = x - step_size * grad
g_prox(beta * step_size, z, 0, n_features)
certificate = np.linalg.norm(x - z)

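The certificate computed above is the norm of the proximal-gradient residual: it vanishes exactly when `x` is a fixed point of the proximal-gradient step, which justifies its use as a stopping criterion. A hedged sketch with an out-of-place prox for readability (copt's `g_prox` mutates its argument in place):

```python
import numpy as np

def prox_grad_residual(x, grad, step_size, beta, prox):
    # distance between x and one proximal-gradient step taken from x
    z = prox(beta * step_size, x - step_size * grad)
    return np.linalg.norm(x - z)
```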
if callback is not None:
callback(x)
if trace:
trace_x.append(x.copy())
trace_certificate.append(certificate)
trace_time.append((datetime.now() - start_time).total_seconds())

if verbose:
-            print(it, certificate)
+            print('Iteration: %s, certificate: %s' % (it, certificate))
if certificate < tol:
success = True
break
if trace:
if verbose:
-            print('.. computing trace ..')
+            print('Computing trace')
# .. compute function values ..
-        with futures.ThreadPoolExecutor(max_workers=n_jobs) as executor:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=n_jobs) as executor:
trace_func = [t for t in executor.map(trace_loss, trace_x)]

return optimize.OptimizeResult(
@@ -260,7 +268,7 @@ def h_func(*args): return 0
h_blocks = np.zeros(n_features, dtype=np.int64)
if g_blocks is None:
g_blocks = np.zeros(n_features, dtype=np.int64)
-    epoch_iteration, trace_loss = _epoch_factory_sparse_PSSAGA_fast(
+    epoch_iteration, trace_loss = _epoch_factory_sparse_PSSAGA(
fun, g_func, h_func, fun_deriv, g_prox, h_prox, g_blocks, h_blocks, A, b,
alpha, beta, gamma)

@@ -344,7 +352,7 @@ def _support_matrix(
return BS_data, BS_indices[:counter_indptr], BS_indptr


-def _epoch_factory_sparse_SAGA_fast(
+def _epoch_factory_sparse_SAGA(
f_func, g_func, f_prime, g_prox, g_blocks, A, b, alpha, beta):

A_data = A.data
@@ -374,42 +382,51 @@ def _epoch_factory_sparse_SAGA_fast(
idx = (d != 0)
d[idx] = n_samples / d[idx]

-    @njit
+    @njit(nogil=True)
def epoch_iteration_template(
x, memory_gradient, gradient_average, sample_indices, step_size):

# .. SAGA estimate of the gradient ..
-        grad_est = np.zeros(n_features)
+        incr = np.zeros(n_features, dtype=x.dtype)
+        x_hat = np.empty(n_features, dtype=x.dtype)

# .. inner iteration ..
for i in sample_indices:
# .. iterate on blocks ..
for g_j in range(BS_indptr[i], BS_indptr[i+1]):
g = BS_indices[g_j]

+                # .. iterate on features inside block ..
+                for b_j in range(RB_indptr[g], RB_indptr[g+1]):
+                    x_hat[b_j] = x[b_j]
p = 0.
for j in range(A_indptr[i], A_indptr[i+1]):
j_idx = A_indices[j]
-                p += x[j_idx] * A_data[j]
+                p += x_hat[j_idx] * A_data[j]

grad_i = f_prime(p, b[i])

# .. update coefficients ..
for j in range(A_indptr[i], A_indptr[i+1]):
j_idx = A_indices[j]
-                grad_est[j_idx] = (grad_i - memory_gradient[i]) * A_data[j]
+                incr[j_idx] = (grad_i - memory_gradient[i]) * A_data[j]

# .. iterate on blocks ..
for g_j in range(BS_indptr[i], BS_indptr[i+1]):
g = BS_indices[g_j]

# .. iterate on features inside block ..
for b_j in range(RB_indptr[g], RB_indptr[g+1]):
-                    grad_est[b_j] += d[g] * (
-                        gradient_average[b_j] + alpha * x[b_j])
-                    x[b_j] -= step_size * grad_est[b_j]
+                    incr[b_j] += d[g] * (
+                        gradient_average[b_j] + alpha * x_hat[b_j])
+                    incr[b_j] = x_hat[b_j] - step_size * incr[b_j]

-                g_prox(step_size * beta * d[g], x, RB_indptr[g], RB_indptr[g+1])
+                g_prox(step_size * beta * d[g], incr, RB_indptr[g], RB_indptr[g+1])

# .. clean up ..
for b_j in range(RB_indptr[g], RB_indptr[g+1]):
-                    grad_est[b_j] = 0
+                    # update vector of coefficients
+                    x[b_j] -= (x_hat[b_j] - incr[b_j])
+                    incr[b_j] = 0

# .. update memory terms ..
for j in range(A_indptr[i], A_indptr[i+1]):
@@ -429,7 +446,7 @@ def full_loss(x):
return epoch_iteration_template, full_loss

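The `x_hat`/`incr` rewrite in the epoch above changes how each inner step touches the shared iterate: it snapshots the coordinates it needs into `x_hat`, builds the proximal update in the private buffer `incr`, and applies the whole step to `x` as a single increment `x[b_j] -= x_hat[b_j] - incr[b_j]`, so concurrent writers never observe a half-applied prox step. A dense, hedged sketch of the same pattern (all names illustrative):

```python
import numpy as np

def apply_block_update(x, block, grad_block, step_size, prox):
    x_hat = x[block].copy()                # consistent read of the block
    incr = x_hat - step_size * grad_block  # plain gradient step
    incr = prox(step_size, incr)           # prox in the private buffer
    x[block] -= x_hat - incr               # applied as one increment
```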

-def _epoch_factory_sparse_PSSAGA_fast(
+def _epoch_factory_sparse_PSSAGA(
fun, g_func, h_func, f_prime, g_prox, h_prox, g_blocks, h_blocks, A, b, alpha,
beta, gamma):

58 changes: 58 additions & 0 deletions examples/plot_asynchronous.py
@@ -0,0 +1,58 @@
"""
Asynchronous Stochastic Gradient
================================

"""
import numpy as np
import matplotlib.pyplot as plt
from copt import stochastic
from scipy import sparse
colors = ['#7fc97f', '#beaed4', '#fdc086']

# generate a random large sparse matrix as input data
# and associated target labels
n_samples, n_features = 10000, 10000
X = sparse.random(n_samples, n_features, density=0.001, format='csr')
w = sparse.random(1, n_features, density=0.01).toarray().ravel()
y = np.sign(X.dot(w) + np.random.randn(n_samples))


alpha = 1.0 / n_samples
beta = 1.0 / n_samples
step_size = stochastic.compute_step_size('logistic', X, alpha)

max_iter = 25

opt_1cores = stochastic.fmin_SAGA(
stochastic.f_logistic, stochastic.deriv_logistic, X, y, np.zeros(X.shape[1]),
step_size=step_size, alpha=alpha, beta=beta, max_iter=max_iter, tol=-1,
trace=True, verbose=True, g_prox=stochastic.prox_L1, g_func=stochastic.f_L1)


opt_2cores = stochastic.fmin_SAGA(
stochastic.f_logistic, stochastic.deriv_logistic, X, y, np.zeros(X.shape[1]),
step_size=step_size, alpha=alpha, beta=beta, max_iter=max_iter, tol=-1,
trace=True, verbose=True, g_prox=stochastic.prox_L1, g_func=stochastic.f_L1, n_jobs=2)


opt_3cores = stochastic.fmin_SAGA(
stochastic.f_logistic, stochastic.deriv_logistic, X, y, np.zeros(X.shape[1]),
step_size=step_size, alpha=alpha, beta=beta, max_iter=max_iter, tol=-1,
trace=True, verbose=True, g_prox=stochastic.prox_L1, g_func=stochastic.f_L1, n_jobs=3)

fmin = min(np.min(opt_1cores.trace_func), np.min(opt_2cores.trace_func),
np.min(opt_3cores.trace_func))

plt.plot(opt_1cores.trace_time, opt_1cores.trace_func - fmin, lw=4, label='1 core',
color=colors[0])
plt.plot(opt_2cores.trace_time, opt_2cores.trace_func - fmin, lw=4, label='2 cores',
color=colors[1])
plt.plot(opt_3cores.trace_time, opt_3cores.trace_func - fmin, lw=4, label='3 cores',
color=colors[2])

plt.yscale('log')
plt.ylabel('Function suboptimality')
plt.xlabel('Time (seconds)')
plt.grid()
plt.legend()
plt.show()
3 changes: 3 additions & 0 deletions tests/test_three_split.py
@@ -9,13 +9,16 @@
X = np.random.randn(n_samples, n_features)
y = np.sign(np.random.randn(n_samples))


# helper functions
def logloss(x):
return logistic._logistic_loss(x, X, y, 1.)


def fprime_logloss(x):
return logistic._logistic_loss_and_grad(x, X, y, 1.)[1]


def fused_lasso(x):
return np.abs(np.diff(x)).sum()
