In [1]:
# Condat's fast projection onto the L1 ball
# For a fixed y in R^n and a > 0, we solve x* = argmin ||x - y||_2 such that ||x||_1 <= a

# This problem is equivalent to the following one:
# if ||y||_1 <= a, then x* = y; otherwise we first solve
#     z* = argmin ||z - abs(y)||_2 such that z^T 1 = a, z >= 0
# and then obtain x* = sign(y) * z*, where * denotes elementwise multiplication.

import numpy as np 


# Debug verbosity for condat_simplex_project: 0 = silent, 1 = print the final
# Lagrange multiplier, >1 = also dump the intermediate active lists.
VERBOSE = 0


def _condat_forward_pass(values, a):
    """First pass of Condat's algorithm.

    Scans ``values`` once, maintaining a candidate support ``active`` and the
    running threshold ``rho = (sum(active) - a) / len(active)``. When the
    candidate set has to restart at a new value, the previous candidates are
    moved to ``deferred`` so the re-insertion pass can reconsider them.

    Returns ``(active, deferred, rho)``.
    """
    active = [values[0]]
    deferred = []
    rho = values[0] - a
    for value in values[1:]:
        if value > rho:
            rho += (value - rho) / (len(active) + 1)
            if rho > value - a:
                active.append(value)
            else:
                # Keep `deferred` flat: extend it, never nest the old list.
                deferred.extend(active)
                active = [value]
                rho = value - a
    return active, deferred, rho


def _condat_reinsert(active, deferred, rho):
    """Append to ``active`` (in place) each deferred value above the running threshold.

    Returns the updated threshold.
    """
    for value in deferred:
        if value > rho:
            active.append(value)
            rho += (value - rho) / len(active)
    return rho


def _condat_prune(active, rho):
    """Drop values ``<= rho`` from ``active`` until a full pass removes nothing.

    Each pass builds a new list instead of popping in place, which avoids the
    quadratic cost of ``list.pop(i)`` while visiting values in the same order.
    Returns ``(active, rho)``.
    """
    num_traversal = 0
    changed = True
    while changed:
        changed = False
        kept = []
        remaining = len(active)
        for value in active:
            # In exact arithmetic the largest candidate always exceeds rho, so
            # the last remaining value is never removed; the `remaining > 1`
            # guard keeps rounding from causing a division by zero.
            if value <= rho and remaining > 1:
                remaining -= 1
                rho += (rho - value) / remaining
                changed = True
            else:
                kept.append(value)
        active = kept
        num_traversal += 1
        if VERBOSE > 1:
            print('*************************')
            print('traversal:', num_traversal, 'len(active list):', len(active))
            print('active list:')
            print(active)
            print('*************************')
    return active, rho


def condat_simplex_project(y, a):
    """Project ``y`` onto ``{z : z >= 0, sum(z) <= a}`` with Condat's algorithm.

    If the positive part of ``y`` already sums to at most ``a``, the sum
    constraint is inactive and that positive part is returned. Otherwise the
    projection lies on the simplex ``{z >= 0, sum(z) = a}`` and equals
    ``max(y - tau, 0)``, where the threshold ``tau`` is computed by Condat's
    algorithm.

    Parameters
    ----------
    y : np.ndarray
        Non-empty vector to project (treated as 1-D).
    a : int or float (Python or NumPy scalar)
        Radius; must be >= 0.

    Returns
    -------
    np.ndarray
        A new float array; ``y`` itself is never returned or modified.

    Raises
    ------
    AssertionError
        If ``y`` is not a non-empty ndarray or ``a`` is not a real scalar >= 0.
    """
    assert isinstance(y, np.ndarray), "vector y should be a numpy ndarray, please check!"
    # isinstance (not `type(a) == ...`) so NumPy scalars such as np.float32 or
    # np.int64 are accepted; bool subclasses int and is rejected explicitly.
    assert isinstance(a, (int, float, np.integer, np.floating)) and not isinstance(a, bool), \
        "radius a should be int or float, please check!"
    assert len(y) > 0, "Length of vector y must be >0, please check!"
    assert a >= 0, "radius a must be >= 0, please check!"

    # Clipping negatives already lands inside the feasible set: nothing to do.
    positive_part = np.maximum(y, 0.0)
    if np.sum(positive_part) <= a:
        return positive_part
    # Zero radius: the feasible set is the single point {0}.
    if a == 0:
        return np.zeros(len(y))

    values = np.asarray(y, dtype=float)

    active, deferred, rho = _condat_forward_pass(values, a)
    if VERBOSE > 1:
        print('*************************')
        print('active list:')
        print(active)
        print('*************************')
        print('active list first iter:')
        print(deferred)
        print('*************************')

    rho = _condat_reinsert(active, deferred, rho)
    if VERBOSE > 1:
        print('*************************')
        print('active list before traversal:')
        print(active)
        print('*************************')

    active, tau = _condat_prune(active, rho)
    if VERBOSE > 0:
        print('Lagrange multiplier: ', tau, 'size of active list:', len(active))

    # Soft-threshold every coordinate at the Lagrange multiplier tau.
    return np.maximum(values - tau, 0.0)

def projection_l1_ball(y, a):
    """Euclidean projection of ``y`` onto the L1 ball ``{x : ||x||_1 <= a}``.

    The magnitudes ``|y|`` are projected with ``condat_simplex_project`` and
    the signs of ``y`` are then restored elementwise.

    Parameters
    ----------
    y : array-like
        Point to project (converted to an ndarray by ``np.abs``).
    a : int or float
        Radius of the ball; validated by ``condat_simplex_project``.
    """
    projected_magnitudes = condat_simplex_project(np.abs(y), a)
    return projected_magnitudes * np.sign(y)



# #y = np.array([3.3,1.2,0.8,1.2,3.3])
# #Example for projection onto simplex
# n = 30
# y = np.random.rand(n)
# print('################')
# print('original point:')
# print(y)
# radius = 3
# proj_y = condat_simplex_project(y,radius)
# print('################')

# #Example for projection onto L1 ball
# n = 50
# y = np.random.randn(n)
# print('################')
# print('original point:')
# print(y)
# radius = 5.5

# z = np.abs(y)
# proj_z = condat_simplex_project(z,radius)

# proj_y = np.sign(y)*proj_z    
# print('projection point onto L1 ball:')
# print(proj_y)
# print('################')


################
original point:
[0.52575944 0.41786517 0.38761423 0.90323707 0.23500298 0.82820306
 0.11848305 0.97440535 0.90505377 0.07843904 0.19910008 0.13960491
 0.66254455 0.40182677 0.72608664 0.81915671 0.95585068 0.81378199
 0.81983898 0.49738275 0.413342   0.61539789 0.66819728 0.82953336
 0.93051302 0.06754882 0.07697351 0.645791   0.85829009 0.96273879]
point of projection:
[0.         0.         0.         0.25928568 0.         0.18425166
 0.         0.33045395 0.26110237 0.         0.         0.
 0.01859315 0.         0.08213525 0.17520531 0.31189929 0.16983059
 0.17588758 0.         0.         0.         0.02424589 0.18558197
 0.28656162 0.         0.         0.0018396  0.2143387  0.31878739]
sum of coordinates: 2.999999999999999
################
################
original point:
[-0.43622077  0.12125096  0.12794784 -0.43827225  0.47890132  1.0311143
 -0.39811251  0.56078251 -1.43481668  0.91013339 -0.53109868  0.38993717
  0.00531226 -0.54010222 -2.03505287 -0.36971351 