In [6]:
import torch
from cirkit.templates import data_modalities, utils
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

PIXEL_RANGE = 255
example_image = None

KERNEL_SIZE = (4, 4)
CIFAR_SIZE = (28, 28)
DEVICE = "cuda:7"
EPOCH = 10

## Data Preparation

Let's define a function to create and use patches of the base Dataset

In [None]:
def patchify(kernel_size, stride, compile=True, contiguous_output=False):
    kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    sh, sw = (stride, stride) if isinstance(stride, int) else stride

    def _patchify(image: torch.Tensor):
        # Accept (C,H,W) or (B,C,H,W)

        # Ensure contiguous NCHW for predictable strides
        x = image.contiguous()  # (B,C,H,W)
        B, C, H, W = x.shape

        # Number of patches along H/W
        Lh = (H - kh) // sh + 1
        Lw = (W - kw) // sw + 1

        # Create a zero-copy view: (B, C, Lh, Lw, kh, kw)
        sN, sC, sH, sW = x.stride()
        patches = x.as_strided(
            size=(B, C, Lh, Lw, kh, kw),
            stride=(sN, sC, sH * sh, sW * sw, sH, sW),
        )
        # Reorder to (B, P, C, kh, kw) where P = Lh*Lw
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * Lh * Lw, C, kh, kw)

        if contiguous_output:
            patches = (
                patches.contiguous()
            )  # materialize if the next ops need contiguous

        return patches

    if compile:
        _patchify = torch.compile(_patchify, fullgraph=True, dynamic=False)
    return _patchify


transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda x: (PIXEL_RANGE * x).long()),
    ]
)

data_train = datasets.MNIST("datasets", train=True, download=True, transform=transform)
data_test = datasets.MNIST("datasets", train=False, download=True, transform=transform)

# Instantiate the training and testing data loaders
train_dataloader = DataLoader(data_train, shuffle=True, batch_size=256)
test_dataloader = DataLoader(data_test, shuffle=False, batch_size=256)

## Defining the Circuit

We want to create a factory to create the different circuit we will want to compare.

In [3]:
def patch_circuit_factory(kernel_size, region_graph, layer_type, num_units):
    return data_modalities.image_data(
        (1, *kernel_size),
        region_graph=region_graph,
        input_layer="categorical",
        num_input_units=num_units,
        sum_product_layer=layer_type,
        num_sum_units=num_units,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    )


def base_circuit_factory(region_graph, layer_type, num_units):
    return data_modalities.image_data(
        (1, *CIFAR_SIZE),
        region_graph=region_graph,
        input_layer="categorical",
        num_input_units=num_units,
        sum_product_layer=layer_type,
        num_sum_units=num_units,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    )

In [4]:
circuits = dict()
units = 32
circuits["patch + cp.T + quad-graph"] = patch_circuit_factory(
    KERNEL_SIZE, "quad-graph", "cp-t", units
)
circuits["patch + cp.T + quad-tree-2"] = patch_circuit_factory(
    KERNEL_SIZE, "quad-tree-2", "cp-t", units
)

circuits["patch + cp + quad-graph"] = patch_circuit_factory(
    KERNEL_SIZE, "quad-graph", "cp", units
)
circuits["patch + cp + quad-tree-2"] = patch_circuit_factory(
    KERNEL_SIZE, "quad-tree-2", "cp", units
)

circuits["patch + tucker + quad-graph"] = patch_circuit_factory(
    KERNEL_SIZE, "quad-graph", "tucker", units
)
circuits["patch + tucker + quad-tree-2"] = patch_circuit_factory(
    KERNEL_SIZE, "quad-tree-2", "tucker", units
)

circuits["base + cp.T + quad-graph"] = base_circuit_factory("quad-graph", "cp-t", units)
circuits["base + cp.T + quad-tree-2"] = base_circuit_factory(
    "quad-tree-2", "cp-t", units
)

circuits["base + cp + quad-graph"] = base_circuit_factory("quad-graph", "cp", units)
circuits["base + cp + quad-tree-2"] = base_circuit_factory("quad-tree-2", "cp", units)

circuits["base + tucker + quad-graph"] = base_circuit_factory(
    "quad-graph", "tucker", units
)
circuits["base + tucker + quad-tree-2"] = base_circuit_factory(
    "quad-tree-2", "tucker", units
)

## Training

In [7]:
import random
import time

import numpy as np
import pandas as pd
from cirkit.pipeline import compile


def train_and_eval_circuit(cc, patch: bool):
    # Set some seeds
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    # torch.cuda.manual_seed(42)

    # Set the torch device to use
    device = torch.device(DEVICE)
    # Move the circuit to chosen device
    circuit = compile(cc)
    circuit = circuit.to(device)

    num_epochs = EPOCH
    step_idx = 0
    running_loss = 0.0
    running_samples = 0
    stats = dict()

    stats["# trainable parameters"] = sum(
        p.numel() for p in circuit.parameters() if p.requires_grad
    )
    stats["train loss"] = []
    patch_fn = patchify(KERNEL_SIZE, KERNEL_SIZE)
    # Initialize a torch optimizer of your choice,
    #  e.g., Adam, by passing the parameters of the circuit
    optimizer = torch.optim.Adam(circuit.parameters(), lr=0.01)
    begin_train = time.time()
    for epoch_idx in range(num_epochs):
        for i, (batch, _) in enumerate(train_dataloader):
            # The circuit expects an input of shape (batch_dim, num_variables)
            if patch:
                batch = patch_fn(batch)
            BS = batch.shape[0]
            batch = batch.view(BS, -1).to(device)

            # Compute the log-likelihoods of the batch, by evaluating the circuit
            log_likelihoods = circuit(batch)

            # We take the negated average log-likelihood as loss
            loss = -torch.mean(log_likelihoods)
            loss.backward()
            # Update the parameters of the circuits, as any other model in PyTorch
            optimizer.step()
            optimizer.zero_grad()
            running_loss += loss.detach() * len(batch)
            running_samples += len(batch)
            step_idx += 1
            if step_idx % 200 == 0:
                average_nll = running_loss / running_samples
                print(f"Step {step_idx}: Average NLL: {average_nll:.3f}")
                running_loss = 0.0
                running_samples = 0

                stats["train loss"].append(average_nll.cpu().item())
    end_train = time.time()

    with torch.no_grad():
        test_lls = 0.0

        for batch, _ in test_dataloader:
            # The circuit expects an input of shape (batch_dim, num_variables)
            if patch:
                batch = patch_fn(batch)
            BS = batch.shape[0]
            batch = batch.view(BS, -1).to(device)

            # Compute the log-likelihoods of the batch
            log_likelihoods = circuit(batch)

            # Accumulate the log-likelihoods
            test_lls += log_likelihoods.sum().item()

        # Compute average test log-likelihood and bits per dimension
        average_nll = -test_lls / len(data_test)
        bpd = average_nll / (28 * 28 * 1 * np.log(2.0))
        print(f"Average test LL: {average_nll:.3f}")
        print(f"Bits per dimension: {bpd:.3f}")

        stats["test loss"] = average_nll
        stats["test bits per dimension"] = bpd
    end_test = time.time()

    # Free GPU memory
    circuit = circuit.to("cpu")
    torch.cuda.empty_cache()

    stats["train loss (min)"] = min(stats["train loss"])
    stats["train time"] = end_train - begin_train
    stats["test time"] = end_test - end_train

    return stats


results = dict()
for k, cc in circuits.items():
    print('\nTraining circuit "%s"' % k)
    ctype = k.split("+")[0].strip()
    results[k] = train_and_eval_circuit(cc, patch=ctype == "patch")
    results[k]["type"] = k.split("+")[0].strip()
    results[k]["sum product layer"] = k.split("+")[1].strip()
    results[k]["structure"] = k.split("+")[2].strip()
    break


Training circuit "patch + cp.T + quad-graph"
Step 200: Average NLL: 26.511
Step 400: Average NLL: 17.413
Step 600: Average NLL: 15.854
Step 800: Average NLL: 15.274
Step 1000: Average NLL: 15.039
Step 1200: Average NLL: 14.900
Step 1400: Average NLL: 14.811
Step 1600: Average NLL: 14.734
Step 1800: Average NLL: 14.707
Step 2000: Average NLL: 14.642
Step 2200: Average NLL: 14.622
Average test LL: 712.793
Bits per dimension: 1.312


In [None]:
df = pd.DataFrame.from_dict(results, orient="index")
df = df.drop(columns="train loss")

df.index = df.index.map(lambda x: x.split("+")[0])
df["# trainable parameters"] = df["# trainable parameters"].map("{:,d}".format)
pd.options.display.float_format = "{:,.3f}".format
df["train time format"] = pd.to_datetime(df["train time"], unit="s").dt.strftime(
    "%m:%S"
)
df["test time format"] = pd.to_datetime(df["test time"], unit="s").dt.strftime("%m:%S")

df.sort_values("test bits per dimension")

Unnamed: 0,# trainable parameters,test loss,test bits per dimension,train loss (min),train time,test time,type,sum product layer,structure,train time format,test time format
base,25657730,681.926,1.255,683.536,140.801,1.277,base,cp,quad-graph,01:20,01:01
base,421306626,683.163,1.257,681.451,530.976,2.879,base,tucker,quad-graph,01:50,01:02
base,19259456,684.707,1.26,686.227,104.214,1.119,base,cp,quad-tree-2,01:44,01:01
base,19259778,689.319,1.268,676.182,121.753,1.206,base,cp.T,quad-graph,01:01,01:01
base,217845760,690.699,1.271,688.059,273.444,1.982,base,tucker,quad-tree-2,01:33,01:01
base,16048192,693.989,1.277,683.492,98.947,1.096,base,cp.T,quad-tree-2,01:38,01:01
patch,377474,723.323,1.331,14.849,95.649,2.056,patch,cp.T,quad-graph,01:35,01:02
patch,319552,729.018,1.342,14.96,82.122,1.094,patch,cp.T,quad-tree-2,01:22,01:01
patch,508546,731.456,1.346,15.045,94.729,1.252,patch,cp,quad-graph,01:34,01:01
patch,385088,734.023,1.351,15.079,83.818,1.131,patch,cp,quad-tree-2,01:23,01:01


In [None]:
print(
    df[
        [
            "type",
            "sum product layer",
            "structure",
            "# trainable parameters",
            "test bits per dimension",
            "test loss",
            "train time format",
            "test time format",
        ]
    ]
    .sort_values("test bits per dimension")
    .to_latex(float_format="%.2f", escape=True, index=False)
)

\begin{tabular}{llllrrll}
\toprule
type & sum product layer & structure & \# trainable parameters & test bits per dimension & test loss & train time format & test time format \\
\midrule
base & cp & quad-graph & 25,657,730 & 1.25 & 681.93 & 01:20 & 01:01 \\
base & tucker & quad-graph & 421,306,626 & 1.26 & 683.16 & 01:50 & 01:02 \\
base & cp & quad-tree-2 & 19,259,456 & 1.26 & 684.71 & 01:44 & 01:01 \\
base & cp.T & quad-graph & 19,259,778 & 1.27 & 689.32 & 01:01 & 01:01 \\
base & tucker & quad-tree-2 & 217,845,760 & 1.27 & 690.70 & 01:33 & 01:01 \\
base & cp.T & quad-tree-2 & 16,048,192 & 1.28 & 693.99 & 01:38 & 01:01 \\
patch & cp.T & quad-graph & 377,474 & 1.33 & 723.32 & 01:35 & 01:02 \\
patch & cp.T & quad-tree-2 & 319,552 & 1.34 & 729.02 & 01:22 & 01:01 \\
patch & cp & quad-graph & 508,546 & 1.35 & 731.46 & 01:34 & 01:01 \\
patch & cp & quad-tree-2 & 385,088 & 1.35 & 734.02 & 01:23 & 01:01 \\
patch & tucker & quad-graph & 7,610,882 & 1.35 & 734.18 & 01:10 & 01:02 \\
patch & tucke

In [None]:
df.to_csv("bench_mnist.csv")