# N-Queens MCMC P.2
## Estimating the number of solutions.

In [None]:
# Imports
import time, sys, timeit
from IPython.display import clear_output
import numpy as np
import itertools as it
import math

In [None]:
# For reproducibility: fix the state of NumPy's global RNG, which drives the
# initial shuffle of z0 and the random swap/acceptance draws in the search.
np.random.seed(seed=2022)

In [None]:
# Initialisation
N = 1000                  # board size (number of queens)
C = math.comb(N, 2)       # number of unordered queen pairs, N(N-1)/2
z0 = np.arange(1, N + 1)  # z0[i] is the row (1..N) of the queen in column i
beta = 1  # NOTE(review): not read by the search cell shown here, which uses its own schedule

# All index pairs (i, j) with 0 <= i < j < N, in the same lexicographic order
# as itertools.combinations(range(N), 2). Built from indices directly, so it no
# longer depends on z0 still being unshuffled, and without a Python-level loop.
idx_pairs = np.column_stack(np.triu_indices(N, k=1))
# Column distance j - i for every pair.
col_diff = idx_pairs[:, 1] - idx_pairs[:, 0]
np.random.shuffle(z0)  # random initial permutation of rows

In [None]:
def swap(z, i, j):
    """
    Exchange the entries of z at positions i and j in place, then hand back
    the very same array object (no copy is made).
    """
    # Fancy indexing on the right-hand side yields a fresh copy
    # [z[j], z[i]], so writing it back cannot alias the values being moved.
    z[[i, j]] = z[[j, i]]
    return z

In [None]:
def threats(z, i):
    """
    Returns the number of queens threatening queen i.

    Queen k sits in column k and row z[k]. Queen k (k != i) threatens queen i
    when the two share a diagonal, i.e. |k - i| == |z[k] - z[i]|. Row conflicts
    are not counted, matching `loss`.

    The board size is taken from len(z) instead of the global N, so the
    function works for any board size. It gives the same result as before
    when len(z) == N.
    """
    z = np.asarray(z)  # no copy for ndarray input; also accepts lists
    others = np.delete(np.arange(len(z)), i)  # every column except i
    return np.sum(np.abs(others - i) == np.abs(z[others] - z[i]))

In [None]:
def loss_diff(z, i, j):
    """
    Given a state z and a swap operation (i, j), returns the change in loss
    (number of attacking pairs) that swapping z[i] and z[j] would cause.

    Only pairs involving queen i or queen j can change, so the loss change
    equals (threats(i) + threats(j)) after the swap minus the same quantity
    before it. If i and j attack each other, that pair is counted twice on
    both sides. Its status is unchanged by the swap, because |z[i] - z[j]| is
    symmetric, so it cancels in the difference.

    z is swapped in place temporarily and is always swapped back before
    returning, even if an exception is raised while counting.
    """
    old = threats(z, i) + threats(z, j)
    swap(z, i, j)  # trial move, in place
    try:
        new = threats(z, i) + threats(z, j)
    finally:
        swap(z, i, j)  # undo the trial move so the caller's z is unchanged
    return new - old

In [None]:
### Loss function: checks all n(n-1)/2 queen pairs in a single vectorised pass.
def loss(z):
    """
    Interprets z as a chessboard where queen i sits in column i and row z[i].
    Queens are only checked against each other diagonally.

    Counts the number of unique pairs (i, j), i < j, of queens on a common
    diagonal, i.e. with j - i == |z[j] - z[i]|. A result of 0 means no two
    queens share a diagonal.

    The pairs are generated from len(z), so the function does not depend on
    the globals idx_pairs/col_diff. For len(z) == N it checks exactly the
    same pairs.
    """
    z = np.asarray(z)  # no copy for ndarray input; also accepts lists
    # All index pairs i < j, as two parallel index arrays.
    i, j = np.triu_indices(len(z), k=1)
    col_gap = j - i
    row_gap = np.abs(z[j] - z[i])
    return np.sum(col_gap == row_gap)

In [None]:
### Run search (NOT FEASIBLE)
# Metropolis search with an increasing inverse temperature over permutations.
# Each step proposes a random swap of two queens and accepts it with
# probability min(1, exp(-b_t * diff)), where b_t = max(0, log(t^2 / N)).
MAX_ITERS = 1000000
REG_ITERS = 1000  # print progress every REG_ITERS iterations
z = z0.copy()
l = loss(z)  # current number of attacking pairs, updated incrementally below
b = 0.0
for t in range(1, MAX_ITERS):

    # Report progress periodically; the loss itself is tracked incrementally.
    if t % REG_ITERS == 0:
        clear_output(wait=True)
        print("t =", t, "| conflicts =", l)

    # Inverse temperature. log(t^2 / N) is negative while t < sqrt(N), which
    # would make conflict-increasing moves *more* likely than improving ones.
    # Clamp at 0, which corresponds to a pure random walk early on.
    b = max(0.0, np.log(t**2 / N))

    # Choose a random swap uniformly among the C index pairs.
    i, j = idx_pairs[np.random.randint(C)]
    diff = loss_diff(z, i, j)

    # Metropolis acceptance. Since b >= 0, moves with diff <= 0 have
    # acceptance probability 1. Only evaluate exp for diff > 0, where the
    # exponent is <= 0, so it cannot overflow.
    if diff <= 0 or np.random.rand() < np.exp(-b * diff):
        swap(z, i, j)
        l += diff

    # If a solution is found, exit.
    if l <= 0:
        break

if loss(z) == 0:
    # Report the inverse temperature actually in use at the final step.
    print("Here's a valid solution: ", z, "\nFound after ", t, " steps. (beta = ", b, ")")
else:
    print("No solution found after ", t, " steps. (conflicts = ", l, ")")