## Solving KenKen with z3+grilops

In [1]:
import grilops
from grilops.geometry import Point
import z3     

In [2]:
# Grid size, and the KenKen cages as (clue, cells) pairs. A clue is the cage
# target followed by its operator ('+', '-', '*', '/', or '=' for a single
# given cell); cells are (row, column) coordinates on the N x N grid.
N = 9
groups = [
  ('8-',  [(0, 0), (1, 0)]),
  ('2/',  [(0, 1), (0, 2)]),
  ('3-',  [(0, 3), (0, 4)]),
  ('90*', [(0, 5), (0, 6), (1, 6)]),
  ('84*', [(0, 7), (0, 8), (1, 7)]),
  ('17+', [(1, 1), (1, 2)]),
  ('56*', [(1, 3), (1, 4), (1, 5)]),
  ('1-',  [(2, 0), (2, 1)]),
  ('6*',  [(2, 2), (3, 2), (3, 3)]),
  ('2/',  [(2, 5), (2, 6)]),
  ('12+', [(2, 7), (3, 7), (3, 8)]),
  ('3-',  [(3, 5), (3, 6)]),
  ('2/',  [(4, 0), (5, 0)]),
  ('2/',  [(4, 2), (4, 3)]),
  ('60*', [(4, 4), (4, 5), (5, 4)]),
  ('16+', [(4, 6), (4, 7), (4, 8)]),
  ('30*', [(5, 1), (6, 1), (6, 2)]),
  ('2-',  [(5, 2), (5, 3)]),
  ('48*', [(5, 7), (5, 8)]),
  ('40*', [(6, 0), (7, 0), (8, 0)]),
  ('3/',  [(6, 3), (6, 4)]),
  ('13+', [(6, 5), (7, 4), (7, 5)]),
  ('3+',  [(6, 7), (7, 7)]),
  ('7-',  [(6, 8), (7, 8)]),
  ('13+', [(7, 3), (8, 3)]),
  ('2/',  [(7, 6), (8, 6)]),
  ('40*', [(8, 1), (8, 2)]),
  ('10+', [(8, 4), (8, 5)]),
  ('9*',  [(8, 7), (8, 8)]),
]


# Set up an N x N square grid whose cells each take a number from 1..N.
sym = grilops.make_number_range_symbol_set(1, N)
lattice = grilops.get_square_lattice(N)
sg = grilops.SymbolGrid(lattice, sym)

# No number may repeat within a row or within a column. Columns are the
# transpose of the rows, so they reference the very same grid variables.
rows = [[sg.grid[Point(r, c)] for c in range(N)] for r in range(N)]
columns = [list(col) for col in zip(*rows)]
for line in rows + columns:
  sg.solver.add(z3.Distinct(*line))


def _require_cage_size(clue, cells, size):
  """Raise ValueError unless cage `clue` covers exactly `size` cells."""
  if len(cells) != size:
    raise ValueError(
        f"cage {clue!r} {cells} must have exactly {size} cell(s), "
        f"got {len(cells)}")


# control_grid[y][x] holds, for display, the clue of the cage owning cell
# (y, x); "?" marks a cell that no cage covers.
control_grid = [["?" for _ in range(N)] for _ in range(N)]
for clue, cells in groups:
  # A clue is "<target><op>", e.g. "90*"; op "=" pins a single given cell.
  if len(clue) < 2 or not cells:
    raise ValueError(f"malformed cage {clue!r} {cells}")
  val = int(clue[:-1])
  sign = clue[-1]
  for y, x in cells:
    if not (0 <= x < N and 0 <= y < N):
      raise ValueError(f"cell {(y, x)} of cage {clue!r} is off the grid")
    if control_grid[y][x] != "?":
      # Overwriting would silently constrain the cell by two cages while the
      # printed control grid showed only the last one.
      raise ValueError(
          f"cell {(y, x)} is claimed by cage {control_grid[y][x]!r} "
          f"and by cage {clue!r}")
    control_grid[y][x] = clue
  pts = [Point(y, x) for y, x in cells]
  if sign == '=':
    _require_cage_size(clue, cells, 1)
    sg.solver.add(sg.cell_is(pts[0], val))
  elif sign == '+':
    sg.solver.add(sum(sg.grid[p] for p in pts) == val)
  elif sign == '-':
    # Cell order inside the cage is arbitrary, so accept either difference.
    _require_cage_size(clue, cells, 2)
    diff = sg.grid[pts[0]] - sg.grid[pts[1]]
    sg.solver.add(z3.Or(diff == val, diff == -val))
  elif sign == '*':
    prod = 1
    for p in pts:
      prod *= sg.grid[p]
    sg.solver.add(prod == val)
  elif sign == '/':
    # Stated multiplicatively to stay in integer arithmetic; either cell may
    # hold the larger value. a/b avoid rebinding the y, x coordinate names.
    _require_cage_size(clue, cells, 2)
    a, b = sg.grid[pts[0]], sg.grid[pts[1]]
    sg.solver.add(z3.Or(val * a == b, val * b == a))
  else:
    raise ValueError(f"unknown operator {sign!r} in clue {clue!r}")
    
    
# Show which cage clue owns each cell: one tab-separated grid row per line.
for cells in control_grid:
  print(*cells, sep='\t')

8-	2/	2/	3-	3-	90*	90*	84*	84*
8-	17+	17+	56*	56*	56*	90*	84*	?
1-	1-	6*	?	?	2/	2/	12+	?
?	?	6*	6*	?	3-	3-	12+	12+
2/	?	2/	2/	60*	60*	16+	16+	16+
2/	30*	2-	2-	60*	?	?	48*	48*
40*	30*	30*	3/	3/	13+	?	3+	7-
40*	?	?	13+	13+	13+	2/	3+	7-
40*	40*	40*	13+	10+	10+	2/	9*	9*


In [3]:
# Ask the solver for a grid satisfying every constraint added above; the
# notebook echoes the returned value (the True shown indicates a solution).
sg.solve()

True

In [4]:
# Print the solved grid, one row of digits per line.
sg.print()

921856374
189427536
873961245
742189653
694235187
317542968
465398712
536714829
258673491


In [5]:
solved_grid = sg.solved_grid()
# Regression check: row index 3 of the solution must match the printed grid.
assert ''.join(
  str(solved_grid[Point(3, col)]) for col in range(N)
) == '742189653'