In [1]:
import numpy as np
from scipy.optimize import minimize
from utils.generator import generate_positive_definite_matrix, make_quadratic_function
from itertools import product
from tqdm import tqdm

In [2]:
n = 5

In [3]:
Q = generate_positive_definite_matrix(n)
c = np.random.normal(size=n)

In [4]:
obj_fn = make_quadratic_function(Q, c)

In [5]:
def sqrt_alm(obj_fn, lam, mu):
    return lambda x: obj_fn(x) + np.dot(lam, np.sqrt(x * (1-x))) + 1 / (2*mu) * (np.sum(x * (1-x)))

In [6]:
mu0 = 1 / np.min(np.linalg.eigvalsh(Q))
lam0 = np.zeros(n)

In [7]:
def alm_sqrt(obj_fn, x0, lam0, mu0, n_iter=50, gamma=0.9):
    x = x0
    lam = lam0
    mu = mu0
    for _ in range(n_iter):
        alm = sqrt_alm(obj_fn, lam, mu)
        x = minimize(alm,x, bounds=[(0,1)]*n).x
        lam += np.sqrt(x * (1-x)) / mu
        mu *= gamma
    return x

In [8]:
def brute_force(Q, c):
    n = len(c)
    obj_fn = make_quadratic_function(Q, c)
    all_possible_x = np.array(list(product([0, 1], repeat=n)))
    x_star = all_possible_x[np.argmin(np.array(list(map(obj_fn, all_possible_x))))]
    return x_star

In [9]:
alm_sqrt(obj_fn, np.ones(n) * 0.5, lam0, mu0)

array([1., 1., 0., 0., 0.])

In [10]:
brute_force(Q, c)

array([1, 1, 0, 0, 0])

In [11]:
correct = 0
n_iter = 100
for i in tqdm(range(n_iter)):
    Q = generate_positive_definite_matrix(n)
    c = np.random.normal(size=n)
    k = np.min(np.linalg.eigvalsh(Q))
    obj_fn = make_quadratic_function(Q, c)
    mu = 1/k * 1.5
    xhat = alm_sqrt(obj_fn, np.ones(n) * 0.5, np.zeros(n), mu, 100, gamma=0.9)
    xhat = np.where(xhat > 0.5, 1, 0)
    xstar = brute_force(Q, c)
    if np.sum(xhat ^ xstar):
        print(mu, np.sum(xhat ^ xstar))
        continue
    correct += 1 - np.sum(xhat ^ xstar) / n

print(correct / n_iter)

 27%|██▋       | 27/100 [00:01<00:04, 16.88it/s]

76.91346058003346 2


 69%|██████▉   | 69/100 [00:03<00:01, 20.31it/s]

19.71495049809526 1


 96%|█████████▌| 96/100 [00:04<00:00, 21.40it/s]

135.5461399134302 2


100%|██████████| 100/100 [00:05<00:00, 19.79it/s]

0.97



