In [1]:
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split 
from sklearn.utils import resample
from model import SudokuSolverNetwork

# Read the data

In [2]:
# The data and this reader are taken from https://www.kaggle.com/bryanpark/sudoku
#
# Every row after the CSV header is parsed as "quiz,solution"; the characters
# of the two strings are written digit by digit into row i of the matching
# array. The arrays are pre-allocated for 1,000,000 puzzles of 81 cells each.

quizzes   = np.zeros((1000000, 81), np.int32)
solutions = np.zeros((1000000, 81), np.int32)

# Open the file in a context manager so the handle is closed once parsing is
# done (a bare open(...).read() leaves closing to the garbage collector).
with open('data/sudoku.csv', 'r') as csvFile:
    # [1:] skips the header row
    for i, line in enumerate(csvFile.read().splitlines()[1:]):
        quiz, solution = line.split(",")
        for j, (q, s) in enumerate(zip(quiz, solution)):
            # q and s are single-character strings; the assignment into the
            # int32 arrays converts them to integers
            quizzes[i, j] = q
            solutions[i, j] = s

# flat 81-cell rows -> 9x9 boards: [num puzzles, 9, 9]
quizzes   = quizzes.reshape((-1, 9, 9))
solutions = solutions.reshape((-1, 9, 9))

In [3]:
# Append a trailing channel axis so the inputs are [batchsize, 9, 9, channel=1]
quizzes = quizzes[..., np.newaxis]

# Fixed seed so the train/test partition is the same on every run
seed = 4

# Hold out 33% of the puzzles for testing:
# 670,000 samples end up in the training set and 330,000 in the test set
X_train, X_test, y_train, y_test = train_test_split(
    quizzes,
    solutions,
    test_size=0.33,
    random_state=seed,
)


# One-hot encode the training solutions (digit d -> index d-1 along a new
# last axis whose length is the largest digit present). Broadcasting trick from
# https://stackoverflow.com/questions/36960320/convert-a-2d-matrix-to-a-3d-one-hot-matrix-numpy
oneHotTargets = np.equal(y_train[..., np.newaxis] - 1,
                         np.arange(y_train.max())).astype(int)

# Train the model

In [4]:
# build the graph
# Reset TensorFlow's global default graph first, so re-running this cell starts
# from an empty graph instead of adding to the one built by a previous run.
tf.reset_default_graph()
# SudokuSolverNetwork comes from the project-local `model` module (not visible
# here); presumably it reads its hyperparameters from 'hyperparams.cfg' -- confirm
# against model.py.
network = SudokuSolverNetwork('hyperparams.cfg')
# Later cells use network.trainOp, network.loss, network.prediction,
# network.globalStep and the feed targets network.train / network.rawInputs /
# network.targets; presumably makeGraph() is what creates them -- TODO confirm.
network.makeGraph()

Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See tf.nn.softmax_cross_entropy_with_logits_v2.



In [None]:
# accuracy function

def accuracy(prediction, exact):
    """Return the fraction of games in a batch that are solved completely.

    A game counts as correct only if every cell of the predicted board equals
    the corresponding cell of the exact solution.

    Args:
        prediction: array of predicted boards, shape [batch size, 9, 9].
        exact: array of reference solutions, same shape as `prediction`.

    Returns:
        float in [0, 1]: number of fully correct games divided by the batch
        size. An empty batch yields 0.0 instead of dividing by zero.
    """
    numGames = prediction.shape[0]
    if numGames == 0:
        return 0.0

    # Compare each game with its own solution, using the function arguments.
    # The previous version compared the globals solutions[0] / predictions[0]
    # on every iteration, so only the first game of the batch was ever scored.
    correct = sum(
        1 for game in range(numGames)
        if np.all(prediction[game] == exact[game])
    )

    return correct / numGames

In [None]:
# Training run: numBatches optimisation steps, each on a freshly sampled batch
# whose size is taken from the network object.
batchSize = network.batchSize
numBatches = 50000

# training loss of every step, in order
losses = []

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for batch in range(numBatches):

        # draw a random training batch (inputs and one-hot targets stay aligned)
        inputs, targets = resample(X_train, oneHotTargets, n_samples=batchSize)

        # one optimisation step; keep only the loss value
        trainFeed = {network.train: True,
                     network.rawInputs: inputs,
                     network.targets: targets}
        _, loss = sess.run([network.trainOp, network.loss], feed_dict=trainFeed)

        losses.append(loss)

        # evaluation and checkpointing happen only on every 1000th step and on
        # the final step; all other steps go straight to the next batch
        if batch % 1000 != 0 and batch != numBatches - 1:
            continue

        # generate a test batch and output its accuracy and also a sample
        print('Batch: {}'.format(batch))
        # NOTE: this rebinds the module-level name `solutions` (the full
        # solution array loaded earlier) to the sampled test batch
        inputs, solutions = resample(X_test, y_test, n_samples=batchSize)
        evalFeed = {network.train: False,
                    network.rawInputs: inputs}
        predictions = sess.run(network.prediction, feed_dict=evalFeed)

        # first puzzle of the batch, the network's answer, the true answer
        print(inputs[0, :, :, 0])
        print(predictions[0])
        print(solutions[0])
        print(accuracy(prediction=predictions, exact=solutions))

        # save checkpoint (skipped for the very first step)
        if batch > 0:
            saver.save(sess, 'checkpoints/SudokuNetwork', global_step=network.globalStep)

Batch: 0
[[0 0 2 4 1 0 5 0 0]
 [0 0 0 0 9 0 8 1 0]
 [3 0 6 0 0 2 0 0 0]
 [0 9 0 6 0 4 0 0 7]
 [7 5 0 0 0 0 0 6 8]
 [0 0 1 8 0 0 0 0 3]
 [5 3 9 0 0 8 4 0 0]
 [0 0 0 2 5 0 0 7 0]
 [0 6 0 1 0 0 3 0 0]]
[[2 2 1 7 1 7 7 9 6]
 [1 5 7 2 7 1 5 2 1]
 [7 7 7 2 7 2 2 7 1]
 [7 2 6 7 5 6 6 7 7]
 [2 1 7 2 5 2 7 7 7]
 [6 7 2 7 7 2 2 1 7]
 [7 1 7 5 2 7 7 1 5]
 [5 2 5 7 6 5 1 7 2]
 [5 7 5 1 5 5 7 1 1]]
[[9 8 2 4 1 7 5 3 6]
 [4 7 5 3 9 6 8 1 2]
 [3 1 6 5 8 2 7 9 4]
 [8 9 3 6 2 4 1 5 7]
 [7 5 4 9 3 1 2 6 8]
 [6 2 1 8 7 5 9 4 3]
 [5 3 9 7 6 8 4 2 1]
 [1 4 8 2 5 3 6 7 9]
 [2 6 7 1 4 9 3 8 5]]
0.0
Batch: 100
[[0 0 5 4 1 2 7 9 0]
 [0 0 1 0 0 0 5 0 3]
 [4 0 0 6 0 0 0 0 0]
 [9 0 0 0 0 7 0 0 0]
 [0 8 6 0 9 0 0 2 0]
 [0 1 0 5 0 0 3 8 0]
 [2 7 0 0 0 5 0 0 0]
 [0 0 4 7 0 9 0 6 8]
 [3 6 0 8 0 0 1 0 5]]
[[1 1 5 4 1 1 6 9 1]
 [1 1 1 1 1 1 5 1 3]
 [4 1 1 6 1 1 1 1 1]
 [8 1 1 1 1 7 1 1 1]
 [1 7 5 1 8 1 1 2 1]
 [1 1 1 5 1 1 3 7 1]
 [2 7 1 1 1 5 1 1 1]
 [1 1 4 6 1 9 1 5 7]
 [3 6 1 7 1 1 1 1 5]]
[[8 3 5 4 1 2 7 9 6]
 [6 2