## "Lie Groups"

This is a puzzle from Microsoft's Puzzlehunt 23 "Puzzle University". [Puzzle statement](https://puzzle.university/puzzle/lie-groups.html). [Official solution](https://puzzle.university/solution/lie-groups.html).

This is a 9x9 [KenKen](https://en.wikipedia.org/wiki/KenKen) puzzle with an additional twist that makes it much harder to solve. But not for Z3 :)

**Spoiler warning!** Below is the full solution to the puzzle. Stop reading if you want to solve it yourself.

In [1]:
# Grid size and KenKen cages. Each cage is (clue, cells): the clue is a target
# value followed by a one-character operator ('+', '-', '*', '/' or '='), and
# each cell is a (row, col) pair.
N = 9
groups = [
  ('9*', [(0,0), (0,1)]),
  ('21+', [(0,2), (1,2), (1,3)]),
  ('15+', [(0,3), (0,4)]),
  ('25+', [(0,5), (0,6), (1,5)]),
  ('9/', [(0,7), (0,8)]),
  ('9*', [(1,0),(2,0)]),
  ('54*', [(1,1),(2,1),(2,2)]),
  ('26+', [(1,4),(2,3),(2,4)]),
  ('12+', [(1,6),(2,6),(3,6)]),
  ('40*', [(1,7),(2,7)]),
  ('21+', [(1,8),(2,8),(3,8)]),
  ('10*', [(2,5),(3,5)]),
  ('6*', [(3,0),(3,1)]),
  ('2-', [(3,2),(3,3)]),
  ('13+', [(3,4),(4,4)]),
  ('17+', [(3,7),(4,7)]),
  ('12+', [(4,0),(5,0)]),
  ('8*', [(4,1),(5,1)]),
  ('3-', [(4,2),(4,3)]),
  ('56*', [(4,5),(4,6)]),
  ('13+', [(4,8),(5,7),(5,8)]),
  ('6+', [(5,2),(6,2)]),
  ('3/', [(5,3),(6,3)]),
  ('15+', [(5,4),(6,4)]),
  ('5-', [(5,5),(6,5)]),
  ('5-', [(5,6),(6,6)]),
  ('5*', [(6,0),(7,0)]),
  ('15+', [(6,1),(7,1)]),
  ('1-', [(6,7),(6,8)]),
  ('4*', [(7,2),(8,2)]),
  ('3-', [(7,3),(8,3)]),
  ('15+', [(7,4),(8,4)]),
  ('6-', [(7,5),(7,6)]),
  ('2/', [(7,7),(8,7)]),
  ('4-', [(7,8),(8,8)]),
  ('54*', [(8,0),(8,1)]),
  ('16+', [(8,5),(8,6)]),
]

# The cages must tile the grid: every cell covered exactly once (this also
# catches duplicated or missing cells, not just a wrong total count).
assert sorted(cell for _, cells in groups for cell in cells) == [
  (y, x) for y in range(N) for x in range(N)
], "cages must cover every cell of the grid exactly once"

In [2]:
import z3
solver = z3.Solver()

# Digits without lies: values 1..N, unique per row and per column.
digs = [[z3.Int(f"dig-{y}-{x}") for x in range(N)] for y in range(N)]
for y in range(N):
  for x in range(N):
    solver.add(digs[y][x] >= 1)
    solver.add(digs[y][x] <= N)
for i in range(N):
  # Row i and column i each contain pairwise-distinct digits.
  solver.add(z3.Distinct(*digs[i]))
  solver.add(z3.Distinct(*[digs[j][i] for j in range(N)]))

# Positions of lies in each row: lie_poss[y] is the column of row y's lie.
# Distinct positions mean no two rows lie in the same column.
lie_poss = [z3.Int(f"lie_pos-{y}") for y in range(N)]
for y in range(N):
  solver.add(lie_poss[y] >= 0)
  solver.add(lie_poss[y] < N)
solver.add(z3.Distinct(*lie_poss))
  
# Condition that "each digit lies exactly once": the digit sitting at each
# row's lie position must differ from every other row's lying digit.
# Since lie_poss[y] is in 0..N-1, exactly one If() term per row is non-zero,
# so each sum picks out that row's lying digit.
# Note: we get the same solution even without this condition.
lying_dig_per_row = [
  sum([z3.If(lie_poss[y] == x, digs[y][x], 0) for x in range(N)])
  for y in range(N)
]
solver.add(z3.Distinct(*lying_dig_per_row))

# Lie values per row: the positive integer used in place of the digit at that
# row's lie position.
lie_vals = [z3.Int(f"lie_val-{y}") for y in range(N)]
for lie_val in lie_vals:
  solver.add(lie_val >= 1)

# Values that take part in the KenKen arithmetic: each cell's digit, except at
# the row's lie position, where the row's lie value is used instead.
kk_exprs = [
  [z3.If(lie_pos == x, lie_val, dig) for x, dig in enumerate(dig_row)]
  for lie_pos, lie_val, dig_row in zip(lie_poss, lie_vals, digs)
]
  
# Encode KenKen arithmetic conditions on the lie-corrected values, and record
# each cage's clue in control_grid so the layout can be checked by eye.
control_grid = [["?" for _ in range(N)] for _ in range(N)]
for clue, cells in groups:
  for y, x in cells:
    assert 0 <= x < N and 0 <= y < N, str(cells)
    control_grid[y][x] = clue
  # A clue is the target value followed by a one-character operator.
  val = int(clue[:-1])
  sign = clue[-1:]
  pts = [kk_exprs[y][x] for y, x in cells]
  if sign == '=':
    assert len(pts) == 1
    solver.add(pts[0] == val)
  elif sign == '+':
    solver.add(sum(pts) == val)
  elif sign == '-':
    # The cage does not say which cell is larger, so accept either order.
    assert len(pts) == 2
    diff = pts[0] - pts[1]
    solver.add(z3.Or(diff == val, diff == -val))
  elif sign == '*':
    prod = 1
    for p in pts:
      prod *= p
    solver.add(prod == val)
  else:
    # Division: one cell is exactly val times the other. Written as a product
    # so the constraint stays in integer arithmetic.
    assert sign == '/', clue
    assert len(pts) == 2
    first, second = pts  # distinct names: don't clobber the x/y coordinates
    solver.add(z3.Or(val * first == second, val * second == first))
    
# Print the cage clue of every cell, one tab-separated row per line, to verify
# by eye that the cages were entered correctly.
print('\n'.join('\t'.join(row) for row in control_grid))

9*	9*	21+	15+	15+	25+	25+	9/	9/
9*	54*	21+	21+	26+	25+	12+	40*	21+
9*	54*	54*	26+	26+	10*	12+	40*	21+
6*	6*	2-	2-	13+	10*	12+	17+	21+
12+	8*	3-	3-	13+	56*	56*	17+	13+
12+	8*	6+	3/	15+	5-	5-	13+	13+
5*	15+	6+	3/	15+	5-	5-	1-	1-
5*	15+	4*	3-	15+	6-	6-	2/	4-
54*	54*	4*	3-	15+	16+	16+	2/	4-


In [3]:
# Run the solver; prints sat / unsat / unknown.
check_result = solver.check()
print(check_result)

sat


In [4]:
m = solver.model()
# Read the solved (lie-free) digit grid back as plain Python ints.
sol_digs = []
for row_vars in digs:
  sol_digs.append([m.eval(var).as_long() for var in row_vars])
sol_digs

[[3, 5, 6, 8, 7, 4, 2, 1, 9],
 [9, 1, 8, 7, 4, 3, 6, 5, 2],
 [1, 6, 9, 4, 3, 2, 5, 8, 7],
 [2, 3, 7, 9, 8, 5, 1, 4, 6],
 [4, 2, 3, 6, 5, 8, 7, 9, 1],
 [8, 4, 2, 3, 6, 1, 9, 7, 5],
 [7, 8, 5, 1, 9, 6, 4, 2, 3],
 [5, 7, 4, 2, 1, 9, 3, 6, 8],
 [6, 9, 1, 5, 2, 7, 8, 3, 4]]

In [5]:
# Column of the lie in each row, and the value that replaces the digit there.
sol_lie_poss = [m.eval(pos).as_long() for pos in lie_poss]
sol_lie_vals = [m.eval(val).as_long() for val in lie_vals]
print(sol_lie_poss)
print(sol_lie_vals)

[1, 5, 3, 8, 7, 2, 0, 4, 6]
[3, 19, 19, 12, 13, 1, 1, 13, 9]


In [6]:
# Decode the answer: each row's lie value is read as a letter index (1 = 'A'),
# and the letters are ordered by the digit found at that row's lie position.
ans = []
for digit_row, lie_pos, lie_val in zip(sol_digs, sol_lie_poss, sol_lie_vals):
  # The solver only bounds lie values from below, so guard against values
  # that would decode to non-letter characters.
  if not 1 <= lie_val <= 26:
    raise ValueError(f"lie value {lie_val} does not map to a letter A-Z")
  lying_digit = digit_row[lie_pos]
  letter = chr(ord('A') - 1 + lie_val)
  ans.append((lying_digit, letter))
ans.sort()
print(''.join(letter for _, letter in ans))

MASSCLAIM
