In [None]:
# Clone SATNet and install it in development mode (installer output goes to install.log).
!git clone https://github.com/locuslab/SATNet
%cd SATNet
!python setup.py develop > install.log 2>&1

Cloning into 'SATNet'...
remote: Enumerating objects: 101, done.[K
remote: Counting objects: 100% (5/5), done.[K
remote: Compressing objects: 100% (5/5), done.[K
remote: Total 101 (delta 0), reused 2 (delta 0), pack-reused 96[K
Receiving objects: 100% (101/101), 497.29 KiB | 7.77 MiB/s, done.
Resolving deltas: 100% (37/37), done.
/content/SATNet


In [None]:
# Mount Google Drive: the Numbrix datasets and model checkpoints are read from MyDrive/Numbrix-data.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
import os
import shutil
import argparse
from collections import namedtuple

import numpy as np
import numpy.random as npr

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

import matplotlib.pyplot as plt
from IPython.display import display, Markdown, Latex, clear_output
import tqdm
import pickle as pk

import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm.auto import tqdm
import random 
import io 

from sklearn.utils import shuffle

# Report the compute device; later .cuda() calls are gated on args.cuda (= torch.cuda.is_available()).
if not torch.cuda.is_available(): 
    print('[WARNING] Not using GPU.')
    print('Please select "Runtime -> Change runtime type" and switch to GPU for better performance')
else:
    print('Using', torch.cuda.get_device_name(0))

Using Tesla T4


# SATNet

Based on the model from the [SATNet paper](https://icml.cc/Conferences/2019/Schedule?showEvent=3947) (ICML 2019).

In [None]:
import satnet
# Print the SATNet layer's docstring (constructor arguments and expected inputs).
print('SATNet document\n', satnet.SATNet.__doc__)

SATNet document
 Apply a SATNet layer to complete the input probabilities.

    Args:
        n: Number of input variables.
        m: Rank of the clause matrix.
        aux: Number of auxiliary variables.

        max_iter: Maximum number of iterations for solving
            the inner optimization problem.
            Default: 40
        eps: The stopping threshold for the inner optimizaiton problem.
            The inner Mixing method will stop when the function decrease
            is less then eps times the initial function decrease.
            Default: 1e-4
        prox_lam: The diagonal increment in the backward linear system
            to make the backward pass more stable.
            Default: 1e-2
        weight_normalize: Set true to perform normlization for init weights.
            Default: True

    Inputs: (z, is_input)
        **z** of shape `(batch, n)`: 
            Float tensor containing the probabilities (must be in [0,1]).
        **is_input** of shape `(batch, 

# Building SATNet-based Models
To solve **Numbrix** we construct a SATNet-based Numbrix Solver layer. It takes as input a logical (bit) representation of the initial Numbrix board, together with a mask marking which bits are given. The bits of every empty cell are unmasked and must be inferred. Both inputs are flattened into vectors. Given them, the Numbrix Solver layer outputs a bit representation of the complete Numbrix board, with guesses for the unknown bits.

In [None]:
# Side length of the Numbrix board; the data cell below picks numbrix_5x5.pk or numbrix_6x6.pk from it.
BOARD_SZ = 6

In [None]:
class NumbrixSolver(nn.Module):
    """SATNet layer sized for a Numbrix board of side `boardSz`.

    The board has boardSz**2 cells and each cell takes one of boardSz**2 values,
    so the one-hot bit encoding has (boardSz**2)**2 == boardSz**4 variables.

    Args:
        boardSz: board side length.
        aux: number of SATNet auxiliary variables.
        m: rank of the SATNet clause matrix.
    """

    def __init__(self, boardSz, aux, m):
        super().__init__()
        num_bits = (boardSz * boardSz) ** 2
        # Attribute name `sat` is kept: pickled checkpoints reference it.
        self.sat = satnet.SATNet(num_bits, m, aux)

    def forward(self, y_in, mask):
        """Forward the flattened board bits and the is_input mask straight to SATNet.

        `mask` is the is_input tensor built by process_inputs (1 for bits of given cells).
        """
        return self.sat(y_in, mask)

In [None]:
# FigLogger (from SATNet's Sudoku experiment) plots logged (epoch, loss, err) tuples onto a matplotlib axis.
from exps.sudoku import FigLogger
# Hyper-parameters: m = SATNet clause-matrix rank, aux = number of SATNet auxiliary variables.
args_dict = {'lr': 1e-4, # earlier runs used 2e-3
             'cuda': torch.cuda.is_available(), 
             'batchSz': 50,
             'boardSz': BOARD_SZ,
             'm': 800,
             'aux': 400,
             'nEpoch': 100
            }
args = namedtuple('Args', args_dict.keys())(*args_dict.values())



# Difficulty: False = Hard (empty border cells are left for the model to fill);
# True = Easy (process_inputs copies border cells of "hard" puzzles from the solution).
EZ = False

# The Numbrix Dataset

In [None]:
cd /content/drive/MyDrive/Numbrix-data

/content/drive/MyDrive/Numbrix-data


In [None]:
def is_hard(puzzle, n):
    """Tell whether an n x n puzzle has an empty cell (value 0) on its border.

    Returns a 0-dim bool tensor: True when any empty cell lies in the first/last
    row or the first/last column.
    """
    empty_coords = torch.tensor(puzzle).eq(0).nonzero(as_tuple=False)
    touches_border = (empty_coords.eq(0) | empty_coords.eq(n - 1)).any(dim=1)
    return touches_border.sum() > 0

def fill_in_edges(x, y, n):
    """Copy solution values from `y` into every empty border cell of puzzle `x`.

    A border cell is any cell in the first/last row or column, corners included.
    This matches is_hard(), so a filled puzzle is no longer "hard".

    Args:
        x: (n, n) integer numpy array; 0 marks an empty cell. Modified in place.
        y: (n, n) solved board aligned with `x`.
        n: board side length.

    Returns:
        `x` itself, with its empty border cells filled.
    """
    empty_idx = np.argwhere(np.asarray(x) == 0)
    # A cell is on the border if ANY coordinate is 0 or n-1. The old `sum == 1`
    # test excluded corners, where both coordinates are on the border.
    on_edge = ((empty_idx == 0) | (empty_idx == n - 1)).any(axis=1)
    rows, cols = empty_idx[on_edge, 0], empty_idx[on_edge, 1]
    x[rows, cols] = np.asarray(y)[rows, cols]
    return x

In [None]:
def process_inputs(X, Y, boardSz, ez=False):
    """Encode Numbrix puzzles and solutions as flattened one-hot bit vectors for SATNet.

    A cell value v in 1..boardSz**2 becomes a one-hot row of length boardSz**2.
    An empty cell (v == 0) becomes an all-zero row.

    Args:
        X: sequence of (boardSz, boardSz) integer arrays (0 = empty cell).
           NOTE: when `ez` is True, entries of X are replaced/modified in place.
        Y: sequence of solved (boardSz, boardSz) integer arrays, values 1..boardSz**2.
        boardSz: board side length.
        ez: if True, puzzles with empty border cells get those cells filled from Y.

    Returns:
        (X_bits, Y_bits, is_input, X):
            X_bits: float tensor of puzzle bits.
            Y_bits: float tensor of solution bits.
            is_input: int32 mask, 1 for every bit of a given cell.
            X: the (possibly modified) puzzle sequence.
    """
    val = boardSz * boardSz
    one_hot = np.eye(val)
    # Loop-invariant lookup tables, built once. An empty cell indexes row -1 (x - 1 == -1),
    # which is the appended all-zero row.
    puzzle_lookup = np.vstack([one_hot, np.zeros(val)])
    mask_lookup = np.vstack([np.ones((val, val)), np.zeros(val)])

    is_input, X_in, Y_in = [], [], []
    for i in range(len(X)):
        x, y = X[i], Y[i]
        # Test `ez` first so the hardness check only runs when it matters.
        if ez and is_hard(x, boardSz):
            x = fill_in_edges(x, y, boardSz)
            X[i] = x
        Y_in.append(torch.tensor(one_hot[y - 1].flatten(), dtype=torch.float))
        X_in.append(torch.tensor(puzzle_lookup[x - 1].flatten(), dtype=torch.float))
        is_input.append(torch.tensor(mask_lookup[x - 1].flatten(), dtype=torch.int32))
    return torch.stack(X_in), torch.stack(Y_in), torch.stack(is_input), X


# Dataset file for the configured board size (numbrix_5x5.pk, numbrix_6x6.pk, ...).
dataset = f'numbrix_{BOARD_SZ}x{BOARD_SZ}.pk'

# NOTE: pickle.load can execute arbitrary code -- only load dataset files you trust.
with open(dataset, 'rb') as f:
    X_in, Y_in = pk.load(f)

# Encode as SATNet bit vectors; X_in comes back with edges filled when EZ is True.
X, Y, is_input, X_in = process_inputs(X_in, Y_in, args.boardSz, EZ)
# Shuffle the four aligned collections together with a fixed seed so the split is reproducible.
X, Y, is_input, X_in = shuffle(X, Y, is_input, X_in, random_state=0)

if args.cuda:
    X, is_input, Y = X.cuda(), is_input.cuda(), Y.cuda()

# 90/10 train/test split; X_in keeps the integer boards aligned with the bit tensors.
N = len(X_in)
nTrain = int(N * 0.9)

numbrix_train = TensorDataset(X[:nTrain], is_input[:nTrain], Y[:nTrain])
numbrix_test = TensorDataset(X[nTrain:], is_input[nTrain:], Y[nTrain:])

In [None]:
print(len(numbrix_train))
print(len(numbrix_test))

21150
2350


In [None]:
cd /content/SATNet

/content/SATNet


## Numbrix Example


In [None]:
print(X_in[0])

[[20  0 22 23  0 25]
 [ 0  0  0  0  0  0]
 [18 11  8  3  4 27]
 [17 12  7  6  5 28]
 [ 0  0  0  0  0  0]
 [15  0 33 34  0 36]]


In [None]:
torch.set_printoptions(threshold=10_000, linewidth=110)
print(X[0])

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 

In [None]:
def get_difficulty(puzzles, n):
  """Return the fraction (0..1) of `puzzles` with at least one empty cell on the border.

  Hardness is decided by is_hard(). An empty collection yields 0.0 instead of
  raising ZeroDivisionError.
  """
  if len(puzzles) == 0:
    return 0.0
  # The previous version also computed is_hard's edge mask inline and then discarded it.
  num_hard = sum(1 for puzzle in puzzles if is_hard(puzzle, n))
  return num_hard / len(puzzles)
  
prop_difficult = get_difficulty(X_in, BOARD_SZ)
print(f"{round(prop_difficult*100, 2):.2f}% of boards have empty cells on the edges")

21.78% of boards have empty cells on the edges


# Experiments
The results for NxN Numbrix are below. We input an unsolved NxN Numbrix board to the SATNet model, which outputs its guess at the solved board.

In [None]:
def run(boardSz, epoch, model, optimizer, logger, dataset, batchSz, to_train=False):
    """Run one pass of the SATNet model over `dataset`, training when `to_train` is True.

    Args:
        boardSz: board side length (used by computeErr to decode boards).
        epoch: epoch number shown in the progress bar and logged.
        model: module called as model(bits, is_input) returning bit probabilities.
        optimizer: stepped once per batch when training; unused otherwise.
        logger: object with ``log((epoch, loss, err))`` (e.g. FigLogger).
        dataset: TensorDataset of (bits, is_input, label) triples.
        batchSz: DataLoader batch size.
        to_train: if True, back-propagate and update the model.

    Logs (epoch, dataset-mean BCE loss, fraction of boards with any wrong cell).
    """
    loss_final, err_final = 0, 0

    loader = DataLoader(dataset, batch_size=batchSz)
    tloader = tqdm(loader)  # total is taken from len(loader)

    for data, is_input, label in tloader:
        if to_train:
            optimizer.zero_grad()
        preds = model(data.contiguous(), is_input.contiguous())
        loss = nn.functional.binary_cross_entropy(preds, label)

        if to_train:
            loss.backward()
            optimizer.step()

        n_batch = data.size(0)  # the last batch may be smaller than batchSz
        loss_val = loss.item()
        err = computeErr(preds.data, label, boardSz)
        tloader.set_description('Epoch {} {} Loss {:.4f} Err: {:.4f}'.format(
            epoch, ('Train' if to_train else 'Test '), loss_val, err / n_batch))
        # Weight each batch's mean loss by its share of the dataset -> dataset-mean loss.
        loss_final += loss_val * (n_batch / len(dataset))
        err_final += err

    err_final = err_final / len(dataset)
    logger.log((epoch, loss_final, err_final))

    if not to_train:
        print('TESTING SET RESULTS: Average loss: {:.4f} Err: {:.4f}'.format(loss_final, err_final))

    torch.cuda.empty_cache()

def train(args, epoch, model, optimizer, logger, dataset, batchSz):
    """One training epoch via run(); despite its name, `args` is the board side length."""
    run(args, epoch, model, optimizer, logger, dataset, batchSz, to_train=True)

@torch.no_grad()
def test(args, epoch, model, optimizer, logger, dataset, batchSz):
    """One evaluation pass via run() with autograd off; `args` is the board side length."""
    run(args, epoch, model, optimizer, logger, dataset, batchSz, to_train=False)

@torch.no_grad()
def computeErr(pred_flat, label_flat, n):
    """Count boards in the batch whose decoded prediction differs from the label.

    Both inputs are flat bit tensors viewed as (batch, n, n, n*n). A cell's value is
    the index of its largest bit. Returns the number of boards with at least one
    mismatching cell, as a Python float.
    """
    cells = n * n
    pred_vals = torch.max(pred_flat.view(-1, n, n, cells), 3)[1]
    batch = pred_vals.size(0)
    label_vals = torch.max(label_flat.view(-1, n, n, cells), 3)[1]
    assert batch == label_vals.size(0)

    wrong = (pred_vals.view(batch, cells) != label_vals.view(batch, cells)).any(dim=1)
    return float(wrong.sum())


In [None]:
# NOTE(review): the checkpoint name appears to encode m=800, aux=500, but args.aux is 400.
# If the file is missing, a model with aux=400 is trained and saved under this name -- confirm.
file = 'numbrix6_800_500_lr1e-4-23500-final.pt'
model_path = "/content/drive/MyDrive/Numbrix-data/" + file

try:
  # The checkpoint is a whole pickled nn.Module (written by torch.save below).
  # map_location lets a GPU-trained checkpoint load on a CPU-only runtime.
  # NOTE: PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects full
  # modules; pass weights_only=False there (only for checkpoints you trust).
  numbrix_5 = torch.load(model_path, map_location='cuda' if args.cuda else 'cpu')

except FileNotFoundError:
  # No checkpoint yet: train from scratch, plotting train/test curves and saving periodically.
  numbrix_5 = NumbrixSolver(args.boardSz, args.aux, args.m)
  if args.cuda:
    numbrix_5 = numbrix_5.cuda()

  plt.ioff()  # interactive mode off; the figure is shown explicitly via display(fig)
  optimizer = optim.Adam(numbrix_5.parameters(), lr=args.lr)
  fig, axes = plt.subplots(1, 2, figsize=(10, 4))
  plt.subplots_adjust(wspace=0.4)
  train_logger = FigLogger(fig, axes[0], 'Training')
  test_logger = FigLogger(fig, axes[1], 'Testing')

  # Epoch 0 evaluates the untrained model as a baseline.
  test(args.boardSz, 0, numbrix_5, optimizer, test_logger, numbrix_test, args.batchSz)
  for epoch in range(1, args.nEpoch + 1):
    train(args.boardSz, epoch, numbrix_5, optimizer, train_logger, numbrix_train, args.batchSz)
    test(args.boardSz, epoch, numbrix_5, optimizer, test_logger, numbrix_test, args.batchSz)
    if epoch % 10 == 1:  # epochs 1, 11, 21, ...: refresh the plot and checkpoint
      display(fig)
      torch.save(numbrix_5, model_path)
  display(fig)

  torch.save(numbrix_5, model_path)


In [None]:
#plt.savefig('numb_5x5_200_500_single.jpeg',  dpi=500)

In [None]:
def get_neighbours(x, y, max_idx):
  """Return the in-bounds orthogonal neighbours of cell (x, y) on a (max_idx+1)-square board.

  Neighbours are checked in the order (x, y-1), (x, y+1), (x-1, y), (x+1, y).
  The result is two parallel lists (first coordinates, second coordinates), ready
  for advanced indexing as board[rows, cols].
  """
  candidates = (
      (x, y - 1, y > 0),
      (x, y + 1, y < max_idx),
      (x - 1, y, x > 0),
      (x + 1, y, x < max_idx),
  )
  kept = [(r, c) for r, c, in_bounds in candidates if in_bounds]
  return [r for r, _ in kept], [c for _, c in kept]

def in_order(board, n):
  """Count cells of an n x n board that continue the 1..n*n path.

  A cell holding v < n*n counts if v + 1 sits in an orthogonal neighbour. The cell
  holding n*n counts if n*n - 1 is a neighbour. A valid complete path therefore
  scores n*n. `board` must support board[rows, cols] list indexing (e.g. a tensor).
  """
  last = n * n
  total = 0
  for r in range(n):
    for c in range(n):
      value = board[r, c]
      nbr_rows, nbr_cols = get_neighbours(r, c, n - 1)
      around = board[nbr_rows, nbr_cols]
      if value < last:
        total += int(value + 1 in around)
      elif value == last:
        total += int(value - 1 in around)
  return total

def is_unique(boards):
  """Per row of `boards` (batch, cells), True iff every value in the row occurs exactly once."""
  occurrences = F.one_hot(boards).sum(dim=1)
  return (occurrences == 1).all(dim=1)

def count_duplicates(board):
  """Count (row, value) pairs of `board` (batch, cells) where the value appears more than once in that row."""
  per_value = F.one_hot(board).sum(dim=1)
  repeated = per_value > 1
  return repeated.sum()

In [None]:
def get_stats(pred_flat, n):
    """Decode predicted bits into a board and summarise its validity.

    Returns (uniq, order, dups):
        uniq  -- number of boards in the batch whose decoded values are all distinct (float).
        order -- in_order() score of the decoded board (n*n for a valid path).
        dups  -- count of (board, value) pairs where a value repeats within a board (float).

    NOTE: the ``view(n, n)`` for the order score only succeeds for a batch of one board.
    """
    cells = n * n
    # 0-based value index chosen for each cell, shape (batch, n, n)
    picked = torch.max(pred_flat.view(-1, n, n, cells), 3)[1]
    per_board = picked.squeeze().view(picked.size(0), cells)
    uniq = float(is_unique(per_board).sum())
    dups = float(count_duplicates(per_board))

    order = in_order((picked + 1).squeeze().view(n, n).detach(), n)
    return uniq, order, dups

def print_board(pred_flat, n, label_flat):
    """Print the decoded model board followed by the decoded ground-truth board (1-based values)."""
    cells = n * n

    def _decode(flat):
        # Pick each cell's most likely value and shift to 1-based.
        return (torch.max(flat.view(-1, n, n, cells), 3)[1] + 1).squeeze().view(n, n)

    sections = (("\nModel Output", _decode(pred_flat)),
                ("\nCorrect Solution", _decode(label_flat)))
    for heading, board in sections:
        print(heading)
        print(board)
    print("")

@torch.no_grad()
def examine_output(boardSz, model, dataset, X_in, ez=False):
    num_correct, examples = 0, 0
    prop_uniq, err_final, avg_in_order, num_in_order, num_dups = 0, 0, 0, 0, 0
    if not ez:
      solved_hard, uns_hard = 0, 0
    loader = DataLoader(dataset)
    tloader = tqdm(enumerate(loader), total=len(loader))

    test_err = 0
    counter, s_counter = 0, 0
    for i,(data,is_input,label) in tloader:
        rand = random.random()
        preds = model(data.contiguous(), is_input.contiguous())
        if not ez:
          hard_puzz = is_hard(X_in[i], boardSz) > 0
        err = computeErr(preds.data, label, boardSz)
        uniq, order, dups = get_stats(preds, boardSz) 

        if order == boardSz*boardSz and uniq == 1:
            test_err += 1
            if not ez and hard_puzz:
              solved_hard += 1
            s_counter += 1

        if err != 0:
          if order == boardSz*boardSz:
            num_in_order += 1
          if not ez and hard_puzz:
            uns_hard += 1
          prop_uniq += uniq
          avg_in_order += order
          num_dups += dups 
          counter += 1


        if err == 0 and num_correct < 1:
          print("Solved board")
          print_board(preds, boardSz, label)
          num_correct += 1
          print("")
        if err > 0 and examples < 1 and rand < 0.2:
          print("Unsolver board")
          print_board(preds, boardSz, label)        
          examples += 1
          print("")

        err_final += err


    print('Unique boards {:.4f}. Percent error: {:.4f}'.format(prop_uniq/len(loader), err_final/len(loader)))
    print('% correctly ordered boards: {:.4f}. Avg # correctly ordered cells per incorrect board: {}'.format(num_in_order/len(loader), round(avg_in_order/counter)))
    print('Avg. # of duplicates per incorrect board: {:.1f}'.format(num_dups/counter))
    if not ez:
      print('{:.2f}% of solved puzzles are hard and {:.2f}% of unsolved'.format(solved_hard/(s_counter)*100, uns_hard/counter*100))
    print(test_err/len(loader))

examine_output(args.boardSz, numbrix_5, numbrix_test, X_in[nTrain:], EZ)

  0%|          | 0/2350 [00:00<?, ?it/s]

Solved board

Model Output
tensor([[33, 34, 35, 36, 11, 10],
        [32, 23, 22, 13, 12,  9],
        [31, 24, 21, 14,  1,  8],
        [30, 25, 20, 15,  2,  7],
        [29, 26, 19, 16,  3,  6],
        [28, 27, 18, 17,  4,  5]], device='cuda:0')

Correct Solution
tensor([[33, 34, 35, 36, 11, 10],
        [32, 23, 22, 13, 12,  9],
        [31, 24, 21, 14,  1,  8],
        [30, 25, 20, 15,  2,  7],
        [29, 26, 19, 16,  3,  6],
        [28, 27, 18, 17,  4,  5]], device='cuda:0')


Unsolver board

Model Output
tensor([[ 6,  7, 10, 11, 12, 13],
        [ 5,  8,  9, 24, 19, 14],
        [ 4, 21, 20, 19, 18, 15],
        [ 3,  2,  1, 24, 17, 16],
        [28, 27, 26, 25, 34, 35],
        [29, 30, 31, 32, 33, 36]], device='cuda:0')

Correct Solution
tensor([[ 6,  7, 10, 11, 12, 13],
        [ 5,  8,  9, 20, 19, 14],
        [ 4,  1, 22, 21, 18, 15],
        [ 3,  2, 23, 24, 17, 16],
        [28, 27, 26, 25, 34, 35],
        [29, 30, 31, 32, 33, 36]], device='cuda:0')


Unique boards 0.

In [24]:
examine_output(args.boardSz, numbrix_5, numbrix_train, X_in[:nTrain], EZ)  


  0%|          | 0/21150 [00:00<?, ?it/s]

Solved board

Model Output
tensor([[27, 28, 29, 30, 33, 34],
        [26, 25, 24, 31, 32, 35],
        [21, 22, 23, 16, 15, 36],
        [20, 19, 18, 17, 14,  1],
        [ 9, 10, 11, 12, 13,  2],
        [ 8,  7,  6,  5,  4,  3]], device='cuda:0')

Correct Solution
tensor([[27, 28, 29, 30, 33, 34],
        [26, 25, 24, 31, 32, 35],
        [21, 22, 23, 16, 15, 36],
        [20, 19, 18, 17, 14,  1],
        [ 9, 10, 11, 12, 13,  2],
        [ 8,  7,  6,  5,  4,  3]], device='cuda:0')


Unsolver board

Model Output
tensor([[31, 32, 33, 34, 35, 36],
        [30, 21, 20, 19, 18, 17],
        [20, 22,  1,  2, 15, 16],
        [11, 23,  4,  3, 14, 13],
        [27, 24,  5,  8,  9, 12],
        [26, 25,  6,  7, 10, 11]], device='cuda:0')

Correct Solution
tensor([[31, 32, 33, 34, 35, 36],
        [30, 21, 20, 19, 18, 17],
        [29, 22,  1,  2, 15, 16],
        [28, 23,  4,  3, 14, 13],
        [27, 24,  5,  8,  9, 12],
        [26, 25,  6,  7, 10, 11]], device='cuda:0')


Unique boards 0.

In [25]:
def get_average_filled(puzzles, n):
  """Mean number of non-empty (non-zero) cells per n x n puzzle in `puzzles`."""
  total_cells = n * n
  filled = sum(total_cells - (p == 0).sum() for p in puzzles)
  return filled / len(puzzles)
  
avg_filled = get_average_filled(X_in, BOARD_SZ)
print(f'On average {avg_filled:.2f} cells are filled.\n~{round(avg_filled/(BOARD_SZ**2)*100, 2)}% of the board')

On average 20.00 cells are filled.
~55.56% of the board
