In [1]:
import numpy as np
from utils.generator import generate_positive_definite_matrix, make_quadratic_function
from itertools import product
from scipy.optimize import minimize
from tqdm import tqdm

In [2]:
# Problem dimension. The brute-force reference solver enumerates all 2**n
# binary vectors, so its runtime grows exponentially with n.
n = 20

# Seed the global NumPy RNG so the sampled problem instances (and therefore the
# reported accuracy) are reproducible under Restart & Run All.
# NOTE(review): `c` is drawn from np.random directly; confirm that
# generate_positive_definite_matrix also draws from the global NumPy RNG,
# otherwise it needs to be seeded separately.
SEED = 0
np.random.seed(SEED)

In [3]:
# One example problem instance. The experiment loop further down draws a fresh
# Q and c for every trial and computes its own mu, so it does not consume these.
Q = generate_positive_definite_matrix(n)
c = np.random.normal(0.0, 1.0, n)
# Initial ADMM penalty parameter: 1.5 divided by the smallest eigenvalue of Q.
mu0 = 1.5 / np.linalg.eigvalsh(Q).min()

In [4]:
def brute_force(Q, c, obj_fn=None):
    """Exhaustively minimise an objective over all binary vectors {0, 1}^n.

    Candidates are streamed from ``itertools.product`` one at a time, so memory
    use is O(n) instead of materialising all 2**n candidates and their values.
    The number of objective evaluations is still 2**n.

    Parameters
    ----------
    Q, c :
        Problem data. When ``obj_fn`` is None, the objective is built with
        ``make_quadratic_function(Q, c)``. ``len(c)`` fixes the dimension n.
    obj_fn : callable, optional
        Objective taking a 1-D integer array of length n and returning a
        scalar. If given, ``Q`` is not used and ``make_quadratic_function`` is
        not called.

    Returns
    -------
    ndarray
        A minimising 1-D integer 0/1 vector of length n. Ties resolve to the
        first minimiser in ``product`` order (first coordinate most
        significant), matching ``np.argmin`` over the full enumeration.
    """
    if obj_fn is None:
        obj_fn = make_quadratic_function(Q, c)
    n = len(c)

    best_x, best_val = None, None
    for bits in product((0, 1), repeat=n):
        candidate = np.array(bits)
        val = obj_fn(candidate)
        # Strict "<" keeps the first minimiser on ties, like np.argmin.
        if best_val is None or val < best_val:
            best_x, best_val = candidate, val
    return best_x

In [5]:
def binary_admm(obj_fn, x0, y0, mu0, n_iter=100, gamma=0.9):
    """Heuristic ADMM for ``min obj_fn(x)`` subject to ``x in {0, 1}^n``.

    The problem is split as a continuous copy ``x`` (box [0, 1]^n) and a binary
    copy ``z`` with the coupling constraint ``x == z``. Each round performs:

    * x-step: minimise the augmented Lagrangian
      ``obj_fn(x) + y.(x - z) + ||x - z||^2 / (2 mu)`` over [0, 1]^n
      (via ``scipy.optimize.minimize`` with box bounds);
    * z-step: exact coordinate-wise minimisation over {0, 1};
    * y-step: dual update ``y += (x - z) / mu``;

    and then multiplies ``mu`` by ``gamma`` (for gamma < 1 this tightens the
    penalty ``1 / mu`` over time).

    Parameters
    ----------
    obj_fn : callable
        Objective taking a 1-D float array of length n and returning a scalar.
    x0 : array_like
        Initial point; its length fixes n.
    y0 : array_like
        Initial dual variable of length n. It is copied, never modified.
    mu0 : float
        Initial (positive) penalty parameter.
    n_iter : int
        Number of ADMM rounds.
    gamma : float
        Multiplicative factor applied to mu after every round.

    Returns
    -------
    x : ndarray
        Last continuous iterate (within [0, 1]); round it to get a binary point.
    y : ndarray
        Final dual variable (a new float array).

    Raises
    ------
    ValueError
        If ``mu0`` is not positive.
    """
    if mu0 <= 0:
        raise ValueError(f"mu0 must be positive, got {mu0}")

    x = np.asarray(x0, dtype=float)
    z = x.copy()
    # Copy and promote to float: the in-place dual update below must neither
    # mutate the caller's y0 nor fail on an integer-typed array.
    y = np.array(y0, dtype=float)
    mu = mu0
    bnds = [(0, 1)] * len(x)

    def aug_lagrangian(x_, z_, y_, mu_):
        # Current z, y and mu are passed explicitly via `args` instead of being
        # captured from the loop by a late-binding closure.
        diff = x_ - z_
        return obj_fn(x_) + np.dot(y_, diff) + 1 / (2 * mu_) * (np.linalg.norm(diff) ** 2)

    for _ in range(n_iter):
        x = minimize(aug_lagrangian, x, args=(z, y, mu), bounds=bnds).x
        # z_i = 0 costs x_i^2/(2mu); z_i = 1 costs (x_i - 1)^2/(2mu) - y_i.
        z = np.where(x ** 2 / (2 * mu) < (x - 1) ** 2 / (2 * mu) - y, 0, 1)
        y += (x - z) / mu
        mu *= gamma
    return x, y

In [6]:
# Monte-Carlo comparison of binary ADMM against the exhaustive reference solver.
# Each trial draws a fresh instance, runs ADMM, rounds the continuous output to
# {0, 1}, and counts coordinates that disagree with the brute-force optimum.
MU_SCALE = 1.5     # initial penalty: mu = MU_SCALE / smallest eigenvalue of Q
ADMM_ITERS = 100   # ADMM rounds per trial
ADMM_GAMMA = 0.9   # per-round multiplicative decay of mu

correct = 0        # running sum of the per-trial fraction of correct bits
n_iter = 100       # number of random trials
for _ in tqdm(range(n_iter)):
    Q = generate_positive_definite_matrix(n)
    c = np.random.normal(size=n)
    k = np.min(np.linalg.eigvalsh(Q))
    obj_fn = make_quadratic_function(Q, c)
    mu = MU_SCALE / k
    xhat, _ = binary_admm(obj_fn, np.full(n, 0.5), np.zeros(n), mu,
                          n_iter=ADMM_ITERS, gamma=ADMM_GAMMA)
    xhat = np.where(xhat > 0.5, 1, 0)
    xstar = brute_force(Q, c)
    n_wrong = np.count_nonzero(xhat != xstar)
    if n_wrong:
        # tqdm.write prints above the bar instead of splitting it mid-line.
        tqdm.write(f"{mu} {n_wrong}")
    correct += 1 - n_wrong / n

# Mean fraction of correctly recovered bits across trials (not the fraction of
# instances solved exactly).
print(correct / n_iter)

  6%|▌         | 6/100 [00:44<11:24,  7.28s/it]

100.67908587252519 1


  7%|▋         | 7/100 [00:51<11:19,  7.31s/it]

121.75210034671841 2


  9%|▉         | 9/100 [01:06<11:09,  7.36s/it]

57.8301635519124 2


 13%|█▎        | 13/100 [01:36<10:45,  7.42s/it]

145.9617731586846 1


 15%|█▌        | 15/100 [01:51<10:31,  7.43s/it]

143.87138576750877 1


 17%|█▋        | 17/100 [02:05<10:10,  7.36s/it]

127.91460778071435 1


 18%|█▊        | 18/100 [02:12<10:01,  7.34s/it]

149.85108128937628 1


 30%|███       | 30/100 [03:40<08:32,  7.32s/it]

95.38386653597304 5


 32%|███▏      | 32/100 [03:55<08:22,  7.38s/it]

116.00176475132767 1


 37%|███▋      | 37/100 [04:33<07:50,  7.47s/it]

146.28357283841726 1


 39%|███▉      | 39/100 [04:48<07:40,  7.55s/it]

87.56559752689441 2


 44%|████▍     | 44/100 [05:25<06:55,  7.41s/it]

146.4475903605521 1


 51%|█████     | 51/100 [06:16<06:00,  7.36s/it]

118.5573507769179 2


 52%|█████▏    | 52/100 [06:24<05:53,  7.37s/it]

148.25641177500458 1


 59%|█████▉    | 59/100 [07:16<05:07,  7.51s/it]

140.59304194236645 3


 60%|██████    | 60/100 [07:24<05:00,  7.50s/it]

147.6982147994334 1


 61%|██████    | 61/100 [07:31<04:51,  7.48s/it]

139.7174663349495 1


 62%|██████▏   | 62/100 [07:38<04:43,  7.46s/it]

121.33564347791807 1


 74%|███████▍  | 74/100 [09:10<03:23,  7.83s/it]

113.49599687497297 3


 75%|███████▌  | 75/100 [09:17<03:14,  7.78s/it]

146.54555091726098 1


 79%|███████▉  | 79/100 [09:47<02:39,  7.57s/it]

112.46620297584126 2


 82%|████████▏ | 82/100 [10:10<02:16,  7.60s/it]

149.99905724508787 2


 87%|████████▋ | 87/100 [10:47<01:35,  7.34s/it]

52.66141494075602 1


 88%|████████▊ | 88/100 [10:54<01:27,  7.33s/it]

149.7353991579245 1


 89%|████████▉ | 89/100 [11:02<01:20,  7.33s/it]

103.90564697995683 2


 93%|█████████▎| 93/100 [11:32<00:51,  7.36s/it]

123.0963304030446 1


 95%|█████████▌| 95/100 [11:47<00:37,  7.43s/it]

145.62334392683636 2


100%|██████████| 100/100 [12:23<00:00,  7.44s/it]

0.9785000000000005



