# Sudoku 9x9 - SAT

| **Puzzle**   | <img src="./assets/3.png" width="300"/> |
|--------------|-----------------------------------------|
| **Solution** | <img src="./assets/4.png" width="300"/> |

In [1]:
from itertools import combinations, product
from pycosat import solve as sat_solve

class Sudoku:
    '''
        SAT solver for 9x9 sudoku.

        Encoding: x_ijk = 1 iff cell (i, j) contains digit k, where i, j, k
        are all in 1..9 and x_ijk is mapped to the integer 100*i + 10*j + k
        (see varnum). The resulting CNF is handed to pycosat (sat_solve).
    '''
    def __init__(self):
        # The solver keeps no state; all clauses are rebuilt for each puzzle.
        pass

    ## Helpers
    def varnum(self, row, col, digit):
        '''Convert (row, col, digit), each in range [1, 9], to a CNF variable number.

        The mapping is 100*row + 10*col + digit, so every triple gets a
        distinct positive integer in range [111, 999].
        '''
        assert row in range(1, 10) and col in range(1, 10)
        assert digit in range(1, 10)
        return 100 * row + 10 * col + digit

    def exactly_one_of(self, literals: list):
        '''Return CNF clauses that force exactly one of `literals` to be true.

        Uses the pairwise encoding: one "at least one" clause (the disjunction
        of all literals) plus one "at most one" clause [-a, -b] for every pair,
        i.e. 1 + n*(n-1)/2 clauses for n literals.
        '''
        clauses = [list(literals)]
        clauses.extend([-a, -b] for a, b in combinations(literals, 2))
        return clauses

    ## Constraints
    def one_digit_in_every_cell(self):
        '''Every cell holds exactly one digit.'''
        clauses = []
        for row, col in product(range(1, 10), repeat=2):
            clauses += self.exactly_one_of([self.varnum(row, col, digit) for digit in range(1, 10)])
        return clauses

    def one_digit_in_every_row(self):
        '''Every digit appears exactly once in every row.'''
        clauses = []
        for row, digit in product(range(1, 10), repeat=2):
            clauses += self.exactly_one_of([self.varnum(row, col, digit) for col in range(1, 10)])
        return clauses

    def one_digit_in_every_col(self):
        '''Every digit appears exactly once in every column.'''
        clauses = []
        for column, digit in product(range(1, 10), repeat=2):
            clauses += self.exactly_one_of([self.varnum(row, column, digit) for row in range(1, 10)])
        return clauses

    def one_digit_in_every_block(self):
        '''Every digit appears exactly once in every 3x3 block.'''
        clauses = []
        # (row, col) iterates over the top-left corner of each 3x3 block.
        for row, col in product([1, 4, 7], repeat=2):
            for digit in range(1, 10):
                clauses += self.exactly_one_of([self.varnum(row+a, col+b, digit) for (a, b) in product(range(3), repeat=2)])
        return clauses

    ## Solver
    def _encode(self, puzzle):
        '''Build the full CNF: sudoku rules followed by one unit clause per given digit.

        `puzzle` is 9 row strings of 9 characters each; "*" marks an empty cell
        and any other character must be a digit 1-9.
        '''
        assert len(puzzle) == 9
        assert all(len(row) == 9 for row in puzzle)

        # Add constraints
        clauses = []
        clauses += self.one_digit_in_every_cell()
        clauses += self.one_digit_in_every_row()
        clauses += self.one_digit_in_every_col()
        clauses += self.one_digit_in_every_block()

        # Add puzzle constraints (preset digits)
        for row, col in product(range(1, 10), repeat=2):
            cell = puzzle[row - 1][col - 1]
            if cell != "*":
                digit = int(cell)
                assert digit in range(1, 10)
                clauses.append([self.varnum(row, col, digit)])
        return clauses

    def _decode(self, model):
        '''Turn a set of true variable numbers into 9 row strings of digits.'''
        return [
            "".join(str(digit)
                    for col in range(1, 10)
                    for digit in range(1, 10)
                    if self.varnum(row, col, digit) in model)
            for row in range(1, 10)
        ]

    @staticmethod
    def _format_grid(rows):
        '''Render 9 row strings as a boxed grid with separators every 3 rows/cols.'''
        border = '+---+---+---+'
        lines = []
        for index, row in enumerate(rows):
            if index % 3 == 0:
                lines.append(border)
            lines.append('|' + '|'.join(row[i:i + 3] for i in (0, 3, 6)) + '|')
        lines.append(border)
        return "\n".join(lines)

    def solve(self, puzzle):
        '''Solve `puzzle` and return the solution as 9 row strings, or None.

        None is returned when pycosat reports a non-model result (it returns
        a string such as "UNSAT" instead of a list of literals).
        '''
        model = sat_solve(self._encode(puzzle))
        if isinstance(model, str):
            return None
        # Only positive literals are looked up, so a set gives O(1) membership.
        return self._decode(set(model))

    def solve_puzzle(self, puzzle):
        '''Solve `puzzle` and print the boxed solution grid (or "No solution").'''
        rows = self.solve(puzzle)
        if rows is None:
            print("No solution")
            return
        print(self._format_grid(rows))

In [2]:
# A single solver instance (it holds no state) is reused for every puzzle below.
sudoku_solver: Sudoku = Sudoku()

In [3]:
# Example puzzle: one string per grid row, "*" marks an empty cell.
sample_puzzle = """
    53**7****
    6**195***
    *98****6*
    8***6***3
    4**8*3**1
    7***2***6
    *6****28*
    ***419**5
    ****8**79
""".split()
sudoku_solver.solve_puzzle(sample_puzzle)

+---+---+---+
|534|678|912|
|672|195|348|
|198|342|567|
+---+---+---+
|859|761|423|
|426|853|791|
|713|924|856|
+---+---+---+
|961|537|284|
|287|419|635|
|345|286|179|
+---+---+---+


In [4]:
# Sparser puzzle (21 givens): one string per grid row, "*" marks an empty cell.
difficult_puzzle = """
    8********
    **36*****
    *7**9*2**
    *5***7***
    ****457**
    ***1***3*
    **1****68
    **85***1*
    *9****4**
""".split()
sudoku_solver.solve_puzzle(difficult_puzzle)

+---+---+---+
|812|753|649|
|943|682|175|
|675|491|283|
+---+---+---+
|154|237|896|
|369|845|721|
|287|169|534|
+---+---+---+
|521|974|368|
|438|526|917|
|796|318|452|
+---+---+---+
