## Getting the Code & Data

In [None]:
# Fetch the SATNet sources and make the checkout the working directory.
!git clone https://github.com/locuslab/SATNet
%cd SATNet
# Install the package in development mode; stdout and stderr both go to install.log.
!python setup.py develop > install.log 2>&1

In [None]:
# Download and unpack the two data archives.
# wget: -c resumes a partial download, -q suppresses output; unzip -qq is quiet as well.
# The `&&` means an archive is only unpacked if its download succeeded.
!wget -cq powei.tw/sudoku.zip && unzip -qq sudoku.zip
!wget -cq powei.tw/parity.zip && unzip -qq parity.zip

## Data Preprocessing

In [None]:
import pandas as pd

# Load the table and preview its first rows.
# NOTE(review): the path is relative to the current working directory; it is not
# obvious from this notebook where sudoku.csv comes from -- confirm.
df = pd.read_csv("sudoku.csv")
df.head()

In [None]:
# Distinct values of the 'Index' column (one per puzzle), in order of first appearance.
indices = df['Index'].unique()

In [None]:
# Show the table contents as a NumPy array (the form the loops below index by position).
df.values

In [None]:
import torch
import numpy as np

# Accumulators for the per-puzzle 9x9 grids; each gets one array appended per
# unique 'Index' value. Re-running the filling cells without re-running this
# one appends duplicates.
rule_matrices = []
order_matrices = []

In [None]:
# Build one 9x9 "order" grid per puzzle and collect them in order_matrices.
for index in indices:
    # all table rows that belong to the current puzzle, as a NumPy array
    current_sudoku = df.loc[df['Index'] == index].values
    # cells that never appear in the table keep the value 0
    current_orders = np.zeros((9, 9))
    for point in current_sudoku:
        # positional columns: 1 holds the order number, 2 and 3 address the cell
        row, col, order = point[2], point[3], point[1]
        current_orders[row, col] = order
    order_matrices.append(current_orders)

In [None]:
# Build one 9x9 "rule" grid per puzzle and collect them in rule_matrices.
for index in indices:
    # all table rows that belong to the current puzzle, as a NumPy array
    current_sudoku = df.loc[df['Index'] == index].values
    # cells that never appear in the table keep the value 0
    current_rule = np.zeros((9, 9))
    for point in current_sudoku:
        # positional columns 2 and 3 address the cell; column 4 is the value stored
        # (presumably the rule identifier -- confirm against the CSV schema)
        row, col, rule = point[2], point[3], point[4]
        current_rule[row, col] = rule
    rule_matrices.append(current_rule)

In [None]:
# Convert the per-puzzle grids to torch tensors.
# The lists hold NumPy arrays; stacking them into a single ndarray first avoids
# the slow path (and the UserWarning) torch.tensor takes for a list of ndarrays.
# Shape and dtype of the result are unchanged for non-empty lists.
final_orders = torch.tensor(np.array(order_matrices))
final_rules = torch.tensor(np.array(rule_matrices))

In [None]:
import os
import shutil
import argparse
from collections import namedtuple

import numpy as np
import numpy.random as npr

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

import matplotlib.pyplot as plt
from IPython.display import display, Markdown, Latex, clear_output
import tqdm

In [None]:
# Report which CUDA device PyTorch will use, or warn when none is available.
if torch.cuda.is_available():
    print('Using', torch.cuda.get_device_name(0))
else:
    print('[WARNING] Not using GPU.')
    print('Please select "Runtime -> Change runtime type" and switch to GPU for better performance')

In [None]:
import satnet
# Print the docstring attached to satnet.SATNet.
print('SATNet document\n', satnet.SATNet.__doc__)

In [None]:
class SudokuSolver(nn.Module):
    '''
    Thin nn.Module wrapper around a single satnet.SATNet layer.

    Arguments:
        boardSz: board size parameter; the layer is created with
            boardSz**6 as its first argument (729 when boardSz is 3).
        aux: passed to satnet.SATNet as its third argument.
        m: passed to satnet.SATNet as its second argument.
    '''
    def __init__(self, boardSz, aux, m):
        super().__init__()
        self.sat = satnet.SATNet(boardSz ** 6, m, aux)

    def forward(self, y_in, mask):
        '''Apply the SATNet layer to y_in with mask and return its output.'''
        return self.sat(y_in, mask)

In [None]:
class DigitConv(nn.Module):
    '''
    Small convolutional digit classifier (MNIST-style). Adapted from:
    https://github.com/pytorch/examples/blob/master/mnist/main.py

    forward() returns, for every input image, the first nine of the ten
    softmax class probabilities as a contiguous tensor.
    '''
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # two identical stages: convolution -> ReLU -> 2x2 max-pooling
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2, stride=2)
        # flatten to the 4*4*50 features expected by fc1
        features = x.view(-1, 4 * 4 * 50)
        logits = self.fc2(F.relu(self.fc1(features)))
        probs = F.softmax(logits, dim=1)
        # keep only the first nine class probabilities
        return probs[:, :9].contiguous()

In [None]:
class MNISTSudokuSolver(nn.Module):
    '''
    Digit classifier followed by the SATNet-based solver.

    Every cell image is classified by DigitConv; the per-cell probabilities
    are then reshaped into one flat vector per puzzle and handed to
    SudokuSolver together with the input mask.
    '''
    def __init__(self, boardSz, aux, m):
        super().__init__()
        self.digit_convnet = DigitConv()
        self.sudoku_solver = SudokuSolver(boardSz, aux, m)
        self.boardSz = boardSz
        self.nSq = boardSz**2

    def forward(self, x, is_inputs):
        nBatch = x.shape[0]
        # merge the batch and cell dimensions so each cell image is classified on its own
        cell_images = x.flatten(start_dim=0, end_dim=1)
        digit_guess = self.digit_convnet(cell_images)
        # back to one row per puzzle: nSq * nSq * nSq entries each
        puzzles = digit_guess.view(nBatch, self.nSq ** 3)
        return self.sudoku_solver(puzzles, is_inputs)

In [None]:
from exps.sudoku import train, test, FigLogger, find_unperm
# Experiment settings, kept both as a dict and as an immutable `args` record
# whose fields follow the dict's insertion order.
args_dict = dict(
    lr=2e-3,
    cuda=torch.cuda.is_available(),
    batchSz=40,
    mnistBatchSz=50,
    boardSz=3,  # for 9x9 Sudoku
    m=600,
    aux=300,
    nEpoch=20,
)
args = namedtuple('Args', list(args_dict))(**args_dict)

In [None]:
def process_inputs(X, Ximg, Y, boardSz):
    '''
    Flatten puzzle tensors and derive the mask of given cells.

    Arguments:
        X: puzzle encoding; summed over dim 3, so it needs at least four
            dimensions (presumably batch x row x col x digit -- confirm).
        Ximg: per-cell images; dims 1 and 2 are merged and a singleton
            dimension is inserted at position 2 before casting to float.
        Y: labels, flattened to one row per first-dimension entry.
        boardSz: accepted for interface compatibility; not used here.

    Returns:
        (X, Ximg, Y, is_input) where X, Y and is_input are 2-D views and
        is_input is an int tensor holding the sign of each cell's sum,
        repeated across that cell's entries.
    '''
    # sign of the per-cell sum, broadcast back over the summed dimension
    cell_totals = X.sum(dim=3, keepdim=True)
    is_input = cell_totals.expand_as(X).int().sign()
    is_input = is_input.view(is_input.size(0), -1)

    Ximg = Ximg.flatten(start_dim=1, end_dim=2).unsqueeze(2).float()

    X = X.view(X.size(0), -1)
    Y = Y.view(Y.size(0), -1)

    return X, Ximg, Y, is_input

In [None]:
# Load the four serialized objects from the sudoku directory.
# NOTE(review): torch.load relies on pickle-based deserialization and these
# files come from a downloaded archive -- only run this on data you trust.
def _load_pt(path):
    '''Open path in binary mode and return what torch.load reads from it.'''
    with open(path, 'rb') as handle:
        return torch.load(handle)

X_in = _load_pt('sudoku/features.pt')
Ximg_in = _load_pt('sudoku/features_img.pt')
Y_in = _load_pt('sudoku/labels.pt')
perm = _load_pt('sudoku/perm.pt')

In [None]:
# Flatten the raw tensors into model inputs and move them to the GPU when enabled.
X, Ximg, Y, is_input = process_inputs(X_in, Ximg_in, Y_in, args.boardSz)
if args.cuda:
    X, Ximg, is_input, Y = (t.cuda() for t in (X, Ximg, is_input, Y))

# The first 90% of the puzzles form the training split.
N = X_in.size(0)
nTrain = int(0.9 * N)
print(nTrain)

In [None]:
# Train/test datasets for three variants: plain encoding, column-permuted
# encoding (columns reordered by `perm`), and image inputs.
train_rows = slice(None, nTrain)
test_rows = slice(nTrain, None)

sudoku_train = TensorDataset(X[train_rows], is_input[train_rows], Y[train_rows])
sudoku_test = TensorDataset(X[test_rows], is_input[test_rows], Y[test_rows])
perm_train = TensorDataset(X[train_rows, perm], is_input[train_rows, perm], Y[train_rows, perm])
perm_test = TensorDataset(X[test_rows, perm], is_input[test_rows, perm], Y[test_rows, perm])
mnist_train = TensorDataset(Ximg[train_rows], is_input[train_rows], Y[train_rows])
mnist_test = TensorDataset(Ximg[test_rows], is_input[test_rows], Y[test_rows])

In [None]:
def show_sudoku(raw):
    '''
    Collapse dim 2 of raw into one integer per position.

    The result is (argmax over dim 2, plus one) multiplied by the sum over
    dim 2 truncated to an integer, so an all-zero slice yields 0.
    '''
    digits = raw.argmax(dim=2) + 1
    weight = raw.sum(dim=2).long()
    return digits * weight

In [None]:
def show_mnist_sudoku(raw):
    '''
    Draw a board of cell images with separator lines using matplotlib.

    raw is converted with .numpy() and tiled into a single 2-D image by two
    successive concatenations (assumes a 4-D rows x cols x height x width
    tensor -- TODO confirm). The tiled image is split into 3x3 blocks that
    are copied onto a canvas with 2-pixel gaps between and around them, and
    the canvas is shown with inverted intensities.
    '''
    cells = raw.numpy()
    strips = np.concatenate(cells, axis=1)
    digits = np.concatenate(strips, axis=1).astype(np.uint8)

    linewidth = 2
    height, width = digits.shape[0], digits.shape[1]
    # block edge length; derived from the image height and used for both axes
    gridwidth = height // 3
    # canvas pre-filled with 255; four separator lines fit around three blocks
    board = np.full((height + 4 * linewidth, width + 4 * linewidth), 255, dtype=np.uint8)

    pitch = linewidth + gridwidth
    for i in range(3):
        src_rows = slice(gridwidth * i, gridwidth * i + gridwidth)
        dst_rows = slice(linewidth + pitch * i, linewidth + pitch * i + gridwidth)
        for j in range(3):
            src_cols = slice(gridwidth * j, gridwidth * j + gridwidth)
            dst_cols = slice(linewidth + pitch * j, linewidth + pitch * j + gridwidth)
            board[dst_rows, dst_cols] = digits[src_rows, src_cols]

    plt.imshow(255 - board, cmap='gray')

In [None]:
def showw(raw):
    '''
    Decode a flat vector of 729 scores into a 9x9 grid of digits.

    The first 729 entries of raw are read as 81 consecutive groups of nine
    scores, one group per cell in row-major order. Each cell receives the
    1-based position of the largest score in its group; on ties the first
    position wins.

    Arguments:
        raw: tensor with at least 729 elements (any shape; it is flattened).

    Returns:
        A 9x9 CPU tensor of the default floating dtype holding values 1..9.
    '''
    # One vectorized argmax replaces 729 per-element .item() reads, each of
    # which forces a device synchronization when raw lives on the GPU.
    scores = raw.detach().reshape(-1)[:729].reshape(9, 9, 9)
    digits = scores.argmax(dim=2) + 1
    return digits.to(dtype=torch.get_default_dtype()).cpu()

In [None]:
# Instantiate the solver from the experiment settings and move it to the GPU when enabled.
sudoku_model = SudokuSolver(boardSz=args.boardSz, aux=args.aux, m=args.m)
if args.cuda:
    sudoku_model = sudoku_model.cuda()

In [None]:
# Disable interactive mode so the figure is only rendered when explicitly displayed.
plt.ioff()
optimizer = optim.Adam(sudoku_model.parameters(), lr=args.lr)

# One figure with two side-by-side panels, each handed to its own logger.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
plt.subplots_adjust(wspace=0.4)
# The third argument is the panel label (presumably shown as the plot title --
# confirm in exps.sudoku); the misspelling 'Traininig' is corrected here.
train_logger = FigLogger(fig, axes[0], 'Training')
test_logger = FigLogger(fig, axes[1], 'Testing')

In [None]:
# Evaluate the model once on the test split as epoch 0, before any training epoch runs.
test(args.boardSz, 0, sudoku_model, optimizer, test_logger, sudoku_test, args.batchSz)
# Short pause so matplotlib gets a chance to process pending drawing events.
plt.pause(0.01)

In [None]:
# Puzzles used for the per-epoch preview: the first batch of the test split.
# This equals the previously hard-coded 9000:9040 slice when nTrain is 9000
# and args.batchSz is 40, and stays inside the test split for other sizes.
# The slices do not change between epochs, so they are taken once up front.
aa = X[nTrain:nTrain + args.batchSz]
bb = is_input[nTrain:nTrain + args.batchSz]

for epoch in range(1, args.nEpoch+1):
    train(args.boardSz, epoch, sudoku_model, optimizer, train_logger, sudoku_train, args.batchSz)
    test(args.boardSz, epoch, sudoku_model, optimizer, test_logger, sudoku_test, args.batchSz)
    display(fig)
    # Preview the decoded prediction for the first held-out puzzle.
    # Calling the module (rather than .forward) keeps nn.Module hooks working,
    # and no_grad avoids building an autograd graph for a value that is only printed.
    with torch.no_grad():
        print(showw(sudoku_model(aa, bb)[0]))