In [None]:
!git clone https://github.com/locuslab/SATNet
%cd SATNet
!python setup.py develop > install.log 2>&1

Cloning into 'SATNet'...
remote: Enumerating objects: 101, done.[K
remote: Counting objects: 100% (5/5), done.[K
remote: Compressing objects: 100% (5/5), done.[K
remote: Total 101 (delta 0), reused 2 (delta 0), pack-reused 96[K
Receiving objects: 100% (101/101), 497.29 KiB | 3.68 MiB/s, done.
Resolving deltas: 100% (37/37), done.
/content/SATNet


In [None]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
import os
import shutil
import argparse
from collections import namedtuple

import numpy as np
import numpy.random as npr

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

import matplotlib.pyplot as plt
from IPython.display import display, Markdown, Latex, clear_output
import tqdm
import pickle as pk

import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm.auto import tqdm
import random 
# Seed every RNG the notebook relies on, not just Python's `random`:
# `random` drives the bagging subsamples, while torch drives model weight
# initialisation. Without the extra seeds a Restart & Run All is not repeatable.
SEED = 6
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

import io 

from sklearn.utils import shuffle

# Report which device will be used; the notebook still runs on CPU, only slower.
if not torch.cuda.is_available(): 
    print('[WARNING] Not using GPU.')
    print('Please select "Runtime -> Change runtime type" and switch to GPU for better performance')
else:
    # Name of CUDA device 0.
    print('Using', torch.cuda.get_device_name(0))

Using Tesla T4


# SATNet

Based on model from [SATNet paper](https://icml.cc/Conferences/2019/Schedule?showEvent=3947).

In [None]:
import satnet
print('SATNet document\n', satnet.SATNet.__doc__)

SATNet document
 Apply a SATNet layer to complete the input probabilities.

    Args:
        n: Number of input variables.
        m: Rank of the clause matrix.
        aux: Number of auxiliary variables.

        max_iter: Maximum number of iterations for solving
            the inner optimization problem.
            Default: 40
        eps: The stopping threshold for the inner optimizaiton problem.
            The inner Mixing method will stop when the function decrease
            is less then eps times the initial function decrease.
            Default: 1e-4
        prox_lam: The diagonal increment in the backward linear system
            to make the backward pass more stable.
            Default: 1e-2
        weight_normalize: Set true to perform normlization for init weights.
            Default: True

    Inputs: (z, is_input)
        **z** of shape `(batch, n)`: 
            Float tensor containing the probabilities (must be in [0,1]).
        **is_input** of shape `(batch, 

# Building SATNet-based Models
To solve **Numbrix** we construct a SATNet-based Numbrix Solver layer that takes as input a logical (bit) representation of the initial Numbrix board along with a mask representing which bits must be learned (i.e. all bits in empty cells). This input is vectorized. Given this input, the Numbrix Solver layer then outputs a bit representation of the Numbrix board with guesses for the unknown bits.

In [None]:
class NumbrixSolver(nn.Module):
    '''
    Thin wrapper around a single SATNet layer sized for a Numbrix board.

    Args:
        boardSz: Side length of the board; the layer gets boardSz**4 input
            variables.
        aux: Number of auxiliary variables passed to SATNet.
        m: Rank of the SATNet clause matrix.
    '''

    def __init__(self, boardSz, aux, m):
        super().__init__()
        num_vars = boardSz ** 4
        self.sat = satnet.SATNet(num_vars, m, aux)

    def forward(self, y_in, mask):
        '''Run SATNet on the inputs `y_in` with the is-input mask `mask` and return its output.'''
        return self.sat(y_in, mask)

In [None]:
from exps.sudoku import FigLogger
# Experiment hyper-parameters, frozen into a namedtuple so later cells can use
# attribute access (args.lr, args.boardSz, ...).
# NOTE(review): the model-construction cells in this notebook take `m`/`aux` from
# their own per-model lists, so args.m / args.aux look unused here — confirm.
args_dict = {'lr': 1e-4, # 2e-3,
             'cuda': torch.cuda.is_available(), 
             'batchSz': 40,
             'boardSz': 5, # for 5x5 Numbrix
             'm': 500,
             'aux': 200,
             'nEpoch': 80
            }
args = namedtuple('Args', args_dict.keys())(*args_dict.values())



# Difficulty switch forwarded to process_inputs as `ez`: when True, empty edge
# cells are pre-filled from the solution; when False the puzzles are left as-is.
EZ = False # hard setting (edges are not pre-filled); set to True for the easy setting

# The Numbrix Dataset

In [None]:
cd /content/drive/MyDrive/Numbrix-data

/content/drive/MyDrive/Numbrix-data


In [None]:
def is_hard(puzzle, n):
    '''
    Check if there are empty cells on the edges of the puzzle.

    Args:
        puzzle: n x n board where 0 marks an empty cell.
        n: Side length of the board.

    Returns:
        A 0-dim bool tensor that is True when at least one empty cell has a
        row or column index equal to 0 or n-1.
    '''
    blank_cells = torch.nonzero(torch.tensor(puzzle) == 0, as_tuple=False)
    touches_border = (blank_cells == 0) | (blank_cells == n - 1)
    return touches_border.sum() > 0

def fill_in_edges(x, y, n):
    '''
    Fill in empty puzzle edges in x with values from the solution y.

    Args:
        x: n x n puzzle array where 0 marks an empty cell. Modified in place.
        y: n x n solved board, indexed the same way as x.
        n: Side length of the board.

    Returns:
        The same object `x`, with every empty border cell (corners included)
        replaced by the corresponding value of `y`. Interior cells are untouched.
    '''
    empty_idx = (torch.tensor(x) == 0).nonzero(as_tuple=False)
    # A cell is on the border when its row OR its column is 0 or n-1. Corner
    # cells satisfy both conditions, so test with any(): requiring exactly one
    # match would leave empty corners unfilled.
    on_border = ((empty_idx == 0) | (empty_idx == n - 1)).any(dim=1)
    border_idx = empty_idx[on_border]
    # Convert to NumPy index arrays so the assignment does not depend on
    # implicit tensor-to-array conversion inside the indexing expression.
    rows = border_idx[:, 0].numpy()
    cols = border_idx[:, 1].numpy()
    x[rows, cols] = y[rows, cols]
    return x

In [None]:
def process_inputs(X, Y, boardSz, ez=False, max_samples=10000):
    '''
    Encode raw Numbrix puzzles and solutions as flat one-hot tensors.

    Every cell is encoded with boardSz*boardSz bits (one per possible value);
    an empty cell (value 0) is encoded as all zeros.

    Args:
        X: Sequence of boardSz x boardSz integer arrays, 0 marking empty cells.
        Y: Sequence of solved boards aligned with X.
        boardSz: Side length of the board.
        ez: When True, empty edge cells of "hard" puzzles are filled in from
            the solution (X is updated in place for those puzzles).
        max_samples: Upper bound on the number of puzzles processed; None
            processes all of them.

    Returns:
        (X_in, Y_in, is_input, X_raw): the encoded puzzles (float), encoded
        solutions (float), the int32 mask that is 1 on the bits of given cells
        and 0 on the bits of empty cells, and the raw puzzles that were
        actually processed (same length as the tensors).
    '''
    val = boardSz*boardSz

    one_hot = np.eye(val)
    # Both lookup tables carry an extra all-zero last row: an empty cell has
    # value 0, so indexing with `x - 1 == -1` selects that row.
    value_table = np.vstack([one_hot, np.zeros(val)])
    mask_table = np.vstack([np.ones((val, val)), np.zeros(val)])

    n_used = len(X) if max_samples is None else min(len(X), max_samples)

    X_in, Y_in, is_input = [], [], []
    for i in range(n_used):
        x, y = X[i], Y[i]
        # Check `ez` first so the edge scan is skipped when it cannot matter.
        if ez and is_hard(x, boardSz):
            x = fill_in_edges(x, y, boardSz)
            X[i] = x
        Y_in.append(torch.tensor(one_hot[y - 1].flatten(), dtype=torch.float))
        X_in.append(torch.tensor(value_table[x - 1].flatten(), dtype=torch.float))
        is_input.append(torch.tensor(mask_table[x - 1].flatten(), dtype=torch.int32))

    # Return only the puzzles that were encoded so all four outputs stay aligned
    # when the cap truncates the dataset.
    X_raw = X if n_used == len(X) else X[:n_used]
    return torch.stack(X_in), torch.stack(Y_in), torch.stack(is_input), X_raw



# Load the raw (puzzle, solution) pairs from the pickle in the current directory.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('numbrix_5x5.pk', 'rb') as f:
    X_in, Y_in = pk.load(f)
    
# Encode the boards. From here on X / Y / is_input are the encoded tensors and
# X_in is rebound to the raw puzzles returned by process_inputs.
X, Y, is_input, X_in = process_inputs(X_in, Y_in, args.boardSz, EZ)
# Shuffle all four collections together with a fixed random_state.
X, Y, is_input, X_in = shuffle(X, Y, is_input, X_in, random_state=0)

if args.cuda: X, is_input, Y = X.cuda(), is_input.cuda(), Y.cuda()



# The first 90% of the shuffled examples are the training pool; later cells use
# the slice from nTrain onwards as the test set.
N = len(X_in)
nTrain = int(N*0.9)


In [None]:
# Number of Models in the bag
NUM_BAG_MODELS = 5
# Fraction (not a percentage) of the training pool drawn for each model.
SUBSAMPLE_PERCENT = 0.8

train_datasets = []

# Build one training subset per model. random.sample draws indices without
# replacement from the first nTrain examples.
for i in range(NUM_BAG_MODELS):
  idxs = random.sample(range(nTrain), int(nTrain*SUBSAMPLE_PERCENT))
  train_datasets.append(TensorDataset(X[idxs], is_input[idxs], Y[idxs]))
  
# Held-out test set shared by all models: everything after the training pool.
numbrix_test =  TensorDataset(X[nTrain:], is_input[nTrain:], Y[nTrain:])

In [None]:
print(len(train_datasets[0]))
print(len(train_datasets))
print(len(numbrix_test))

6120
5
850


In [None]:
cd /content/SATNet

/content/SATNet


## Numbrix Example


In [None]:
print(X_in[0])

[[11 10  9  8  7]
 [12  0  0  0  6]
 [13  0  0  0  1]
 [14  0  0  0 20]
 [15 16 17 18 19]]


In [None]:
torch.set_printoptions(threshold=10_000, linewidth=110)
print(X[0])

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 

In [None]:
def get_difficulty(puzzles, n):
  '''
  Return the fraction of unsolved boards that have empty cells on the edge.

  Args:
      puzzles: Collection of n x n puzzles where 0 marks an empty cell.
      n: Side length of the boards.

  Returns:
      Fraction in [0, 1] of puzzles for which is_hard() is true; 0.0 when
      `puzzles` is empty.
  '''
  total = len(puzzles)
  if total == 0:
    # Nothing to classify; avoid dividing by zero.
    return 0.0
  num_hard = sum(1 for puzzle in puzzles if is_hard(puzzle, n))
  return num_hard / total
  
prop_difficult = get_difficulty(X_in, 5)
print(f"{round(prop_difficult*100, 2):.2f}% of boards have empty cells on the edges")

32.18% of boards have empty cells on the edges


# Experiments
The results for training a stacking ensemble of 5x5 Numbrix models are below. Each ensemble member is a SATNet model for 5x5 Numbrix, trained with a different combination of the $m$ and $aux$ parameters.

In [None]:
from google.colab import output

def run(boardSz, epoch, model, optimizer, logger, dataset, batchSz, to_train=False):
    '''
    Run one pass of a SATNet model over `dataset`, optionally training it.

    Args:
        boardSz: Side length of the board, forwarded to computeErr.
        epoch: Epoch number, used for the progress bar and the logger.
        model: Module called as model(data, is_input).
        optimizer: Optimizer stepped after every batch when `to_train` is True.
        logger: Object whose log() receives (epoch, mean loss, board error rate).
        dataset: Dataset yielding (data, is_input, label) triples.
        batchSz: Batch size for the DataLoader.
        to_train: When True, back-propagate and update the model.
    '''
    loss_final, err_final = 0, 0

    loader = DataLoader(dataset, batch_size=batchSz)
    tloader = tqdm(enumerate(loader), total=len(loader))

    print("loader len: ", len(loader))

    phase = 'Train' if to_train else 'Test '
    for i,(data,is_input,label) in tloader:
        if to_train: optimizer.zero_grad()
        preds = model(data.contiguous(), is_input.contiguous())
        loss = nn.functional.binary_cross_entropy(preds, label)

        if to_train:
            loss.backward()
            optimizer.step()

        # Actual number of boards in this batch; the last batch may be smaller
        # than batchSz, so the per-batch error rate must use this value.
        n_boards = data.size(0)
        err = computeErr(preds.detach(), label, boardSz)
        tloader.set_description('Epoch {} {} Loss {:.4f} Err: {:.4f}'.format(epoch, phase, loss.item(), err/n_boards))
        # Weight each batch loss by its share of the dataset so the total is a
        # per-example mean even when batches differ in size.
        loss_final += loss.item() * (n_boards/len(dataset))
        err_final += err

    # Convert the count of wrong boards into an error rate over the dataset.
    err_final = err_final/len(dataset)
    logger.log((epoch, loss_final, err_final))

    if not to_train:
        print('TESTING SET RESULTS: Average loss: {:.4f} Err: {:.4f}'.format(loss_final, err_final))

    torch.cuda.empty_cache()

def train(args, epoch, model, optimizer, logger, dataset, batchSz):
    '''
    Run one training epoch via run().

    `args` is forwarded unchanged as the first argument of run(), i.e. the
    board size; the remaining arguments are passed through as given.
    '''
    run(args, epoch, model, optimizer, logger, dataset, batchSz, to_train=True)

@torch.no_grad()
def test(args, epoch, model, optimizer, logger, dataset, batchSz):
    '''
    Run one evaluation pass via run() with gradient tracking disabled.

    `args` is forwarded unchanged as the first argument of run(), i.e. the
    board size; the remaining arguments are passed through as given.
    '''
    run(args, epoch, model, optimizer, logger, dataset, batchSz, to_train=False)

@torch.no_grad()
def computeErr(pred_flat, label_flat, n):
    '''
    Count the boards in a batch whose predicted cell values differ from the labels.

    Both inputs are viewed as (batch, n, n, n*n); the value of a cell is the
    position of the largest entry along the last axis.

    Args:
        pred_flat: Predicted per-bit scores for the batch.
        label_flat: Label encoding with the same layout as `pred_flat`.
        n: Side length of the board.

    Returns:
        Number of boards (as a float) with at least one mismatching cell.
    '''
    nsq = n*n
    pred_cells = pred_flat.view(-1, n, n, nsq)
    label_cells = label_flat.view(-1, n, n, nsq)
    batchSz = pred_cells.size(0)
    assert batchSz == label_cells.size(0)

    # Arg-max over the value axis gives one class index per cell.
    pred_vals = pred_cells.max(dim=3).indices.view(batchSz, nsq)
    true_vals = label_cells.max(dim=3).indices.view(batchSz, nsq)

    solved = (pred_vals == true_vals).all(dim=1)
    return float(batchSz - solved.sum())


In [None]:
# Build one NumbrixSolver per ensemble member; model i uses aux[i] auxiliary
# variables and clause-matrix rank m[i].
models = []
aux = [200, 300, 300, 400, 500]
m = [500, 600, 800, 900, 1000]
for i in range(NUM_BAG_MODELS):
  numbrix_5x5 = NumbrixSolver(args.boardSz, aux[i], m[i])
  models.append(numbrix_5x5)


# NOTE(review): range(4, NUM_BAG_MODELS) trains only the model with index 4.
# The cell that loads the ensemble reads a checkpoint for every index, so the
# other checkpoints presumably come from earlier runs — confirm, or start the
# range at 0 for a full Restart & Run All.
for i in range(4, NUM_BAG_MODELS):
  print("Training Model -->", i+1 )
  if args.cuda: models[i] = models[i].cuda()

  plt.ioff()
  optimizer = optim.Adam(models[i].parameters(), lr=args.lr)

  # One figure per model: training curve on the left, test curve on the right.
  fig, axes = plt.subplots(1,2, figsize=(10,4))
  plt.subplots_adjust(wspace=0.4)
  train_logger = FigLogger(fig, axes[0], 'Traininig')
  test_logger = FigLogger(fig, axes[1], 'Testing')

  # Epoch 0: evaluate the untrained model as a baseline.
  test(args.boardSz, 0, models[i], optimizer, test_logger, numbrix_test, args.batchSz)

  for epoch in range(1, args.nEpoch+1):
      train(args.boardSz, epoch, models[i], optimizer, train_logger, train_datasets[i], args.batchSz)
      test(args.boardSz, epoch, models[i], optimizer, test_logger, numbrix_test, args.batchSz)
      # Show the curves at epochs 1, 21, 41, ...
      if epoch%20 == 1:
        display(fig)
      # Checkpoint to Drive at epochs 1, 11, 21, ... (same file is overwritten).
      if epoch%10 == 1:
        torch.save(models[i].state_dict(), f'/content/drive/MyDrive/Numbrix-data/ensemb_1e-4_bag {i}')

  display(fig)

  # Final checkpoint for this model, written to the same path as the periodic ones.
  torch.save(models[i].state_dict(), f'/content/drive/MyDrive/Numbrix-data/ensemb_1e-4_bag {i}')
  output.clear()



In [None]:
# Rebuild every ensemble member with its own (aux, m) setting and restore the
# weights saved by the training cell.
ensemble = []

for i in range(NUM_BAG_MODELS):
  numbrix_5x5 = NumbrixSolver(args.boardSz, aux[i], m[i])
  # map_location='cpu' lets checkpoints saved from a CUDA model load on a
  # CPU-only runtime as well; the evaluation cells move the models to the GPU
  # when one is available.
  numbrix_5x5.load_state_dict(
      torch.load(f'/content/drive/MyDrive/Numbrix-data/ensemb_1e-4_bag {i}', map_location='cpu'))
  ensemble.append(numbrix_5x5)

In [None]:
def get_neighbours(x, y, max_idx):
  '''
  Get the neighbours of a Numbrix cell.

  Args:
      x: First coordinate of the cell.
      y: Second coordinate of the cell.
      max_idx: Largest valid coordinate value.

  Returns:
      (first_coords, second_coords): two parallel lists with the coordinates
      of the existing neighbours, in the order (x, y-1), (x, y+1), (x-1, y),
      (x+1, y). A neighbour is dropped when it would step below 0 or above
      max_idx.
  '''
  # Each candidate neighbour is paired with the condition under which it exists.
  candidates = (
      ((x, y - 1), y > 0),
      ((x, y + 1), y < max_idx),
      ((x - 1, y), x > 0),
      ((x + 1, y), x < max_idx),
  )
  kept = [cell for cell, exists in candidates if exists]
  row_idx = [cell[0] for cell in kept]
  col_idx = [cell[1] for cell in kept]
  return row_idx, col_idx

def in_order(board, n):
  '''
  Count the cells of `board` that are correctly chained to an adjacent cell.

  A cell with value v < n*n counts when v+1 appears among its neighbours; the
  cell holding n*n counts when n*n-1 appears among its neighbours. Cells with
  any other value never count.

  Args:
      board: n x n board indexable as board[i, j] and with index lists.
      n: Side length of the board.

  Returns:
      Number of cells satisfying the rule above.
  '''
  nsq = n*n
  last_idx = n - 1
  matched = 0
  for row in range(n):
    for col in range(n):
      value = board[row, col]
      if value < nsq:
        wanted = value + 1
      elif value == nsq:
        wanted = value - 1
      else:
        continue

      row_idx, col_idx = get_neighbours(row, col, last_idx)
      if wanted in board[row_idx, col_idx]:
        matched += 1
  return matched

def is_unique(boards):
  '''
  Flag the boards in which every class index occurs exactly once.

  Args:
      boards: Integer tensor of shape (batch, cells) holding class indices.

  Returns:
      Bool tensor of shape (batch,).
  '''
  # One-hot then sum over the cell axis gives per-board occurrence counts.
  occurrences = F.one_hot(boards).sum(dim=1)
  return (occurrences == 1).all(dim=1)


In [None]:
@torch.no_grad()
def get_stats(pred_flat, n):
    '''
    Compute validity statistics for a predicted board.

    Args:
        pred_flat: Prediction viewable as (batch, n, n, n*n). The final
            reshape to (n, n) only succeeds for a batch of one board.
        n: Side length of the board.

    Returns:
        (uniq, order): the number of boards in the batch whose values are all
        distinct (as a float) and the in_order() count of the decoded board.
    '''
    nsq = n*n
    # Arg-max over the value axis, computed once and reused below.
    cell_argmax = pred_flat.view(-1, n, n, nsq).max(dim=3).indices
    batchSz = cell_argmax.size(0)

    uniq = is_unique(cell_argmax.view(batchSz, nsq)).sum()

    # Shift to 1-based cell values before checking adjacency.
    board = (cell_argmax + 1).squeeze().view(n, n).detach()
    order = in_order(board, n)
    return float(uniq), order 

In [None]:
# Validity-based ensemble evaluation: a test puzzle counts as solved when at
# least one model outputs a board whose values are all distinct and in which
# every cell is chained to a neighbour.
correct = 0 
correct_per = [0 for _ in range(len(ensemble))]

# Device placement and eval mode do not change between batches, so set them once.
for model in ensemble:
  if args.cuda: model.cuda()
  model.eval()

loader = DataLoader(numbrix_test)
tloader = tqdm(enumerate(loader), total=len(loader))
# Inference only: disable autograd so no graph is built per forward pass.
with torch.no_grad():
  # The mask is bound to `batch_mask` so the loop does not overwrite the
  # notebook-level `is_input` tensor used to build the datasets.
  for i,(data,batch_mask,label) in tloader:
    c = 0
    for j, model in enumerate(ensemble):
      preds = model(data.contiguous(), batch_mask.contiguous())
      unique_vals, ordered = get_stats(preds, args.boardSz)
      if ordered == args.boardSz*args.boardSz and unique_vals == 1:
        c = 1
        correct_per[j] += 1 
    correct += c
print("Total Correct: ", correct)
print("Accuracy: ", correct/len(loader))
print("Correct each: ", correct_per)

  0%|          | 0/850 [00:00<?, ?it/s]

Total Correct:  779
Accuracy:  0.9164705882352941
Correct each:  [664, 688, 694, 708, 702]


## Additional Aggregation Methods not used in the Paper

In [None]:
def get_board(pred_flat, n):
    '''
    Decode flat per-bit scores into 1-based cell values.

    Args:
        pred_flat: Tensor viewable as (batch, n, n, n*n).
        n: Side length of the board.

    Returns:
        Detached integer tensor of shape (batch, n*n) holding, for each cell,
        the position of its largest score plus one.
    '''
    nsq = n*n
    per_cell = pred_flat.view(-1, n, n, nsq)
    values = per_cell.max(dim=3).indices + 1
    return values.view(per_cell.size(0), nsq).detach()

In [None]:
# Majority-vote aggregation: each cell takes the most frequent value across the
# ensemble's decoded boards, and a board is correct when it matches the label.
correct = 0 
loader = DataLoader(numbrix_test, batch_size = args.batchSz)
tloader = tqdm(enumerate(loader), total=len(loader))

# Device placement and eval mode do not change between batches, so set them once.
for model in ensemble:
  if args.cuda: model.cuda()
  model.eval()

# Inference only: disable autograd so no graph is built per forward pass.
with torch.no_grad():
  # The mask is bound to `batch_mask` so the loop does not overwrite the
  # notebook-level `is_input` tensor used to build the datasets.
  for i,(data,batch_mask,label) in tloader:
    boards = []
    for j, model in enumerate(ensemble):
      preds = model(data.contiguous(), batch_mask.contiguous())
      board = get_board(preds, args.boardSz)
      boards.append(board)
    # Per-cell mode over the model axis.
    max_vote = torch.mode(torch.stack(boards), dim=0)[0]
    expected_output = get_board(label, args.boardSz)
    isCorrect = torch.all(max_vote == expected_output, dim=1).sum()
    correct +=int(isCorrect)
print("Total Correct: ", correct)
print("Accuracy: ", correct/len(numbrix_test))

  0%|          | 0/22 [00:00<?, ?it/s]

Total Correct:  727
Accuracy:  0.8552941176470589


In [None]:
# Probability-averaging aggregation: average the models' outputs element-wise,
# decode the averaged prediction, and compare it with the label.
correct = 0 
loader = DataLoader(numbrix_test, batch_size = args.batchSz)
tloader = tqdm(enumerate(loader), total=len(loader))

# Device placement and eval mode do not change between batches, so set them once.
for model in ensemble:
  if args.cuda: model.cuda()
  model.eval()

# Inference only: disable autograd so no graph is built per forward pass.
with torch.no_grad():
  # The mask is bound to `batch_mask` so the loop does not overwrite the
  # notebook-level `is_input` tensor used to build the datasets.
  for i,(data,batch_mask,label) in tloader:
    predictions = []
    for j, model in enumerate(ensemble):
      preds = model(data.contiguous(), batch_mask.contiguous())
      predictions.append(preds)
    avg_prob = torch.mean(torch.stack(predictions), dim=0)
    pred_board = get_board(avg_prob, args.boardSz)
    expected_output = get_board(label, args.boardSz)
    isCorrect = torch.all(pred_board == expected_output, dim=1).sum()
    correct +=int(isCorrect)
print("Total Correct: ", correct)
print("Accuracy: ", correct/len(numbrix_test))

  0%|          | 0/22 [00:00<?, ?it/s]

Total Correct:  736
Accuracy:  0.8658823529411764
