In [7]:
import os
import numpy as np
import torch
from problog.logic import Term, Constant, list2term
from deepproblog.dataset import Dataset
from deepproblog.query import Query

In [8]:
class MNIST_SudokuNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(1, 6, 5),
            torch.nn.MaxPool2d(2, 2),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(6, 16, 5),
            torch.nn.MaxPool2d(2, 2),
            torch.nn.ReLU(True),
        )
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(16*4*4, 120),
            torch.nn.ReLU(),
            torch.nn.Linear(120, 84),
            torch.nn.ReLU(),
            torch.nn.Linear(84, 9),
            torch.nn.Softmax(dim=1),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(-1, 16 * 4 * 4)
        x = self.classifier(x)
        return x

In [9]:
class SudokuPuzzleDataset(Dataset):
    """
    Loads puzzles (images) and their corresponding labels (correct digits).
    - puzzle_path: .npy files of shape (9,9), each cell is either None or a 28x28 image
    - label_path:  .npy files of shape (9,9), each cell is an integer label (1..9 or 0 for empty)
    This dataset can:
      (1) Provide queries for entire puzzles (puzzle_solve).
      (2) Provide cell-level queries for training each digit with a known label.
    """
    def __init__(self, subset: str, puzzle_path: str, label_path: str):
        self.subset = subset
        self.puzzle_files = sorted(f for f in os.listdir(puzzle_path) if f.endswith('.npy'))
        self.label_files  = sorted(f for f in os.listdir(label_path)  if f.endswith('.npy'))

        self.puzzles = []
        self.labels  = []
        for p_file, l_file in zip(self.puzzle_files, self.label_files):
            puzzle = np.load(os.path.join(puzzle_path, p_file), allow_pickle=True)
            label  = np.load(os.path.join(label_path,  l_file), allow_pickle=True)
            # Normalize the images; keep 'None' as is. 
            # If label is 0 or None => means an empty cell
            for i in range(9):
                for j in range(9):
                    if puzzle[i][j] is not None:
                        puzzle[i][j] = puzzle[i][j].astype(np.float32) / 255.0
            self.puzzles.append(puzzle)
            self.labels.append(label)

    def __len__(self):
        return len(self.puzzles)

    def __getitem__(self, item):
        """
        This is how DeepProbLog retrieves actual image data:
        We expect item = (puzzle_index, row, col).
        If puzzle cell is None => can't retrieve an image => raise error.
        """
        puzzle_index, row, col = item
        cell_image = self.puzzles[puzzle_index][row][col]
        if cell_image is None:
            raise ValueError("Tried to retrieve an image from an empty (None) cell.")
        return torch.tensor(cell_image).unsqueeze(0)  # shape (1,28,28)

    ##############################
    # 1) Query for Full Puzzle
    ##############################
    def to_query_puzzle(self, index: int) -> Query:
        """
        Constructs a query: puzzle_solve(RowListOfLists)
        Each cell => 'none' if empty, or tensor(...) if image exists.
        """
        puzzle = self.puzzles[index]
        row_list_terms = []

        # We'll reference each cell by puzzle-index, row, col
        for i in range(9):
            col_terms = []
            for j in range(9):
                if puzzle[i][j] is None:
                    # No image => 'none'
                    col_terms.append(Constant('none'))
                else:
                    # Has image => reference it
                    col_terms.append(Term("tensor", Term(self.subset, Constant(index), Constant(i), Constant(j))))
            row_list_terms.append(list2term(col_terms))

        # puzzle_solve([...])
        puzzle_term = list2term(row_list_terms)
        query_term = Term("puzzle_solve", puzzle_term)
        return Query(query_term)

    ##############################
    # 2) Query for Labeled Cell
    ##############################
    def to_query_cell(self, puzzle_index: int, row: int, col: int) -> Query:
        """
        Creates a query digit(tensor(...), CorrectDigit).
        Used for supervised training of the digit recognition.
        """
        cell_image = self.puzzles[puzzle_index][row][col]
        cell_label = self.labels[puzzle_index][row][col]  # e.g. 1..9 or 0 if empty

        # If there's no image or label is 0 => skip (or raise an error)
        if cell_image is None or cell_label == 0:
            raise ValueError("Cell is empty or label=0, cannot do supervised digit training.")

        # Prolog side: digit(tensor(subset, puzzle_index, row, col), label)
        image_term = Term("tensor", Term(self.subset, Constant(puzzle_index), Constant(row), Constant(col)))
        label_const = Constant(int(cell_label))
        return Query(Term("digit", image_term, label_const))

In [10]:
from deepproblog.model import Model
from deepproblog.network import Network
from deepproblog.engines import ExactEngine
from deepproblog.dataset import DataLoader
from deepproblog.train import train_model

In [11]:
# 1) Load dataset
train_data = SudokuPuzzleDataset("train", "../../mnist_sudoku_generator/dataset/images/puzzles/train", "../../mnist_sudoku_generator/dataset/arrays/puzzles/train")
network = MNIST_SudokuNet()
net = Network(network, "mnist_net", batching=True)
net.optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

# 3) Build the model from the updated Prolog
model = Model("sudoku_no_clpfd.pl", [net])
model.set_engine(ExactEngine(model))

# 4) Add the dataset as a tensor source
model.add_tensor_source("train", train_data)

TypeError: Can't instantiate abstract class SudokuPuzzleDataset with abstract method to_query