Sudoku Solver (Jhalley)

In [1]:
from collections import Counter, defaultdict
from pprint import pprint as pp
import itertools
import copy

class Solution:
  """Sudoku solver for a 9x9 grid of single-character cells.

  Cells hold the characters '1'..'9'; the character '0' marks an empty cell.
  The grid is stored in ``self.grid`` as a list of 9 lists of 9 characters.
  Solving combines filling of cells that have exactly one candidate with
  backtracking on the empty cell that has the fewest candidates.
  """

  def __init__(self, grid):
    """Copy, validate and pretty-print the input puzzle.

    Args:
      grid: sequence of 9 rows, each an iterable of 9 single-character
        strings (for example a list of 9 nine-character strings).

    Raises:
      AssertionError: if the grid is not 9x9, contains a character outside
        '0'..'9', or repeats a non-zero digit in a row, column or box.
    """
    # Box number -> the 9 (row, col) coordinates of that 3x3 box, listed
    # row by row. Boxes are numbered 0..8 left-to-right, top-to-bottom.
    self.boxes = {
        box: [(r, c)
              for r in range(3 * (box // 3), 3 * (box // 3) + 3)
              for c in range(3 * (box % 3), 3 * (box % 3) + 3)]
        for box in range(9)
    }

    # Copy every supplied row (not just the first nine) so that the row
    # count check in validateInputGrid can actually detect a bad shape.
    self.grid = [list(row) for row in grid]
    self.validateInputGrid()
    pp([''.join(x) for x in self.grid])

  def genSolve(self):
    """Solve the puzzle held in ``self.grid`` without modifying it.

    Returns:
      The solved grid as a list of 9 lists of single-character strings, or
      False when the puzzle has no solution.
    """
    def genSolveHelper(grid):
      # Rows only contain immutable strings, so a per-row copy is enough to
      # isolate this search branch from the caller's grid.
      working_grid = [list(row) for row in grid]

      # Fill in every cell that has a single candidate. Each fill changes
      # the candidates of other cells, so the scan restarts after it.
      while True:
        filled_single = False
        best = None  # (row, col, candidates) of the most constrained cell
        for row, col in itertools.product(range(9), range(9)):
          if working_grid[row][col] != '0':
            continue
          candidates = self.genPotentialsSquare(working_grid, row, col)
          if not candidates:  # empty cell with no legal digit: dead end
            return False
          if len(candidates) == 1:
            working_grid[row][col] = candidates.pop()
            filled_single = True
            break
          if best is None or len(candidates) < len(best[2]):
            best = (row, col, candidates)

        if not filled_single:
          break

      if self.isSolved(working_grid):
        return working_grid
      if best is None:
        # Grid is full but not a valid solution; there is nothing to branch on.
        return False

      # Branch on the cell with the fewest candidates. The recursive call
      # copies working_grid, so overwriting the cell for each guess is safe.
      r, c, candidates = best
      for potential in candidates:
        working_grid[r][c] = potential
        maybe_solved = genSolveHelper(working_grid)
        if maybe_solved:
          return maybe_solved
      return False

    return genSolveHelper(self.grid)

  def genPotentialsSquare(self, grid, row, col):
    """Return the set of digits not yet used in the cell's row, column and box.

    The value currently stored at (row, col) counts as used, so this is only
    meaningful for empty ('0') cells.
    """
    return set('123456789').difference(
        self.getRow(grid, row),
        self.getCol(grid, col),
        self.getBox(grid, self.getBoxNum(row, col)))

  def isSolved(self, grid):
    """Return True if every row, column and box contains all of '1'..'9'."""
    digits = set('123456789')
    for i in range(9):
      units = (self.getRow(grid, i), self.getCol(grid, i), self.getBox(grid, i))
      if not all(digits.issubset(unit) for unit in units):
        return False
    return True

  def getRow(self, grid, row_num):
    """Return row ``row_num`` of ``grid`` (the row object itself, not a copy)."""
    return grid[row_num]

  def getCol(self, grid, col_num):
    """Return column ``col_num`` of ``grid`` as a new list, top to bottom."""
    return [grid[i][col_num] for i in range(9)]

  def getBox(self, grid, box_num):
    """Return the 9 cells of box ``box_num`` in ``self.boxes`` order."""
    return [grid[r][c] for r, c in self.boxes[box_num]]

  def getBoxNum(self, row, col):
    """Return the box number (0..8) containing (row, col), or None if outside the grid."""
    if 0 <= row < 9 and 0 <= col < 9:
      return (row // 3) * 3 + col // 3
    return None

  def validateInputGrid(self):
    """Check that ``self.grid`` is a well-formed, conflict-free 9x9 puzzle.

    Raises:
      AssertionError: on a wrong shape, an invalid character, or a non-zero
        digit repeated within a row, column or box. The exception is raised
        explicitly (not via ``assert``) so the checks survive ``python -O``.
    """
    if len(self.grid) != 9:
      raise AssertionError(
          'grid must have 9 rows, got {}'.format(len(self.grid)))
    for i, row in enumerate(self.grid):
      if len(row) != 9:
        raise AssertionError(
            'row {} must have 9 columns, got {}'.format(i, len(row)))

    allowed = set('0123456789')
    for i in range(9):
      units = (
          ('row', self.getRow(self.grid, i)),
          ('column', self.getCol(self.grid, i)),
          ('box', self.getBox(self.grid, i)),
      )
      # no invalid characters
      for kind, cells in units:
        invalid = set(cells) - allowed
        if invalid:
          raise AssertionError(
              '{} {} contains invalid characters: {!r}'.format(kind, i, invalid))
      # no repeated characters (except 0)
      for kind, cells in units:
        repeated = [d for d, n in Counter(cells).items() if d != '0' and n > 1]
        if repeated:
          raise AssertionError(
              '{} {} repeats digits: {!r}'.format(kind, i, repeated))





In [2]:
# Puzzle fixtures. Each puzzle is a list of 9 strings of 9 characters, one
# string per row; '0' marks an empty cell.

# Completely filled grid (no '0' cells).
# NOTE(review): presumably a correct solution -- confirm with isSolved.
sample_puzzle_valid = [
  '539728461',
  '471695382',
  '862134795',
  '356281974',
  '148957623',
  '927346158',
  '685419237',
  '294573816',
  '713862549'
]
# Same as sample_puzzle_valid except the fifth row, where the '8' became a
# '2', so that row contains '2' twice.
sample_puzzle_invalid = [
  '539728461',
  '471695382',
  '862134795',
  '356281974',
  '142957623',
  '927346158',
  '685419237',
  '294573816',
  '713862549'
]
# Same as sample_puzzle_valid except the sixth row, where the '9' and '3'
# are swapped; the row is still duplicate-free but the first column now
# contains '3' twice.
sample_puzzle_invalid2 = [
  '539728461',
  '471695382',
  '862134795',
  '356281974',
  '148957623',
  '327946158',
  '685419237',
  '294573816',
  '713862549'
]
# Unsolved puzzles, named by their intended difficulty.
input_test_easy = [
  '400307600',
  '003002800',
  '028510704',
  '100823900',
  '000750128',
  '004009000',
  '602048351',
  '030070400',
  '009000280',
]
input_test_medium = [
  '600130780',
  '000047106',
  '001000030',
  '080960000',
  '007208900',
  '000014050',
  '070000500',
  '402870000',
  '059021007',
]
input_test_hard = [
  '005261070',
  '140000000',
  '003050000',
  '000300580',
  '000904000',
  '017005000',
  '000030900',
  '000000016',
  '050192300',
]
input_test_very_hard = [
  '002000000',
  '000000913',
  '090300050',
  '000180040',
  '000004720',
  '073000000',
  '700000000',
  '010070000',
  '680000409',
]
input_test_hardest_ever = [
  '005300000',
  '800000020',
  '070010500',
  '400005300',
  '010070006',
  '003200080',
  '060500009',
  '004000030',
  '000009700',
]

# Demo driver: solve each unsolved puzzle in order of difficulty and print
# the result. Constructing a Solution prints the input grid; the solved grid
# is printed right after it, and '---' separates consecutive puzzles.
# (The sample_puzzle_* fixtures are not exercised here.)
_labelled_puzzles = [
  ('easy', input_test_easy),
  ('medium', input_test_medium),
  ('hard', input_test_hard),
  ('very hard', input_test_very_hard),
  ('hardest ever sudoku', input_test_hardest_ever),
]
_solvers = []
for _position, (_label, _puzzle) in enumerate(_labelled_puzzles):
  if _position > 0:
    print('---')
  print(_label)
  _solver = Solution(_puzzle)
  _solvers.append(_solver)
  _solved_rows = [''.join(_cells) for _cells in _solver.genSolve()]
  pp(_solved_rows)

# Keep the solver objects available under their original names.
test4, test5, test6, test7, test8 = _solvers

easy
['400307600',
 '003002800',
 '028510704',
 '100823900',
 '000750128',
 '004009000',
 '602048351',
 '030070400',
 '009000280']
['415387692',
 '763492815',
 '928516734',
 '157823946',
 '396754128',
 '284169573',
 '672948351',
 '831275469',
 '549631287']
---
medium
['600130780',
 '000047106',
 '001000030',
 '080960000',
 '007208900',
 '000014050',
 '070000500',
 '402870000',
 '059021007']
['645139782',
 '238547196',
 '791682435',
 '584963271',
 '317258964',
 '926714853',
 '173496528',
 '462875319',
 '859321647']
---
hard
['005261070',
 '140000000',
 '003050000',
 '000300580',
 '000904000',
 '017005000',
 '000030900',
 '000000016',
 '050192300']
['895261473',
 '142873695',
 '673459821',
 '264317589',
 '538924167',
 '917685234',
 '481736952',
 '329548716',
 '756192348']
---
very hard
['002000000',
 '000000913',
 '090300050',
 '000180040',
 '000004720',
 '073000000',
 '700000000',
 '010070000',
 '680000409']
['342591687',
 '568742913',
 '197368254',
 '926187345',
 '851934726',
 '4736258