<a href="https://colab.research.google.com/github/hahnfabian/sudoku-solver/blob/main/sudoku_solver_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
# Colab setup: install/upgrade scikit-learn, mount Google Drive, and move into
# the project directory so the local `scripts` package imported below resolves.
# NOTE(review): scikit-learn is not used directly in this notebook; presumably
# the `scripts` package needs it — confirm. The install is unpinned (-U), so
# re-runs may pick up a different version; pin one for reproducibility.
! pip install -U scikit-learn

from google.colab import drive
drive.mount("/content/gdrive")
%cd /content/gdrive/My Drive/sudoku_solver/sudoku-solver

import copy
import keras
import numpy as np
from scripts.model import get_model
from scripts.process_data import get_data

Mounted at /content/gdrive
/content/gdrive/My Drive/sudoku_solver/sudoku-solver


## Load data

In [11]:
# Load puzzles from sudoku.csv (relative to the directory set by %cd above).
# NOTE(review): judging by the unpacked names this returns train/test features
# and labels; the features appear to be normalized (inference_sudoku denorms
# them) and the second argument (1000) presumably caps the number of rows
# loaded — confirm both against scripts/process_data.get_data.
x_train, x_test, y_train, y_test = get_data('sudoku.csv', 1000)

## Train the model

In [13]:
# Seed Python, NumPy and the backend RNGs before the model is built, so weight
# initialisation and batch shuffling are reproducible across Restart & Run All.
# NOTE(review): any randomness inside get_data (cell above) is not covered.
SEED = 42
keras.utils.set_random_seed(SEED)

model = get_model()

adam = keras.optimizers.Adam(learning_rate=.001)
model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)

# Keep the History object (per-epoch loss) rather than echoing its repr as the
# cell's output.
history = model.fit(x_train, y_train, batch_size=32, epochs=2)

Epoch 1/2
Epoch 2/2


<keras.src.callbacks.History at 0x7a9e18db0e20>

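## Load pretrained model

Alternatively, skip training and load the saved model. Running this cell replaces the model trained above, so run one or the other.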

In [None]:
# Optional alternative to training: load the saved model from disk (path is
# relative to the directory set by %cd in the setup cell).
# NOTE: this rebinds `model`, replacing the one trained above — run either the
# training cell or this one, not both.
model = keras.models.load_model('model/sudoku.model')

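## Solve a sudoku

The solver fills the blank cells one at a time. At each step it commits the blank cell whose predicted digit the model is most confident about, then re-predicts on the updated grid.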

In [14]:
def norm(a):
    '''Scale values 0-9 linearly onto [-0.5, 0.5] (0 -> -0.5, 9 -> 0.5).

    Works element-wise on scalars and numpy arrays; `denorm` is its inverse.
    '''
    scaled = a / 9
    return scaled - .5

In [15]:
def denorm(a):
    '''Invert `norm`: map [-0.5, 0.5] back onto 0-9 (-0.5 -> 0, 0.5 -> 9).

    Works element-wise on scalars and numpy arrays.
    '''
    shifted = a + .5
    return shifted * 9

In [21]:
def inference_sudoku(sample, predictor=None):

    '''
        Solve a sudoku by filling its blank cells one at a time.

        On every iteration the model predicts a digit distribution for all
        81 cells. The blank cell whose most likely digit has the highest
        (2-decimal rounded) probability is filled with that digit, and the
        partially solved grid is fed back to the model. This repeats until no
        blank cell remains.

        Args:
            sample: a normalized puzzle (see `norm`) with 81 values; blank
                cells are the ones that denormalize to 0.
            predictor: model to query (anything with a Keras-style
                `predict`); defaults to the global `model`.

        Returns:
            A (9, 9) integer numpy array holding the completed grid.
    '''

    if predictor is None:
        predictor = model

    # Track the grid as integers. Round-tripping through denorm/norm on every
    # step accumulates float error (e.g. 3 -> 3.0000000000000004), which makes
    # exact comparisons against the labels fail.
    grid = np.rint(denorm(np.asarray(sample, dtype=float))).astype(int).reshape((9, 9))

    while (grid == 0).any():

        # verbose=0: this runs once per blank cell; a progress bar per call
        # would flood the notebook output.
        out = predictor.predict(norm(grid).reshape((1, 9, 9, 1)), verbose=0)
        out = out.squeeze()

        pred = np.argmax(out, axis=1).reshape((9, 9)) + 1
        prob = np.around(np.max(out, axis=1).reshape((9, 9)), 2)

        # Exclude filled cells with -inf (rather than multiplying by 0) so
        # argmax can never pick an already-filled cell.
        prob_blank = np.where(grid == 0, prob, -np.inf)
        x, y = np.unravel_index(np.argmax(prob_blank), (9, 9))

        grid[x, y] = pred[x, y]

    return grid

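## Evaluate on 10 test games

Accuracy is the fraction of puzzles solved exactly, meaning every one of the 81 cells is correct.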

In [18]:
def test_accuracy(feats, labels):

    '''
        Return the fraction of puzzles that `inference_sudoku` solves exactly.

        Args:
            feats: normalized puzzles, one per entry along the first axis.
            labels: matching targets storing each cell's digit minus one
                (hence the +1 below), 81 values per puzzle.

        Returns:
            Accuracy in [0, 1] as a float.

        Raises:
            ValueError: if `feats` is empty or its length differs from
                `labels`.
    '''

    n_games = len(feats)
    if n_games == 0:
        raise ValueError('feats must contain at least one puzzle')
    if n_games != len(labels):
        raise ValueError(
            f'feats and labels differ in length ({n_games} vs {len(labels)})')

    correct = 0

    for feat, label in zip(feats, labels):

        # Round before comparing: an exact comparison on a float grid fails on
        # tiny representation errors even when every digit is right.
        solution = np.rint(inference_sudoku(feat)).astype(int)

        true = np.asarray(label).reshape((9, 9)) + 1

        if np.array_equal(true, solution):
            correct += 1

    return correct / n_games

In [20]:
# Each puzzle costs roughly one model.predict call per blank cell, so evaluate
# on a small slice; the cell reports the fraction of puzzles solved exactly.
N_TEST_GAMES = 10
test_accuracy(x_test[:N_TEST_GAMES], y_test[:N_TEST_GAMES])

0.0
