In [1]:
import numpy as np
import torch
from cirkit.templates import data_modalities, utils
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms

PIXEL_RANGE = 255
example_image = None

KERNEL_SIZE = (1, 1)
CIFAR_SIZE = (32, 32)
DEVICE = "cuda:5"
EPOCH = 30

## Data Preparation

Let's define a function to create and use patches of the base Dataset

In [2]:
def patchify(kernel_size, stride, compile=True, contiguous_output=False):
    kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    sh, sw = (stride, stride) if isinstance(stride, int) else stride

    def _patchify(image: torch.Tensor):
        # Accept (C,H,W) or (B,C,H,W)

        # Ensure contiguous NCHW for predictable strides
        x = image.contiguous()  # (B,C,H,W)
        B, C, H, W = x.shape

        # Number of patches along H/W
        Lh = (H - kh) // sh + 1
        Lw = (W - kw) // sw + 1

        # Create a zero-copy view: (B, C, Lh, Lw, kh, kw)
        sN, sC, sH, sW = x.stride()
        patches = x.as_strided(
            size=(B, C, Lh, Lw, kh, kw),
            stride=(sN, sC, sH * sh, sW * sw, sH, sW),
        )
        # Reorder to (B, P, C, kh, kw) where P = Lh*Lw
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * Lh * Lw, C, kh, kw)

        if contiguous_output:
            patches = (
                patches.contiguous()
            )  # materialize if the next ops need contiguous

        return patches

    if compile:
        _patchify = torch.compile(_patchify, fullgraph=True, dynamic=False)
    return _patchify


transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda x: (PIXEL_RANGE * x).long()),
    ]
)

data_train = datasets.CIFAR10(
    "datasets", train=True, download=True, transform=transform
)
data_test = datasets.CIFAR10(
    "datasets", train=False, download=True, transform=transform
)

train_idx, val_idx = train_test_split(range(len(data_train)), test_size=0.25)
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
# Instantiate the training and testing data loaders
train_dataloader = DataLoader(data_train, batch_size=512, sampler=train_sampler)
val_dataloader = DataLoader(data_train, batch_size=512, sampler=val_sampler)
test_dataloader = DataLoader(data_test, shuffle=False, batch_size=512)

## Defining the Circuit

We want to create a factory to create the different circuit we will want to compare.

In [3]:
from cirkit.symbolic.circuit import Circuit, Scope
from cirkit.symbolic.layers import (
    GaussianLayer,
    SumLayer,
    HadamardLayer,
    CategoricalLayer,
)
from cirkit.templates import utils


def build_1_1() -> Circuit:
    # This parametrizes the mixture weights such that they add up to one.
    weight_factory = utils.parameterization_to_factory(
        utils.Parameterization(
            activation="softmax",  # Parameterize the sum weights by using a softmax activation
            initialization="uniform",  # Initialize the sum weights by sampling from a standard normal distribution
        )
    )

    # We introduce one more mixture than in the original model
    # Again, SGD/Adam is not the best way to fit a (shallow) Gaussian mixture model
    units = 10

    g0 = CategoricalLayer(Scope((0,)), units, num_categories=256)
    g1 = CategoricalLayer(Scope((1,)), units, num_categories=256)
    g2 = CategoricalLayer(Scope((2,)), units, num_categories=256)

    s0 = SumLayer(units, 1, 1, weight_factory=weight_factory)
    s2 = SumLayer(units, 1, 1, weight_factory=weight_factory)
    s1 = SumLayer(units, 1, 1, weight_factory=weight_factory)

    return Circuit(
        layers=[
            g0,
            g1,
            g2,
            s0,
            s1,
            s2,
        ],  # Layers that appear in the circuit (i.e. nodes in the graph)
        in_layers={  # Connections between layers (i.e. edges in the graph as an adjacency list)
            g0: [],
            g1: [],
            g2: [],
            s0: [g0],
            s1: [g1],
            s2: [g2],
        },
        outputs=[s0, s1, s2],  # Nodes that are returned by the circuit
    )


def patch_circuit_factory(kernel_size, region_graph, layer_type, num_units):
    return data_modalities.image_data(
        (3, *kernel_size),
        region_graph=region_graph,
        input_layer="categorical",
        num_input_units=num_units,
        sum_product_layer=layer_type,
        num_sum_units=num_units,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    )


def base_circuit_factory(region_graph, layer_type, num_units):
    return data_modalities.image_data(
        (3, *CIFAR_SIZE),
        region_graph=region_graph,
        input_layer="categorical",
        num_input_units=num_units,
        sum_product_layer=layer_type,
        num_sum_units=num_units,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    )


def circuit_factory(circuit_type: str, **kwargs):
    name = f"{circuit_type}"
    for item in kwargs.values():
        name += f" + {item}"
    if circuit_type == "patch":
        return name, patch_circuit_factory(KERNEL_SIZE, **kwargs)
    else:
        return name, base_circuit_factory(**kwargs)

In [4]:
import itertools

explore_grid = {
    "circuit_type": ["patch", "base"],
    "layer_type": ["cp-t", "cp", "tucker"],
    "region_graph": ["quad-graph", "quad-tree-2"],
    "num_units": [16, 32, 64, 128],
}
keys, values = zip(*explore_grid.items())
explore_list = [dict(zip(keys, v)) for v in itertools.product(*values)]

In [5]:
circuits = dict((circuit_factory(**config) for config in explore_list))

## Training

In [4]:
import random
import time

import pandas as pd
from cirkit.pipeline import compile
from torch.utils.flop_counter import FlopCounterMode


def get_flops(model, inp, with_backward=False):
    istrain = model.training
    model.eval()

    inp = inp if isinstance(inp, torch.Tensor) else torch.randn(inp)

    flop_counter = FlopCounterMode(mods=model, display=False, depth=None)
    with flop_counter:
        if with_backward:
            model(inp).sum().backward()
        else:
            model(inp)
    total_flops = flop_counter.get_total_flops()
    if istrain:
        model.train()
    return total_flops


def train_and_eval_circuit(cc, patch: bool):
    torch.cuda.memory.reset_peak_memory_stats()
    # Set some seeds
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    # torch.cuda.manual_seed(42)

    # Set the torch device to use
    device = torch.device(DEVICE)
    # Move the circuit to chosen device
    circuit = compile(cc)
    circuit = circuit.to(device)

    num_epochs = 20
    step_idx = 0
    running_loss = 0.0
    running_samples = 0
    stats = dict()

    stats["# trainable parameters"] = sum(
        p.numel() for p in circuit.parameters() if p.requires_grad
    )
    stats["train loss"] = []
    patch_fn = patchify(KERNEL_SIZE, KERNEL_SIZE)
    # Initialize a torch optimizer of your choice,
    #  e.g., Adam, by passing the parameters of the circuit
    optimizer = torch.optim.Adam(circuit.parameters(), lr=0.01)
    begin_train = time.time()
    keep_batch = None

    for epoch_idx in range(num_epochs):
        for i, (batch, _) in enumerate(train_dataloader):
            # The circuit expects an input of shape (batch_dim, num_variables)
            if patch:
                batch = patch_fn(batch)
            BS = batch.shape[0]
            batch = batch.view(BS, -1)
            if keep_batch is None:
                keep_batch = batch
            batch = batch.to(device)
            # Compute the log-likelihoods of the batch, by evaluating the circuit
            log_likelihoods = circuit(batch)

            # We take the negated average log-likelihood as loss
            loss = -torch.mean(log_likelihoods)
            loss.backward()
            # Update the parameters of the circuits, as any other model in PyTorch
            optimizer.step()
            optimizer.zero_grad()
            running_loss += loss.detach() * len(batch)
            running_samples += len(batch)
            step_idx += 1
            if step_idx % 200 == 0:
                average_nll = running_loss / running_samples
                print(f"Step {step_idx}: Average NLL: {average_nll:.3f}")
                running_loss = 0.0
                running_samples = 0

                stats["train loss"].append(average_nll.cpu().item())
    end_train = time.time()

    with torch.no_grad():
        test_lls = 0.0

        for batch, _ in val_dataloader:
            # The circuit expects an input of shape (batch_dim, num_variables)
            if patch:
                batch = patch_fn(batch)
            BS = batch.shape[0]
            batch = batch.view(BS, -1).to(device)

            # Compute the log-likelihoods of the batch
            log_likelihoods = circuit(batch)

            # Accumulate the log-likelihoods
            test_lls += log_likelihoods.sum().item()

        # Compute average test log-likelihood and bits per dimension
        average_nll = -test_lls / len(data_test)
        bpd = average_nll / (32 * 32 * 3 * np.log(2.0))
        print(f"Average test LL: {average_nll:.3f}")
        print(f"Bits per dimension: {bpd:.3f}")

        stats["test loss"] = average_nll
        stats["test bits per dimension"] = bpd
    end_test = time.time()

    stats["train loss (min)"] = min(stats["train loss"])
    stats["train time"] = end_train - begin_train
    stats["test time"] = end_test - end_train
    stats["FLOPs"] = get_flops(circuit, keep_batch.to(device))
    stats["memory"] = torch.cuda.memory.max_memory_allocated(device)
    print(f"Total Flops {stats['FLOPs']}")
    print(f"Total Memory cost {stats['memory']}")

    # Free GPU memory
    circuit = circuit.to("cpu")
    torch.cuda.empty_cache()

    return stats


# results = dict()
# for k, cc in circuits.items():
#     print('\nTraining circuit "%s"' % k)
#     ctype = k.split("+")[0].strip()
#     results[k] = train_and_eval_circuit(cc, patch=ctype == "patch")
#     results[k]["type"] = k.split("+")[0].strip()
#     results[k]["sum product layer"] = k.split("+")[1].strip()
#     results[k]["structure"] = k.split("+")[2].strip()

In [5]:
circ = build_1_1()
train_and_eval_circuit(circ, patch=True)

Step 200: Average NLL: 5.496
Step 400: Average NLL: 5.477
Step 600: Average NLL: 5.477
Step 800: Average NLL: 5.478
Step 1000: Average NLL: 5.477
Step 1200: Average NLL: 5.478
Step 1400: Average NLL: 5.477
Average test LL: 21027.791
Bits per dimension: 9.875
Total Flops 31457280
Total Memory cost 341210112


  flop_counter = FlopCounterMode(mods=model, display=False, depth=None)


{'# trainable parameters': 7710,
 'train loss': [5.496336460113525,
  5.477449417114258,
  5.477447509765625,
  5.477670669555664,
  5.4773030281066895,
  5.477632999420166,
  5.477457523345947],
 'test loss': 21027.791075,
 'test bits per dimension': np.float64(9.875224578369696),
 'train loss (min)': 5.4773030281066895,
 'train time': 120.21795463562012,
 'test time': 1.78206205368042,
 'FLOPs': 31457280,
 'memory': 341210112}

In [None]:
compile(circ)(torch.randint(0, 255, (1, 3)))

tensor([[[-5.8599],
         [-5.7600],
         [-4.8966]]], grad_fn=<TransposeBackward0>)

In [None]:
# circuit =compile(list(circuits.values())[3])
# circuit = Meter(circuit)
# circuit.to("cuda:6")
batch, _ = next(iter(val_dataloader))
patch_fn = patchify(KERNEL_SIZE, KERNEL_SIZE)

batch = patch_fn(batch)
BS = batch.shape[0]
batch = batch.view(BS, -1).to("cuda:6")

res = circuit(batch)
print(res)
circuit.overview()

tensor([[[-266.7532]],

        [[-268.8200]],

        [[-268.7206]],

        ...,

        [[-267.1960]],

        [[-268.8501]],

        [[-268.8837]]], device='cuda:6', grad_fn=<TransposeBackward0>)


IndexError: tuple index out of range

In [None]:
df = pd.DataFrame.from_dict(results, orient="index")
df = df.drop(columns="train loss")

df.index = df.index.map(lambda x: x.split("+")[0])
df["# trainable parameters"] = df["# trainable parameters"].map("{:,d}".format)
pd.options.display.float_format = "{:,.3f}".format
df["train time format"] = pd.to_datetime(df["train time"], unit="s").dt.strftime(
    "%m:%S"
)
df["test time format"] = pd.to_datetime(df["test time"], unit="s").dt.strftime("%m:%S")

df.sort_values("test bits per dimension")

Unnamed: 0,# trainable parameters,test loss,test bits per dimension,train loss (min),train time,test time,type,sum product layer,structure,train time format,test time format
patch,8135170,12483.816,5.863,194.937,1048.771,3.044,patch,tucker,quad-graph,01:28,01:03
patch,4460544,12647.629,5.94,197.561,535.622,2.226,patch,tucker,quad-tree-2,01:55,01:02
patch,1032834,12685.412,5.957,198.096,294.293,1.839,patch,cp,quad-graph,01:54,01:01
patch,901762,12689.151,5.959,198.086,286.904,1.766,patch,cp.T,quad-graph,01:46,01:01
patch,909376,12832.732,6.027,200.429,268.187,1.509,patch,cp,quad-tree-2,01:28,01:01
patch,843840,12860.927,6.04,200.87,261.979,1.584,patch,cp.T,quad-tree-2,01:21,01:01
base,67136130,13412.097,6.299,13043.954,364.976,1.746,base,cp,quad-graph,01:04,01:01
base,586205698,13430.028,6.307,13018.108,1864.905,3.985,base,tucker,quad-graph,01:04,01:03
base,58712128,13458.212,6.32,13103.154,300.24,1.567,base,cp,quad-tree-2,01:00,01:01
base,318246912,13459.455,6.321,13045.494,992.584,2.709,base,tucker,quad-tree-2,01:32,01:02


In [None]:
print(
    df[
        [
            "type",
            "sum product layer",
            "structure",
            "# trainable parameters",
            "test bits per dimension",
            "test loss",
            "train time format",
            "test time format",
        ]
    ]
    .sort_values("test bits per dimension")
    .to_latex(float_format="%.2f", escape=True, index=False)
)

\begin{tabular}{llllrrll}
\toprule
type & sum product layer & structure & \# trainable parameters & test bits per dimension & test loss & train time format & test time format \\
\midrule
patch & tucker & quad-graph & 8,135,170 & 5.86 & 12483.82 & 01:28 & 01:03 \\
patch & tucker & quad-tree-2 & 4,460,544 & 5.94 & 12647.63 & 01:55 & 01:02 \\
patch & cp & quad-graph & 1,032,834 & 5.96 & 12685.41 & 01:54 & 01:01 \\
patch & cp.T & quad-graph & 901,762 & 5.96 & 12689.15 & 01:46 & 01:01 \\
patch & cp & quad-tree-2 & 909,376 & 6.03 & 12832.73 & 01:28 & 01:01 \\
patch & cp.T & quad-tree-2 & 843,840 & 6.04 & 12860.93 & 01:21 & 01:01 \\
base & cp & quad-graph & 67,136,130 & 6.30 & 13412.10 & 01:04 & 01:01 \\
base & tucker & quad-graph & 586,205,698 & 6.31 & 13430.03 & 01:04 & 01:03 \\
base & cp & quad-tree-2 & 58,712,128 & 6.32 & 13458.21 & 01:00 & 01:01 \\
base & tucker & quad-tree-2 & 318,246,912 & 6.32 & 13459.45 & 01:32 & 01:02 \\
base & cp.T & quad-graph & 58,747,522 & 6.35 & 13515.25 & 01:3

In [None]:
df.to_csv("bench_cifar.csv")