In [4]:
from typing import Iterator
import numpy as np
import os
import torch
from problog.logic import Term, Constant
from deepproblog.dataset import Dataset
from deepproblog.query import Query

In [5]:
class SudokuCellDataset(Dataset):
    """
    Each item in the dataset is a single cell (image + label).
    Good for supervised training of digit recognition.
    """
    def __init__(self, subset: str, puzzle_path: str, solution_path: str):
        self.subset = subset

        puzzle_files = sorted(f for f in os.listdir(puzzle_path) if f.endswith('.npy'))[:1000]
        solution_files = sorted(f for f in os.listdir(solution_path) if f.endswith('.npy'))[:1000]

        self.cells = []  # will hold (image_array, correct_digit, cell_id, puzzle_id)
        puzzle_id = 0

        for p_file, s_file in zip(puzzle_files, solution_files):
            puzzle = np.load(os.path.join(puzzle_path, p_file), allow_pickle=True)
            solution = np.load(os.path.join(solution_path, s_file), allow_pickle=True)
            
            # Both puzzle & solution are 9x9. solution[i][j] is the correct digit for puzzle[i][j].
            for i in range(9):
                for j in range(9):
                    # If puzzle[i][j] is None, we typically define it as zero-image.
                    if puzzle[i][j] is None:
                        puzzle[i][j] = np.ones((28,28), dtype=np.float32)
                    else:
                        puzzle[i][j] = puzzle[i][j].astype(np.float32)
                    puzzle[i][j] /= 255.0

                    # solution[i][j] is the label (1..9 or 0 if blank?)
                    if solution[i][j] is not None:
                        label_digit = int(solution[i][j])  # e.g. 1..9
                    else:
                        label_digit = 0

                    # We'll store each cell as a separate training example
                    cell_id = puzzle_id*81 + (i*9 + j)
                    self.cells.append((puzzle[i][j], label_digit, cell_id, puzzle_id))

            puzzle_id += 1

    def __len__(self):
        return len(self.cells)

    def to_query(self, index: int) -> Query:
        """
        Construct a query for this single cell:
          digit(tensor(subset, cell_id), CorrectDigit)
        """
        img_array, label_digit, cell_id, _ = self.cells[index]

        # The Prolog side sees "digit(ImageTerm, Label)" with domain 1..9
        image_term = Term("tensor", Term(self.subset, Constant(cell_id)))
        label = Constant(label_digit)
        return Query(Term('digit', image_term, label))

    def __getitem__(self, item):
        """Used by DeepProbLog to retrieve the actual tensor for 'tensor(subset, cell_id)'."""
        img_array, label_digit, cell_id, _ = self.cells[int(item[0])]
        # Return as 1×28×28 PyTorch tensor:
        return torch.tensor(img_array).unsqueeze(0)

In [4]:
import torch
import torch.nn as nn
from deepproblog.model import Model
from deepproblog.network import Network
from deepproblog.engines import ExactEngine

class MNIST_Net(nn.Module):
    def __init__(self):
        super(MNIST_Net, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 6, 5),
            nn.MaxPool2d(2, 2),
            nn.ReLU(True),
            nn.Conv2d(6, 16, 5),
            nn.MaxPool2d(2, 2),
            nn.ReLU(True),
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 9),  # 9 classes (digits 1..9)
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(-1, 16 * 4 * 4)
        x = self.classifier(x)
        return x

network = MNIST_Net()
net = Network(network, "mnist_net", batching=True)
net.optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

# Load the "sudoku_no_clpfd.pl", which has `digit(...), sudoku(...)` definitions.
model = Model("sudoku_model.pl", [net])
model.set_engine(ExactEngine(model))

# Attach the dataset as a tensor source
train_data = SudokuCellDataset("train", "../mnist_sudoku_generator/dataset/images/puzzles/train", "../mnist_sudoku_generator/dataset/arrays/puzzles/train")
model.add_tensor_source("train", train_data)


NameError: name 'SudokuCellDataset' is not defined

In [5]:
# Now train the 'digit' predicate:
from deepproblog.dataset import DataLoader
from deepproblog.train import train_model
from deepproblog.utils.stop_condition import Threshold

stopper = Threshold('Iteration', 6000)

loader = DataLoader(train_data, batch_size=32, shuffle=False)
train_model(model, loader, 1, log_iter=100, )

model.save_state("snapshot/trained_sudoku_digits.pth")

Training  for 1 epoch(s)
Epoch 1
Iteration:  100 	s:4.9950 	Average Loss:  14.251375604881286
Iteration:  200 	s:4.4945 	Average Loss:  13.930967620051248
Iteration:  300 	s:4.4341 	Average Loss:  13.819868718602917
Iteration:  400 	s:4.4505 	Average Loss:  13.77460189233665
Iteration:  500 	s:4.3996 	Average Loss:  13.774253270671322
Iteration:  600 	s:4.4707 	Average Loss:  13.700331466796424
Iteration:  700 	s:4.4276 	Average Loss:  13.72305796475496
Iteration:  800 	s:4.8924 	Average Loss:  13.72079532426276
Iteration:  900 	s:4.5488 	Average Loss:  13.71642169526322
Iteration:  1000 	s:4.5091 	Average Loss:  13.68330980060716
Iteration:  1100 	s:4.3821 	Average Loss:  13.70845503873437
Iteration:  1200 	s:4.4274 	Average Loss:  13.655003630793674
Iteration:  1300 	s:4.4542 	Average Loss:  13.731020794009732
Iteration:  1400 	s:4.4817 	Average Loss:  13.627937947417553
Iteration:  1500 	s:4.4970 	Average Loss:  13.74822774195583
Iteration:  1600 	s:4.5297 	Average Loss:  13.6375902

In [5]:
# Attach the dataset as a tensor source
test_data = SudokuCellDataset("test", "../mnist_sudoku_generator/dataset/images/puzzles/test", "../mnist_sudoku_generator/dataset/arrays/puzzles/test")
model.add_tensor_source("test", test_data)

In [6]:
from deepproblog.evaluate import get_confusion_matrix
cm = get_confusion_matrix(model, test_data)

In [7]:
cm.accuracy()

Accuracy:  0.5050370370370371


np.float64(0.5050370370370371)

In [None]:
print(str(cm))