# Use RNN-GAN to generate sudoku

## dataset 
- [10,000 solved sudoku](http://www.printable-sudoku-puzzles.com/wfiles/)

***I didn't manage to get it work yet - at the end, the discriminator outperforms the geneator too much!***

This is an example of how important the balance between generator/discriminator learning is in GAN.

## Load sudoku training data

In [1]:
import numpy as np
from glob import glob

import torch
from torch.autograd import Variable
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader

In [2]:

puzzles = sum([open(f).readlines() 
               for f in glob("/home/dola/data/sudoku/solved/*.txt")], [])
puzzles = [p.strip() for p in puzzles if len(p.strip())==81]
len(puzzles)

10000

In [3]:
puzzles[0]

'123456789578139624496872153952381467641297835387564291719623548864915372235748916'

### check they are valid sudoku

In [4]:
def check_sudoku(puzzle):
    assert len(puzzle) == 81
    p = np.array(list(puzzle)).reshape(9, 9)
    rows = range(9)
    cols = range(9)
    digits = set('123456789')
    strides = [slice(0, 3), slice(3, 6), slice(6, 9)]
    squares = [(r, c) for r in strides for c in strides]
    for r in rows:
        assert set(p[r,:]) == digits, "err: row %i" % r
    for c in cols:
        assert set(p[:,c]) == digits, "err: col %i" % c
    for sr, sc in squares:
        assert set(p[sr, sc].ravel()) == digits, "err: sqr %i %i" % (sr, sc)

In [5]:
# negative example
p = ''.join(['123456789'] * 9)
print(p)
check_sudoku(p)

123456789123456789123456789123456789123456789123456789123456789123456789123456789


AssertionError: err: col 0

In [6]:
for p in puzzles:
    check_sudoku(p)

## Data Processing

In [32]:
symbol2index = dict(zip('123456789', range(9)))
index2symbol = dict(zip(range(9), '123456789'))

def puzzle2tensor(puzzle_batch):
    batch_size = len(puzzle_batch)
    t = torch.rand([batch_size, 81, 9]) * 0.1
    for r in range(batch_size):
        for c in range(81):
            s = symbol2index[puzzle_batch[r][c]]
            t[r, c, s] = 0.9
    return t

def tensor2puzzle(tensors):
    """tensors.size() == [batch_size, 81, 9]
    """
    _, p = tensors.max(dim=2)
    p = p.squeeze().numpy()
    puzzles = []
    for r in p:
        puzzles.append(''.join([index2symbol.get(s) for s in r]))
    return puzzles

## Models

### Generative model

In [222]:
class Generator(nn.Module):
    def set_params(self):
        self.seq_len = 81
        self.input_dim = 64

        self.rnn_layer = 1
        self.rnn_hidden_size = 128

        self.fc1_hidden_size = 128

        self.output_dim = 9
    def __init__(self):
        super(Generator, self).__init__()
        self.set_params()
        self.rnn = nn.GRU(input_size=self.input_dim,
                          hidden_size=self.rnn_hidden_size,
                          num_layers=self.rnn_layer,
                          batch_first=True,
                          bidirectional=False)
        self.fc1 = nn.Linear(self.rnn_hidden_size, self.fc1_hidden_size)
        self.bn1 = nn.BatchNorm1d(self.fc1_hidden_size)
        self.fc2 = nn.Linear(self.fc1_hidden_size, self.output_dim)
        self.elu = nn.ELU()
        self.softmax= nn.Softmax()
    def forward(self, x0):
        batch_size = x0.size(0)
        input_dim = x0.size(2)
#         pads = Variable(torch.zeros([batch_size, self.seq_len-1, input_dim])).cuda()
#         x = torch.cat([x0, pads], dim=1).contiguous()
        x = x0
        
        h0 = Variable(torch.zeros([self.rnn_layer, batch_size, self.rnn_hidden_size])).cuda()
        
        
        out, _ = self.rnn(x, h0)
        out = out.contiguous()
        out = out.view([-1, self.rnn_hidden_size])
        out = self.fc1(out)
        out = self.elu(out)
#         out = self.bn1(out)
        out = self.fc2(out)
        out = self.softmax(out)
        out = out.view([batch_size, self.seq_len, self.output_dim])
        return out
        
generator = Generator().cuda()

In [223]:
## test
x0 = Variable(torch.randn([2, generator.seq_len, generator.input_dim])).cuda()
generated = generator(x0)
tensor2puzzle(generated.cpu().data)

['333339933333884777777777773733973397567373593333393483777377697999892733333773788',
 '379997173447773333333331993313999933933988333395734373394733965778777773355739399']

### Discriminative Model

In [224]:
class Discriminator(nn.Module):
    def set_params(self):
        self.seq_len = 81
        self.input_dim = 9
        
        self.rnn_layer = 1
        self.rnn_hidden_size = 32
        
        self.fc1_hidden_size = 64
        
        self.output_dim = 1
    def __init__(self):
        super(Discriminator, self).__init__()
        self.set_params()
        
        self.rnn = nn.GRU(input_size=self.input_dim,
                          hidden_size=self.rnn_hidden_size,
                          num_layers=self.rnn_layer,
                          batch_first=True,
                          bidirectional=True)
        self.fc1 = nn.Linear(self.rnn_hidden_size*2, self.fc1_hidden_size)
        self.bn1 = nn.BatchNorm1d(self.fc1_hidden_size)
        self.fc2 = nn.Linear(self.fc1_hidden_size, self.output_dim)
        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        batch_size = x.size(0)
        h0 = Variable(torch.zeros([self.rnn_layer*2, batch_size, self.rnn_hidden_size])).cuda()
        out, _ = self.rnn(x, h0)
        out = out[:, -1, :]
        out = self.fc1(out)
        out = self.elu(out)
#         out = self.bn1(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out
    
discriminator = Discriminator().cuda()

In [225]:
## test
x_real = Variable(puzzle2tensor(puzzles[:2])).cuda()
x_fake = generated
discriminator(x_real), discriminator(x_fake)

(Variable containing:
  0.5076
  0.5061
 [torch.cuda.FloatTensor of size 2x1 (GPU 0)], Variable containing:
  0.5056
  0.5056
 [torch.cuda.FloatTensor of size 2x1 (GPU 0)])

## training GAN

In [226]:
puzzles = np.array(puzzles)
n = len(puzzles)

In [227]:
objective = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=1e-4)
d_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)

In [228]:
n_epochs = 10
batch_size = 128

generator.train()
discriminator.train()

for epoch in range(n_epochs):
    index = np.random.permutation(n)
    n_batches = n // batch_size
    for ib, batch_index in enumerate(np.array_split(index, n_batches)):
        # train discriminator
        puzzle_batch = puzzles[batch_index]
        real_puzzles = Variable(puzzle2tensor(puzzle_batch)).cuda()
        real_output = discriminator(real_puzzles)
        real_labels = Variable(torch.ones(real_puzzles.size(0))).cuda()
        
        g_x0 = Variable(torch.randn([batch_size, 81, generator.input_dim])).cuda()
        fake_puzzles = generator(g_x0)
        fake_output = discriminator(fake_puzzles)
        fake_labels = Variable(torch.zeros(fake_puzzles.size(0))).cuda()
        
        discriminator.zero_grad()
        d_loss = objective(real_output, real_labels) + objective(fake_output, fake_labels)
        d_loss.backward()
        d_optimizer.step()
        
        ## train generator
        g_x0 = Variable(torch.randn([batch_size, 81, generator.input_dim])).cuda()
        fake_puzzles = generator(g_x0)
        real_labels = Variable(torch.ones(fake_puzzles.size(0))).cuda()
        d_output = discriminator(fake_puzzles)
        
        generator.zero_grad()
        g_loss = objective(d_output, real_labels)
        g_loss.backward()
        g_optimizer.step()
        
        if ib % 30 == 0:
            print(epoch, ib, d_loss.data[0], g_loss.data[0])

0 0 1.387967586517334 0.6823969483375549
0 30 1.3857080936431885 0.6892383098602295
0 60 1.3826104402542114 0.6926155090332031
1 0 1.3801488876342773 0.6955673098564148
1 30 1.373881459236145 0.6996903419494629
1 60 1.3694381713867188 0.6979137659072876
2 0 1.363874912261963 0.7004465460777283
2 30 1.3516838550567627 0.7063392996788025
2 60 1.334257960319519 0.7186979651451111
3 0 1.3199820518493652 0.728507399559021
3 30 1.269355297088623 0.7603364586830139
3 60 1.2026350498199463 0.7926214933395386
4 0 1.1010921001434326 0.8623343110084534
4 30 0.28382647037506104 2.0830180644989014
4 60 0.03941308706998825 3.6772522926330566
5 0 0.027451906353235245 4.0355401039123535
5 30 0.019295286387205124 4.384931564331055
5 60 0.014923274517059326 4.638742923736572
6 0 0.013085928745567799 4.7688093185424805
6 30 0.010758897289633751 4.960851669311523
6 60 0.00905336532741785 5.129833221435547
7 0 0.008300753310322762 5.222613334655762
7 30 0.007140171714127064 5.365665912628174
7 60 0.0062630

***Discriminator totally outperforms!!***