<a href="https://colab.research.google.com/github/hahnfabian/sudoku-solver/blob/main/sudoku_solver_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

Google Drive is mounted to access the project's Python scripts and the data.

In [1]:
from google.colab import drive
drive.mount("/content/gdrive")
%cd /content/gdrive/My Drive/sudoku_solver/sudoku-solver

! pip install -U scikit-learn

import copy
import keras
import numpy as np
from scripts.model import get_model
from scripts.process_data import get_data

Mounted at /content/gdrive
/content/gdrive/My Drive/sudoku_solver/sudoku-solver
Collecting scikit-learn
  Downloading scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m54.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: scikit-learn
  Attempting uninstall: scikit-learn
    Found existing installation: scikit-learn 1.2.2
    Uninstalling scikit-learn-1.2.2:
      Successfully uninstalled scikit-learn-1.2.2
Successfully installed scikit-learn-1.3.2


## Load data

The second argument of `get_data` is the number of puzzle-solution pairs to load.

In [29]:
N_PAIRS = 1_000_000  # number of puzzle-solution pairs to read from the CSV
x_train, x_test, y_train, y_test = get_data('sudoku.csv', N_PAIRS)

## Train the model

In [None]:
# Hyper-parameters for this training run.
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
EPOCHS = 2

# Network architecture comes from scripts/model.py.
model = get_model()

# sparse_categorical_crossentropy: targets are integer class ids rather than
# one-hot vectors (presumably digit - 1, given the `+ 1` applied to labels
# when measuring accuracy — TODO confirm in scripts/process_data.py).
adam = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=adam, loss='sparse_categorical_crossentropy')

history = model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS)

Epoch 1/2
Epoch 2/2


<keras.src.callbacks.History at 0x7b8f35171270>

In [None]:
# Persist the trained network so a later session can skip training and use the
# "Load pretrained model" cell instead. Keras emits a warning here (see the
# output); presumably the notice that .h5 is a legacy format — TODO confirm.
model.save(filepath="sudoku_model.h5")

  saving_api.save_model(


## Load pretrained model

In [2]:
# Load the network written by `model.save` in the training section; the path is
# relative to the repository directory selected with %cd at the top.
model = keras.models.load_model(filepath='sudoku_model.h5')

## Solve a sudoku

In [7]:
def norm(array):
    '''
    Map sudoku digits 0-9 linearly onto [-0.5, 0.5] (0 -> -0.5, 9 -> 0.5).

    Works element-wise on numpy arrays as well as on plain numbers.
    '''
    # Divide first, then shift: this exact operation order keeps the
    # round trip through `denorm` bit-exact for the digits 0-9.
    scaled = array / 9
    return scaled - 0.5

In [6]:
def denorm(array):
    '''
    Inverse of `norm`: map values in [-0.5, 0.5] back to sudoku digits 0-9.

    -0.5 -> 0 (a blank cell) and 0.5 -> 9. For values produced by `norm`
    from the integer digits 0-9 the round trip is exact, which matters
    because blank cells are later located with an exact ``== 0`` comparison.
    Works element-wise on numpy arrays as well as on plain numbers.
    '''
    return (array + 0.5) * 9

In [9]:
def inference_sudoku(sample, predictor=None):
    '''
    Solve a sudoku by filling its blank cells one at a time.

    Each iteration the network scores every cell. Among the cells that are
    still blank, the one whose most likely digit has the highest probability
    (rounded to 2 decimals; the first such cell wins ties) is filled with that
    digit, and the updated grid is fed back in until no blank cell is left.

    Parameters
    ----------
    sample : np.ndarray
        Normalized puzzle (see `norm`) with 81 entries, e.g. shape (9, 9, 1).
        Blank cells are the entries that denormalize to 0. Not modified.
    predictor : object with a Keras-style ``predict`` method, optional
        Model to query. Defaults to the notebook-global ``model``.

    Returns
    -------
    np.ndarray
        (9, 9) float array of the solved grid, denormalized to digits 1-9.
    '''
    if predictor is None:
        predictor = model  # global trained/loaded earlier in the notebook

    feat = copy.copy(sample)

    while True:
        # verbose=0: predict() runs once per blank cell (plus one final call),
        # and the default progress bar would print a line for every call.
        model_output = predictor.predict(feat.reshape((1, 9, 9, 1)), verbose=0)
        # Presumably (1, 81, 9): one distribution over digits 1-9 per cell.
        model_output = model_output.squeeze()

        pred = np.argmax(model_output, axis=1).reshape((9, 9)) + 1
        prob = np.around(np.max(model_output, axis=1).reshape((9, 9)), 2)

        feat = denorm(feat).reshape((9, 9))
        mask = (feat == 0)

        if not mask.any():
            break

        # Zero out given/already-filled cells so argmax only picks a blank.
        prob_new = prob * mask

        x, y = divmod(np.argmax(prob_new), 9)
        feat[x, y] = pred[x, y]
        feat = norm(feat)

    return feat

## Accuracy on 100 test games

In [8]:
def test_accuracy(feats, labels):

    correct = 0

    for i, feat in enumerate(feats):

        pred = inference_sudoku(feat)

        true = labels[i].reshape((9,9)) + 1

        if(abs(true - pred).sum() == 0):
            correct += 1

    print(correct/feats.shape[0])

In [30]:
N_TEST_GAMES = 100  # held-out puzzles to solve; each needs one model call per blank cell
test_accuracy(x_test[:N_TEST_GAMES], y_test[:N_TEST_GAMES]);

0.86


## Solve and validate a single game

In [20]:
def is_valid_sudoku(solution):
    '''
    Return True if no row, column or 3x3 box of `solution` repeats a digit.

    `solution` is indexed as a 2-D (9, 9) array. Zeros are skipped by
    `is_valid_group`, so this checks consistency, not completeness.
    '''
    def units():
        # Row k then column k for each k, then the nine boxes: the same order
        # (and the same laziness) as checking them one by one.
        for k in range(9):
            yield solution[k, :]
            yield solution[:, k]
        for top in range(0, 9, 3):
            for left in range(0, 9, 3):
                yield solution[top:top + 3, left:left + 3].flatten()

    return all(is_valid_group(unit) for unit in units())

def is_valid_group(group):
    '''
    Return True if the non-zero entries of `group` are pairwise distinct.

    Zeros mark blank cells and are ignored, so a partially filled row,
    column or box still counts as valid.
    '''
    filled = [value for value in group if value != 0]
    return len(set(filled)) == len(filled)




In [13]:
def solve_sudoku(game):
    '''
    Parse a sudoku written as text and solve it with `inference_sudoku`.

    Parameters
    ----------
    game : str
        The 81 cells in row-major order. Digits 1-9 are clues; `0` or `.`
        marks a blank cell. All whitespace (spaces, tabs, newlines, `\\r`)
        is ignored, so the grid may be laid out freely.

    Returns
    -------
    np.ndarray
        (9, 9) float array with the filled-in grid.

    Raises
    ------
    ValueError
        If the text does not contain exactly 81 cells, or contains a
        character other than a digit, `.` or whitespace.
    '''
    cells = ''.join(game.split()).replace('.', '0')
    if len(cells) != 81:
        raise ValueError(f'expected 81 cells, got {len(cells)}')
    if not all(c in '0123456789' for c in cells):
        raise ValueError('sudoku may only contain digits 0-9, "." and whitespace')

    grid = np.array([int(c) for c in cells]).reshape((9, 9, 1))
    return inference_sudoku(norm(grid))

In [23]:
# Blank cells are written as 0; solve_sudoku strips the spaces and newlines.
game = '''
          0 8 0 0 3 2 0 0 1
          7 0 3 0 8 0 0 0 2
          5 0 0 0 0 7 0 3 0
          0 5 0 0 0 1 9 7 0
          6 0 0 7 0 9 0 0 8
          0 4 7 2 0 0 0 5 0
          0 2 0 6 0 0 0 0 9
          8 0 0 0 9 0 3 0 5
          3 0 0 8 2 0 0 1 0
      '''

game = solve_sudoku(game)
if is_valid_sudoku(game):
    valid_solution = 'The solution is valid.'
else:
    valid_solution = 'The solution is INVALID.'
print('solved puzzle:\n', game, valid_solution, sep='\n')

solved puzzle:

[[4. 8. 9. 5. 3. 2. 7. 6. 1.]
 [7. 1. 3. 4. 8. 6. 5. 9. 2.]
 [5. 6. 2. 9. 1. 7. 8. 3. 4.]
 [2. 5. 8. 3. 4. 1. 9. 7. 6.]
 [6. 3. 1. 7. 5. 9. 2. 4. 8.]
 [9. 4. 7. 2. 6. 8. 1. 5. 3.]
 [1. 2. 5. 6. 7. 3. 4. 8. 9.]
 [8. 7. 6. 1. 9. 4. 3. 2. 5.]
 [3. 9. 4. 8. 2. 5. 6. 1. 7.]]
The solution is valid.
