In [1]:
import numpy as np
from scipy.optimize import linprog
import scipy.sparse as spr
import itertools
from scipy.linalg import sqrtm
from jax.scipy.optimize import minimize
import jax.numpy as jnp
import jax
from itertools import product
import torch

In [2]:
import numpy as np

np.random.seed(1)
J = 500  # number of items (rows of phi_j)
K = 5    # number of columns of phi_j, weighted by lambda_k
lambda_k = np.ones(K)

# Item data drawn once (normal, mean 10, std 300) so every cell sees the same instance.
phi_j = np.random.normal(10, 300, size = [J,K])


def v(B):
    """Value of bundle B: weighted sum of its items' rows of phi_j minus |B|**2.

    Parameters
    ----------
    B : array-like
        Either a boolean mask of length J (True = item in bundle) or an
        array of integer item indices (any NumPy integer dtype).

    Returns
    -------
    float
        ``phi_j[B].sum(0) @ lambda_k - |B|**2``. Note that for an index array,
        |B| is ``len(B)``, so duplicate indices are counted twice.

    Raises
    ------
    TypeError
        If B is neither boolean nor integer typed.
    """
    B = np.asarray(B)
    # Check bool first: NumPy's bool is not a subtype of np.integer.
    if B.dtype == bool:
        return phi_j[B, :].sum(0) @ lambda_k - B.sum() ** 2
    # Accept every integer width/signedness (int32, int64, uint, ...), not only
    # the platform default `int`, so np.where / argsort output always works.
    if np.issubdtype(B.dtype, np.integer):
        return phi_j[B, :].sum(0) @ lambda_k - len(B) ** 2
    raise TypeError("B must be a boolean or integer array")



In [3]:
# Sanity check: a random bundle (boolean mask of length J) must give the same
# value whether v receives it as a mask or as an array of item indices.
np.random.seed(1)
B = np.random.randint(2, size=J, dtype=bool)
print(v(B))
print(v(np.where(B)[0]))
assert np.isclose(v(B), v(np.where(B)[0]))

-53015.568602015424
-53015.568602015424


In [4]:
def minimize_v(value_fn=None, n_items=None, max_items=20):
    """Exhaustively search all 2**n_items bundles for the one minimizing value_fn.

    Parameters
    ----------
    value_fn : callable, optional
        Maps a boolean mask of length ``n_items`` to a scalar. Defaults to the
        notebook's ``v``.
    n_items : int, optional
        Number of items. Defaults to the notebook's ``J``.
    max_items : int
        Refuse to run when ``n_items`` exceeds this. The search costs
        2**n_items evaluations (e.g. 2**500 for J = 500, which never finishes).

    Returns
    -------
    (min_value, best_array)
        The minimum value and the first bundle (in enumeration order) that
        attains it, since only a strictly smaller value replaces the incumbent.

    Raises
    ------
    ValueError
        If ``n_items > max_items``.
    """
    if value_fn is None:
        value_fn = v
    if n_items is None:
        n_items = J
    if n_items > max_items:
        raise ValueError(
            f"brute force over 2**{n_items} bundles is infeasible "
            f"(max_items={max_items})"
        )

    min_value = np.inf
    best_array = None

    # Generate all binary arrays of size n_items
    for bits in product([0, 1], repeat=n_items):
        B = np.array(bits, dtype=bool)
        current_value = value_fn(B)
        if current_value < min_value:
            min_value = current_value
            best_array = B

    return min_value, best_array

# Find the minimizer and the minimum value (exhaustive search over all 2**J bundles; only feasible for small J)
# min_val, minimizer = minimize_v()

# print("Minimum Value:", min_val)
# print("minimizer:", minimizer)

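The Lovász extension of a set function $v: 2^{[n]} \rightarrow \mathbb{R}$ is, for $z \in [0,1]^n$ with $z_1 \geq z_2 \geq \dots \geq z_n$ and the conventions $z_0 = 1$, $z_{n+1} = 0$,
$$ \hat v (z) = \sum_{i = 0}^n (z_{i} - z_{i+1})\, v([i]), $$
where $[i] = \{1, \dots, i\}$ and $[0] = \emptyset$.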

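In other words, for a general $z \in [0,1]^n$,
$$ \hat v (z) = \sum_{i = 0}^n (z_{\sigma_i} - z_{\sigma_{i+1}})\, v(\{\sigma_1, \dots, \sigma_i\}) $$
where $\sigma$ is any permutation that sorts $z$ in non-increasing order, $z_{\sigma_1} \geq \dots \geq z_{\sigma_n}$, with the conventions $z_{\sigma_0} = 1$ and $z_{\sigma_{n+1}} = 0$. Ties can be broken arbitrarily without changing the value.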


In [5]:
def v_hat(z_j):
    """Evaluate the Lovász extension of v at a point z_j of [0, 1]^J.

    Items are visited in decreasing order of z_j. The bundle made of the first
    k visited items is weighted by the gap between consecutive sorted
    coordinates, after padding the sorted coordinates with a leading 1 and a
    trailing 0.
    """
    order = np.argsort(z_j)[::-1]
    levels = np.concatenate(([1], z_j[order], [0]))
    gaps = levels[:-1] - levels[1:]
    return sum(v(order[:k]) * gaps[k] for k in range(J + 1))


# Sanity check: on a 0/1 vector the Lovász extension must reproduce v.
print(v(B))
print(v_hat(B * 1))

# Repeat the check on a few more random bundles. A local generator is used so
# the global NumPy RNG state seen by later cells is not disturbed.
check_rng = np.random.default_rng(0)
for _ in range(3):
    z_check = check_rng.integers(0, 2, size=J).astype(bool)
    assert np.isclose(v(z_check), v_hat(z_check * 1))


-53015.568602015424
-53015.568602015424


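The subdifferential of $\hat v$ at $z$ is the convex hull of the vectors $g \in \mathbb{R}^n$ defined by
$$ g_{\sigma_i} = v(\{\sigma_1, \dots, \sigma_i\}) - v(\{\sigma_1, \dots, \sigma_{i-1}\}), \qquad i = 1, \dots, n, $$
for some permutation $\sigma$ that sorts $z$ in non-increasing order (so the $i = 1$ term uses $v(\emptyset)$).
This applies here because $v(B) = \sum_{j \in B} \phi_j^\top \lambda - |B|^2$ is a modular function minus a convex function of $|B|$. It is therefore submodular, which makes $\hat v$ convex, so each such $g$ is a genuine subgradient.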


In [6]:
def grad_v_hat(z_j):
    """Return one subgradient of the Lovász extension v_hat at z_j.

    Items are ordered by decreasing z_j (np.argsort reversed). The item in
    position j receives the marginal value
    v(first j+1 items) - v(first j items).

    Each prefix value is computed once and then differenced, so this makes
    J + 1 calls to v instead of the 2J calls needed to evaluate both prefixes
    for every item.
    """
    order = np.argsort(z_j)[::-1]
    prefix_values = np.array([v(order[:j]) for j in range(J + 1)])
    grad = np.zeros(J)
    # np.diff gives prefix_values[j + 1] - prefix_values[j] for j = 0..J-1.
    grad[order] = np.diff(prefix_values)
    return grad

In [7]:
def mirror_descent(num_iterations, alpha, grad_fn=None, z0=None):
    """Minimize over the hypercube [0, 1]^n by projected normalized subgradient steps.

    Despite the name, each update is a Euclidean step of length ``alpha``
    along the normalized subgradient, followed by clipping to [0, 1]. The
    exponentiated (mirror) update is not used.

    Parameters
    ----------
    num_iterations : int
        Number of update steps.
    alpha : float
        Step length. Before clipping, z moves by exactly ``alpha`` in
        Euclidean norm.
    grad_fn : callable, optional
        Subgradient oracle ``z -> g``. Defaults to ``grad_v_hat``.
    z0 : array-like, optional
        Starting point. Defaults to the center of the hypercube,
        ``np.ones(J) / 2``.

    Returns
    -------
    z_star : ndarray
        Average of the last ``max(1, num_iterations // 30)`` iterates, or the
        starting point when ``num_iterations`` is 0.
    bundle_star : ndarray of bool
        ``z_star`` rounded to the nearest integer and cast to bool.
    """
    if grad_fn is None:
        grad_fn = grad_v_hat
    # Kept for compatibility with earlier runs: this reseeds the global RNG,
    # although the starting point below is deterministic.
    np.random.seed(1)
    z = np.ones(J) / 2 if z0 is None else np.asarray(z0, dtype=float)

    z_list = []
    for _ in range(num_iterations):
        grad = grad_fn(z)
        grad_norm = np.linalg.norm(grad)
        # For a convex objective, a zero subgradient means z is already optimal.
        # Stay put instead of dividing by zero, which would make z all NaN.
        if grad_norm > 0:
            z = np.clip(z - alpha * grad / grad_norm, 0, 1)
        z_list.append(z)

    if not z_list:
        z_star = z.copy()
    else:
        # At least one iterate: for num_iterations < 30 the old expression gave
        # 0, and the slice [-0:] silently averaged every iterate.
        n_tail = max(1, int(num_iterations) // 30)
        z_star = np.array(z_list)[-n_tail:, :].mean(0)
    bundle_star = np.array(z_star.round(0), dtype=bool)
    return z_star, bundle_star

In [16]:
# Step-size tuning for projected subgradient descent on [0, 1]^J.
# R = sqrt(J) is the Euclidean diameter of the hypercube. With normalized
# subgradient steps, a constant step R / sqrt(T) over T iterations is the
# textbook choice. An earlier (commented-out) version used
# precision = (M / eps)**2 with M = 500, eps = 5.
R = np.sqrt(J)
precision = .3
num_iterations_theory = int(9 * precision * R**2)
print(num_iterations_theory)
alpha_theory = R / np.sqrt(num_iterations_theory)
print(alpha_theory)

# NOTE: the run below uses these hand-picked values. They override the
# theoretical ones printed above.
num_iterations = 1000
alpha = 1

1350
0.6085806194501846


In [17]:
# Run the descent with the hand-picked settings; keep the averaged point and its rounding.
z_star, bundle_star = mirror_descent(num_iterations=num_iterations, alpha=alpha)

### Check whether every entry of z_j is within a tolerance of 0 or of 1
def check_z_j(z_j, atol=1e-01):
    """Return True if every entry of z_j lies within ``atol`` of 0 or of 1.

    The test uses np.isclose, so the tolerance is ``atol`` plus NumPy's default
    relative term ``rtol * |reference|`` (0 near 0, 1e-5 near 1). An empty
    input returns True.
    """
    z_j = np.asarray(z_j)
    near_zero = np.isclose(z_j, 0, atol=atol)
    near_one = np.isclose(z_j, 1, atol=atol)
    return bool(np.all(near_zero | near_one))
# Is the averaged solution essentially integral?
print(check_z_j(z_j=z_star))

False


In [18]:
# Coordinates of z_star that are still fractional (strictly between 0.1 and 0.9)
z_star[(z_star > 0.1) & (z_star < 0.9)]

array([0.8401985 , 0.64882383])

In [19]:
# Same inspection with a tighter band (strictly between 0.05 and 0.95)
z_star[np.logical_and(z_star > 0.05, z_star < 0.95)]

array([0.8401985 , 0.64882383])

In [20]:
# Relaxed objective at z_star vs. true value of its rounded bundle
(v_hat(z_star), v(bundle_star))

(-220140.60865027763, -220365.55600183434)

In [13]:
# Size of the rounded bundle, then the total number of items
print(bundle_star.sum(), J, sep="\n")

449
500


In [14]:
# Value of the bundle containing every item, for comparison
v(np.full(J, True))

-207202.9691422304

In [15]:
# print(v(minimizer))