# Further Sudoku Adventures:  Using classes

In this notebook we use a record-keeping class
to implement a function that checks a puzzle, whether partially or fully filled in, for consistency with the rules of Sudoku.

The code in the first cell is Peter Norvig's.  It is a small excerpt from his Sudoku module.

In [119]:
def grid_values(grid):
    """Convert grid into a dict of {square: char} with '0' or '.' for empties.

    Characters of `grid` other than the digits, '0' and '.' are ignored, so
    the input may contain line breaks or other layout characters.  Exactly
    one kept character per square is required; the kept characters are
    assigned to `squares` in order.
    """
    chars = [c for c in grid if c in digits or c in '0.']
    # Tie the length check to the board definition so zip() below can never
    # silently drop or leave out squares.
    assert len(chars) == len(squares), \
        f'Expected {len(squares)} square characters, got {len(chars)}'
    return dict(zip(squares, chars))

def cross(A, B):
    """
    Return every concatenation a + b, pairing each a in A with each b in B.

    Results are in row-major order: all pairings for A's first element come
    first.  With strings this builds square names, e.g.
    cross('AB', '12') == ['A1', 'A2', 'B1', 'B2'].
    """
    products = []
    for left in A:
        for right in B:
            products.append(left + right)
    return products

# Board labels: rows are letters; cols are digits, which double as the
# legal square values.
digits     = '123456789'
rows       = 'ABCDEFGHI'
cols       = digits
# Row bands and column stacks; each (band, stack) pair forms one box.
row_groups = ('ABC', 'DEF', 'GHI')
col_groups = ('123', '456', '789')
# Square names 'A1' .. 'I9' in row-major order.
squares    = cross(rows, cols)
# All units, in this order: the 9 cols, then the 9 rows, then the 9 boxes.
unitlist = (
    [cross(rows, c) for c in cols]
    + [cross(r, cols) for r in rows]
    + [cross(rs, cs) for rs in row_groups for cs in col_groups]
)

################ Display as 2-D grid ################

def display(values,indent = 3):
    """Display these values as a 2-D grid.

    values: dict {square_name: str}, e.g. as returned by grid_values().
            Entries equal to '0' are shown as '.'; the caller's dict is not
            modified.
    indent: column at which the grid's left border is drawn.  Row labels
            are printed in the columns before it, so indent - 1 should be at
            least as wide as a row label (default 3).

    Each cell is one character wider than the widest value, and the column
    labels are centred the same way as the cells, so multi-character values
    still line up.  Box borders follow the last row/col of each group in
    `row_groups` / `col_groups`.
    """
    # Work on a copy, showing empty squares ('0') as the less prominent '.'.
    values = values.copy()
    for s in squares:
        if values[s] == '0':
            values[s] = '.'
    # Cell width: the widest value plus one space of separation.
    width = 1 + max(len(values[s]) for s in squares)
    margin = ' ' * (indent - 1)
    # Filler line: '-'s spanning each column group, '+' at the box corners.
    line = margin + '+' + '+'.join('-' * (width * len(cg)) for cg in col_groups) + '+'
    # A '|' follows the last col of every col group.  A filler line follows
    # the last row of every row group except the final one, because the
    # closing line printed after the loop covers that one.
    box_end_cols = {cg[-1] for cg in col_groups}
    box_end_rows = {rg[-1] for rg in row_groups[:-1]}

    def render(cells):
        # Centre each (label, text) pair in a cell, adding '|' at box edges.
        return ''.join(text.center(width) + ('|' if c in box_end_cols else '')
                       for c, text in cells)

    # Banner: column labels aligned with the cells below.
    print(line)
    print(margin + '|' + render((c, c) for c in cols))
    print(line)
    for r in rows:
        print(r.ljust(indent - 1) + '|' + render((c, values[r + c]) for c in cols))
        if r in box_end_rows:
            print(line)
    print(line)
    print()

####################   Grids ##################################################################

# Puzzles as 81-character strings, read row by row (A1..A9, B1..B9, ...);
# '0' marks an empty square.
# grid1: a legal, unsolved puzzle.
grid1  = '003020600900305001001806400008102900700000008006708200002609500800203009005010300'
# grid1_soln: appears to be grid1's solution (it agrees with grid1's givens)
# -- TODO confirm.
grid1_soln = '483921657967345821251876493548132976729564138136798245372689514814253769695417382'
# grid2: grid1 with G5 filled in as 8 (the value grid1_soln has there);
# still legal.
grid2  = '003020600900305001001806400008102900700000008006708200002689500800203009005010300'
# Some illegal grids (they break the rules of Sudoku)
# grid3: an 8 at both H1 and H8 -- repeated 8 in row H.
grid3  = '003020600900305001001806400008102900700000008006708200002609500800203089005010300'
# grid4: a 9 at both A9 and H9 -- repeated 9 in col 9.
grid4  = '003020609900305001001806400008102900700000008006708200002609500800203009005010300'
# grid5: a 6 at both A7 and B8 -- repeated 6 in the top-right box.
grid5  = '003020600900305061001806400008102900700000008006708200002609500800203009005010300'

The code in the next cell defines a class for checking Sudoku puzzles.  It is a class because
it maintains data structures for keeping track of puzzle entries.

The key idea is that each square simultaneously belongs to 3 units (its row, its column, and its box).  So each square update
updates 3 unit records.  The update will fail if the value has already been assigned to that unit.
This is implemented by starting out each unit with a set containing all possible values and
removing each value with the `set.remove` method as it is found.  The `set.remove` method raises
a `KeyError` if the value is not in the set, i.e. if an earlier square in the same unit already used it.

In [134]:
def make_val_set ():
    """Return a new, independent set holding every legal square value."""
    return {*digits}

def noncomprehension_make_row_col2box_dict (start_index=0, rgroups=None, cgroups=None):
    """
    Loop-based version of make_row_col2box_dict: map each square name
    (row label + col label) to the index of the box containing it.

    Boxes are numbered band by band (a band is one row group), left to right
    within a band, starting at `start_index`.

    rgroups, cgroups: sequences of row-label / col-label strings.  They
    default to the module-level `row_groups` / `col_groups`.
    """
    if rgroups is None:
        rgroups = row_groups
    if cgroups is None:
        cgroups = col_groups
    # A band contains one box per col group, so that count is the index
    # stride from one band to the next.
    boxes_per_band = len(cgroups)
    row_col2box = dict()
    for i, rg in enumerate(rgroups):
        for j, cg in enumerate(cgroups):
            for r in rg:
                for c in cg:
                    row_col2box[r + c] = (boxes_per_band * i + j) + start_index
    return row_col2box

def make_row_col2box_dict (start_index=0, rgroups=None, cgroups=None):
    """
    Map each square name (row label + col label) to the index of its box.

    Boxes are numbered band by band (a band is one row group), left to right
    within a band, starting at `start_index`.  The stride between bands is
    the number of boxes in a band, len(cgroups), not the height of a box.
    That keeps the indices distinct for any grouping, e.g. 2-row x 3-col
    groups over 6 rows x 9 cols.

    rgroups, cgroups: sequences of row-label / col-label strings.  They
    default to the module-level `row_groups` / `col_groups`.
    """
    if rgroups is None:
        rgroups = row_groups
    if cgroups is None:
        cgroups = col_groups
    boxes_per_band = len(cgroups)
    return {r + c: (boxes_per_band * i + j) + start_index
            for i, rg in enumerate(rgroups)
            for j, cg in enumerate(cgroups)
            for r in rg
            for c in cg}

def make_index_dict (labels,start_index=0):
    """
    Number `labels` consecutively from `start_index`: {label: position}.

    If a label repeats, its last position wins.

    >>> make_index_dict('ABC', start_index=9)
    {'A': 9, 'B': 10, 'C': 11}
    """
    index = {}
    for offset, label in enumerate(labels):
        index[label] = offset + start_index
    return index

class UnitTracker:
    """
    Check Sudoku grids for repeated values.  For every unit, the tracker keeps
    the set of values not yet seen in that unit.

    Unit indices: [0, unit_size) are rows, [unit_size, 2*unit_size) are cols,
    and [2*unit_size, 3*unit_size) are boxes.  `name2index` maps a row label
    to its row, a col label to its col, and a square name ('A1') to its box.

    rows, cols, row_groups, col_groups and digits are externally defined
    globals.  unit_size must be the same for all three unit types and must be
    a multiple of three.
    """

    unit_types = ['row', 'col', 'box']
    unit_size = len(rows)
    assert unit_size % 3 == 0, 'The unit size must be a multiple of 3'
    # Record mapping row labels, col labels and square names to distinct
    # unit indices.
    name2index = dict()
    # First unit_size indices: rows.
    name2index.update(make_index_dict(rows))
    # Next unit_size indices: cols.
    name2index.update(make_index_dict(cols, start_index=unit_size))
    # Last unit_size indices: boxes, keyed by square name.
    name2index.update(make_row_col2box_dict(start_index=2 * unit_size))

    def __init__(self):
        # Allocate the records up front, so update_* can be used before
        # check_puzzle has been called.
        self.init_unit_records()

    def init_unit_records(self):
        """Reset every unit's record to the full set of legal values."""
        self.units = [make_val_set()
                      for _ in range(len(self.unit_types) * self.unit_size)]

    def update_val_set(self, name, value):
        """
        Record `value` in the unit selected by `name`.

        name: a row label, a col label, or a square name (which selects the
              square's box).
        Returns True if `value` was still available in that unit.  Otherwise
        prints which unit already holds it and returns False.  Any value not
        in the unit's set is reported this way, including a value that was
        never a legal digit.
        """
        idx = self.name2index[name]
        try:
            self.units[idx].remove(value)
        except KeyError:
            unit_type = self.unit_types[idx // self.unit_size]
            # Rows are reported by letter; cols and boxes by 1-based number.
            if unit_type == 'row':
                unit_idx = name
            else:
                unit_idx = (idx % self.unit_size) + 1
            print(f'There were multiple occurrences of {value} in {unit_type} {unit_idx}')
            return False
        return True

    def update_puzzle_rec(self, square_name, value):
        """
        Record `value` at `square_name` in its row, col and box.

        Stops at the first unit that already holds `value` and returns False.
        Returns True if all three updates succeed.
        """
        row, col = square_name
        return (self.update_val_set(row, value)
                and self.update_val_set(col, value)
                and self.update_val_set(square_name, value))

    def check_puzzle(self, puzzle):
        """
        Return True if no value repeats in any row, col or box of `puzzle`.

        puzzle: dict {square_name: char}, as produced by grid_values().
                Empty squares are skipped; grid_values() leaves them as
                either '0' or '.'.
        Returns False at the first repeat found, after printing a message
        describing it.  Unit records are reset on every call.
        """
        self.init_unit_records()
        for s, val in puzzle.items():
            if val in ('0', '.'):
                continue
            if not self.update_puzzle_rec(s, val):
                return False
        return True
          

## Tests

In [136]:
# Check the three illegal grids first, then the two legal ones: draw each
# grid, then print whether it passed, followed by a blank line.
ut = UnitTracker()

for grid in (grid3, grid4, grid5, grid1, grid2):
    grid_dict = grid_values(grid)
    display(grid_dict)
    passed = ut.check_puzzle(grid_dict)
    print(passed)
    print()
print("Z")

  +------+------+------+
  |1 2 3 |4 5 6 |7 8 9 |
  +------+------+------+
A |. . 3 |. 2 . |6 . . |
B |9 . . |3 . 5 |. . 1 |
C |. . 1 |8 . 6 |4 . . |
  +------+------+------+
D |. . 8 |1 . 2 |9 . . |
E |7 . . |. . . |. . 8 |
F |. . 6 |7 . 8 |2 . . |
  +------+------+------+
G |. . 2 |6 . 9 |5 . . |
H |8 . . |2 . 3 |. 8 9 |
I |. . 5 |. 1 . |3 . . |
  +------+------+------+

There were multiple occurrences of 8 in row H
False

  +------+------+------+
  |1 2 3 |4 5 6 |7 8 9 |
  +------+------+------+
A |. . 3 |. 2 . |6 . 9 |
B |9 . . |3 . 5 |. . 1 |
C |. . 1 |8 . 6 |4 . . |
  +------+------+------+
D |. . 8 |1 . 2 |9 . . |
E |7 . . |. . . |. . 8 |
F |. . 6 |7 . 8 |2 . . |
  +------+------+------+
G |. . 2 |6 . 9 |5 . . |
H |8 . . |2 . 3 |. . 9 |
I |. . 5 |. 1 . |3 . . |
  +------+------+------+

There were multiple occurrences of 9 in col 9
False

  +------+------+------+
  |1 2 3 |4 5 6 |7 8 9 |
  +------+------+------+
A |. . 3 |. 2 . |6 . . |
B |9 . . |3 . 5 |. 6 1 |
C |. . 1 |8 . 6 |