In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

from collections import Counter

from keras.models import load_model

from utils import get_digits

In [None]:
# Trained digit classifier, loaded from the saved model at 'models/digit'.
# It is passed to extract_digit() below to classify every puzzle cell.
model = load_model('models/digit')

# Directory holding the puzzle CSVs and their '*_solution.csv' counterparts.
path = 'data/puzzles'

In [None]:
def has_more_than_one(lst):
    """Report whether any non-zero value appears more than once in ``lst``.

    Zeros are skipped entirely, so any number of them never counts as a
    duplicate. The result is a NumPy boolean.
    """
    occurrences = Counter(lst)
    repeat_counts = [times for value, times in occurrences.items() if value != 0]

    return np.any(np.asarray(repeat_counts) > 1)


def is_valid_answer(answer):
    """Check a 9x9 grid for repeated digits in its rows, columns and 3x3 boxes.

    Returns False as soon as any unit repeats a non-zero digit and True
    otherwise. Because the duplicate check skips zeros, a grid that still
    contains zeros can pass.
    """
    # Rows and columns share the same index range, so test them together.
    for idx in range(9):

        duplicates = (
            has_more_than_one(answer[idx, :]),
            has_more_than_one(answer[:, idx]),
        )

        if any(duplicates):
            return False

    # Walk the nine 3x3 boxes by their top-left corner.
    for top in range(0, 9, 3):
        for left in range(0, 9, 3):

            box = answer[top:top + 3, left:left + 3]

            if has_more_than_one(box.flatten()):
                return False

    return True

In [None]:
def extract_digit(puzzle, w, h, model):
    """Classify the 81 cells of a puzzle image and return them as a 9x9 grid.

    Parameters
    ----------
    puzzle : 2-D array indexed as ``[row, column]`` holding the puzzle image.
    w, h : int
        Width (number of columns) and height (number of rows) of the image.
        Each cell is ``h // 9`` rows by ``w // 9`` columns; that size has to
        match what ``model`` accepts (the original code assumed 42x42).
    model : object exposing ``predict(batch)`` that returns one row of class
        scores per input cell.

    Returns
    -------
    numpy.ndarray of shape (9, 9) with the arg-max class index of every cell,
    in row-major order.
    """
    cell_w = w // 9
    cell_h = h // 9

    # Axis 0 of the image runs over rows, so it is cut by the cell height;
    # axis 1 runs over columns and is cut by the cell width.
    cells = [
        puzzle[row*cell_h:(row+1)*cell_h, col*cell_w:(col+1)*cell_w]
        for row in range(9)
        for col in range(9)
    ]

    # Single batch of shape (81, cell_h, cell_w, 1): one predict() call
    # instead of 81 separate ones.
    batch = np.stack(cells)[..., np.newaxis]

    scores = np.asarray(model.predict(batch))

    return scores.argmax(-1).reshape((9, 9))

In [None]:
levels = ['easy', 'medium', 'hard', 'extreme']

# Puzzles whose given digits agree with the recognised solution, and whose
# recognised solution has no duplicate digits, go to the "correct" lists;
# everything else goes to the "wrong" lists.
correct_games = []
correct_answers = []
correct_games_names = []

wrong_games = []
wrong_answers = []
wrong_games_names = []

# Number of given cells whose digit differs between puzzle and solution.
num_wrong = 0

# List the directory once; every level filters the same listing.
puzzle_files = os.listdir(path)


for level in levels:

    # Puzzle files of this level; solution files are excluded by name.
    puzzles = [p for p in puzzle_files if level in p and 'solution' not in p]

    # The listing is only used for its length: files are addressed by running
    # number, so they are expected to be named <level>_1 .. <level>_N.
    for i in range(len(puzzles)):

        name = '{}/{}_{}.csv'.format(path, level, i+1)
        answer_name = '{}/{}_{}_solution.csv'.format(path, level, i+1)

        puzzle = pd.read_csv(name).values
        answer = pd.read_csv(answer_name).values

        w, h = puzzle.shape[1], puzzle.shape[0]

        puzzle_digit = extract_digit(puzzle, w, h, model)
        # Slice the solution image by its own dimensions, not the puzzle's.
        answer_digit = extract_digit(answer, answer.shape[1], answer.shape[0], model)

        # Only the cells that are filled in the puzzle can be compared.
        given_cells = puzzle_digit != 0
        compare_base = puzzle_digit[given_cells]
        compare_answer = answer_digit[given_cells]

        is_same = (compare_base == compare_answer).all()

        num_wrong += (compare_base != compare_answer).sum()

        print('Extracting', name, '\t{}/{}\tBase Matched : {}'.format(i+1, len(puzzles), is_same), flush=True, end='\r')

        if is_same and is_valid_answer(answer_digit):

            correct_games_names.append([name, answer_name])
            correct_games.append(puzzle_digit)
            correct_answers.append(answer_digit)

        else:

            wrong_games_names.append([name, answer_name])
            wrong_games.append(puzzle_digit)
            wrong_answers.append(answer_digit)

    print()

In [None]:
# Total number of processed puzzles (the original referenced an undefined `games`).
total_games = len(correct_games) + len(wrong_games)

# NOTE(review): num_wrong only counts mismatches among the given cells, while
# the denominator counts all 81 cells of every puzzle - confirm this is the
# intended metric.
if total_games:
    digit_accuracy = 1 - num_wrong / (total_games * 81)
else:
    # Nothing was processed, so there is no accuracy to report.
    digit_accuracy = float('nan')

print('Correctly extracted puzzles :', len(correct_games_names))
print('Incorrectly extracted puzzles :', len(wrong_games_names))
print('Total number of digits misclassified :', num_wrong)
print('Accuracy to correctly classify digits :', digit_accuracy)

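A puzzle counts as "correctly extracted" when every given digit in the base puzzle matches the digit recognised at the same row and column of its solution, and the recognised solution contains no repeated digit in any row, column, or 3x3 box.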

In [None]:
class SudokuSolver():
    """Exhaustive backtracking solver for 9x9 Sudoku grids (0 marks an empty cell).

    Typical use: ``take`` a puzzle, call ``solve`` to collect every valid
    completion in ``self.answers``, then optionally ``compare_answer`` against
    a known solution.
    """


    def __init__(self):

        # Puzzle handed over by take(); None until then.
        self.puzzle = None
        # Completed grids found by the most recent solve().
        self.answers = []


    def take(self, puzzle):
        """Remember ``puzzle`` (a 9x9 NumPy array, 0 = empty) for the next solve()."""

        self.puzzle = puzzle



    def has_more_than_one(self, lst):
        """Return True when any non-zero value occurs more than once in ``lst``."""

        counter = Counter(lst)
        counter = np.array([v for k, v in counter.items() if k != 0])

        return (counter > 1).any()


    def is_valid_move(self, next_puzzle):
        """Return True when no row, column or 3x3 box repeats a non-zero digit."""

        # Check row and column
        for i in range(9):

            row_invalid = self.has_more_than_one(next_puzzle[i, :])
            col_invalid = self.has_more_than_one(next_puzzle[:, i])

            if row_invalid or col_invalid:
                return False


        # Check the nine 3x3 boxes
        for row in range(3):
            for col in range(3):

                quad = next_puzzle[row*3:(row+1)*3, col*3:(col+1)*3]
                quad_invalid = self.has_more_than_one(quad.flatten())

                if quad_invalid:
                    return False

        return True


    def _is_valid_placement(self, puzzle, row, col):
        """Check only the row, column and 3x3 box that contain cell (row, col)."""

        top = (row // 3) * 3
        left = (col // 3) * 3
        quad = puzzle[top:top+3, left:left+3]

        return not (
            self.has_more_than_one(puzzle[row, :])
            or self.has_more_than_one(puzzle[:, col])
            or self.has_more_than_one(quad.flatten())
        )


    def solve(self):
        """Find every valid completion of the taken puzzle and store it in ``self.answers``.

        Raises
        ------
        ValueError
            If no puzzle has been provided through ``take``.
        """

        if self.puzzle is None:
            raise ValueError('No puzzle to solve; call take() first')

        # Start from a clean slate so repeated solves do not accumulate
        # answers from earlier runs.
        self.answers = []

        board = self.puzzle.copy()

        # step() only re-checks the units touched by each new digit, so a
        # clash among the given digits has to be ruled out up front.
        if self.is_valid_move(board):
            self.step(board)

        print('There are {} computed answer(s)'.format(len(self.answers)))



    def check_complete(self, puzzle):
        """Return True when ``puzzle`` has no empty cell and no repeated digit."""

        return 0 not in puzzle.flatten() and self.is_valid_move(puzzle)


    def step(self, puzzle, loc=0):
        """Fill cells from ``loc`` (row-major index 0..80) onwards by backtracking.

        ``puzzle`` is modified in place while candidates are tried; every
        branch that survives continues on its own copy. Completed, valid grids
        are appended to ``self.answers``.
        """

        if loc > 80:
            # Final full-board check before accepting the grid.
            if self.is_valid_move(puzzle):
                self.answers.append(puzzle)
            return

        row = loc // 9
        col = loc % 9

        # Already filled
        if puzzle[row, col] != 0:
            self.step(puzzle, loc+1)
            return

        for num in range(1, 10):
            puzzle[row, col] = num

            # A new digit can only clash inside its own row, column and box.
            if self._is_valid_placement(puzzle, row, col):
                self.step(puzzle.copy(), loc+1)


    def compare_answer(self, true_answer):
        """Return the computed answer equal to ``true_answer``, or None if there is none."""

        for answer in self.answers:

            if (answer == true_answer).all():

                print('There is a correct solved answer')

                return answer

        print('There are no correct solved answer(s)')

        return None


In [None]:
# Run the solver on every puzzle that ended up in the "correct" lists and
# check its computed answers against the recognised solution.
for i, (game_name, answer_name) in enumerate(correct_games_names):
    
    # A fresh solver per puzzle, so its answers list starts out empty.
    solver = SudokuSolver()
    
    # Strip the directory and the file extension, e.g. 'a/b/easy_1.csv' -> 'easy_1'.
    puzzle_name = game_name.split('/')[-1].split('.')[0]
    
    print('Solving', puzzle_name)
    # The three correct_* lists are filled in lockstep, so index i lines up.
    puzzle = correct_games[i]
    answer = correct_answers[i]
    
    solver.take(puzzle)
    solver.solve()
    
    # Prints whether a computed answer equals `answer` and returns the
    # matching grid if there is one.
    correct_answer = solver.compare_answer(answer)
    
    print()