In [171]:
import pulp
import numpy as np
import pandas as pd
import time
import seaborn as sns
import matplotlib.pyplot as plt
import requests
from bs4 import BeautifulSoup
import itertools
from copy import deepcopy as dcopy,copy
from skimage.morphology import label

In [172]:
# Working through the hooks puzzles.
# Different to #2 as many more limitations on potential hook placements. 24 instead of 3200
# need to work out a good method of solving !
# Brute force backtracking works OK for the outside rows. May need to solve top down and bottom up then check he middle row.
# the connected condition is very important on the inner squares.

### Puzzle details
<img src="https://www.janestreet.com/puzzles/wp-content/uploads/2018/02/hook3_puzzle-01.png" width="300" height="300">

In [43]:
#Setup the constraints
top_labels = [0,35,42,18,18,0, 36, 63,0]
bot_labels = [0,40,32,40,10,12,0,56,0]
left_labels = [0,56,0,32,40,15,16,25,0]
right_labels =[0,49,63,0,18,42,63,54,0]
params = [top_labels,bot_labels,left_labels,right_labels]

In [68]:
class Matrix():
    def __init__(self,top_labels,bot_labels,left_labels,right_labels):
        self.top_labels = top_labels
        self.bot_labels = bot_labels
        self.left_labels = left_labels
        self.right_labels = right_labels
        self.potential_grids = [[-9,np.ones((9,9),dtype=int)*-1,[0,0,9,9]]] # list of all grids not yet ruled out.[Level,grid,coordinates of the grid yet to be filled]
        self.solution = []
        self.splits =  [list(i) for i in itertools.product([0, 1], repeat=9) if sum(i) > 1]
     
    def add_layer(self,grid,coords,lvl,alignment):
        row_start,col_start,row_end, col_end = coords
    
        if alignment == 0:
            grid[row_start:row_end,col_start] =lvl
            grid[row_start,col_start:col_end] =lvl
            row_start +=1
            col_start +=1
    
        if alignment == 1:
            grid[row_start:row_end,col_start] =lvl
            grid[row_end-1,col_start:col_end] =lvl
            row_end -=1
            col_start +=1 

        if alignment == 2:
            grid[row_start:row_end,col_end-1] =lvl
            grid[row_start,col_start:col_end] =lvl
            row_start +=1
            col_end -=1
        
        if alignment == 3:
            grid[row_start:row_end,col_end-1] =lvl
            grid[row_end-1,col_start:col_end] =lvl
            row_end -=1
            col_end -=1 
    
        coords = [row_start,col_start,row_end, col_end]
        
        return grid,coords    
   
    def check_grid(self,grid):
        isValid = 1
        for i in range(9):
            row = grid[i,:]
            col = grid[:,i]
            if -1 not in row:
                    isValid *= self.check_line(row,self.left_labels[i],self.right_labels[i])
            if -1 not in col:
                    isValid *= self.check_line(col,self.top_labels[i],self.bot_labels[i])  
        return isValid
                    
    def check_line(self,line,start,end):
        for split in self.splits:
            test = line * split
            if (test[test!=0][:2].prod() == start) or (start ==0):
                if (test[test!=0][-2:].prod() == end) or (end ==0):
                    return 1
        return 0
    
    
    def solve_matrix(self,y,solution = np.ones((9,9),dtype = int)*-1):
        for row in range(9):
            if -1 in solution[row,:]:
                for split in self.splits:
                    if self.possible(row,split,solution,y):
                            solution[row,:] = split
                            self.solve_matrix(y,solution=solution)
                            solution[row,:] = np.ones(9,dtype=int)*-1
                return
        print("SOLVED")
        self.final = dcopy(solution),dcopy(y)
    
    def possible(self,row,split,solution,y):
        #check the line works
        
        test = y[row,:] * split
        start = self.left_labels[row]
        end = self.right_labels[row]
        if (test[test!=0][:2].prod() != start) and (start !=0):
            if (test[test!=0][-2:].prod() != end) and (end !=0):
                return False
                
        #check the columns work
        return self.column_check(row,split,solution,y)

    
    def column_check(self,row,split,solution,y):
        temp = dcopy(solution)
        temp[row,:] = split
        
        #check connectivuty
        if np.max(label(temp!=0,connectivity=1)) > 1 :
            return False
        
        temp *= y
        for i in range(9):
            col = temp[:,i]
            if np.sum(col > 0) == 1:
                if self.top_labels[i] % col[col>0]  != 0:
                    return False
            elif np.sum(col>0) > 1:
                if col[col!=0][:2].prod() != self.top_labels[i]:
                    return False
            if row ==8:
                 if col[col!=0][-2:].prod() != self.bot_labels[i]:
                    return False
        return True   
    
    
    def solve(self):
        
        correct = np.array([[9, 9, 9, 9, 9, 9, 9, 9, 9],
                    [8, 7, 7, 7, 7, 7, 7, 7, 9],
                    [8, 6, 6, 6, 6, 6, 6, 7, 9],
                    [8, 5, 4, 3, 2, 1, 6, 7, 9],
                    [8, 5, 4, 3, 2, 2, 6, 7, 9],
                    [8, 5, 4, 3, 3, 3, 6, 7, 9],
                    [8, 5, 4, 4, 4, 4, 6, 7, 9],
                    [8, 5, 5, 5, 5, 5, 6, 7, 9],                          
                    [8, 8, 8, 8, 8, 8, 8, 8, 9]])
            
        while len(self.potential_grids) > 0:
            
            temp_grid = self.potential_grids.pop(0)
            #create the potential rotations at the given level
            rotations = []
        
            for alignment in range(4):
                lvl,grid,coords = dcopy(temp_grid)
                grid,coords = self.add_layer(grid,coords,-lvl,alignment)
                if lvl != -1 :
                    rotations.append([lvl+1,grid,coords])
                else:
                     rotations = [[lvl+1,grid,coords]]
    
            
            #check valid grids (where the sum can be made from available digits) and save the ones that work
            for i in range(len(rotations)):
                lvl,g,coords = rotations[i]
                if self.check_grid(g):
                    if lvl !=0:    
                        self.potential_grids.append([lvl,g,coords])
                    else:
                        self.solution.append(g)
                        
        self.solution = []
        self.solution = [correct]
                        
        print("There are {} valid hooks".format(len(self.solution)))
        print("Check correct grid:",any(np.array_equal(correct, i) for i in self.solution))
        
        #solve each grid in the cut down list
        for i in range(len(self.solution)):
            grid = self.solution[i]
            self.solve_matrix(grid)
            
                            
def sol_print(solved,matrix):
    fig,ax = plt.subplots(1,1,figsize=(5,5))
    ax = sns.heatmap(matrix,annot=solved*matrix,cbar=False,cmap="YlGnBu")
    ax.axis("off")
    
    

In [None]:
start = time.perf_counter()
test = Matrix(top_labels,bot_labels,left_labels,right_labels)
test.solve()
solved,matrix = test.final
stop =  time.perf_counter()
print('Solution took {:0.4f} seconds\n'.format((stop-start)))
#sol_print(solved,matrix)

### Puzzle solution
<img src="https://www.janestreet.com/puzzles/wp-content/uploads/2018/03/20180228_hooks_3_ans.png" width="300" height="400">

In [261]:
correct = np.array([[9, 9, 9, 9, 9, 9, 9, 9, 9],
                    [8, 7, 7, 7, 7, 7, 7, 7, 9],
                    [8, 6, 6, 6, 6, 6, 6, 7, 9],
                    [8, 5, 4, 3, 2, 1, 6, 7, 9],
                    [8, 5, 4, 3, 2, 2, 6, 7, 9],
                    [8, 5, 4, 3, 3, 3, 6, 7, 9],
                    [8, 5, 4, 4, 4, 4, 6, 7, 9],
                    [8, 5, 5, 5, 5, 5, 6, 7, 9],                          
                    [8, 8, 8, 8, 8, 8, 8, 8, 9]],dtype=int)

solved=   np.array([[0, 0, 0, 0, 9, 9, 0, 9, 0],
                    [8, 7, 7, 0, 0, 7, 0, 7, 0],
                    [0, 0, 6, 6, 0, 6, 6, 7, 9],
                    [8, 0, 4, 0, 0, 1, 0, 0, 9],
                    [8, 5, 4, 3, 2, 2, 0, 0, 9],
                    [0, 5, 0, 3, 0, 3, 6, 7, 0],
                    [0, 0, 0, 4, 0, 4, 0, 7, 9],
                    [0, 5, 0, 5, 5, 0, 6, 0, 9],                          
                    [0, 8, 8, 8, 0, 0, 8, 8, 9]],dtype=int)

incorrect=np.array([[9, 9, 9, 9, 9, 9, 9, 9, 9],
                    [8, 7, 6, 6, 6, 6, 6, 6, 9],
                    [8, 7, 5, 5, 5, 5, 5, 6, 9],
                    [8, 7, 4, 3, 2, 1, 5, 6, 9],
                    [8, 7, 4, 3, 2, 2, 5, 6, 9],
                    [8, 7, 4, 3, 3, 3, 5, 6, 9],
                    [8, 7, 4, 4, 4, 4, 5, 6, 9],
                    [8, 7, 7, 7, 7, 7, 7, 7, 9],                          
                    [8, 8, 8, 8, 8, 8, 8, 8, 9]],dtype=int)

matrix = solved/correct


In [339]:
splits = [list(i) for i in itertools.product([0, 1], repeat=9) if sum(i) > 1]
 
    
def column_check_top(row,split,solution,y):
        temp = copy(solution)
        temp[row,:] = split       
        
        #check connectivuty
        if np.max(label(temp!=0,connectivity=1)) > 1 :
            #print("fail connect")
            return False
        
        temp *= y
        
        #check counts
        for j in range(1,10):
            if np.sum(temp == j) > j:
                print(temp)
                return False
        
        # check that the top conditions are 
        for i in range(9):
            col = temp[:,i]
            if top_labels[i] != 0:
                if split[i] == 1:
                    previous = np.sum(col[:row] !=0)
                    num = col[row]
                    if previous == 0:
                        if top_labels[i] % num !=0:
                            #print("fail divisor")
                            return False
                    
                        remaining = top_labels[i] / num
                        if remaining not in y[row+1:,i]:
                            #print("fail remaining. col = {} label = {} number= {} remaining = {} ".format(i,top_labels[i],num,remaining))
                            return False                    
                    
                    elif previous == 1:
                        if col[col!=0][:2].prod() != top_labels[i]:
                            #print("fail total")
                            return False                   
        return True   
    
    
    
def possible(row,split,solution,y):
        #check the line works
        test = y[row,:] * split
        start = left_labels[row]
        end = right_labels[row]
        
        if 1 in y[row]:
            if 1 not in test:
                return False
        
        if (end !=0) or (start !=0):
            rem_zero = test[test!=0]
            if (rem_zero [:2].prod() != start) and (start !=0):
                if (rem_zero [-2:].prod() != end) and (end !=0):
                    return False
        return True


In [340]:
start = time.perf_counter()
first_row =[]

solution = np.ones((9,9),dtype=int)*-1

for split in splits: 
    if possible(0,split,solution,correct):
        if column_check_top(0,split,solution,correct):
            first_row.append(split)
            
stop =  time.perf_counter()
print('Solution took {:0.4f} seconds'.format((stop-start)))
print("Number of poss solutions = ",len(first_row))

print("Contains correct solution =",any(np.array_equal(matrix[0,:], i) for i in first_row))

Solution took 0.0794 seconds
Number of poss solutions =  26
Contains correct solution = True


In [341]:
start = time.perf_counter()
second_row =[]
for row in first_row:
    solution[0,:] = row
    for split in splits:
        if possible(1,split,solution,correct):
            if column_check_top(1,split,solution,correct):
                second_row.append(np.array([row,split]))
                
stop =  time.perf_counter()
print('Solution took {:0.4f} seconds'.format((stop-start)))
print("Number of poss solutions = ",len(second_row))
print("Contains correct solution =",any(np.array_equal(matrix[:2,:], i) for i in second_row))

Solution took 1.1715 seconds
Number of poss solutions =  215
Contains correct solution = True


In [342]:
third_row =[]
start = time.perf_counter()
for row1,row2  in second_row:
    solution = np.ones((9,9))*-1
    solution[0,:] = row1
    solution[1,:] = row2
    for split in splits:
            if possible(2,split,solution,correct):
                if column_check_top(2,split,solution,correct):
                    third_row.append([row1,row2,split])

stop =  time.perf_counter()   

print('Solution took {:0.4f} seconds'.format((stop-start)))
print("Number of poss solutions = ",len(third_row))
print("Contains correct solution =",any(np.array_equal(matrix[:3,:], i) for i in third_row))

Solution took 11.3493 seconds
Number of poss solutions =  5504
Contains correct solution = True


In [343]:
fourth_row =[]
start = time.perf_counter()
for row1,row2,row3  in third_row:
    solution = np.ones((9,9),dtype=int)*-1
    solution[0,:] = row1
    solution[1,:] = row2
    solution[2,:] = row3
    for split in splits:
            if possible(3,split,solution,correct):
                if column_check_top(3,split,solution,correct):
                    fourth_row.append([row1,row2,row3,split])
                    
stop =  time.perf_counter()   

print('Solution took {:0.4f} seconds'.format((stop-start)))
print("Number of poss solutions = ",len(fourth_row))
print("Contains correct solution =",any(np.array_equal(matrix[:4,:], i) for i in fourth_row))

Solution took 177.4579 seconds
Number of poss solutions =  150294
Contains correct solution = True
