In [1]:
import numpy as np
from utils.generator import generate_positive_definite_matrix, make_quadratic_function
from itertools import product
from scipy.optimize import minimize
from tqdm import tqdm

In [2]:
n = 15

In [3]:
Q = generate_positive_definite_matrix(n)
c = np.random.normal(size=n)
mu0 = 1.5 / np.min(np.linalg.eigvalsh(Q))

In [4]:
def brute_force(Q, c):
    n = len(c)
    obj_fn = make_quadratic_function(Q, c)
    all_possible_x = np.array(list(product([0, 1], repeat=n)))
    x_star = all_possible_x[np.argmin(np.array(list(map(obj_fn, all_possible_x))))]
    return x_star

In [5]:
def binary_admm(obj_fn, x0, y0, mu0, n_iter=100, gamma=0.9):
    n = len(x0)
    x = z = x0
    y = y0
    mu = mu0
    bnds = [(0,1)] * n
    for _ in range(n_iter):
        lx = lambda x: obj_fn(x) + np.dot(y, x - z) + 1/(2*mu) * (np.linalg.norm(x - z) ** 2)
        res = minimize(lx, x, bounds=bnds)
        x = res.x
        z = np.where(x ** 2 / (2 * mu) < (x - 1) ** 2 / (2 * mu) - y, 0, 1)
        y += (x - z) / mu
        mu *= gamma
    return x, y

In [6]:
correct = 0
n_iter = 100
for i in tqdm(range(n_iter)):
    Q = generate_positive_definite_matrix(n)
    c = np.random.normal(size=n)
    k = np.min(np.linalg.eigvalsh(Q))
    obj_fn = make_quadratic_function(Q, c)
    mu = 1/k * 1.5
    xhat, yhat = binary_admm(obj_fn, np.ones(n) * 0.5, np.zeros(n), mu, 100, gamma=0.9)
    xhat = np.where(xhat > 0.5, 1, 0)
    xstar = brute_force(Q, c)
    if np.sum(xhat ^ xstar):
        print(mu, np.sum(xhat ^ xstar), np.linalg.cond(Q))
    correct += 1 - np.sum(xhat ^ xstar) / n

print(correct / n_iter)

  0%|          | 0/100 [00:00<?, ?it/s]

  9%|▉         | 9/100 [00:04<00:42,  2.16it/s]

116.79536899722052 2 356.0707401283331


 10%|█         | 10/100 [00:04<00:41,  2.15it/s]

148.1648545296715 2 587.1003252967416


 13%|█▎        | 13/100 [00:05<00:38,  2.23it/s]

111.77336675274421 5 379.5550223813773


 14%|█▍        | 14/100 [00:06<00:38,  2.24it/s]

148.40243534445173 1 465.510895373


 15%|█▌        | 15/100 [00:06<00:38,  2.18it/s]

149.99877026049552 1 580.6638841168723


 43%|████▎     | 43/100 [00:19<00:24,  2.31it/s]

88.13934543207085 1 359.53179254150814


 45%|████▌     | 45/100 [00:20<00:24,  2.20it/s]

131.79531005602723 1 518.7342785220218


 65%|██████▌   | 65/100 [00:29<00:16,  2.11it/s]

140.09883264070137 1 419.82436722481776


 69%|██████▉   | 69/100 [00:31<00:13,  2.22it/s]

146.445485469685 4 488.29406912921576


 74%|███████▍  | 74/100 [00:33<00:11,  2.35it/s]

85.51454404897676 1 339.53266200846235


 80%|████████  | 80/100 [00:36<00:09,  2.20it/s]

122.49690184817285 1 413.3386409113681


 89%|████████▉ | 89/100 [00:40<00:04,  2.28it/s]

115.63799175976277 2 509.1246854001138


 92%|█████████▏| 92/100 [00:41<00:03,  2.22it/s]

137.0764474173562 4 477.63460227908433


 94%|█████████▍| 94/100 [00:42<00:02,  2.12it/s]

64.19619894006634 2 197.80052718475008


100%|██████████| 100/100 [00:45<00:00,  2.22it/s]

0.9813333333333333



