In [6]:
import pulp
import numpy as np
import time
import seaborn as sns
import matplotlib.pyplot as plt
import requests
from bs4 import BeautifulSoup
import itertools
from copy import copy,deepcopy as dcopy


In [7]:
# Working through the Hooks puzzles.
# Solving the matrix is fast using the same tools as the Feb 14 puzzle.
# This time I've cut the number of candidate hook grids down from 65k to 3,200,
# a 16x speed-up once the cost of the pruning itself is included.
# Also fixed some confusion between the row and column labels.

url = 'https://www.janestreet.com/puzzles/hooks-2/'
# A timeout stops the cell from hanging forever if the server stalls, and
# raise_for_status() fails loudly instead of silently parsing an error page.
res = requests.get(url, timeout=30)
res.raise_for_status()
soup = BeautifulSoup(res.content, 'html.parser')
x = list(soup.body.stripped_strings)

# NOTE(review): the first 7 strings appear to be page header/navigation text;
# the offset was found by inspection and may need updating if the page layout changes.
print(" ".join(x[7:]))

The grid below can be partitioned into 9 L-shaped “hooks”.  The largest is 9-by-9 (contains 17 squares), the next largest is 8-by-8 (contains 15 squares), and so on.  The smallest hook is just a single square. Find where the hooks are located, and place nine 9’s in the largest hook, eight 8’s in the next-largest, etc., down to one 1 in the smallest hook. The goal is for the sum of the numbers in each row and column to match the number given outside the grid. As your answer to this puzzle, submit the largest product one can achieve using a subset of the numbered squares in the completed grid, satisfying the condition that no two squares in the subset are in the same row or column.


### Puzzle details
<img src="https://www.janestreet.com/puzzles/wp-content/uploads/2016/05/may16_puzzle_expanded.png" width="500" height="600">

In [8]:
# Set up the clue constraints: row_labels[i] is the required sum of grid row i,
# col_labels[j] the required sum of grid column j.
row_labels = [45, 44, 4, 48, 7, 14, 47, 43, 33]
col_labels = [36, 5, 47, 35, 17, 30, 21, 49, 45]

# Sanity-check the transcription: one clue per row/column of the 9x9 grid, and
# both sets of clues must add up to the same grid total.
assert len(row_labels) == len(col_labels) == 9
assert sum(row_labels) == sum(col_labels)

In [9]:
# The solver is implemented as a class:
# 1. enumerate the valid hook placements
# 2. solve the number placement for each one
#
# Methods:
# - add_layer: place the next (smaller) hook in one of the four corner orientations
# - check_grid: reject placements where a fully-assigned row/column clue cannot be
#   formed from the digits it contains (e.g. 33 cannot be made from 9s alone, so the
#   9-hook cannot run along the bottom row)
# - make_change: coin-change test used by check_grid to decide whether a clue is reachable
# - solve_matrix: solve the integer program for a given hook placement (next step: replace
#   the LP with backtracking, since the LP will be harder to apply to the next puzzle)

class Matrix():
    """Solver for Jane Street's "Hooks 2" puzzle on an n-by-n grid.

    The grid is partitioned into nested L-shaped hooks of sizes n, n-1, ..., 1
    (hook k covers 2k - 1 squares and holds exactly k copies of the digit k).
    Solving happens in two stages:

    1. ``_enumerate_hook_grids`` builds every hook placement, discarding those
       whose fully-assigned rows/columns cannot reach their clue (``check_grid``).
    2. ``solve_matrix`` solves a 0/1 integer program (pulp) for each surviving
       placement, choosing which squares keep their digit.

    Attributes:
        col_labels, row_labels: required sum of each column / row (0 = no clue).
        size: grid dimension n, taken from ``len(row_labels)``.
        potential_grids: work queue of partial placements ``[level, grid, coords]``.
        solution: complete hook placements that survived pruning.
        final: ``[solved, matrix]`` for the first placement whose program
            solution meets every clue, or ``None`` if none has been found.
    """

    def __init__(self, col_labels, row_labels):
        if len(col_labels) != len(row_labels):
            raise ValueError("row_labels and col_labels must have the same length")
        self.col_labels = col_labels
        self.row_labels = row_labels
        self.size = len(row_labels)
        n = self.size
        # Each queue entry is [level, grid, coords]:
        #   level  - negated size of the next hook to place (starts at -n; 0 = complete)
        #   grid   - hook size in each square, -1 where nothing is placed yet
        #   coords - [row_start, col_start, row_end, col_end] of the still-unfilled square
        self.potential_grids = [[-n, np.full((n, n), -1, dtype=int), [0, 0, n, n]]]
        self.solution = []
        self.final = None

    def add_layer(self, grid, coords, lvl, alignment):
        """Write value ``lvl`` as an L along two edges of the unfilled square.

        Args:
            grid: 2-D int array; modified in place.
            coords: ``[row_start, col_start, row_end, col_end]`` of the unfilled
                square (end indices exclusive).
            lvl: value written into the hook's squares.
            alignment: corner the hook wraps around: 0 = top-left,
                1 = bottom-left, 2 = top-right, 3 = bottom-right. Any other
                value leaves ``grid`` untouched.

        Returns:
            ``(grid, coords)`` where ``coords`` is a new list describing the
            square left unfilled after this hook.
        """
        row_start, col_start, row_end, col_end = coords
        if alignment in (0, 1, 2, 3):
            on_bottom = alignment in (1, 3)
            on_right = alignment in (2, 3)
            arm_col = col_end - 1 if on_right else col_start   # vertical arm
            arm_row = row_end - 1 if on_bottom else row_start  # horizontal arm
            grid[row_start:row_end, arm_col] = lvl
            grid[arm_row, col_start:col_end] = lvl
            # Shrink the unfilled square away from the edges just used.
            if on_bottom:
                row_end -= 1
            else:
                row_start += 1
            if on_right:
                col_end -= 1
            else:
                col_start += 1
        return grid, [row_start, col_start, row_end, col_end]

    def check_grid(self, grid):
        """Return False if some fully-assigned row/column cannot reach its clue.

        A row/column is only tested once every square in it has a hook value
        (no -1 left) and its clue is non-zero. It is then checked with
        ``make_change``, using every hook size from the largest down to the
        smallest value present in that line. This is a necessary (not
        sufficient) condition used for pruning.
        """
        largest = max(grid.shape)
        for lines, labels in ((grid, self.row_labels), (grid.T, self.col_labels)):
            for line, label in zip(lines, labels):
                if label == 0:
                    continue
                smallest = line.min()
                if smallest > 0:
                    coins = list(range(largest, smallest - 1, -1))
                    if not self.make_change(label, coins):
                        return False
        return True

    def make_change(self, amount, coins):
        """Return True if ``amount`` is a sum of values from ``coins`` (each reusable).

        Unbounded coin-change search: either use ``coins[0]`` once more, or drop
        it and try the remaining coins. ``coins`` must be positive, otherwise
        the recursion does not terminate.
        """
        if amount == 0:
            return True
        if len(coins) == 0 or amount < 0:
            return False
        return (self.make_change(amount - coins[0], coins)
                or self.make_change(amount, coins[1:]))

    def solve_matrix(self, y):
        """Choose which squares keep their digit for a fixed hook placement ``y``.

        Builds a 0/1 program with one variable per square. Exactly k squares of
        hook k are selected, and the selected digits in each row/column sum to
        that row's/column's clue.

        Returns:
            ``(solution, y)``. ``solution`` is an int 0/1 array shaped like ``y``.
            It is all zeros when the solver does not report an optimal
            (i.e. feasible) solution.
        """
        n = y.shape[0]
        nums = range(1, n + 1)
        problem = pulp.LpProblem('Problem')
        x = pulp.LpVariable.dicts('x', [(row, col) for row in nums for col in nums],
                                  lowBound=0, upBound=1, cat='Binary')

        # Hook k keeps exactly k of its squares.
        for index in nums:
            problem += pulp.lpSum(x[(row, col)] for row in nums for col in nums
                                  if y[row - 1, col - 1] == index) == index

        for row in nums:
            problem += pulp.lpSum(int(y[row - 1, col - 1]) * x[(row, col)]
                                  for col in nums) == self.row_labels[row - 1]

        for col in nums:
            problem += pulp.lpSum(int(y[row - 1, col - 1]) * x[(row, col)]
                                  for row in nums) == self.col_labels[col - 1]

        problem.solve()

        solution = np.zeros((n, n), dtype=int)
        if problem.status != pulp.LpStatusOptimal:
            return solution, y
        for row in nums:
            for col in nums:
                value = x[(row, col)].varValue
                # Round rather than truncate: solvers may return 0.9999999 for a 1.
                solution[row - 1, col - 1] = int(round(value)) if value is not None else 0
        return solution, y

    def _enumerate_hook_grids(self):
        """Expand ``potential_grids`` breadth-first into complete hook placements.

        Every placement that passes ``check_grid`` at each step is appended to
        ``self.solution``. Returns ``self.solution``.
        """
        while self.potential_grids:
            parent = self.potential_grids.pop(0)
            # The size-1 hook is a single square, so all four orientations give
            # the same grid; expand only one of them.
            alignments = range(1) if parent[0] == -1 else range(4)
            for alignment in alignments:
                lvl, grid, coords = dcopy(parent)
                grid, coords = self.add_layer(grid, coords, -lvl, alignment)
                lvl += 1
                if not self.check_grid(grid):
                    continue
                if lvl != 0:
                    self.potential_grids.append([lvl, grid, coords])
                else:
                    self.solution.append(grid)
        return self.solution

    def solve(self):
        """Enumerate hook placements, then solve the number placement for each.

        Sets ``self.final`` to ``[solved, matrix]`` for the first placement whose
        program solution meets every row and column clue (``None`` if none does).
        """
        self._enumerate_hook_grids()
        print("There are {} valid hooks".format(len(self.solution)))

        self.final = None
        for grid in self.solution:
            solved, matrix = self.solve_matrix(grid)
            placed = solved * matrix
            # Verify the clues directly rather than trusting the solver alone.
            if (np.array_equal(placed.sum(axis=1), self.row_labels)
                    and np.array_equal(placed.sum(axis=0), self.col_labels)):
                self.final = [solved, matrix]
                break
                
def sol_print(solved, matrix):
    """Draw the hook layout as a heatmap, annotating each square with its kept digit.

    Args:
        solved: 0/1 array marking which squares keep their digit.
        matrix: hook-size array of the same shape; it sets the heatmap colours.

    Returns:
        The matplotlib Axes the heatmap was drawn on.
    """
    _, ax = plt.subplots(1, 1, figsize=(5, 5))
    # Squares that are not selected show 0 because solved is 0 there.
    sns.heatmap(matrix, annot=solved * matrix, cbar=False, cmap="YlGnBu", ax=ax)
    ax.axis("off")
    return ax

In [None]:
start = time.perf_counter()
test = Matrix(col_labels, row_labels)
test.solve()
# `final` holds [solved, matrix] once a placement meeting every clue is found;
# it may be missing or None otherwise, so fail with a clear message.
result = getattr(test, "final", None)
if result is None:
    raise RuntimeError("no hook placement satisfied every row and column clue")
solved, matrix = result
stop = time.perf_counter()
print('Solution took {:0.4f} seconds\n'.format((stop-start)))
sol_print(solved, matrix)
print("row check", np.sum(solved*matrix, axis=1) - row_labels)
print("col check", np.sum(solved*matrix, axis=0) - col_labels)
print("\n")

There are 3200 valid hooks


### Puzzle solution
<img src="https://www.janestreet.com/puzzles/wp-content/uploads/2016/06/may16_solution.png" width="500" height="600">

In [6]:
        correct = np.array([[9, 9, 9, 9, 9, 9, 9, 9, 9],
                            [6, 5, 5, 5, 5, 5, 7, 8, 9],
                            [6, 5, 4, 4, 4, 4, 7, 8, 9],
                            [6, 5, 4, 3, 3, 3, 7, 8, 9],
                            [6, 5, 4, 2, 1, 3, 7, 8, 9],
                            [6, 5, 4, 2, 2, 3, 7, 8, 9],
                            [6, 6, 6, 6, 6, 6, 7, 8, 9],
                            [7, 7, 7, 7, 7, 7, 7, 8, 9],                          
                            [8, 8, 8, 8, 8, 8, 8, 8, 9]])
        