# About

Comparison of different algorithms for projection on $\ell_1$ ball.


In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils

mpl_stylesheet.banskt_presentation(splinecolor = 'black', dpi = 120, colors = 'kelly')

import sys
sys.path.append("../utils/")
import histogram as mpy_histogram
import simulate as mpy_simulate
import plot_functions as mpy_plotfn

In [2]:
v_orig = np.random.normal(0, 1, 10)

In [4]:
# def proj_l1ball_sort(y, a):
#     if np.sum(y) == a and np.alltrue(y >= 0):
#         return y
#     yabs = np.abs(y)
#     u = np.sort(yabs)[::-1]
#     ukvals = (np.cumsum(u) - a) / np.arange(1, y.shape[0] + 1)
#     K = np.max(np.where(ukvals < u))
#     tau = ukvals[K]
#     x = np.sign(y) * np.clip(yabs - tau, a_min=0, a_max=None)
#     return x

def proj_simplex_sort(y, a = 1.0):
    if np.sum(y) == a and np.alltrue(y >= 0):
        return y
    u = np.sort(y)[::-1]
    ukvals = (np.cumsum(u) - a) / np.arange(1, y.shape[0] + 1)
    K = np.nonzero(ukvals < u)[0][-1]
    tau = ukvals[K]
    x = np.clip(y - tau, a_min=0, a_max=None)
    return x

def proj_l1ball_sort(y, a = 1.0):
    return np.sign(y) * proj_simplex_sort(np.abs(y), a = a)

def l1_norm(x):
    return np.sum(np.abs(x))

In [5]:
l1_norm(proj_l1ball_sort(v_orig, 1.0))

1.0

In [6]:
l1_norm(v_orig)

7.817970379144475

In [7]:
proj_l1ball_sort(v_orig, 1.0)

array([-0.00168506, -0.        ,  0.        ,  0.        ,  0.        ,
       -0.        ,  0.        ,  0.        ,  0.        ,  0.99831494])

In [9]:
def proj_simplex_michelot(y, a = 1.0):
    auxv = y.copy()
    N = y.shape[0]
    rho = (np.sum(y) - a) / N
    istep = 0
    vnorm_last = l1_norm(auxv)
    while True:
        istep += 1
        allowed = auxv > rho
        auxv = auxv[allowed]
        nv = np.sum(allowed)
        vnorm = l1_norm(auxv)
        if vnorm == vnorm_last:
            break
        rho = (np.sum(auxv) - a) / nv
        vnorm_last = vnorm
    x = np.clip(y - rho, a_min = 0, a_max = None)
    return x
    
def proj_l1_michelot(y, a = 1.0):
    return np.sign(y) * proj_simplex_michelot(np.abs(y), a)

proj_l1_michelot(v_orig)

array([-0.00168506, -0.        ,  0.        ,  0.        ,  0.        ,
       -0.        ,  0.        ,  0.        ,  0.        ,  0.99831494])

In [10]:
def proj_simplex_condat(y, a = 1.0):
    auxv = np.array([y[0]])
    vtilde = np.array([])
    rho = y[0] - a
    N = y.shape[0]
    # Step 2
    for i in range(1, N):
        if y[i] > rho:
            rho += (y[i] - rho) / (auxv.shape[0] + 1)
            if rho > (y[i] - a):
                auxv = np.append(auxv, y[i])
            else:
                vtilde = np.append(vtilde, auxv)
                auxv = np.array([y[i]])
                rho = y[i] - a
    # Step 3
    if vtilde.shape[0] > 0:
        for v in vtilde:
            if v > rho:
                auxv = np.append(auxv, v)
                rho += (v - rho) / (auxv.shape[0])                
    # Step 4
    nv_last = auxv.shape[0]
    istep = 0
    while True:
        istep += 1
        to_remove = list()
        nv_ = auxv.shape[0]
        for i, v in enumerate(auxv):
            if v <= rho:
                to_remove.append(i)
                nv_ = nv_ - 1
                rho += (rho - v) / nv_
        auxv = np.delete(auxv, to_remove)
        nv = auxv.shape[0]
        assert nv == nv_
        if nv == nv_last:
            break
        nv_last = nv
    # Step 5
    x = np.clip(y - rho, a_min=0, a_max=None)
    return x

def proj_l1_condat(y, a = 1.0):
    return np.sign(y) * proj_simplex_condat(np.abs(y), a)

proj_l1_condat(v_orig)

array([-0.00168506, -0.        ,  0.        ,  0.        ,  0.        ,
       -0.        ,  0.        ,  0.        ,  0.        ,  0.99831494])

In [11]:
%%timeit

proj_l1_condat(v_orig)

84.8 µs ± 2.24 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [12]:
%%timeit
proj_l1ball_sort(v_orig)

56.4 µs ± 1.21 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [13]:
v_orig2 = np.random.normal(0, 1, 100000)

In [14]:
%%timeit -n 1 -r 1
proj_l1_condat(v_orig2)

29.1 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [15]:
%%timeit -n 1 -r 1
proj_l1ball_sort(v_orig2)

12.6 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
