In [None]:
# Run only once: `%cd ..` is relative to the current working directory, so
# re-running this cell moves the kernel up one more directory level each time.
# NOTE(review): presumably this moves the kernel to the repository root so the
# `src.qubo_solver` imports below resolve — confirm against the repo layout.
%cd ..

### 2-SAT Solver and Posiform planted solutions

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from src.qubo_solver.utils import *
from src.qubo_solver.sat_solver import *
from src.qubo_solver.plot import *

In [None]:
if __name__ == "__main__":
    # Hand-built directed graph on nodes 0..7: each key maps to the set of
    # nodes it has an outgoing edge to.
    adjacency_list = {
        0: {1},
        1: {2},
        2: {0},
        3: {4, 7},
        4: {5},
        5: {0, 6},
        6: {0, 2, 4},
        7: {3, 5},
    }

    # NOTE(review): `low_link` is presumably the per-node component labelling
    # produced by the SCC helper — confirm in src.qubo_solver.sat_solver.
    low_link = SCC(adjacency_list).get_scc()
    draw_2sat(adjacency_list, low_link)

In [None]:
# Small hand-written 2-SAT instance, solved end to end and then drawn.
# NOTE(review): the flat-list literal encoding is whatever graph_from_2sat
# expects — confirm the format in src.qubo_solver before editing the numbers.
expression = [2, 5, 0, 5, 0, 4, 1, 4]
# Build the graph for the expression and run the SCC pass over it.
adjacency_list = graph_from_2sat(expression)
low_link = SCC(adjacency_list).get_scc()
# Other cells in this notebook guard get_solution with `except ValueError`;
# this cell does not, so such an error would surface here directly.
solution = get_solution(adjacency_list, low_link)
print(*solution)
draw_2sat(adjacency_list, low_link)

In [None]:
# Rejection-sample random instances until get_solution succeeds, then draw and
# print that instance.
#
# Only get_solution sits inside the `try`: a ValueError raised while drawing or
# printing must not be mistaken for "this instance has no solution" and cause
# the instance to be silently thrown away.
# NOTE(review): ValueError is assumed to be get_solution's "no solution"
# signal — confirm in src.qubo_solver.sat_solver.
max_attempts = 1000  # upper bound so the cell cannot spin forever

for _ in range(max_attempts):
    expression = generate_expression(510, 405)
    adjacency_list = graph_from_2sat(expression)
    low_link = SCC(adjacency_list).get_scc()

    try:
        solution = get_solution(adjacency_list, low_link)
    except ValueError:
        continue  # resample a fresh instance
    break
else:
    raise RuntimeError(
        f"no solvable instance found in {max_attempts} attempts"
    )

draw_2sat(
    adjacency_list,
    low_link,
    seed=121343
    )
print(*solution)

In [None]:
# Same generate -> graph -> SCC -> solve pipeline as above, with both
# generate_expression arguments set to 1e5 and no plotting.
# NOTE(review): unlike the sampling loops in this notebook, ValueError from
# get_solution is not caught here, so this cell presumably fails whenever the
# sampled instance has no solution — confirm whether that is intended.
expression = generate_expression(int(1e5), int(1e5))
adjacency_list = graph_from_2sat(expression)
low_link = SCC(adjacency_list).get_scc()
solution = get_solution(adjacency_list, low_link)

In [None]:
# Monte-Carlo estimate of how often get_solution succeeds, for a sweep of
# values of the second generate_expression argument `m` at fixed `n`.
n = int(1e2)
iters = int(4e2)
ms = range(1, 2*n+2, n//10)
# The first grid point (ms[0]) is not sampled; its probability is fixed to 1.0.
probs = [1.0]

for m in ms[1:]:
    solved_runs = 0
    for _ in range(iters):
        expression = generate_expression(n, m)
        adjacency_list = graph_from_2sat(expression)
        low_link = SCC(adjacency_list).get_scc()
        try:
            solution = get_solution(adjacency_list, low_link)
        except ValueError:
            # This trial does not count as solved.
            continue
        solved_runs += 1
    # Fraction of the `iters` trials for which get_solution returned.
    probs.append(solved_runs/iters)

In [None]:

# Plot the estimated success probability (y) against the swept value m (x).
# `ms` and `probs` have the same length, so nothing is appended to `probs`.
plt.plot(ms, probs, linewidth=3, color='red', alpha=0.75)
plt.xlim(0, 2*n)
# The x-axis carries the m values and the y-axis the probabilities.
plt.xlabel("m")
plt.ylabel('Prob')
# Derive the title from `n` so it stays correct if the sweep size is changed.
plt.title(f"n={n}")

### Gradient-based planting

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
def find_Q(x, reg, iters=int(1e3)):
    """Fit a square matrix Q by gradient descent so that x has low energy.

    Minimises ``x^T Q x + ||Q W||^2`` with Adam (default hyper-parameters),
    where ``W`` is a fixed random matrix rescaled so that its norm equals
    ``reg``.

    Args:
        x: 1-D tensor; its length sets the size of Q. It is combined with the
            float matrix Q through ``torch.dot``, so it must be a float tensor.
        reg: Norm given to the random matrix W used in the penalty term.
        iters: Number of optimiser steps.

    Returns:
        A pair ``(Q, logs)``: the optimised matrix, detached from the autograd
        graph, and a list with the loss value of every iteration (each entry
        is a 0-d NumPy array, evaluated before that iteration's step).
    """
    dim = x.shape[0]
    # Draw order matters for reproducibility under a fixed seed: Q first, then W.
    Q = torch.rand(size=(dim, dim), requires_grad=True)
    noise = torch.rand(size=(dim, dim))
    W = reg * noise / torch.norm(noise)

    adam = torch.optim.Adam([Q])
    history = []

    for _ in range(iters):
        # Clear gradients left over from the previous step before accumulating.
        adam.zero_grad()

        energy = torch.dot(x, Q @ x)
        penalty = torch.norm(Q @ W) ** 2
        loss = energy + penalty
        history.append(loss.detach().numpy())

        loss.backward()
        adam.step()

    return Q.detach(), history


def gen_all_binary_vectors(n: int) -> torch.Tensor:
    """Enumerate every length-``n`` vector with entries in {0, 1}.

    Returns a float tensor of shape ``(2**n, n)`` whose row ``i`` holds the
    binary digits of ``i``, most significant bit first.
    """
    values = torch.arange(2**n).unsqueeze(1)   # column of 0 .. 2**n - 1
    shifts = torch.arange(n-1, -1, -1)         # bit positions n-1 .. 0
    bits = (values >> shifts) & 1              # broadcast to (2**n, n)
    return bits.float()


def is_solution(
    x: torch.Tensor,
    Q: torch.Tensor,
    *,
    rtol: float = 1e-5,
    atol: float = 1e-8,
) -> bool:
    """Brute-force check that ``x`` attains the minimum of ``v^T Q v``.

    Evaluates the quadratic form on every binary vector of length ``len(x)``
    and compares the energy of ``x`` against the smallest value found. This
    enumerates ``2**len(x)`` candidates, so it is only usable for small sizes.

    The energy of ``x`` and the enumerated energies are computed through
    different floating-point code paths, so they are compared with a small
    tolerance; a strict ``<=`` can reject a true minimiser because of rounding.

    Args:
        x: 1-D float tensor, the candidate solution.
        Q: Square float matrix compatible with ``x`` (``Q @ x`` must be valid).
        rtol: Relative tolerance on the minimum energy.
        atol: Absolute tolerance on the minimum energy.

    Returns:
        ``True`` if the energy of ``x`` does not exceed the enumerated minimum
        by more than ``atol + rtol * |minimum|``, else ``False``.
    """
    n = len(x)
    # Columns of `candidates` are the binary vectors, shape (n, 2**n).
    candidates = gen_all_binary_vectors(n).T
    # energies[j] = candidates[:, j]^T Q candidates[:, j]
    energies = torch.sum(candidates * (Q @ candidates), dim=0)

    energy_x = float(torch.dot(x, Q @ x))
    best = float(torch.min(energies))

    # Plain Python bool, as the annotation promises (not a 0-d tensor).
    return energy_x <= best + atol + rtol * abs(best)
    

In [None]:
n = 10
# Penalty scale handed to find_Q. The call previously hard-coded reg=4 while a
# separate, unused `reg = 1.0` was assigned; the effective value 4 is kept.
reg = 4.0

# Target vector to plant. It was not defined anywhere in the notebook, so this
# cell raised NameError on a fresh kernel.
# NOTE(review): a uniformly random binary vector is assumed — replace it with
# the specific target if one was intended.
x = torch.randint(0, 2, (n,)).float()

Q, logs = find_Q(x, reg=reg, iters=int(1e4))
print(torch.norm(Q))
plt.plot(logs)

### Chook testing

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from chook.planters.wishart_planting import generate_problem

In [None]:
alpha = 0.3
# Assign n before deriving m from it. In a tuple assignment
# (`n, m = 100, int(alpha*n)`) the right-hand side is evaluated first, so m
# would be computed from whatever `n` an earlier cell left behind, or raise
# NameError on a fresh kernel.
n = 100
m = int(alpha * n)
# NOTE(review): `M` is passed through to chook's Wishart planter unchanged —
# see the chook documentation for its exact meaning.
Q = generate_problem(n, M=m)
plt.imshow(Q)