# Varational autoencoder by RNN for sudoku sequence generation

## dataset 
- [10,000 solved sudoku](http://www.printable-sudoku-puzzles.com/wfiles/)
- [bigger dataset: 1M sudoku and solutions](https://www.kaggle.com/bryanpark/sudoku)

## Load sudoku training data

In [2]:
import numpy as np
from glob import glob

import torch
from torch.autograd import Variable
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

In [7]:

puzzles = sum([open(f).readlines() 
               for f in glob("/home/dola/ws/data/123456789.txt")], [])
puzzles = [p.strip() for p in puzzles if len(p.strip())==81]
len(puzzles)

10000

In [6]:
import pandas as pd
kaggle_sudoku = pd.read_csv("/home/dola/data/sudoku.csv")
puzzles = list(kaggle_sudoku.solutions)
len(puzzles)

1000000

In [8]:
puzzles[0]

'123456789578139624496872153952381467641297835387564291719623548864915372235748916'

### check they are valid sudoku

In [21]:
def check_sudoku(puzzle):
    assert len(puzzle) == 81
    p = np.array(list(puzzle)).reshape(9, 9)
    rows = range(9)
    cols = range(9)
    digits = set('123456789')
    strides = [slice(0, 3), slice(3, 6), slice(6, 9)]
    squares = [(r, c) for r in strides for c in strides]
    for r in rows:
        assert set(p[r,:]) == digits, "err: row %i" % r
    for c in cols:
        assert set(p[:,c]) == digits, "err: col %i" % c
    for sr, sc in squares:
        assert set(p[sr, sc].ravel()) == digits, "err: sqr %i %i" % (sr, sc)

In [10]:
# negative example
p = ''.join(['123456789'] * 9)
print(p)
check_sudoku(p)

123456789123456789123456789123456789123456789123456789123456789123456789123456789


AssertionError: err: col 0

In [11]:
from tqdm import tqdm_notebook
for i in tqdm_notebook(range(len(puzzles))):
    check_sudoku(puzzles[i])




## Data Processing

### data augumentation

In [12]:
from tqdm import tqdm_notebook

In [13]:
%%time
def augument(puzzles):
    augumented = [p for p in puzzles]
#     augumented += [p[::-1] for p in puzzles]
    for p in puzzles:
        p = np.array(list(p)).reshape([9, 9])
        augumented.append(''.join(np.fliplr(p).ravel()))
        p = p.T
        augumented.append(''.join(p.ravel()))
        augumented.append(''.join(np.fliplr(p).ravel()))
        p = p.T
        augumented.append(''.join(p.ravel()))
        augumented.append(''.join(np.fliplr(p).ravel()))
        p = p.T
        augumented.append(''.join(p.ravel()))
        augumented.append(''.join(np.fliplr(p).ravel()))
    return augumented

puzzles = augument(puzzles)

CPU times: user 1.41 s, sys: 24 ms, total: 1.43 s
Wall time: 1.41 s


In [14]:
print(len(puzzles))

for p in puzzles:
    check_sudoku(p)

80000


In [15]:
symbol2index = dict(zip('123456789', range(9)))
index2symbol = dict(zip(range(9), '123456789'))

def puzzle2tensor(puzzle_batch):
    batch_size = len(puzzle_batch)
    seq_len = len(puzzle_batch[0])
    t = torch.zeros([batch_size, seq_len, 9])
    for r in range(batch_size):
        for c in range(seq_len):
            s = symbol2index[puzzle_batch[r][c]]
            t[r, c, s] = 1
    return t

def puzzle2target(puzzle_batch):
    batch_size = len(puzzle_batch)
    seq_len = len(puzzle_batch[0])
    t = torch.LongTensor(batch_size, seq_len).zero_()
    for r in range(batch_size):
        for c in range(seq_len):
            s = symbol2index[puzzle_batch[r][c]]
            t[r, c] = s
    return t

def tensor2puzzle(tensors):
    """tensors.size() == [batch_size, seq, 9]
    """
    _, p = tensors.max(dim=2)
    p = p.squeeze().numpy()
    puzzles = []
    for r in p:
        puzzles.append(''.join([index2symbol.get(s) for s in r]))
    return puzzles

def output2puzzle(y):
    """y.size() == [batch_size, 9]
    """
    _, labels = y.max(dim=1)
    return labels.numpy().squeeze() + 1

In [16]:
## test
t = puzzle2tensor([p[:5] for p in puzzles[:2]])
p = tensor2puzzle(t)
print(t.size())
print(p)

torch.Size([2, 5, 9])
['12345', '12345']


In [17]:
## test
puzzle2target([p[:5] for p in puzzles[:2]])


 0  1  2  3  4
 0  1  2  3  4
[torch.LongTensor of size 2x5]

## Models - reinforcement learning

### build a sudoku env following openai

In [170]:
class Sudoku(object):
    
    def __init__(self):
        self.state = None
        
    def reset(self):
        """return observation"""
        self.state = []
        obs = self.state
        return obs
    
    def step(self, action):
        """action: '123456789'
        return [obs, reward, done, info]"""
        self.state.append(action)
        obs = self.state
        reward = 1
        done, info = self.is_valid(self.state)
        return [obs, reward, done, info]
    
    def is_valid(self, puzzle):
        """puzzle: an array of digits.
        Check if a partial puzzle is valid
        """
        nrows = int(np.ceil(len(puzzle)/9))
        for r in range(0, nrows):
            row = puzzle[r*9:(r+1)*9]
            if len(row) != len(set(row)):
                return False, "row violation"
        for c in range(0, 9):
            col = puzzle[c:c+9*9:9]
            if len(col) != len(set(col)):
                return False, "col violation"
        strides = [list(range(0, 3)), list(range(3, 6)), list(range(6, 9))]
        squares = [(r, c) for r in strides for c in strides]
        for sqr, sqc in squares:
            sqi = [r*9+c for r in sqr for c in sqc]
            print(sqi)
            square = [puzzle[i] for i in sqi if i < len(puzzle)]
            print(square)
        return True, "valid"

In [171]:
env = Sudoku()
env.reset()
env.step('1')

[0, 1, 2, 9, 10, 11, 18, 19, 20]
['1']
[3, 4, 5, 12, 13, 14, 21, 22, 23]
[]
[6, 7, 8, 15, 16, 17, 24, 25, 26]
[]
[27, 28, 29, 36, 37, 38, 45, 46, 47]
[]
[30, 31, 32, 39, 40, 41, 48, 49, 50]
[]
[33, 34, 35, 42, 43, 44, 51, 52, 53]
[]
[54, 55, 56, 63, 64, 65, 72, 73, 74]
[]
[57, 58, 59, 66, 67, 68, 75, 76, 77]
[]
[60, 61, 62, 69, 70, 71, 78, 79, 80]
[]


[['1'], 1, True, 'valid']

In [172]:
puzzle = list("123456789987654321")
puzzle = puzzles[0]
env.is_valid(puzzle)

[0, 1, 2, 9, 10, 11, 18, 19, 20]
['1', '2', '3', '5', '7', '8', '4', '9', '6']
[3, 4, 5, 12, 13, 14, 21, 22, 23]
['4', '5', '6', '1', '3', '9', '8', '7', '2']
[6, 7, 8, 15, 16, 17, 24, 25, 26]
['7', '8', '9', '6', '2', '4', '1', '5', '3']
[27, 28, 29, 36, 37, 38, 45, 46, 47]
['9', '5', '2', '6', '4', '1', '3', '8', '7']
[30, 31, 32, 39, 40, 41, 48, 49, 50]
['3', '8', '1', '2', '9', '7', '5', '6', '4']
[33, 34, 35, 42, 43, 44, 51, 52, 53]
['4', '6', '7', '8', '3', '5', '2', '9', '1']
[54, 55, 56, 63, 64, 65, 72, 73, 74]
['7', '1', '9', '8', '6', '4', '2', '3', '5']
[57, 58, 59, 66, 67, 68, 75, 76, 77]
['6', '2', '3', '9', '1', '5', '7', '4', '8']
[60, 61, 62, 69, 70, 71, 78, 79, 80]
['5', '4', '8', '3', '7', '2', '9', '1', '6']


(True, 'valid')

In [159]:
strides = [list(range(0, 3)), list(range(3, 6)), list(range(6, 9))]
squares = [(r, c) for r in strides for c in strides]

In [160]:
squares

[([0, 1, 2], [0, 1, 2]),
 ([0, 1, 2], [3, 4, 5]),
 ([0, 1, 2], [6, 7, 8]),
 ([3, 4, 5], [0, 1, 2]),
 ([3, 4, 5], [3, 4, 5]),
 ([3, 4, 5], [6, 7, 8]),
 ([6, 7, 8], [0, 1, 2]),
 ([6, 7, 8], [3, 4, 5]),
 ([6, 7, 8], [6, 7, 8])]