# Sudoku data handlers

API surface: creates the objects
`train_loader` and `test_loader`, and the function `check_valid_sudoku(tensor)`.

## Code

In [None]:
# Colab setup: quietly (-q) install polars and PyTorch into the environment.
# NOTE(review): versions are unpinned; `%pip` would target the running kernel's environment.
!pip install -q polars torch torchvision

In [None]:
%pdb on

Automatic pdb calling has been turned ON


In [None]:
import polars as pl

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
from matplotlib import pyplot as plt
import time

from torch.utils.data import Dataset, DataLoader

In [None]:
# Load the raw puzzle table. infer_schema_length=0 reads every column as a string,
# so the 81-character puzzle strings keep their leading zeros.
df = pl.read_csv('/content/data/sudoku-small.csv', infer_schema_length=0)

In [None]:
# Raw 81-character puzzle string of row 3 ('0' = empty cell).
df['puzzle'][3]

'008317000004205109000040070327160904901450000045700800030001060872604000416070080'

```
008317000
004205109
000040070
327160904
901450000
045700800
030001060
872604000
416070080
```

In [None]:
# Same puzzle as a 9x9 grid of digits (0 = empty).
torch.tensor(list(map(int, df[3]['puzzle'].item()))).view(9, 9)

tensor([[0, 0, 8, 3, 1, 7, 0, 0, 0],
        [0, 0, 4, 2, 0, 5, 1, 0, 9],
        [0, 0, 0, 0, 4, 0, 0, 7, 0],
        [3, 2, 7, 1, 6, 0, 9, 0, 4],
        [9, 0, 1, 4, 5, 0, 0, 0, 0],
        [0, 4, 5, 7, 0, 0, 8, 0, 0],
        [0, 3, 0, 0, 0, 1, 0, 6, 0],
        [8, 7, 2, 6, 0, 4, 0, 0, 0],
        [4, 1, 6, 0, 7, 0, 0, 8, 0]])

In [None]:
class SudokuDataset(Dataset):
  """Dataset of Sudoku puzzles stored as 81-character digit strings.

  Each item is the ``'puzzle'`` string of one row, turned into a (9, 9, 10)
  one-hot tensor: channel 0 marks an empty cell ('0' in the string) and
  channels 1-9 mark the given digit.

  Args:
    path: CSV file to read with polars (all columns kept as strings).
    ds: an already-loaded table to wrap instead; takes precedence over ``path``.

  Raises:
    ValueError: if neither ``ds`` nor ``path`` is supplied.
  """
  def __init__(self, path=None, ds=None):
    # Compare against None explicitly: polars DataFrames refuse truth-value
    # testing, and an empty table is still a valid (empty) dataset.
    if ds is not None:
      self.ds = ds
    elif path:
      self.ds = pl.read_csv(path, infer_schema_length=0)
    else:
      raise ValueError("where dataset parameters????")

  def __len__(self):
    return len(self.ds)

  def __getitem__(self, idx):
    row = self.ds[idx]
    digits = [int(ch) for ch in row['puzzle'].item()]
    puzzle = torch.tensor(digits, dtype=torch.long).reshape(9, 9)
    return F.one_hot(puzzle, 10)

In [None]:
# Load every puzzle and split 80/20 into train/test subsets.
# A seeded generator makes the split identical across kernel restarts.
dataset = SudokuDataset(path='/content/data/sudoku-small.csv')
len_ds = len(dataset)
n_train = int(0.8 * len_ds)

train_set, test_set = torch.utils.data.random_split(
    dataset, (n_train, len_ds - n_train), generator=torch.Generator().manual_seed(42)
)


In [None]:
batch_size = 1  # the training loop below manipulates a single board per batch
train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
# Evaluate on the held-out split (this loader previously wrapped train_set);
# no shuffling, so evaluation order is deterministic.
test_loader = DataLoader(test_set, shuffle=False, batch_size=batch_size)

In [None]:
# Decode the first test puzzle from one-hot back to a digit grid.
torch.argmax(next(iter(test_loader))[0], dim=-1)

tensor([[6, 0, 0, 0, 1, 7, 4, 0, 0],
        [4, 0, 1, 0, 0, 3, 0, 0, 8],
        [0, 5, 9, 8, 0, 0, 7, 2, 1],
        [1, 2, 0, 0, 0, 0, 0, 5, 0],
        [0, 0, 0, 0, 4, 0, 8, 0, 0],
        [0, 0, 8, 0, 2, 0, 1, 0, 0],
        [0, 0, 4, 5, 3, 0, 0, 0, 7],
        [7, 0, 0, 0, 9, 0, 0, 8, 6],
        [2, 6, 3, 1, 7, 0, 0, 0, 0]])

In [None]:
def is_row_valid(board, row, num):
    """Return True when ``num`` does not already occur in row ``row`` of ``board``."""
    return not (board[row] == num).any()

def is_col_valid(board, col, num):
    """Return True when ``num`` does not already occur in column ``col`` of ``board``."""
    return not (board[:, col] == num).any()

def is_subgrid_valid(board, x, y, num):
    """Return True when ``num`` is absent from the 3x3 box containing cell (x, y)."""
    box = 3
    top = (x // box) * box
    left = (y // box) * box
    cells = board[top:top + box, left:left + box]
    return num not in cells


def has_legal_moves(board): # avoid cycling through all moves for speed - Danny
    """Return True as soon as any empty cell (value 0) admits some digit 1-9.

    Short-circuits on the first legal placement instead of enumerating them all.
    """
    n = 9
    return any(
        is_move_legal(board, row, col, digit)
        for row, col in np.argwhere(board == 0)
        for digit in range(1, n + 1)
    )

def is_move_legal(board, x, y, num):
    """Return True if writing ``num`` at cell (x, y) respects the Sudoku rules.

    Checks, in order, row ``x``, column ``y`` and the 3x3 box around (x, y)
    for an existing ``num``, then that ``num`` is a digit 1-9. Cell occupancy
    is not checked; the callers in this notebook iterate over empty cells.
    """
    n = 9
    if not is_row_valid(board, x, num):
        return False
    if not is_col_valid(board, y, num):
        return False
    if not is_subgrid_valid(board, x, y, num):
        return False
    return 0 < num <= n



def get_legal_moves(board):
    """List every legal placement on ``board``.

    Args:
      board: (9, 9) numpy array or CPU torch tensor of digits, 0 meaning empty.

    Returns:
      list of ``(row, col, num)`` int tuples, one for every empty cell and digit
      1-9 that is legal there, in row-major cell order and increasing ``num``.
    """
    n = 9
    moves = []
    # np.argwhere yields one (row, col) pair per empty cell. Iterate the pairs
    # directly: torch.chunk rejects numpy arrays, and chunking along dim 0 would
    # split the list of pairs in half rather than into row/col columns.
    for x, y in np.argwhere(board == 0):
        for num in range(1, n + 1):
            if is_move_legal(board, x, y, num):
                moves.append((int(x), int(y), num))
    return moves

def get_valid_moves(board):
    """Encode the legal placements on ``board`` as a dense 0/1 array.

    Returns a float numpy array of shape (9, 9, 10) whose entry [row, col, num]
    is 1 when writing ``num`` (1-9) at (row, col) is legal. Channel 0 is never set.
    """
    size = 9
    mask = np.zeros((size, size, size + 1))
    for row, col, digit in get_legal_moves(board):
        mask[row, col, digit] = 1
    return mask


In [None]:
def is_row_valid_torch(board, num):
    """Per-row flags for a 2-D digit grid: entry r is True when row r lacks ``num``."""
    contains = (board == num).any(dim=1)
    return contains.logical_not()

def is_col_valid_torch(board, num):
    """Per-column flags for a 2-D digit grid: entry c is True when column c lacks ``num``."""
    contains = (board == num).any(dim=0)
    return contains.logical_not()

def is_subgrid_valid_torch(board, num):
    """Per-box flags for a 9x9 digit grid: a (3, 3) bool tensor whose entry
    [i, j] is True when box (i, j) does not contain ``num``."""
    box = 3
    n = 9
    # (box_row, row_in_box, box_col, col_in_box) -> (box_row, box_col, row_in_box, col_in_box)
    blocks = board.view(n // box, box, -1, box).permute(0, 2, 1, 3)
    hit = (blocks == num).flatten(start_dim=2).any(dim=-1)
    return hit.logical_not()

def is_move_legal_torch(board, num):
    """Vectorised legality mask for placing ``num`` on every cell of ``board``.

    Args:
      board: (9, 9) tensor of digits with 0 meaning empty. This is the digit
        grid, not the one-hot encoding; apply ``.argmax(dim=-1)`` first.
      num: digit 1-9, as an int or a 0-d tensor.

    Returns:
      (9, 9) bool tensor. Entry (r, c) is True iff cell (r, c) is empty and
      ``num`` occurs in neither row r, column c, nor the 3x3 box holding (r, c).
    """
    subgrid_size = 3
    row_valid = is_row_valid_torch(board, num)          # (9,): one flag per row
    col_valid = is_col_valid_torch(board, num)          # (9,): one flag per column
    subgrid_valid = is_subgrid_valid_torch(board, num)  # (3, 3): one flag per box
    # Each flag set must be expanded to (9, 9) before combining:
    # rows broadcast across columns, columns down rows, and every box flag is
    # repeated over the 3x3 cells it covers.
    box_valid = subgrid_valid.repeat_interleave(subgrid_size, dim=0) \
                             .repeat_interleave(subgrid_size, dim=1)
    return row_valid[:, None] & col_valid[None, :] & box_valid & (board == 0)

def has_legal_moves_torch(board):
    """Return True if some digit 1-9 is legal in at least one cell of ``board``."""
    n = 9
    return any(
        bool(is_move_legal_torch(board, torch.tensor(digit)).any())
        for digit in range(1, n + 1)
    )

def get_valid_moves_torch(board):
    """Torch counterpart of ``get_valid_moves``.

    Returns a (9, 9, 10) bool mask where [row, col, num] is True when ``num``
    may be placed at (row, col); channel 0 stays False.
    """
    size = 9
    legal = torch.zeros((size, size, size + 1), dtype=torch.bool)
    for digit in range(1, size + 1):
        legal[..., digit] = is_move_legal_torch(board, torch.tensor(digit))
    return legal

In [None]:
# The *_torch legality helpers compare cell values against digits, so they need the
# (9, 9) digit grid rather than the (9, 9, 10) one-hot board the loader yields.
x_test = next(iter(test_loader))[0].argmax(dim=-1)

In [None]:
# Does the sampled test puzzle have at least one legal placement?
has_legal_moves_torch(board=x_test)

RuntimeError: ignored

> [0;32m<ipython-input-23-237b30701a04>[0m(19)[0;36mis_move_legal_torch[0;34m()[0m
[0;32m     17 [0;31m    [0mcol_valid[0m [0;34m=[0m [0mis_col_valid_torch[0m[0;34m([0m[0mboard[0m[0;34m,[0m [0mnum[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m    [0msubgrid_valid[0m [0;34m=[0m [0mis_subgrid_valid_torch[0m[0;34m([0m[0mboard[0m[0;34m,[0m [0mnum[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 19 [0;31m    [0;32mreturn[0m [0mrow_valid[0m [0;34m&[0m [0mcol_valid[0m [0;34m&[0m [0msubgrid_valid[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m[0;34m[0m[0m
[0m[0;32m     21 [0;31m[0;32mdef[0m [0mhas_legal_moves_torch[0m[0;34m([0m[0mboard[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> w
  [0;32m<ipython-input-24-165b195e845b>[0m(1)[0;36m<cell line: 1>[0;34m()[0m
[0;32m----> 1 [0;31m[0mhas_legal_moves_torch[0m[0;34m([0m[0mx_test[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m



sys.settrace() should not be used when the debugger is being used.
This may cause the debugger to stop working correctly.
If this is needed, please check: 
http://pydev.blogspot.com/2007/06/why-cant-pydev-debugger-work-with.html
to see how to restore the debug tracing back correctly.
Call Location:
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/debugger.py", line 1075, in cmdloop
    sys.settrace(None)



--KeyboardInterrupt--

KeyboardInterrupt: Interrupted by user


# Model section

# New Idea

It's hard to predict the correct moves directly, so what if we let the model learn them implicitly? :)

In [None]:
from tqdm.notebook import tqdm

In [None]:
# First test puzzle as a numpy digit grid (0 = empty).
x = torch.argmax(next(iter(test_loader))[0], dim=-1).numpy()
np.shape(x)

(9, 9)

In [None]:
# (9, 9, 10) mask of legal placements on that puzzle.
get_valid_moves(board=x)

In [None]:
def calculate_reward(x: torch.Tensor):
  """Reward for the board reached at the end of a sampled trajectory.

  ``x`` is a one-hot board whose last dimension indexes the cell value, with
  channel 0 meaning "empty", so ``x.argmax(dim=-1)`` recovers the digit grid.

  Returns 3 for a completely filled board. Otherwise it returns the fraction
  of filled cells (0.0 to 1.0), so the reward grows with progress and is
  well-defined (0.0) for an empty board. Rule validity is not checked here.
  """
  num_solved = torch.count_nonzero(x.argmax(dim=-1)).item()
  if num_solved == 81:
    return 3
  return num_solved / 81

In [None]:
class StateFlow(nn.Module):
  """GFlowNet policy over one-hot Sudoku boards.

  Boards are one-hot tensors of shape (batch, 9, 9, 10) with channel 0 meaning
  "empty". The MLP emits 729 forward logits followed by 729 backward logits;
  both are indexed by the flat action row*81 + col*9 + (digit - 1).
  ``logZ`` is the learned log-partition estimate used by the loss.
  """
  def __init__(self, num_hidden=512):
    super().__init__()
    self.mlp = nn.Sequential(
        nn.Linear(810, num_hidden),    # 9*9*10 one-hot inputs
        nn.LeakyReLU(),
        nn.Linear(num_hidden, 729*2)   # 729 forward + 729 backward logits
    )

    self.logZ = nn.Parameter(torch.ones(1))

  def forward(self, x, init_board):
    b = x.size(0)

    valid_moves = torch.stack([torch.tensor(get_valid_moves(state)) for state in x.argmax(dim=-1)])

    logits = self.mlp(x.float().view(b, -1))

    # Forward policy: only placements that are currently legal.
    illegal = (valid_moves[:, :, :, 1:] == 0).reshape(b, -1)
    pf = logits[:, :729].masked_fill(illegal, -1e10)

    # Backward policy: only undo digits the agent placed, i.e. cells filled now
    # but empty in the initial puzzle. Mask the rest out instead of scaling the
    # logits, so they get (numerically) zero probability.
    placed = (x[:, :, :, 1:] > 0) & (init_board[:, :, :, 1:] == 0)
    pb = logits[:, 729:].masked_fill(~placed.reshape(b, -1), -1e10)
    return pf, pb

In [None]:
model = StateFlow()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 50000  # NOTE(review): not used by the loop below, which runs num_episodes
num_episodes = 5000
n = 9  # board side; flat actions are row*n*n + col*n + (digit - 1)

losses = []            # one entry per optimizer update
logZs = []             # learned logZ after each update

generated_boards = []  # final board of every sampled trajectory
minibatch_loss = 0
update_freq = 2        # episodes accumulated per optimizer step

pbar = tqdm(range(num_episodes))
for episode in pbar:
  # One puzzle per episode; the move bookkeeping below assumes batch_size == 1.
  board = next(iter(train_loader)).float()
  initial_board = board.clone().detach()

  pf, pb = model(board, initial_board)

  total_pf = 0   # sum of log P_F over the trajectory
  total_pb = 0   # sum of log P_B over the trajectory
  reward = None

  num_moves_required = 81 - torch.count_nonzero(initial_board.argmax(dim=-1)).item()

  for _ in range(num_moves_required):
    cat = torch.distributions.Categorical(logits=pf)
    action = cat.sample()
    total_pf += cat.log_prob(action)

    # Decode the flat action index (same layout as StateFlow's logits).
    row = action // (n ** 2)
    col = (action % (n ** 2)) // n
    digit = (action % (n ** 2)) % n + 1

    # Move the cell's one-hot mass from "empty" (channel 0) to the chosen digit.
    new_move = torch.zeros((n, n, n + 1))
    new_move[row, col, digit] = 1
    new_move[row, col, 0] = -1
    board = board + new_move

    pf, pb = model(board, initial_board)
    # Accumulate (not overwrite) the backward log-prob of undoing this move.
    total_pb += torch.distributions.Categorical(logits=pb).log_prob(action)

    # Stop once no legal placement remains (solved or stuck); sampling further
    # would draw from an all-masked distribution and corrupt the board.
    if not get_valid_moves(board.squeeze(0).argmax(dim=-1)).any():
      reward = torch.tensor(float(calculate_reward(board)))
      break

  if reward is None:
    # Fallback when no stopping condition was hit, e.g. the puzzle was already complete.
    reward = torch.tensor(float(calculate_reward(board)))

  # Trajectory balance: logZ + sum log P_F should equal log R + sum log P_B.
  loss = (model.logZ + total_pf - torch.log(reward).clip(-20) - total_pb).pow(2)
  minibatch_loss += loss

  generated_boards.append(board)

  if episode % update_freq == 0:
    losses.append(minibatch_loss.item())
    if episode % 10 == 0:
      # Iterating over pbar already advances the bar; only refresh the label.
      pbar.set_description(f"Loss: {minibatch_loss.item():.4f}")
    minibatch_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    minibatch_loss = 0
    logZs.append(model.logZ.item())

  0%|          | 0/5000 [00:00<?, ?it/s]

In [None]:
# torch.save needs a file path, not a directory; store just the weights.
torch.save(model.state_dict(), "stateflow_state_dict.pt")

In [None]:
# One point per optimizer update; using len(losses) also works for interrupted runs.
plt.plot(np.arange(len(losses)), losses)
plt.xlabel(f"update (every {update_freq} episodes)")
plt.ylabel("minibatch loss");

In [None]:
# The training loop records logZ in `logZs` (there is no `logz`).
plt.plot(np.arange(len(logZs)), logZs)
plt.xlabel(f"update (every {update_freq} episodes)")
plt.ylabel("logZ");

In [None]:
# Example final board from a sampled trajectory; cell (row 2, col 3) is still empty.
# NOTE(review): pasted output, presumably from `generated_boards`. Confirm the source.
torch.tensor([[[6, 1, 3, 7, 5, 9, 4, 2, 8],
               [4, 5, 7, 3, 8, 2, 1, 9, 6],
               [9, 8, 2, 0, 4, 6, 7, 3, 5],
               [1, 2, 6, 4, 3, 7, 5, 8, 9],
               [5, 3, 8, 6, 9, 1, 2, 7, 4],
               [7, 9, 4, 5, 2, 8, 3, 6, 1],
               [2, 6, 5, 9, 1, 3, 8, 4, 7],
               [8, 7, 1, 2, 6, 4, 9, 5, 3],
               [3, 4, 9, 8, 7, 5, 6, 1, 2]]])

# old stuff

In [None]:
def check_valid_sudoku(puzz: torch.Tensor) -> bool:
  """Return True if no filled digit repeats in any row, column or 3x3 box.

  Empty cells (digit 0) are ignored, so partially filled boards can be valid.

  Args:
    puzz: either a one-hot board of shape (9, 9, 10) with channel 0 meaning
      "empty" (as produced by SudokuDataset), or a (9, 9) grid of digits 0-9.
  """
  if puzz.dim() == 3:
    puzz = puzz.argmax(dim=-1)

  def _has_duplicates(cells: torch.Tensor) -> bool:
    # Compare unique vs total count over the filled (non-zero) cells only.
    filled = cells[cells != 0]
    return len(torch.unique(filled)) != len(filled)

  # check rows & columns
  for i in range(9):
    if _has_duplicates(puzz[i, :]) or _has_duplicates(puzz[:, i]):
      return False

  # check subgrids
  for i in range(0, 9, 3):
    for j in range(0, 9, 3):
      if _has_duplicates(puzz[i:i+3, j:j+3]):
        return False

  return True

In [None]:
def check_finished(puzz: torch.TensorType) -> bool:
  """Return True when all 81 cells of the one-hot board hold a digit (argmax != 0)."""
  filled_cells = torch.count_nonzero(puzz.argmax(dim=-1))
  return bool(filled_cells == 81)

In [None]:
def calculate_reward(puzz: torch.TensorType):
  """Terminal reward: 3 for a complete board with no rule violations, else 0.

  Non-terminal or invalid states earn nothing (an earlier version returned 1).
  NOTE(review): running this cell replaces the earlier `calculate_reward`
  definition in the notebook namespace.
  """
  solved = check_valid_sudoku(puzz) and check_finished(puzz)
  return 3 if solved else 0

In [None]:
# The loaders yield one-hot puzzles only (no solution tensor), shape (batch_size, 9, 9, 10).
x = next(iter(test_loader))
# Number of empty cells the agent has to fill in the first puzzle of the batch.
81 - torch.count_nonzero(x[0].argmax(dim=-1)).item()

43

In [None]:
class StateFlow(nn.Module):
  """Earlier flow-network draft: maps one-hot boards to 729 positive edge flows.

  Input ``x``: one-hot boards of shape (batch, 9, 9, 10).
  Output: (batch, 729) strictly positive flows, one per (row, col, digit 1-9)
  placement. Moves to the "empty" value are excluded simply by having no
  output for them.

  NOTE(review): this redefinition shadows the StateFlow class defined earlier
  in the notebook, and its ``forward`` has a different signature.
  """
  def __init__(self, num_hidden=729*2):
    super().__init__()  # must run before submodules are assigned
    self.mlp = nn.Sequential(
        nn.Linear(810, num_hidden),
        nn.LeakyReLU(),
        nn.Linear(num_hidden, 729)
    )

  def forward(self, x):
    # Flatten each (9, 9, 10) board into the 810 features the first Linear expects.
    return self.mlp(x.float().reshape(x.size(0), -1)).exp()

In [None]:
def calculate_parents(x):
  """Unfinished stub for enumerating the parent states of board ``x``.

  The draft collected ``parent_states`` and ``parent_actions`` lists, which
  suggests it should return the boards that lead to ``x`` in one move together
  with those moves. TODO: confirm the intended contract and implement it.
  The original draft loop (``for i in range(x):``) had no body, which made
  the whole cell a SyntaxError.
  """
  raise NotImplementedError("calculate_parents is not implemented yet")

In [None]:
# CONSTRUCTION ZONE - pls ping Alexander when you make changes here :)
class StateFlow(nn.Module):
  def __init__(self, num_hidden):
    super().__init__()  # must run before submodules are assigned
    # consider a convolution here to expand this into channels grid by grid
    # or potentially an attention layer, intialized with row by row
    self.mlp = nn.Sequential(
        nn.Linear(810, num_hidden),
        nn.LeakyReLU(),
        nn.Linear(num_hidden, 729)
    )

  def forward(self, x):
    # not sure if we need to do onehot encodeing here or something similar in order to deal with gradient sizing
    # onehot encoding is a go
    # what if we do a 9 factor step at once? (for later)
    # ignore filled lines (this happens later)
    # allowing deletions lets you effectively introduce cycles by letting the model eliminate values.  is this a problem from a flow consistency standpoint?
    # not allowing for now
    #
    # x: one unbatched one-hot board (9, 9, 10); channel 0 is hot for an empty cell.
    # Only empty cells may receive a digit. Keep a trailing axis so the (9, 9)
    # mask broadcasts over each cell's 9 digit flows.
    mask = (x[:, :, 0] > 0).float().unsqueeze(-1)

    # Flatten the board into the 810 features the first Linear expects.
    flows = torch.exp(self.mlp(x.float().reshape(-1))).view(9, 9, 9) * mask
    return torch.cat([
                      torch.zeros(9, 9, 1),  # no flow into the "empty" value
                      flows
                    ], dim=-1).view(-1)


In [None]:
# Preliminary sketch (single puzzle, no batching): take one forward step with the policy.
# NOTE(review): assumes `model` is the instance created in the training cell, whose
# forward takes (board, initial_board) and returns (forward_logits, backward_logits).
puzz = next(iter(test_loader)).float()        # one-hot board(s), (batch_size, 9, 9, 10)
pf, _ = model(puzz, puzz)
dist = torch.distributions.Categorical(logits=pf)
sample = dist.sample()

# Same flat-index decoding as the training loop: action = row*81 + col*9 + (digit - 1).
action = sample[0].item()
x, y, digit = action // 81, (action % 81) // 9, action % 9 + 1
puzz[0, x, y] = F.one_hot(torch.tensor(digit), 10).float()

NameError: ignored

> [0;32m<ipython-input-27-b9b1a3ea67f3>[0m(3)[0;36m<cell line: 3>[0;34m()[0m
[0;32m      1 [0;31m[0;31m# something something model forward.  this needs to handle batches, but as a preliminary overview[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      2 [0;31m[0;34m[0m[0m
[0m[0;32m----> 3 [0;31m[0mdist[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mdistributions[0m[0;34m.[0m[0mCategorical[0m[0;34m([0m[0mmodel[0m[0;34m([0m[0mpuzz[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      4 [0;31m[0msample[0m [0;34m=[0m [0mdist[0m[0;34m.[0m[0msample[0m[0;34m([0m[0mdist[0m[0;34m.[0m[0msoftmax[0m[0;34m([0m[0mdim[0m[0;34m=[0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      5 [0;31m[0;34m[0m[0m
[0m
--KeyboardInterrupt--

KeyboardInterrupt: Interrupted by user


In [None]:
# need to write sudoku candidate values functions to get entropy for this board (average entropy per cell, divided by the board overall, and maybe scaled further?)

In [None]:
# TESTING FUNCTION
# good_sudoku_puzzle = torch.tensor([
#     [5, 3, 0, 0, 7, 0, 0, 0, 0],
#     [6, 0, 0, 1, 9, 5, 0, 0, 0],
#     [0, 9, 8, 0, 0, 0, 0, 6, 0],
#     [8, 0, 0, 0, 6, 0, 0, 0, 3],
#     [4, 0, 0, 8, 0, 3, 0, 0, 1],
#     [7, 0, 0, 0, 2, 0, 0, 0, 6],
#     [0, 6, 0, 0, 0, 0, 2, 8, 0],
#     [0, 0, 0, 4, 1, 9, 0, 0, 5],
#     [0, 0, 0, 0, 8, 0, 0, 7, 9]
# ])

# bad_sudoku_puzzle_1 = torch.tensor([
#     [5, 3, 0, 0, 7, 0, 0, 0, 0],
#     [6, 3, 0, 1, 9, 5, 0, 0, 0],
#     [0, 9, 8, 0, 0, 0, 0, 6, 0],
#     [8, 0, 0, 0, 6, 0, 0, 0, 3],
#     [4, 0, 0, 8, 0, 3, 0, 0, 1],
#     [7, 0, 0, 0, 2, 0, 0, 0, 6],
#     [0, 6, 0, 0, 0, 0, 2, 8, 0],
#     [0, 0, 0, 4, 1, 9, 0, 0, 5],
#     [0, 0, 0, 0, 8, 0, 0, 7, 9]
# ])

# bad_sudoku_puzzle_2 = torch.tensor([
#     [5, 3, 3, 0, 7, 0, 0, 0, 0],
#     [6, 0, 0, 1, 9, 5, 0, 0, 0],
#     [0, 9, 8, 0, 0, 0, 0, 6, 0],
#     [8, 0, 0, 0, 6, 0, 0, 0, 3],
#     [4, 0, 0, 8, 0, 3, 0, 0, 1],
#     [7, 0, 0, 0, 2, 0, 0, 0, 6],
#     [0, 6, 0, 0, 0, 0, 2, 8, 0],
#     [0, 0, 0, 4, 1, 9, 0, 0, 5],
#     [0, 0, 0, 0, 8, 0, 0, 7, 9]
# ])

# bad_sudoku_puzzle_3 = torch.tensor([
#     [5, 3, 0, 0, 7, 0, 0, 0, 0],
#     [6, 0, 3, 1, 9, 5, 0, 0, 0],
#     [0, 9, 8, 0, 0, 0, 0, 6, 0],
#     [8, 0, 0, 0, 6, 0, 0, 0, 3],
#     [4, 0, 0, 8, 0, 3, 0, 0, 1],
#     [7, 0, 0, 0, 2, 0, 0, 0, 6],
#     [0, 6, 0, 0, 0, 0, 2, 8, 0],
#     [0, 0, 0, 4, 1, 9, 0, 0, 5],
#     [0, 0, 0, 0, 8, 0, 0, 7, 9]
# ])



# assert check_valid_sudoku(good_sudoku_puzzle)
# assert not check_valid_sudoku(bad_sudoku_puzzle_1)
# assert not check_valid_sudoku(bad_sudoku_puzzle_2)
# assert not check_valid_sudoku(bad_sudoku_puzzle_3)
# print("check_sudoku test passed!")