<a href="https://colab.research.google.com/github/nhhung1810/8-queens-SAT/blob/main/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A SAT Approach for 8 Queens Problem

## Preparation

In [None]:
!pip install python-sat
from pysat.formula import CNF
from pysat.solvers import Lingeling
from typing import List, Tuple
from heapq import *

Collecting python-sat
  Downloading python_sat-0.1.7.dev15-cp37-cp37m-manylinux2010_x86_64.whl (1.8 MB)
[?25l[K     |▏                               | 10 kB 25.8 MB/s eta 0:00:01[K     |▍                               | 20 kB 31.0 MB/s eta 0:00:01[K     |▋                               | 30 kB 19.7 MB/s eta 0:00:01[K     |▊                               | 40 kB 16.8 MB/s eta 0:00:01[K     |█                               | 51 kB 16.2 MB/s eta 0:00:01[K     |█▏                              | 61 kB 14.9 MB/s eta 0:00:01[K     |█▎                              | 71 kB 13.6 MB/s eta 0:00:01[K     |█▌                              | 81 kB 15.2 MB/s eta 0:00:01[K     |█▊                              | 92 kB 14.7 MB/s eta 0:00:01[K     |█▉                              | 102 kB 13.3 MB/s eta 0:00:01[K     |██                              | 112 kB 13.3 MB/s eta 0:00:01[K     |██▎                             | 122 kB 13.3 MB/s eta 0:00:01[K     |██▍                      

In [None]:
def num_attacking_pairs (status: List[int]) -> int:
  """Count the pairs of placed queens that attack each other.

  Args:
    status: one entry per index; `status[i]` is the position of the queen at
      index `i`, or -1 when no queen has been placed there yet.

  Returns:
    The number of index pairs (i, j), i < j, whose queens share the same
    value or lie on a common diagonal.  Entries equal to -1 are ignored.

  The board size is taken from `len(status)` instead of being fixed at 8, so
  boards of any size (including an empty one) are handled.
  """
  n = len(status)
  res = 0
  for i in range(n):
    if status[i] == -1:
      continue
    for j in range(i + 1, n):
      if status[j] == -1:
        continue
      same_value = status[i] == status[j]
      # Cells on one diagonal share either (index + value) or (index - value).
      same_diagonal = i + status[i] == j + status[j] or i - status[i] == j - status[j]
      if same_value or same_diagonal:
        res += 1

  return res

# Sanity check: the first seven queens all share value 1 (21 attacking pairs),
# and the queens at indices 6 and 7 lie on a common diagonal, giving 22 in total.
a = [1, 1, 1, 1, 1, 1, 1, 2]

print(num_attacking_pairs(a))


22


In [1]:
# A* solver
from heapq import *
from functools import total_ordering
from typing import Callable, List
from copy import deepcopy

# Cost added for every step from a node to one of its successors.
PATH_COST = 1

class Heap():
    """Small min-priority queue built on the `heapq` functions.

    Entries are kept as `(cost, element)` tuples in `self.arr`, so the entry
    with the smallest tuple is the one returned by `pop`.
    """

    def __init__(self):
        # Backing list maintained in heap order by heappush / heappop.
        self.arr = []

    def size(self):
        """Return how many entries are currently stored."""
        return len(self.arr)

    def clear(self):
        """Drop every stored entry."""
        self.arr.clear()

    def isEmpty(self):
        """Return True when nothing is stored."""
        return not self.arr

    def push(self, cost, element):
        """Store `element` with priority `cost`."""
        entry = (cost, element)
        heappush(self.arr, entry)

    def pop(self):
        """Remove the smallest entry and return its element (the cost is discarded)."""
        _, element = heappop(self.arr)
        return element

    def printHeap(self):
        """Print the raw backing list."""
        print(self.arr)
        
@total_ordering
class Node():
    """Search node wrapping one board state plus its bookkeeping flags."""

    def __init__(self, state : list):
        self.state = state
        # Bookkeeping used by the solver; every flag starts cleared.
        self.visited = False
        self.isInFrontier = False
        self.isExpanded = False
        # Large sentinel meaning "no cost recorded yet".
        self.bestCost = 10**8

    def __lt__(self, other):
        # Nodes are ordered lexicographically by their string label.
        mine = self.getStateLabel()
        theirs = other.getStateLabel()
        return mine < theirs

    def getState(self) -> list:
        """Return the wrapped state list (not a copy)."""
        return self.state

    def getStateLabel(self):
        """Return the state as a string: each non-negative entry as text, anything else as "x"."""
        return "".join(str(pos) if pos >= 0 else "x" for pos in self.state)

class AStarSolver():
    """Best-first search driven by a successor function and a heuristic.

    Args:
        state: the initial state.
        action: callable mapping a state to the list of its successor states.
        heuristic: callable mapping a state to a number; 0 means the state
            has no conflicts.

    The solver reports its outcome with `print` and records every state it
    pops from the frontier in `self.path` (retrievable via `getResult`).
    """
    def __init__(self, state : list, action : Callable, heuristic : Callable):
        self.init : Node = Node(state)
        self.action : Callable = action
        self.frontier = Heap()
        self.heuristic : Callable = heuristic
        # Maps a state label to its Node so each generated state shares one Node.
        self.stateDict = {}
        # States in the order they were popped from the frontier.
        self.path = []

    def solve(self):
        """Run the search from the initial state.

        Prints the goal state when one is found, or a failure message when the
        frontier runs dry.  Returns None in every case.
        """
        # Start from a clean slate so that solve() can safely be called again.
        self.path = []
        self.frontier.clear()
        self.stateDict = {}

        # Validate before touching the frontier so an early return leaves it empty.
        if self.heuristic(self.init.getState()):
            print("Initial state is already invalid, how can we solve it?")
            return

        self.init.bestCost = 0
        self.init.isExpanded = False
        self.init.isInFrontier = True
        self.frontier.push(0, self.init)

        while not self.frontier.isEmpty():
            node : Node = self.frontier.pop()
            if node.isExpanded:
                # Stale duplicate left behind when this node was re-pushed
                # with a better cost; it has already been handled.
                continue
            node.isInFrontier = False
            self.path.append(node.getState())

            # Test if popped node is goal
            if self.goalTest(node):
                print("Final result: ", node.getState())
                return

            node.isExpanded = True
            for adjNode in self.generateNextNodes(node):
                if adjNode.isExpanded:
                    continue
                newCost = self.calculateCost(node, adjNode)
                if adjNode.isInFrontier and adjNode.bestCost <= newCost:
                    # Already queued with a cost at least as good.
                    continue
                # Either a brand-new node or a cheaper route to a queued one.
                # Pushing again (instead of editing the heap in place) keeps
                # the heap ordered; the older entry is skipped when popped.
                adjNode.bestCost = newCost
                adjNode.isInFrontier = True
                self.frontier.push(newCost, adjNode)

        # If it go out of the loop, probably it fail
        print("Fail to find the result!!")

    def goalTest(self, node : Node) -> bool:
        """Return True when the state has heuristic 0 and no -1 (unfilled) entry."""
        state = node.getState()
        return self.heuristic(state) == 0 and -1 not in state

    def calculateCost(self, parent, node):
        """Return the priority for `node` when reached from `parent`.

        NOTE(review): `parent.bestCost` already contains the parent's heuristic
        value, so heuristic values accumulate along a path rather than being
        added once to a pure path cost. Left unchanged because it determines
        the expansion order; confirm whether this is intended.
        """
        return parent.bestCost + PATH_COST + self.heuristic(node.getState())

    def generateNextNodes(self, node : Node) -> List[Node]:
        """Return the successor nodes of `node`, reusing the Node of any state seen before."""
        result = []
        for newState in self.action(node.getState()):
            candidate = Node(newState)
            # setdefault registers the new Node only when the label is unseen.
            result.append(self.stateDict.setdefault(candidate.getStateLabel(), candidate))
        return result

    def getResult(self, isPrint : bool):
        """Return the list of popped states, optionally printing it.

        Prints a reminder and returns [] when nothing has been recorded.
        """
        if len(self.path) == 0:
            print("Run the solve() function first!!")
            return []
        if isPrint:
            print("This is the path:")
            print(self.path)
        return self.path

# Start state for the search: one entry per index, holding a value 0..7 for a
# placed queen, or -1 where no queen has been placed yet.
intiState = [6, -1, -1, -1, -1, -1, -1, -1]
# intiState = [0, -1, 7, 5, 2, -1, 1, 3]

def action(state : list) -> list:
    """Return the successor states of `state`.

    The first entry equal to -1 is filled with each value 0..7 in turn, giving
    eight new lists; `state` itself is left untouched.  An empty list is
    returned when `state` does not have exactly 8 entries or has no -1 left.
    """
    if len(state) != 8 or -1 not in state:
        return []

    slot = state.index(-1)
    successors = []
    for value in range(8):
        child = deepcopy(state)
        child[slot] = value
        successors.append(child)
    return successors

def num_attacking_pairs (status: List[int]) -> int:
  """Count the pairs of placed queens that attack each other.

  Args:
    status: one entry per index; `status[i]` is the position of the queen at
      index `i`, or -1 when no queen has been placed there yet.

  Returns:
    The number of index pairs (i, j), i < j, whose queens share the same
    value or lie on a common diagonal.  Entries equal to -1 are ignored.

  The board size is taken from `len(status)` instead of being fixed at 8, so
  boards of any size (including an empty one) are handled.
  """
  n = len(status)
  res = 0
  for i in range(n):
    if status[i] == -1:
      continue
    for j in range(i + 1, n):
      if status[j] == -1:
        continue
      same_value = status[i] == status[j]
      # Cells on one diagonal share either (index + value) or (index - value).
      same_diagonal = i + status[i] == j + status[j] or i - status[i] == j - status[j]
      if same_value or same_diagonal:
        res += 1

  return res

def heuristic(state) -> int:
    """Heuristic value of `state`: its number of attacking queen pairs."""
    attacking = num_attacking_pairs(state)
    return attacking
            
# Run the search from the module-level start state.
solver  = AStarSolver(intiState, action, heuristic)
solver.solve()
# getResult returns the recorded list of popped states; isPrint=False skips printing it.
res = solver.getResult(isPrint=False)
print("Count:", len(res))

Final result:  [0, 4, 7, 5, 2, 6, 1, 3]
This is the path:
[[0, -1, -1, -1, -1, -1, -1, -1], [0, 0, -1, -1, -1, -1, -1, -1], [0, 1, -1, -1, -1, -1, -1, -1], [0, 2, -1, -1, -1, -1, -1, -1], [0, 3, -1, -1, -1, -1, -1, -1], [0, 4, -1, -1, -1, -1, -1, -1], [0, 5, -1, -1, -1, -1, -1, -1], [0, 6, -1, -1, -1, -1, -1, -1], [0, 7, -1, -1, -1, -1, -1, -1], [0, 2, 0, -1, -1, -1, -1, -1], [0, 2, 1, -1, -1, -1, -1, -1], [0, 2, 2, -1, -1, -1, -1, -1], [0, 2, 3, -1, -1, -1, -1, -1], [0, 2, 4, -1, -1, -1, -1, -1], [0, 2, 5, -1, -1, -1, -1, -1], [0, 2, 6, -1, -1, -1, -1, -1], [0, 2, 7, -1, -1, -1, -1, -1], [0, 3, 0, -1, -1, -1, -1, -1], [0, 3, 1, -1, -1, -1, -1, -1], [0, 3, 2, -1, -1, -1, -1, -1], [0, 3, 3, -1, -1, -1, -1, -1], [0, 3, 4, -1, -1, -1, -1, -1], [0, 3, 5, -1, -1, -1, -1, -1], [0, 3, 6, -1, -1, -1, -1, -1], [0, 3, 7, -1, -1, -1, -1, -1], [0, 4, 0, -1, -1, -1, -1, -1], [0, 4, 1, -1, -1, -1, -1, -1], [0, 4, 2, -1, -1, -1, -1, -1], [0, 4, 3, -1, -1, -1, -1, -1], [0, 4, 4, -1, -1, -1, -1, -1], [

## Requirement a: Formulate the problem by specifying the following points.

(TODO: to be written once the implementation is finished.)

## Requirement b: Write CNF clauses to describe restrictions required when Florence places a queen in the cell[3][3]

First of all, let `b[x][y]` denote "there is a queen on cell (x, y)" and `-b[x][y]` denote its negation, "there is no queen on cell (x, y)".
If Florence places a queen on cell (3, 3), every other cell on the same horizontal line, vertical line, main diagonal and sub diagonal becomes invalid for the next queen. Therefore:

$b[3][3] \implies -b[0][3] \wedge -b[1][3] \wedge -b[2][3] \wedge -b[4][3] \wedge -b[5][3] \wedge -b[6][3] \wedge -b[7][3] \wedge $

$-b[3][0] \wedge -b[3][1] \wedge -b[3][2] \wedge -b[3][4] \wedge -b[3][5] \wedge -b[3][6] \wedge -b[3][7] \wedge $

$-b[0][0] \wedge -b[1][1] \wedge -b[2][2] \wedge -b[4][4] \wedge -b[5][5] \wedge -b[6][6] \wedge -b[7][7] \wedge $

$-b[0][6] \wedge -b[1][5] \wedge -b[2][4] \wedge -b[4][2] \wedge -b[5][1] \wedge -b[6][0]$

Which is equivalent to:

$-b[3][3] \vee (-b[0][3] \wedge -b[1][3] \wedge -b[2][3] \wedge -b[4][3] \wedge -b[5][3] \wedge -b[6][3] \wedge -b[7][3] \wedge $

$-b[3][0] \wedge -b[3][1] \wedge -b[3][2] \wedge -b[3][4] \wedge -b[3][5] \wedge -b[3][6] \wedge -b[3][7] \wedge $

$-b[0][0] \wedge -b[1][1] \wedge -b[2][2] \wedge -b[4][4] \wedge -b[5][5] \wedge -b[6][6] \wedge -b[7][7] \wedge $

$-b[0][6] \wedge -b[1][5] \wedge -b[2][4] \wedge -b[4][2] \wedge -b[5][1] \wedge -b[6][0])$

Distributing the disjunction over the conjunction gives the following CNF clauses:

$ (-b[3][3] \vee -b[0][3]) \wedge $

$ (-b[3][3] \vee -b[1][3]) \wedge $

$ (-b[3][3] \vee -b[2][3]) \wedge $

$ (-b[3][3] \vee -b[4][3]) \wedge $

$ (-b[3][3] \vee -b[5][3]) \wedge $

$ (-b[3][3] \vee -b[6][3]) \wedge $

$ (-b[3][3] \vee -b[7][3]) \wedge $

$ (-b[3][3] \vee -b[3][0]) \wedge $

$ (-b[3][3] \vee -b[3][1]) \wedge $

$ (-b[3][3] \vee -b[3][2]) \wedge $

$ (-b[3][3] \vee -b[3][4]) \wedge $

$ (-b[3][3] \vee -b[3][5]) \wedge $

$ (-b[3][3] \vee -b[3][6]) \wedge $

$ (-b[3][3] \vee -b[3][7]) \wedge $

$ (-b[3][3] \vee -b[0][0]) \wedge $

$ (-b[3][3] \vee -b[1][1]) \wedge $

$ (-b[3][3] \vee -b[2][2]) \wedge $

$ (-b[3][3] \vee -b[4][4]) \wedge $

$ (-b[3][3] \vee -b[5][5]) \wedge $

$ (-b[3][3] \vee -b[6][6]) \wedge $

$ (-b[3][3] \vee -b[7][7]) \wedge $

$ (-b[3][3] \vee -b[0][6]) \wedge $

$ (-b[3][3] \vee -b[1][5]) \wedge $

$ (-b[3][3] \vee -b[2][4]) \wedge $

$ (-b[3][3] \vee -b[4][2]) \wedge $

$ (-b[3][3] \vee -b[5][1]) \wedge $

$ (-b[3][3] \vee -b[6][0]) $



## Requirement c: Create a function that returns the expected CNF set



In [None]:
def encode_coor (x: int, y: int, size: int) -> int:
  """Map the cell (x, y) of a size-by-size board to its 1-based variable number."""
  zero_based = size * x + y
  return zero_based + 1

def decode_coor (v: int, size: int) -> Tuple[int, int]:
  """Turn a 1-based variable number back into its (x, y) cell on a size-by-size board."""
  # divmod yields (quotient, remainder), i.e. (v - 1) // size and (v - 1) % size.
  return divmod(v - 1, size)

In [None]:
def restriction_at (x: int, y: int, size: int) -> List[List[int]]:
  """Build the binary clauses forbidding a second queen on any line through (x, y).

  Each clause is `[-cell(x, y), -cell(i, j)]` for another cell (i, j) that
  shares a line with (x, y).  Clauses are emitted in four groups: cells with
  the same y, cells with the same x, the main diagonal, then the sub
  diagonal.  Returns [] when (x, y) lies outside the board.
  """
  if not (0 <= x < size and 0 <= y < size):
    return []

  here = -encode_coor(x, y, size)
  clauses = []

  # Horizontal: same y, every other x
  clauses.extend([here, -encode_coor(i, y, size)] for i in range(size) if i != x)

  # Vertical: same x, every other y
  clauses.extend([here, -encode_coor(x, j, size)] for j in range(size) if j != y)

  # Main diagonal: cells where i - j equals x - y
  for i in range(size):
    j = i - (x - y)
    if 0 <= j < size and (i, j) != (x, y):
      clauses.append([here, -encode_coor(i, j, size)])

  # Sub diagonal: cells where i + j equals x + y
  for i in range(size):
    j = (x + y) - i
    if 0 <= j < size and (i, j) != (x, y):
      clauses.append([here, -encode_coor(i, j, size)])

  return clauses

In [None]:
def restrictions (level: int, size: int, selected: List[int] = [], excluded: List[int] = []) -> List[List[int]]:
  """Assemble the CNF clause list for a size-by-size board.

  The pairwise attack clauses of every cell are always included.  On top of
  those:
    * level 1 adds, for each x, one clause requiring a queen somewhere among
      the cells (x, 0..size-1);
    * level 2 adds a negative unit clause for each variable in `excluded`, a
      positive unit clause for each remaining variable in `selected`, and one
      clause over all other variables (when there are any).
  Any other level adds nothing further.
  """
  clauses = [
    clause
    for i in range(size)
    for j in range(size)
    for clause in restriction_at(i, j, size)
  ]

  if level == 1:
    clauses.extend(
      [encode_coor(i, j, size) for j in range(size)]
      for i in range(size)
    )
  elif level == 2:
    free_cells = []
    for v in range(1, size * size + 1):
      if v in excluded:
        clauses.append([-v])
      elif v in selected:
        clauses.append([v])
      else:
        free_cells.append(v)
    if free_cells:
      clauses.append(free_cells)

  return clauses

# restrictions(2, 4, [3, 4], [1, 2])

In [None]:
def _run_lingeling (clauses: List[List[int]]) -> Tuple[bool, list]:
  """Solve `clauses` once with Lingeling and release the solver afterwards.

  Returns (True, model) when satisfiable, otherwise (False, proof).  The
  `with` block guarantees the native solver and its proof file are freed;
  the proof must therefore be read before leaving the block.
  """
  formula = CNF()
  for clause in clauses:
    formula.append(clause)

  with Lingeling(bootstrap_with=formula.clauses, with_proof=True) as solver:
    if solver.solve():
      return True, solver.get_model()
    return False, solver.get_proof()


def _solve_all_at_once (size: int) -> List[int]:
  """Level 1: hand the complete formula to the solver in a single call."""
  satisfiable, payload = _run_lingeling(restrictions(1, size))
  if not satisfiable:
    print("Failed, proof:", payload)
    return []

  print("Success")
  return [v for v in payload if v > 0]


def _solve_queen_by_queen (size: int) -> List[int]:
  """Level 2: grow the placement one solver call at a time, backtracking on failure."""
  result = []
  # exclude[d] lists the variables already ruled out while d queens are placed.
  exclude = [[] for _ in range(size + 1)]
  while len(result) < size:
    # NOTE(review): the proof is never read on this path; proof tracing could
    # probably be switched off here to save time -- confirm before changing.
    clauses = restrictions(2, size, result, exclude[len(result)])
    satisfiable, model = _run_lingeling(clauses)

    if not satisfiable:
      if len(result) == 0:
        print("Failed")
        return []
      print("Remove latest queen")
      # Forget what was ruled out at this depth, then rule out the queen
      # being removed at the depth below.
      exclude[len(result)] = []
      exclude[len(result) - 1].append(result[-1])
      result = result[:-1]
      continue

    found = False
    for v in model:
      if v > 0 and v not in result:
        found = True
        result.append(v)
        print(f"Found the {len(result)}th queen at {str(decode_coor(v, size))}")

    if not found:
      # Satisfiable yet no new queen: dump the state for debugging.
      print("wtf???")
      print(result, exclude)
      print(clauses)
      print(model)
      return []

  print("Success")
  return result


def solve (level: int = 1, size: int = 8) -> List[int]:
  """Solve the N-queens problem with a SAT solver.

  Args:
    level: 1 solves the whole formula in one call; 2 places queens one solver
      call at a time with backtracking.
    size: board side length.

  Returns:
    The variable numbers of the cells holding a queen, or [] on failure or
    when `level` is neither 1 nor 2.  Progress and outcome are printed.
  """
  if level == 1:
    return _solve_all_at_once(size)
  if level == 2:
    return _solve_queen_by_queen(size)

  print("Invalid level")
  return []

# Demo on a 4x4 board: run both strategies and print each result as (x, y) cells.
size = 4
res = solve(level = 1, size = size)
print([decode_coor(v, size) for v in res])
res = solve(level = 2, size = size) # Takes really long time
print([decode_coor(v, size) for v in res])

Success
[(0, 2), (1, 0), (2, 3), (3, 1)]
Found the 1th queen at (0, 0)
Found the 2th queen at (2, 1)
Found the 3th queen at (1, 3)
Remove latest queen
Remove latest queen
Found the 2th queen at (3, 1)
Found the 3th queen at (1, 2)
Remove latest queen
Found the 3th queen at (2, 3)
Remove latest queen
Remove latest queen
Found the 2th queen at (1, 2)
Found the 3th queen at (3, 1)
Remove latest queen
Remove latest queen
Found the 2th queen at (2, 3)
Found the 3th queen at (3, 1)
Remove latest queen
Remove latest queen
Found the 2th queen at (3, 2)
Found the 3th queen at (1, 3)
Remove latest queen
Remove latest queen
Found the 2th queen at (1, 3)
Found the 3th queen at (2, 1)
Remove latest queen
Found the 3th queen at (3, 2)
Remove latest queen
Remove latest queen
Remove latest queen
Found the 1th queen at (1, 0)
Found the 2th queen at (0, 2)
Found the 3th queen at (3, 3)
Remove latest queen
Found the 3th queen at (3, 1)
Found the 4th queen at (2, 3)
Success
[(1, 0), (0, 2), (3, 1), (2, 3)

In [None]:
#@title Default title text
