# Building and Training a Custom ConvLogic Model for MNIST

## 🔧 Clone & Install ConvLogic Repository

We begin by cloning the `convlogic` GitHub repository and installing all necessary dependencies.

In [2]:
!git clone https://github.com/lkorinek/convlogic.git

Cloning into 'convlogic'...
remote: Enumerating objects: 97, done.[K
remote: Counting objects: 100% (97/97), done.[K
remote: Compressing objects: 100% (76/76), done.[K
remote: Total 97 (delta 20), reused 90 (delta 16), pack-reused 0 (from 0)[K
Receiving objects: 100% (97/97), 557.12 KiB | 9.61 MiB/s, done.
Resolving deltas: 100% (20/20), done.


In [3]:
!pip install -r convlogic/requirements.txt

Collecting torch==2.7.0 (from -r convlogic/requirements.txt (line 1))
  Downloading torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting torchvision==0.22.0 (from -r convlogic/requirements.txt (line 4))
  Downloading torchvision-0.22.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting matplotlib==3.10.3 (from -r convlogic/requirements.txt (line 5))
  Downloading matplotlib-3.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting lightning==2.3.3 (from -r convlogic/requirements.txt (line 7))
  Downloading lightning-2.3.3-py3-none-any.whl.metadata (35 kB)
Collecting hydra-core==1.3.2 (from -r convlogic/requirements.txt (line 8))
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting sympy>=1.13.3 (from torch==2.7.0->-r convlogic/requirements.txt (line 1))
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch==2.7.0->-r convlogi

In [4]:
!pip install -v convlogic/. --no-deps --no-build-isolation

Using pip 24.1.2 from /usr/local/lib/python3.11/dist-packages/pip (python 3.11)
Processing ./convlogic
  Running command Preparing metadata (pyproject.toml)
  >>> Building with CUDA support
  running dist_info
  creating /tmp/pip-modern-metadata-0oyyi325/convlogic.egg-info
  writing /tmp/pip-modern-metadata-0oyyi325/convlogic.egg-info/PKG-INFO
  writing dependency_links to /tmp/pip-modern-metadata-0oyyi325/convlogic.egg-info/dependency_links.txt
  writing requirements to /tmp/pip-modern-metadata-0oyyi325/convlogic.egg-info/requires.txt
  writing top-level names to /tmp/pip-modern-metadata-0oyyi325/convlogic.egg-info/top_level.txt
  writing manifest file '/tmp/pip-modern-metadata-0oyyi325/convlogic.egg-info/SOURCES.txt'
  reading manifest file '/tmp/pip-modern-metadata-0oyyi325/convlogic.egg-info/SOURCES.txt'
  adding license file 'LICENSE'
  writing manifest file '/tmp/pip-modern-metadata-0oyyi325/convlogic.egg-info/SOURCES.txt'
  creating '/tmp/pip-modern-metadata-0oyyi325/convlogic-1

In [5]:
import sys

sys.path.append("/content/convlogic/src")

In [6]:
import os

import pytorch_lightning as pl
import torch
import torch.nn as nn
from data import ConvLogicDataModule
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import get_original_cwd
from model import ConvLogicModel
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, WandbLogger

## ⚙️ Load and Customize Configuration

We use Hydra to load the base configuration and optionally override specific parameters such as the number of training epochs, learning rate, or dataset type. These values define the training behavior of the ConvLogic model.

In [7]:
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
initialize(config_path="convlogic/configs", version_base="1.1")
cfg = compose(config_name="config")

# Override config if needed
cfg.trainer.max_epochs = 50
cfg.model.lr: 0.01
cfg.model.weight_decay: 0.0
cfg.model.k: 16
cfg.model.tau: 6.5
cfg.model.batch_size: 512
cfg.model.dataset_name = "mnist"
cfg.evaluate.model_filename = "model_mnist_demo"
cfg.general.seed = 42
print(OmegaConf.to_yaml(cfg))

general:
  seed: 42
  profile: false
  profile_type: advanced
logging:
  wandb: false
trainer:
  max_epochs: 50
  accelerator: gpu
  devices: 1
  deterministic: false
data:
  num_workers: 2
  threshold_levels: null
  train_val_split: 0.9
  data_dir: ../../data
early_stopping:
  use: false
  monitor: val_eval/acc
  min_delta: 0.001
  patience: 50
  mode: max
accuracy_threshold_stop:
  monitor: val_no_eval/acc
  use: false
  epoch: 50
  threshold: 0.58
evaluate:
  checkpoint_path: models/
  model_filename: model_mnist_demo
model:
  lr: 0.02
  weight_decay: 0.002
  k: 32
  tau: 20
  dataset_name: mnist
  batch_size: 128
  input_channels: null
  implementation: cuda



In [8]:
callbacks = ModelCheckpoint(
    monitor="val_eval/acc",
    mode="max",
    save_top_k=1,
    save_last=True,
    dirpath="models/",
    filename="model_cifar10_demo",
    auto_insert_metric_name=False,
)

## Prepare the MNIST Dataset

We use `ConvLogicDataModule` to load and preprocess the MNIST dataset. The data is thresholded to convert pixel values into binary inputs. We also extract the number of input channels from the dataset to configure our model.

In [9]:
dm = ConvLogicDataModule(
    dataset_name=cfg.model.dataset_name,
    data_dir=cfg.data.data_dir,
    batch_size=cfg.model.batch_size,
    num_workers=cfg.data.num_workers,
    threshold_levels=cfg.data.threshold_levels,
    train_val_split=cfg.data.train_val_split,
)
dm.prepare_data()
dm.setup("fit")

cfg.model.input_channels = dm.input_channels

100%|██████████| 9.91M/9.91M [00:00<00:00, 54.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.73MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 12.7MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 2.60MB/s]


## Define a Custom ConvLogic Model for MNIST

Here we define our own neural network using `ConvLogicLayer` for logic-based convolutions and `LogicLayer` for dense logic gate layers.

We integrate our custom `ConvLogicMnistModel` into the higher-level `ConvLogicModel` interface, which provides the training and evaluation logic.

In [13]:
from convlogic import ConvLogicLayer
from difflogic import GroupSum, LogicLayer


class ConvLogicMnistModel(nn.Module):
    """
    ConvLogic model for MNIST dataset.
    """

    def __init__(self, input_channels=1, input_size=28, k_param=32, tau=6.5, implementation="cuda"):
        super().__init__()
        c1, c2, c3 = k_param, 3 * k_param, 9 * k_param
        # First logic convolution layer with a larger kernel
        layers = [
            ConvLogicLayer(
                in_channels=input_channels,
                out_channels=c1,
                kernel=5,
                padding=0,
                residual_init=True,  # Initializes weights to favor A gate (90%) (useful for training stability).
                complete=True,  # The full 3-stage logic pipeline: ConvLogic → TreeLogic1 → TreeLogic2.
                implementation=implementation,
            ),
            ConvLogicLayer(
                in_channels=c1,
                out_channels=c2,
                kernel=3,
                padding=1,
                residual_init=True,
                complete=True,
                implementation=implementation,
            ),
            ConvLogicLayer(
                in_channels=c2,
                out_channels=c3,
                kernel=3,
                padding=1,
                residual_init=True,
                complete=True,
                implementation=implementation,
            ),
            # Flatten for fully connected layers
            nn.Flatten(),
            LogicLayer(81 * k_param, 1280 * k_param, residual_init=False, implementation=implementation),
            LogicLayer(1280 * k_param, 640 * k_param, residual_init=False, implementation=implementation),
            LogicLayer(640 * k_param, 320 * k_param, residual_init=False, implementation=implementation),
            GroupSum(k=10, tau=tau),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


model = ConvLogicModel(cfg.model)
# Replace the internal model with a custom-defined one
model.model = ConvLogicMnistModel(
    input_channels=cfg.model.input_channels,
    input_size=28,
    k_param=cfg.model.k,
    tau=cfg.model.tau,
    implementation=cfg.model.implementation,
)

## Train the Model

Using PyTorch Lightning's `Trainer`, we now begin training our custom model. The training loop includes automatic checkpointing of the best-performing model based on validation accuracy.

In [15]:
trainer = Trainer(
    max_epochs=cfg.trainer.max_epochs,
    accelerator="gpu",
    devices=1,
    callbacks=callbacks,
    deterministic=cfg.trainer.deterministic,
)

trainer.fit(model, datamodule=dm)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type                | Params | Mode 
----------------------------------------------------------
0 | model     | ConvLogicMnistModel | 1.2 M  | train
1 | criterion | CrossEntropyLoss    | 0      | train
----------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.774     Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=50` reached.


## ✅ Evaluate on Test Set

After training, we evaluate the final model performance on the held-out MNIST test set.

- `test/acc_eval`: Accuracy using **discretized logic gates**, where each gate selects its most likely logic function. This reflects the performance of the model if deployed with fixed, hard logic.
- `test/acc`: Accuracy using **probabilistic logic functions**, where the output is a weighted sum over all possible gates. This reflects the soft version used during training.

In [17]:
trainer.test(model, datamodule=dm)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test/loss_eval': 0.08189473301172256,
  'test/acc_eval': 0.9773637652397156,
  'test/loss': 0.08079593628644943,
  'test/acc': 0.9768629670143127}]