# Variational autoencoder with RNNs for sudoku sequence generation

## Dataset
- [10,000 solved sudoku](http://www.printable-sudoku-puzzles.com/wfiles/)
- [bigger dataset: 1M sudoku and solutions](https://www.kaggle.com/bryanpark/sudoku)

## Load sudoku training data

In [1]:
import numpy as np
from glob import glob

import torch
from torch.autograd import Variable
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

In [153]:

# Read every line of the solved-sudoku text file(s). Each file is opened in a
# `with` block so its handle is closed as soon as it has been read.
puzzles = []
for path in glob("123456789.txt"):
    with open(path) as fh:
        puzzles.extend(line.strip() for line in fh)
# Keep only entries that are exactly 81 characters long after stripping.
puzzles = [p for p in puzzles if len(p) == 81]
len(puzzles)

10000

In [6]:
import pandas as pd
# Load the Kaggle CSV and use its `solutions` column as the puzzle list.
# NOTE: this rebinds `puzzles`, and the path is an absolute,
# machine-specific location.
kaggle_sudoku = pd.read_csv("/home/dola/data/sudoku.csv")
puzzles = list(kaggle_sudoku.solutions)
len(puzzles)

1000000

In [8]:
# Peek at the first entry of the puzzle list.
puzzles[0]

'864371259325849761971265843436192587198657432257483916689734125713528694542916378'

### Check that they are valid sudoku solutions

In [9]:
def check_sudoku(puzzle):
    """Assert that `puzzle` is a completely and correctly filled sudoku grid.

    Parameters
    ----------
    puzzle : str
        81 characters, read row by row (9 rows of 9 cells).

    Raises
    ------
    AssertionError
        If the length is not 81, or if some row, column or 3x3 square does
        not contain each of the digits '1'..'9' exactly once. The message
        names the first offending unit by its 0-based index.
    """
    assert len(puzzle) == 81
    grid = np.array(list(puzzle)).reshape(9, 9)
    digits = set('123456789')
    for r in range(9):
        assert set(grid[r, :]) == digits, "err: row %i" % r
    for c in range(9):
        assert set(grid[:, c]) == digits, "err: col %i" % c
    # Squares are addressed by their (row, col) block index so the message can
    # be formatted with %i. Formatting the slice objects themselves raised
    # TypeError instead of the intended AssertionError.
    for br in range(3):
        for bc in range(3):
            square = grid[3 * br:3 * br + 3, 3 * bc:3 * bc + 3]
            assert set(square.ravel()) == digits, "err: sqr %i %i" % (br, bc)

In [10]:
# negative example: every row is 1..9, so the rows pass but the columns fail.
# The AssertionError is caught and printed so that a Run All continues past
# this cell.
p = ''.join(['123456789'] * 9)
print(p)
try:
    check_sudoku(p)
except AssertionError as err:
    print("AssertionError: %s" % err)

123456789123456789123456789123456789123456789123456789123456789123456789123456789


AssertionError: err: col 0

In [12]:
from tqdm import tqdm_notebook
# Validate every loaded puzzle with a progress bar. check_sudoku raises
# AssertionError on the first invalid one, so finishing without an error
# means all of them passed.
for i in tqdm_notebook(range(len(puzzles))):
    check_sudoku(puzzles[i])




## Data Processing

### Data augmentation

In [158]:
from tqdm import tqdm_notebook

In [159]:
%%time
def augument(puzzles):
    """Expand every puzzle into the 8 symmetries of the square grid.

    Each 81-character puzzle is viewed as a 9x9 grid and emitted in its four
    90-degree rotations plus the left-right mirror image of each rotation.

    The previous implementation alternated `.T` and `fliplr`; transposing
    twice restores the original grid, so it produced only 4 distinct
    orientations and appended each of them twice.

    Parameters
    ----------
    puzzles : sequence of str
        Puzzles as 81-character strings in row-major order.

    Returns
    -------
    list of str
        The input puzzles first (unchanged, same order), followed by the 7
        transformed copies of each puzzle; 8 * len(puzzles) entries in total.
    """
    augumented = list(puzzles)
    for puzzle in puzzles:
        grid = np.array(list(puzzle)).reshape([9, 9])
        for k in range(4):
            rotated = np.rot90(grid, k)
            if k:  # k == 0 is the untouched puzzle, which is already listed
                augumented.append(''.join(rotated.ravel()))
            augumented.append(''.join(np.fliplr(rotated).ravel()))
    return augumented

# NOTE: rebinds `puzzles` to its own augmented version, so re-running this
# cell augments the already-augmented list again.
puzzles = augument(puzzles)

CPU times: user 1.41 s, sys: 104 ms, total: 1.52 s
Wall time: 1.44 s


In [160]:
# Show how many puzzles there are, then validate every entry. check_sudoku
# raises AssertionError on the first grid that is not a valid solution.
print(len(puzzles))

for p in puzzles:
    check_sudoku(p)

80000


In [13]:
# Lookup tables between the digit characters '1'..'9' and class indices 0..8.
symbol2index = {symbol: index for index, symbol in enumerate('123456789')}
index2symbol = {index: symbol for symbol, index in symbol2index.items()}

def puzzle2tensor(puzzle_batch):
    """One-hot encode a batch of digit strings.

    The sequence length is taken from the first puzzle. Each character is
    looked up in `symbol2index` to pick which of the 9 channels is set to 1.
    Returns a tensor of size [batch_size, seq_len, 9].
    """
    n_puzzles, n_cells = len(puzzle_batch), len(puzzle_batch[0])
    onehot = torch.zeros([n_puzzles, n_cells, 9])
    for row in range(n_puzzles):
        puzzle = puzzle_batch[row]
        for col in range(n_cells):
            onehot[row, col, symbol2index[puzzle[col]]] = 1
    return onehot

def puzzle2target(puzzle_batch):
    """Encode a batch of digit strings as integer class indices.

    The sequence length is taken from the first puzzle. Each character is
    mapped through `symbol2index`. Returns a LongTensor of size
    [batch_size, seq_len].
    """
    n_puzzles, n_cells = len(puzzle_batch), len(puzzle_batch[0])
    target = torch.LongTensor(n_puzzles, n_cells).zero_()
    for row in range(n_puzzles):
        puzzle = puzzle_batch[row]
        for col in range(n_cells):
            target[row, col] = symbol2index[puzzle[col]]
    return target

def tensor2puzzle(tensors):
    """Decode per-cell scores back into digit strings.

    tensors.size() == [batch_size, seq, 9]

    For every cell the highest-scoring of the 9 channels is chosen and mapped
    through `index2symbol`. Returns a list of `batch_size` strings, each of
    length `seq`.
    """
    batch_size, seq_len = tensors.size(0), tensors.size(1)
    _, labels = tensors.max(dim=2)
    # Reshape explicitly instead of calling squeeze(): squeeze() also removes
    # a batch or sequence axis of length 1, which broke decoding of a single
    # puzzle. view() still absorbs any trailing singleton axis that
    # max(dim=...) may keep.
    labels = labels.contiguous().view(batch_size, seq_len).numpy()
    return [''.join(index2symbol[int(s)] for s in row) for row in labels]

def output2puzzle(y):
    """Convert classifier outputs to digit values.

    y.size() == [batch_size, 9]

    Takes the arg-max over the 9 scores of each row and adds 1, so class
    index 0..8 becomes digit 1..9. Returns a squeezed NumPy array.
    """
    best_class = y.max(dim=1)[1]
    digits = best_class.numpy().squeeze()
    return digits + 1

In [14]:
## test
# Round trip: one-hot encode the first 5 cells of two puzzles, decode them
# again, and show the tensor size next to the recovered strings.
t = puzzle2tensor([p[:5] for p in puzzles[:2]])
p = tensor2puzzle(t)
print(t.size())
print(p)

torch.Size([2, 5, 9])
['86437', '34617']


In [15]:
## test
# The same two 5-cell prefixes, encoded as integer class indices.
puzzle2target([p[:5] for p in puzzles[:2]])


 7  5  3  2  6
 2  3  5  0  6
[torch.LongTensor of size 2x5]

## Models - seq prediction 

In [29]:
class RnnVA(nn.Module):
    """RNN based Variational Autoencoder

    Encoder: 2-layer GRU -> two ELU dense layers -> `va_mean` / `va_gamma`.
    Decoder: ELU dense layer -> 1-layer GRU -> linear classifier -> sigmoid.

    Every sequence position gets its own latent vector: the encoder GRU
    output is flattened to [batch_size * seq_len, hidden] before the dense
    layers, so `mean`, `gamma` and `sigma` have batch_size * seq_len rows.

    After a call to `forward` the attributes `mean`, `gamma` (the log of the
    variance, since sigma = exp(gamma / 2)) and `sigma` hold the latent
    statistics of that batch, for use in the latent loss.

    The model runs on whichever device its parameters and inputs live on.
    Random tensors are created on the CPU and moved to CUDA only when the
    tensors they are combined with are CUDA tensors.
    """

    def __init__(self):
        super(RnnVA, self).__init__()
        self.input_size = 9
        self.seq_len = 81
        self.encoder_rnn_num_layers = 2
        self.encoder_rnn_hidden_size = 128
        self.encoder_fc1_hidden_size = 256
        self.encoder_fc2_hidden_size = 256
        self.va_size = 256
        self.decoder_fc1_hidden_size = 256
        self.decoder_rnn_num_layers = 1
        self.decoder_rnn_hidden_size = 128
        self.output_size = 9

        self.encoder_rnn = nn.GRU(input_size=self.input_size,
                                  hidden_size=self.encoder_rnn_hidden_size,
                                  num_layers=self.encoder_rnn_num_layers,
                                  batch_first=True,
                                  bidirectional=False)
        self.encoder_fc1 = nn.Linear(self.encoder_rnn_hidden_size,
                                     self.encoder_fc1_hidden_size)
        self.encoder_fc2 = nn.Linear(self.encoder_fc1_hidden_size,
                                     self.encoder_fc2_hidden_size)
        self.elu = nn.ELU()
        self.va_mean = nn.Linear(self.encoder_fc2_hidden_size, self.va_size)
        self.va_gamma = nn.Linear(self.encoder_fc2_hidden_size, self.va_size)
        self.decoder_fc1 = nn.Linear(self.va_size, self.decoder_fc1_hidden_size)
        self.decoder_rnn = nn.GRU(input_size=self.decoder_fc1_hidden_size,
                                  hidden_size=self.decoder_rnn_hidden_size,
                                  num_layers=self.decoder_rnn_num_layers,
                                  batch_first=True,
                                  bidirectional=False)
        self.decoder_classifier = nn.Linear(self.decoder_rnn_hidden_size,
                                            self.output_size)
        self.sigmoid = nn.Sigmoid()

    def _decode(self, h, batch_size, seq_len):
        """Map latent vectors [batch_size * seq_len, va_size] to
        probabilities of size [batch_size, seq_len, output_size]."""
        out = self.elu(self.decoder_fc1(h))
        out = out.view([batch_size, seq_len, -1])
        # No explicit h0: nn.GRU starts from an all-zero hidden state on the
        # input's device when none is passed.
        out, _ = self.decoder_rnn(out)
        out = out.contiguous().view([-1, self.decoder_rnn_hidden_size])
        logits = self.decoder_classifier(out)
        probs = self.sigmoid(logits)
        return probs.view([batch_size, seq_len, -1])

    def forward(self, x):
        """Encode `x`, sample the latent code and reconstruct the sequence.

        x.size() == [batch_size, seq_len, 9]
        Returns probabilities of size [batch_size, seq_len, 9] and stores
        `self.mean`, `self.gamma` and `self.sigma` as a side effect.
        """
        batch_size = x.size(0)
        seq_len = x.size(1)
        # Zero initial hidden state is the GRU default.
        out, _ = self.encoder_rnn(x)
        out = out.contiguous().view([-1, self.encoder_rnn_hidden_size])
        out = self.elu(self.encoder_fc1(out))
        out = self.elu(self.encoder_fc2(out))
        self.mean = self.va_mean(out)
        self.gamma = self.va_gamma(out)
        self.sigma = torch.exp(self.gamma * .5)
        # Reparameterisation: z = mean + sigma * eps with eps ~ N(0, 1).
        noise = Variable(torch.randn(self.sigma.size()))
        if self.sigma.is_cuda:
            noise = noise.cuda()
        z = self.mean + self.sigma * noise
        return self._decode(z, batch_size, seq_len)

    def generate(self, batch_size):
        """Decode `batch_size` sequences of length `self.seq_len` from
        standard-normal latent vectors.

        Returns probabilities of size [batch_size, seq_len, 9].
        """
        seq_len = self.seq_len
        z = Variable(torch.randn([batch_size * seq_len, self.va_size]))
        # Follow the device of the decoder's parameters.
        if self.decoder_fc1.weight.is_cuda:
            z = z.cuda()
        return self._decode(z, batch_size, seq_len)

In [30]:
# Smoke test on an all-zero batch of 32 sequences: show the size of the
# output and of the latent mean / gamma tensors. Needs a CUDA device.
model = RnnVA().cuda()
x = Variable(torch.zeros([32, 81, 9])).cuda()
y = model(x)
y.size(), model.mean.size(), model.gamma.size()

(torch.Size([32, 81, 9]), torch.Size([2592, 256]), torch.Size([2592, 256]))

In [31]:
# Generate 10 sequences with the current `model` and decode them to strings.
model.eval()
generated = model.generate(10)
print(generated.size())
tensor2puzzle(generated.cpu().data)

torch.Size([10, 81, 9])


['358833389399139797338337888888737397777333888888888383333797333783733337333938335',
 '779333311633337983389373877773397999898938888835893388378333338936878888899878933',
 '373933788777777733338988998883333899838978988558889338383989889998888888836693377',
 '537888788779883383438935388899838993988273387333333887787818787733839993337779339',
 '889998383877798868888899899967993998918895583883739987779353738873878399333888883',
 '888873883333387553399843338773839939388389747883488998888889887773779778333339999',
 '988333777833888888777333333731383833733138839333383333335988183357788988833311337',
 '787788338833383388585888787733388388887333377999378788888938797833377788889799773',
 '358888388588893339973339983398787788888388889383339988333113333998888733778511379',
 '333388888893838933233888999978887899889953337368333979937887778883888888883833333']

## Data Preparation

In [19]:
class SudokuDataSet(Dataset):
    """Dataset that serves one one-hot encoded puzzle per index.

    All puzzles are encoded once, up front, by `puzzle2tensor`.
    """

    def __init__(self, puzzles):
        self.puzzles = puzzle2tensor(puzzles)

    def __len__(self):
        n_puzzles = self.puzzles.size(0)
        return n_puzzles

    def __getitem__(self, i):
        # Indexing the leading axis returns the full [seq_len, 9] slice.
        return self.puzzles[i]

In [21]:
# Encode all puzzles up front; %time reports how long the conversion takes.
%time data = SudokuDataSet(puzzles)


CPU times: user 49 s, sys: 1.06 s, total: 50.1 s
Wall time: 46.6 s


In [32]:
# Create a new model instance on the GPU and put it in training mode.
model = RnnVA().cuda()
model.train()

RnnVA (
  (encoder_rnn): GRU(9, 128, num_layers=2, batch_first=True)
  (encoder_fc1): Linear (128 -> 256)
  (encoder_fc2): Linear (256 -> 256)
  (elu): ELU (alpha=1.0)
  (va_mean): Linear (256 -> 256)
  (va_gamma): Linear (256 -> 256)
  (decoder_fc1): Linear (256 -> 256)
  (decoder_rnn): GRU(256, 128, batch_first=True)
  (decoder_classifier): Linear (128 -> 9)
  (sigmoid): Sigmoid ()
)

In [33]:
# Train the autoencoder with Adam on reconstruction loss + latent loss.
n_epochs = 10

batches = DataLoader(data, batch_size=128, shuffle=True, num_workers=4)
# Element-wise binary cross-entropy between the sigmoid outputs and the
# one-hot input.
xentropy = nn.BCELoss()


optimizer = optim.Adam(model.parameters())

for epoch in range(n_epochs):
    for b, batch in enumerate(batches):
        x = Variable(batch).cuda()
        
        model.zero_grad()
        xx = model(x)
        # Reconstruction term: compare every output channel with its input.
        restore_loss = xentropy(xx.view([-1, 9]), x.view([-1, 9]))
        # KL divergence of N(mean, exp(gamma)) from N(0, 1), averaged over
        # all latent elements; mean/gamma were stored by the forward pass.
        latent_loss = 0.5 * torch.mean(torch.exp(model.gamma) + model.mean*model.mean -1 - model.gamma)
        loss = restore_loss + latent_loss
        loss.backward()
        optimizer.step()
        
        # Log total, reconstruction and latent loss every 500 batches.
        if b % 500 == 0:
            print(epoch, b, loss.data[0], restore_loss.data[0], latent_loss.data[0])

0 0 0.6869287490844727 0.6853573322296143 0.0015713907778263092
0 500 0.07329657673835754 0.0546133890748024 0.018683191388845444
0 1000 0.032200321555137634 0.013167204335331917 0.019033119082450867
0 1500 0.02408786304295063 0.004989929962903261 0.019097933545708656
0 2000 0.022378552705049515 0.00337812933139503 0.019000424072146416
0 2500 0.020982276648283005 0.0018435517558827996 0.019138725474476814
0 3000 0.01959368586540222 0.0018789124442264438 0.01771477423608303
0 3500 0.017221301794052124 0.0013526127440854907 0.015868689864873886
0 4000 0.015896441414952278 0.001477017649449408 0.014419423416256905
0 4500 0.016673315316438675 0.00212892796844244 0.014544387347996235
0 5000 0.015405920334160328 0.0012989819515496492 0.014106938615441322
0 5500 0.014674486592411995 0.0010229457402601838 0.013651540502905846
0 6000 0.01421832013875246 0.0008640493033453822 0.013354270718991756
0 6500 0.013682294636964798 0.000946232583373785 0.012736061587929726
0 7000 0.013658950105309486 0.

Process Process-23:
Process Process-24:
Process Process-22:
Traceback (most recent call last):
Traceback (most recent call last):
Process Process-21:
  File "/home/dola/anaconda3/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/dola/anaconda3/lib/python3.6/multiprocessing/process.py", line 249, in _bootstrap
    self.run()
  File "/home/dola/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dola/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dola/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 35, in _worker_loop
    r = index_queue.get()
  File "/home/dola/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 35, in _worker_loop
    r = index_queue.get()
  File "/home/dola/anaconda3/lib/python3.6/multiprocessing/queues.py", line 34

KeyboardInterrupt: 

In [34]:
# Switch to eval mode and decode 10 sequences generated from random latent
# vectors into digit strings.
model.eval()
generated_puzzles = tensor2puzzle(model.generate(10).cpu().data)

In [35]:
# Display the generated digit strings.
generated_puzzles

['839142457983517148669418775773893255482886815236132119489648265767795791199359421',
 '467134614492673621933698721574684973228448167313726958646314479952274592924168257',
 '114726342615876993922421447967989344642814466978357984611758196647462461467161329',
 '254972914673464527561396759575252552173957545841127269826643517525487264427215953',
 '766553754832922653527878568411254457557921952977216315424359934552181374787154385',
 '234342868761257145872557745163983116811391597957484564868472411359214469618198138',
 '894191559625138577857145466382918997947861467296716321357318215768585416675975964',
 '945194212669111117762376876499474372722574679572558492927247748324616819344471649',
 '422943964256846866884981479744683575714365183542841247199147471251898136376164766',
 '288554218565177221675134621911393561932829729877683444436577951131495156537641457']

In [36]:
# Validate one of the generated sequences; check_sudoku raises
# AssertionError if it is not a valid solution.
check_sudoku(generated_puzzles[-3])

AssertionError: err: row 0