In [1]:
from itertools import product
from timeit import timeit

import numpy as np
from scipy.optimize import minimize
from tqdm import tqdm

from utils.generator import generate_positive_definite_matrix, make_quadratic_function

In [2]:
n: int = 20  # problem dimension; brute_force enumerates all 2**n binary vectors, so keep this small

In [3]:
# One sample problem instance (the evaluation loop below draws its own instances).
# mu0 follows the same penalty rule as the loop: 1.5 / smallest eigenvalue of Q.
Q = generate_positive_definite_matrix(n)
c = np.random.normal(0.0, 1.0, n)
mu0 = 1.5 / np.linalg.eigvalsh(Q).min()

In [4]:
def brute_force(Q, c):
    """Exact minimiser of the quadratic objective over all binary vectors.

    Enumerates every x in {0, 1}^n (n = len(c)) in lexicographic order and
    returns the one with the smallest objective value, where the objective is
    built by ``make_quadratic_function(Q, c)``. Cost is 2**n objective
    evaluations, so this is only practical for small n.

    Parameters
    ----------
    Q : array_like
        Matrix passed through to ``make_quadratic_function``.
    c : array_like
        Linear term; its length defines the dimension n.

    Returns
    -------
    ndarray
        Integer 0/1 vector of length n attaining the minimum; on ties, the
        first minimiser in enumeration order.
    """
    obj_fn = make_quadratic_function(Q, c)
    # Stream candidates instead of materialising a 2**n x n table (about
    # 160 MB at n = 20) plus a 2**n list of objective values.
    candidates = (np.array(bits) for bits in product((0, 1), repeat=len(c)))
    # min() keeps the first minimal element, matching np.argmin's tie-breaking.
    return min(candidates, key=obj_fn)

In [5]:
def binary_admm(obj_fn, x0, y0, mu0, n_iter=100, gamma=0.9):
    """Approximately minimise ``obj_fn`` over binary vectors with ADMM.

    The problem ``min f(x) s.t. x in {0, 1}^n`` is split as ``x = z`` with
    ``z`` binary, using the augmented Lagrangian

        L(x, z, y) = f(x) + y . (x - z) + ||x - z||^2 / (2 * mu).

    Each iteration
      1. minimises L over x in the box [0, 1]^n (``scipy.optimize.minimize``,
         warm-started at the previous x),
      2. minimises L exactly over binary z, coordinate by coordinate,
      3. takes a dual ascent step ``y <- y + (x - z) / mu``,
      4. shrinks the penalty parameter ``mu <- gamma * mu``.

    Parameters
    ----------
    obj_fn : callable
        Objective taking a length-n float array and returning a scalar.
    x0 : array_like
        Initial point of length n; also used as the initial ``z``.
    y0 : array_like
        Initial dual variable of length n. It is copied, never modified.
    mu0 : float
        Initial penalty parameter; must be positive.
    n_iter : int, optional
        Number of ADMM iterations (default 100).
    gamma : float, optional
        Factor applied to ``mu`` after every iteration; must be positive
        (default 0.9).

    Returns
    -------
    x : ndarray
        Final relaxed iterate in [0, 1]^n (not rounded; callers threshold it).
    y : ndarray
        Final dual variable (float).

    Raises
    ------
    ValueError
        If ``mu0`` or ``gamma`` is not positive.
    """
    if mu0 <= 0:
        raise ValueError(f"mu0 must be positive, got {mu0}")
    if gamma <= 0:
        raise ValueError(f"gamma must be positive, got {gamma}")

    n = len(x0)
    x = z = x0
    # Work on a float copy: the caller's y0 must not be mutated, and an
    # integer-typed y0 must not break the float dual update below.
    y = np.array(y0, dtype=float)
    mu = mu0
    bounds = [(0, 1)] * n

    for _ in range(n_iter):
        # The closure reads y, z and mu at call time; minimize() returns before
        # any of them is reassigned, so it always sees this iteration's values.
        def augmented_lagrangian(v):
            return obj_fn(v) + np.dot(y, v - z) + 1 / (2 * mu) * (np.linalg.norm(v - z) ** 2)

        x = minimize(augmented_lagrangian, x, bounds=bounds).x
        # Exact binary z-step: per coordinate choose z = 0 (cost x^2 / (2 mu))
        # or z = 1 (cost (x - 1)^2 / (2 mu) - y), whichever is smaller.
        z = np.where(x ** 2 / (2 * mu) < (x - 1) ** 2 / (2 * mu) - y, 0, 1)
        y = y + (x - z) / mu
        mu *= gamma
    return x, y

In [6]:
# Compare rounded binary-ADMM solutions with the exact brute-force optimum on
# `n_iter` random instances of size `n`, and report the mean fraction of bits
# that agree. Instances with any disagreement are logged as "mu n_wrong".
np.random.seed(0)  # reproducible instances; assumes utils.generator draws from NumPy's global RNG -- verify
correct = 0  # running sum of per-instance bit accuracy (not a count of exact solves)
n_iter = 100  # number of random instances (also read by the timing cell)
for _ in tqdm(range(n_iter)):
    Q = generate_positive_definite_matrix(n)
    c = np.random.normal(size=n)
    k = np.min(np.linalg.eigvalsh(Q))  # smallest eigenvalue of Q
    obj_fn = make_quadratic_function(Q, c)
    mu = 1.5 / k
    xhat, yhat = binary_admm(obj_fn, np.full(n, 0.5), np.zeros(n), mu, 100, gamma=0.9)
    xhat = np.where(xhat > 0.5, 1, 0)  # round the relaxed iterate to a 0/1 vector
    xstar = brute_force(Q, c)
    n_wrong = int(np.count_nonzero(xhat != xstar))
    if n_wrong:
        # tqdm.write keeps the progress bar intact; print() interleaves with it.
        tqdm.write(f"{mu} {n_wrong}")
    correct += 1 - n_wrong / n

print(correct / n_iter)

  4%|▍         | 4/100 [00:30<12:20,  7.71s/it]

36.3727582066606 3


 15%|█▌        | 15/100 [01:56<11:00,  7.77s/it]

114.7361393506507 3


 16%|█▌        | 16/100 [02:03<10:49,  7.74s/it]

102.19056383291729 1


 17%|█▋        | 17/100 [02:11<10:40,  7.72s/it]

130.91650518134736 5


 19%|█▉        | 19/100 [02:27<10:26,  7.74s/it]

122.26756726528244 2


 21%|██        | 21/100 [02:42<10:03,  7.64s/it]

147.41912103010662 1


 25%|██▌       | 25/100 [03:11<09:16,  7.42s/it]

129.33915698101828 2


 30%|███       | 30/100 [03:49<08:49,  7.56s/it]

131.84739224240553 2


 31%|███       | 31/100 [03:57<08:42,  7.58s/it]

114.58657251611442 2


 34%|███▍      | 34/100 [04:20<08:23,  7.63s/it]

128.88532972258562 2


 37%|███▋      | 37/100 [04:44<08:20,  7.95s/it]

60.58815494617083 5


 38%|███▊      | 38/100 [04:52<08:13,  7.95s/it]

145.89288374477377 3


 41%|████      | 41/100 [05:15<07:36,  7.74s/it]

84.94075395725915 2


 47%|████▋     | 47/100 [06:00<06:41,  7.58s/it]

144.86407607888117 1


 48%|████▊     | 48/100 [06:08<06:28,  7.48s/it]

88.16201188177706 2


 50%|█████     | 50/100 [06:23<06:23,  7.67s/it]

148.1787814728162 1


 52%|█████▏    | 52/100 [06:39<06:15,  7.83s/it]

82.22653474517502 4


 53%|█████▎    | 53/100 [06:47<06:05,  7.77s/it]

90.79617134187528 3


 61%|██████    | 61/100 [07:46<04:44,  7.29s/it]

122.1899414924082 5


 68%|██████▊   | 68/100 [08:37<03:57,  7.41s/it]

142.35493265567894 2


 69%|██████▉   | 69/100 [08:45<03:50,  7.44s/it]

100.81140690774396 3


 72%|███████▏  | 72/100 [09:07<03:29,  7.49s/it]

79.37095508090604 2


 76%|███████▌  | 76/100 [09:39<03:09,  7.89s/it]

124.67323988150025 2


 84%|████████▍ | 84/100 [10:41<02:00,  7.54s/it]

148.38375282607674 1


 89%|████████▉ | 89/100 [11:18<01:21,  7.40s/it]

93.76004600982262 3


 93%|█████████▎| 93/100 [11:47<00:50,  7.27s/it]

149.96815461768497 2


 95%|█████████▌| 95/100 [12:02<00:37,  7.49s/it]

121.09150316283093 1


 97%|█████████▋| 97/100 [12:17<00:22,  7.52s/it]

146.62111534499485 1


 99%|█████████▉| 99/100 [12:33<00:07,  7.51s/it]

120.1985553261722 4


100%|██████████| 100/100 [12:40<00:00,  7.60s/it]

0.9650000000000002





In [11]:
# Mean wall-clock time of one binary_admm solve at a size where brute force
# (2**n candidates) is infeasible. A dedicated size variable is used so the
# global `n` read by the accuracy cell is not silently changed to 50 -- re-running
# that cell afterwards would otherwise try to enumerate 2**50 points.
n_timing = 50
solve_times = []  # seconds per solve, one entry per random instance

for _ in tqdm(range(n_iter)):
    Q = generate_positive_definite_matrix(n_timing)
    c = np.random.normal(size=n_timing)
    k = np.min(np.linalg.eigvalsh(Q))  # smallest eigenvalue of Q
    obj_fn = make_quadratic_function(Q, c)
    mu = 1.5 / k
    # Only the solver call is timed; instance generation is excluded.
    solve_times.append(
        timeit(
            lambda: binary_admm(obj_fn, np.full(n_timing, 0.5), np.zeros(n_timing), mu, 100, gamma=0.9),
            number=1,
        )
    )

sum(solve_times) / n_iter

100%|██████████| 100/100 [01:58<00:00,  1.19s/it]


1.1852028419999987