In [1]:
import time
from datetime import datetime

import numpy as np
import torch
import torchinfo

from neurodifflogic.difflogic.compiled_model import CompiledLogicNet
from neurodifflogic.models import CNN

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Build the CNN logic-gate network with the pure-Python backend on CPU.
# No weights are loaded in this notebook, so the network keeps whatever
# parameters the constructor initialises.
model = CNN(class_count=10, tau=1, implementation="python", device="cpu")
model.eval()  # inference mode; equivalent to model.train(False)

# Layer-by-layer overview for a single-sample batch of 1x28x28 inputs.
print(
    torchinfo.summary(
        model,
        input_size=(1, 1, 28, 28),
    )
)

Layer (type:depth-idx)                   Output Shape              Param #
CNN                                      [1, 10]                   --
├─Sequential: 1-1                        [1, 10]                   --
│    └─LogicConv2d: 2-1                  [1, 16, 24, 24]           3,840
│    └─OrPoolingLayer: 2-2               [1, 16, 12, 12]           --
│    └─LogicConv2d: 2-3                  [1, 48, 10, 10]           11,520
│    └─OrPoolingLayer: 2-4               [1, 48, 6, 6]             --
│    └─LogicConv2d: 2-5                  [1, 144, 4, 4]            34,560
│    └─OrPoolingLayer: 2-6               [1, 144, 3, 3]            --
│    └─Flatten: 2-7                      [1, 1296]                 --
│    └─LogicLayer: 2-8                   [1, 20480]                327,680
│    └─LogicLayer: 2-9                   [1, 10240]                163,840
│    └─LogicLayer: 2-10                  [1, 5120]                 81,920
│    └─GroupSum: 2-11                    [1, 10]            

In [3]:

# Wrap the underlying network (`model.model`) for compilation with gcc.
# NOTE(review): num_bits=8 is presumably the integer width used to pack
# samples bitwise in the generated code — confirm against CompiledLogicNet.
# verbose=True on the wrapper prints the layer-parsing log shown below.
compiled_model = CompiledLogicNet(
    model=model.model,
    num_bits=8,
    cpu_compiler="gcc",
    verbose=True,
)

# Build and write the shared library. This is slow (about 30 s in the recorded
# run). It is deliberately rebuilt every run, because the model above is
# re-created on each kernel start.
compiled_model.compile(
    save_lib_path="compiled_clgn_model.so",
    verbose=False,
)


Found Flatten layer
Found GroupSum layer with 10 classes
Parsed 3 conv, 3 pooling, 3 linear layers
Layer execution order: [('conv', 0), ('pool', 0), ('conv', 1), ('pool', 1), ('conv', 2), ('pool', 2), ('flatten', 0), ('linear', 0), ('linear', 1), ('linear', 2)]
Compiling finished in 30.275 seconds.


In [4]:
# Sanity check: on a small random binary batch, the compiled library must
# reproduce the PyTorch model's class scores. A seeded generator makes the
# check reproducible across kernel restarts.
check_gen = torch.Generator().manual_seed(0)
x = torch.randint(0, 2, (8, 1, 28, 28), dtype=torch.int16, generator=check_gen)
x_np = x.bool().numpy()

# Pure inference: no_grad avoids recording autograd state and guarantees that
# `.numpy()` below cannot fail on a tensor that requires grad.
with torch.no_grad():
    preds = model(x)
preds_compiled = compiled_model(x_np)

# Unlike a bare `assert`, this check is not stripped under `python -O`.
# It reports which entries differ, and it rejects shape mismatches that
# np.allclose would silently broadcast. The rtol/atol values keep the previous
# np.allclose(..., atol=1e-5) tolerance (np.allclose's default rtol is 1e-5).
np.testing.assert_allclose(
    preds.numpy(),
    preds_compiled,
    rtol=1e-5,
    atol=1e-5,
    err_msg="Compiled model predictions do not match original model predictions",
)


In [5]:
# Benchmark: 1,000 random binary images through the PyTorch model and through
# the compiled library. time.perf_counter() is a monotonic, high-resolution
# clock meant for measuring intervals; datetime.now() is wall-clock time and
# can jump.
bench_gen = torch.Generator().manual_seed(1)  # fixed seed for reproducible inputs
x = torch.randint(0, 2, (1_000, 1, 28, 28), dtype=torch.int16, generator=bench_gen)
x_np = x.bool().numpy()

start = time.perf_counter()
with torch.no_grad():  # inference only; don't spend time on autograd bookkeeping
    preds = model(x)
elapsed = time.perf_counter() - start
print(f"Original model inference time: {elapsed:.6f} s")

start = time.perf_counter()
preds_compiled = compiled_model(x_np)
elapsed = time.perf_counter() - start
print(f"Compiled model inference time: {elapsed:.6f} s")


Original model inference time: 0:00:15.035126
Compiled model inference time: 0:00:00.008955


In [6]:
# Throughput at scale: 1,000,000 images through the compiled library only.
# At the ~15 s per 1,000 samples measured above, the PyTorch model would need
# hours for this batch. Sampling booleans directly (~0.8 GB) avoids first
# materialising a ~1.5 GB int16 tensor just to convert it.
bench_rng = np.random.default_rng(2)  # fixed seed for reproducible inputs
x_np = bench_rng.integers(0, 2, size=(1_000_000, 1, 28, 28), dtype=np.bool_)

start = time.perf_counter()
preds_compiled = compiled_model(x_np)
elapsed = time.perf_counter() - start
print(f"Compiled model inference time: {elapsed:.6f} s")


Compiled model inference time: 0:00:09.226433
