In [9]:
%load_ext autoreload
%autoreload 2
import sys
import math

import matplotlib.pyplot as plt
import torch
import cirkit
import numpy as np
from cirkit.templates import data_modalities, utils
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

PIXEL_RANGE = 255
example_image = None

KERNEL_SIZE = (4, 4)
CIFAR_SIZE = (28,28)
DEVICE = "cuda:1"
EPOCH = 30
import gc
gc.collect()
torch.cuda.empty_cache()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data Preparation

Let's define a function to create and use patches of the base Dataset

In [2]:
def patchify(kernel_size, stride, compile=True, contiguous_output=False):
    kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    sh, sw = (stride, stride) if isinstance(stride, int) else stride

    def _patchify(image: torch.Tensor):
        # Accept (C,H,W) or (B,C,H,W)

        # Ensure contiguous NCHW for predictable strides
        x = image.contiguous()  # (B,C,H,W)
        B, C, H, W = x.shape

        # Number of patches along H/W
        Lh = (H - kh) // sh + 1
        Lw = (W - kw) // sw + 1

        # Create a zero-copy view: (B, C, Lh, Lw, kh, kw)
        sN, sC, sH, sW = x.stride()
        patches = x.as_strided(
            size=(B, C, Lh, Lw, kh, kw),
            stride=(sN, sC, sH * sh, sW * sw, sH, sW),
        )
        # Reorder to (B, P, C, kh, kw) where P = Lh*Lw
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * Lh * Lw, C, kh, kw)

        if contiguous_output:
            patches = (
                patches.contiguous()
            )  # materialize if the next ops need contiguous

        return patches

    if compile:
        _patchify = torch.compile(_patchify, fullgraph=True, dynamic=False)
    return _patchify


transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda x: (PIXEL_RANGE * x).long()),
    ]
)

data_train = datasets.MNIST("datasets", train=True, download=True, transform=transform)
data_test = datasets.MNIST("datasets", train=False, download=True, transform=transform)

# Instantiate the training and testing data loaders
train_dataloader = DataLoader(data_train, shuffle=True, drop_last=True, batch_size=256)
test_dataloader = DataLoader(data_test, shuffle=False, batch_size=256)

## Defining the Circuit

We want to create a factory to create the different circuit we will want to compare.

In [3]:
def patch_share_circuit_factory(
    kernel_size, region_graph, layer_type, num_units, big_region_graph=None
):
    shared_circ = data_modalities.image_data(
        (1, *kernel_size),
        region_graph=region_graph,
        input_layer="categorical",
        num_input_units=num_units,
        sum_product_layer=layer_type,
        num_sum_units=num_units,
        num_classes=num_units,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    )

    big_circ = data_modalities.image_data(
        (1, *CIFAR_SIZE),
        region_graph=big_region_graph if big_region_graph is not None else region_graph,
        input_layer="categorical",
        num_input_units=num_units,
        sum_product_layer=layer_type,
        num_sum_units=num_units,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    )

    return share_scope(big_circ, shared_circ, math.prod(kernel_size))

def patch_circuit_factory(
    kernel_size, region_graph, layer_type, num_units, big_region_graph=None
):
    shared_circ = data_modalities.image_data(
        (1, *kernel_size),
        region_graph=region_graph,
        input_layer="categorical",
        num_input_units=num_units,
        sum_product_layer=layer_type,
        num_sum_units=num_units,
        num_classes=num_units,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    )
    return shared_circ

def top_circuit_factory(
    subspace_size, region_graph, layer_type, num_units, big_region_graph=None
):
    shared_circ = data_modalities.image_data(
        (1, *subspace_size),
        region_graph=region_graph,
        input_layer="categorical",
        num_input_units=num_units,
        sum_product_layer=layer_type,
        num_sum_units=num_units,
        num_classes=1,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    )
    return shared_circ

def base_circuit_factory(region_graph, layer_type, num_units):
    return data_modalities.image_data(
        (1, *CIFAR_SIZE),
        # (1,*kernel_size),
        region_graph=region_graph,
        input_layer="categorical",
        num_input_units=num_units,
        sum_product_layer=layer_type,
        num_sum_units=num_units,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    )

In [4]:
from cirkit.pipeline import compile as cirkit_compile
big_circ = cirkit_compile(base_circuit_factory("quad-graph", "cp-t", 128))
patch_circ = cirkit_compile(patch_circuit_factory((2,2),"quad-graph", "cp-t", 128))

pytorch_total_params = sum(p.numel() for p in big_circ.parameters() if p.requires_grad)
print(pytorch_total_params)


51282690


In [4]:
%load_ext autoreload
%autoreload 2
from cirkit.backend.torch.layers.input import TorchCategoricalLayer
from cirkit.backend.torch.parameters.parameter import TorchParameter, TorchParameterNode, FoldIndexInfo
from cirkit.backend.torch.parameters.nodes import TorchParameterInput
from cirkit.backend.torch.graph.folding import build_unfold_index_info
from torch import Tensor
from cirkit.backend.torch.semiring import Semiring
from typing import Sequence, Mapping
import functools


def share_parameter(graph:TorchParameter, new_fold:int):
    new_param_nodes = []
    for i,n in enumerate(graph.topological_ordering()):
        instance = type(n)
        new_conf=n.config
        new_conf["num_folds"]=new_fold
        if "initializer_" in new_conf:
            reduced_initializer_list = new_conf["initializer_"].keywords["initializers"][:new_fold]
            new_conf["initializer_"]= functools.partial(new_conf["initializer_"].func, initializers=reduced_initializer_list)
        if "shape" in new_conf:
            shape=new_conf["shape"]
            del new_conf["shape"]
            new_param = instance(*shape, **new_conf)
        else:
            new_param = instance( **new_conf)
        new_param.reset_parameters()
        new_param_nodes.append(new_param)
   
    shared= TorchSharedParameter(graph.shape, new_param_nodes, num_folds=graph.num_folds)
    fold_idx_info = FoldIndexInfo(
        ordering=[shared],
        in_fold_idx={0:[[]]},
        out_fold_idx=[(0, f) for f in range(graph.num_folds)]
    )
    return TorchParameter([shared], {shared:[]}, [shared], fold_idx_info=fold_idx_info)


class PatchOrderingLayer:
    def __init__(self, size:tuple[int,int,int], patch:tuple[int,int]):
        self.patch_fn = patchify(patch, patch)
        self.size=size
        self.patch=patch
    
    def __call__(self,x:torch.Tensor):
        #x: (B,N) where N is W*H from the original image
        #We first retrieve the original image
        B,N = x.shape
        x=x.reshape(B,*self.size)
        patched = self.patch_fn(x)
        return patched.reshape(B,N)

# class TorchSharedParameter(TorchParameterInput):
#     def __init__(self,
#         in_shape:tuple[int,...],
#         parameter:list[torch.nn.Module],
#         num_folds:int
# ):class TorchSharedParameter(TorchParameterInput):
#     def __init__(self,
#         in_shape:tuple[int,...],
#         parameter:list[torch.nn.Module],
#         num_folds:int
# ):
#         super().__init__()
#         self._num_folds=num_folds
#         self.in_shape=in_shape
#         self.internal_param = torch.nn.ModuleList(parameter)
    
#     def forward(self):
#         current_input=None
#         for param in self.internal_param:
#             if current_input is None:
#                 current_input=param()
#             else:
#                 current_input=param(current_input)
#         share_fold, *inner_units = current_input.shape
#         return current_input.expand(self.num_folds//share_fold,share_fold, *inner_units).reshape(self.num_folds,*inner_units)

#     @property
#     def shape(self):
#         return self.in_shape
#         super().__init__()
#         self._num_folds=num_folds
#         self.in_shape=in_shape
#         self.internal_param = torch.nn.ModuleList(parameter)
    
#     def forward(self):
#         current_input=None
#         for param in self.internal_param:
#             if current_input is None:
#                 current_input=param()
#             else:
#                 current_input=param(current_input)
#         share_fold, *inner_units = current_input.shape
#         return current_input.expand(self.num_folds//share_fold,share_fold, *inner_units).reshape(self.num_folds,*inner_units)

#     @property
#     def shape(self):
#         return self.in_shape

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
from copy import copy
from cirkit.symbolic.circuit import Scope
from cirkit.symbolic.circuit import Circuit
import sys
sys.path.append("pconv")

from cirkit.symbolic.layers import InputLayer, SumLayer
from cirkit.symbolic.parameters import (
    Parameter,
)
from cirkit.backend.torch.parameters.parameter import TorchParameter
from cirkit.backend.torch.parameters.nodes import TorchParameterInput

from cirkit.backend.torch.layers.inner import TorchSumLayer
from cirkit.backend.torch.layers.optimized import TorchCPTLayer
from cirkit.backend.torch.layers.input import TorchInputLayer
from cirkit.backend.torch.circuits import TorchCircuit
from cirkit.backend.torch.parameters.nodes import TorchParameterInput, TorchMixingWeightParameter, TorchTensorParameter, TorchUnaryParameterOp
from cirkit.backend.torch.parameters.parameter import TorchParameter
from cirkit.backend.torch.layers.inner import TorchSumLayer
from cirkit.backend.torch.layers.optimized import TorchCPTLayer, TorchTuckerLayer
from cirkit.backend.torch.layers.input import TorchInputLayer
from cirkit.backend.torch.circuits import TorchCircuit

def patch_circuits(top_circuit_param, patch_circuit_param):
    im_shape = [i*k for i,k in zip(top_circuit_param["image_shape"],patch_circuit_param["image_shape"])]
    kernel_shape = patch_circuit_param["image_shape"]

    example_data = torch.arange(im_shape[0]*im_shape[1]*im_shape[2]).reshape(1, *im_shape)
    patch_fn = patchify(kernel_shape[1:], kernel_shape[1:])
    scope_order=patch_fn(example_data).reshape(-1, kernel_shape[1]*kernel_shape[2])
    # print(example_data)
    top = data_modalities.image_data(**top_circuit_param)
    new_layers = top._nodes.copy()
    new_inputs = top._in_nodes.copy()
    for new_scope,input_node in zip(scope_order, list(top.layerwise_topological_ordering())[0]):
        # Remove input node
        new_layers.remove(input_node)
        # add output of patch (create patch)
        patch = data_modalities.image_data(**patch_circuit_param)
        patch_input = list(patch.layerwise_topological_ordering())[0]
        for idx,inp in enumerate(patch_input):
            inp.scope= Scope([new_scope[list(inp.scope)[0]].item()])
        new_layers.extend(patch._nodes)

        # verify connections
        for node, inputs in top._in_nodes.items():
            if input_node in inputs:
                new_inputs[node].remove(input_node)
                new_inputs[node].extend(patch.outputs)
        new_inputs.update(patch._in_nodes)
        

        
    return Circuit(new_layers, new_inputs, top.outputs)
    

def copy_parameter(graph: TorchParameter, new_shape):
    new_param_nodes = []
    copy_map = {}
    in_nodes = {}
    outputs = []
    for n in graph.topological_ordering():
        instance = type(n)
        config = n.config
        if isinstance(n, TorchTensorParameter):
            del config["shape"]
            new_param = instance(*new_shape,**config)
            new_param._ptensor = torch.nn.Parameter(torch.zeros((graph.shape[0],*new_shape)))

        elif isinstance(n, TorchUnaryParameterOp):
            config["in_shape"]=new_shape

            new_param = instance(**config)
        new_param_nodes.append(new_param)
        copy_map[n] = new_param
        inputs = [copy_map[in_node] for in_node in graph.node_inputs(n)]
        if len(inputs) > 0:
            in_nodes[new_param] = inputs
    outputs = [copy_map[out_node] for out_node in graph.outputs]
    parameter= TorchParameter(modules=new_param_nodes, in_modules=in_nodes, outputs=outputs)
    return parameter

class TorchSharedParameter(TorchParameterInput):
    def __init__(self,
        in_shape:tuple[int,...],
        parameter:list[torch.nn.Module],
        num_folds:int
):
        super().__init__()
        self._num_folds=num_folds
        self.in_shape=in_shape
        self.internal_param = parameter
    
    def forward(self):
        current_input=None
        for param in self.internal_param:
            if current_input is None:
                current_input=param()
            else:
                current_input=param(current_input)
        share_fold, *inner_units = current_input.shape
        expanded = current_input.expand(self.num_folds//share_fold,share_fold, *inner_units).reshape(self.num_folds,*inner_units)

        return expanded

    @property
    def shape(self):
        return self.in_shape

def share_param_like(base_circ:TorchCircuit, share_struct:TorchCircuit, should_init_mean=False, should_freeze=False):
    for idx, layer in enumerate(share_struct.layers):
        if isinstance(layer, TorchInputLayer):
            folds = base_circ.layers[idx].probs.num_folds
            shared_param = TorchSharedParameter(base_circ.layers[idx].probs.shape, parameter=layer.probs.nodes, num_folds=folds)
            base_circ.layers[idx].probs = shared_param
        elif isinstance(layer, TorchCPTLayer) or isinstance(layer, TorchSumLayer):
            internal_param=layer.weight.nodes
            has_new_nodes=False
            if layer.num_output_units != base_circ.layers[idx].num_output_units:
                new_parameter = copy_parameter(layer.weight, base_circ.layers[idx].weight.nodes[0].shape)
                _,o,i = internal_param[0]._ptensor.data.shape
                _, goal_o, goal_i = new_parameter.nodes[0]._ptensor.data.shape
                num_input = goal_i//i
                num_output = goal_o//o
                new_parameter.nodes[0]._ptensor.data=internal_param[0]._ptensor.data.clone().repeat((1, num_output, num_input))
                has_new_nodes=True
                internal_param=new_parameter.nodes
            folds = base_circ.layers[idx].weight.num_folds
            shared_param = TorchSharedParameter(base_circ.layers[idx].weight.shape, parameter=internal_param, num_folds=folds)
            
            if should_freeze and not has_new_nodes:
                freeze_parameter(shared_param)
            base_circ.layers[idx].weight = shared_param
    if should_init_mean:
        init_mean(base_circ, len(share_struct.layers))


def init_mean(circ, start_idx):
    for layer in circ.layers[start_idx:]:
        if isinstance(layer, TorchCPTLayer) or isinstance(layer, TorchSumLayer):
            param = layer.weight
            tensor = param.nodes[0]._ptensor
            inputs =tensor.shape[-1]

            param.nodes[0]._ptensor.data = torch.full(tensor.shape, torch.exp(torch.tensor(1/inputs)))

def freeze_parameter(param:TorchSharedParameter):
    for p in param.parameters():
        p.requires_grad = False
            
patch_units=512
top_units=512
new_circ=patch_circuits(
    {
        "image_shape": (1, 7,7),
        "region_graph": "quad-graph",
        "input_layer": "categorical",
        "num_input_units": top_units,
        "sum_product_layer": "cp",
        "num_sum_units": top_units,
        "sum_weight_param": utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    },
     {
        "image_shape": (1, 4,4),
        "region_graph": "quad-graph",
        "input_layer": "categorical",
        "num_input_units": patch_units,
        "sum_product_layer": "cp-t",
        "num_sum_units": patch_units,
        "num_classes":top_units,
        "sum_weight_param": utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    }
)



from cirkit.pipeline import compile


from src.benchmark_logic import BenchPCImage

trained_mod = BenchPCImage.load_from_checkpoint("pconv/checkpoints/epoch=38-step=13728.ckpt")
patch_circ_trained = trained_mod.circuit
patch_circ=data_modalities.image_data(
        (1,4,4),
        region_graph="quad-graph",
        input_layer="categorical",
        num_input_units=patch_units,
        sum_product_layer="cp-t",
        num_sum_units=patch_units,
        num_classes=1,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
)

input_circ=data_modalities.image_data(
        (1,1,1),
        region_graph="quad-graph",
        input_layer="categorical",
        num_input_units=patch_units,
        sum_product_layer="cp-t",
        num_sum_units=patch_units,
        num_classes=patch_units,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
)
cpatch = compile(patch_circ)
cbase = compile(new_circ)
# cinput = compile(input_circ)
share_param_like(cbase, patch_circ_trained, should_freeze=True)
# share_param_like(cbase, cinput)
cbase
test_circuit(cbase)


Average test LL: 669.779
Bits per dimension: 1.233


In [11]:
test_circuit(patch_circ_trained, patch=True)

Average test LL: 669.779
Bits per dimension: 1.233


In [12]:
pytorch_shared_params = sum(p.numel() for p in cbase.parameters() if p.requires_grad)
old_big_circ = compile(base_circuit_factory("quad-graph", "cp", 512))
pytorch_total_params = sum(p.numel() for p in old_big_circ.parameters() if p.requires_grad)
trained_params = sum(p.numel() for p in patch_circ_trained.parameters() if p.requires_grad)

print("Big circuit:", pytorch_total_params)
print("Shared circuit:",pytorch_shared_params)
print("Trained circuit:",trained_params)


Big circuit: 920914946
Shared circuit: 49822722
Trained circuit: 2098178


In [14]:
pytorch_shared_params = sum(p.numel() for p in cbase.parameters() )
pytorch_shared_params

57166850

## Training

In [None]:
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "0"

In [13]:
import random
import time

import numpy as np
import pandas as pd
from cirkit.pipeline import compile

gc.collect()
torch.cuda.empty_cache()

def train_and_eval_circuit(circuit, patch: bool):
    # Set some seeds
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    # torch.cuda.manual_seed(42)

    # Set the torch device to use
    device = torch.device(DEVICE)
    # device="cpu"
    # Move the circuit to chosen device
    circuit = circuit.to(device)
    patch_fn = patchify(KERNEL_SIZE, KERNEL_SIZE)
    num_epochs = 20
    step_idx = 0
    running_loss = 0.0
    running_samples = 0
    stats = dict()

    stats["# trainable parameters"] = sum(
        p.numel() for p in circuit.parameters() if p.requires_grad
    )
    stats["train loss"] = []
    patch_order=PatchOrderingLayer((1,*CIFAR_SIZE), KERNEL_SIZE)
    # Initialize a torch optimizer of your choice,
    #  e.g., Adam, by passing the parameters of the circuit
    optimizer = torch.optim.Adam(circuit.parameters(), lr=0.05)
    begin_train = time.time()
    for epoch_idx in range(num_epochs):
        for i, (batch, _) in enumerate(train_dataloader):
            # The circuit expects an input of shape (batch_dim, num_variables)
            BS = batch.shape[0]
            batch=batch.view(BS, -1)
            if patch:
                batch = patch_order(batch)
            batch = batch.to(device)
            # Compute the log-likelihoods of the batch, by evaluating the circuit
            log_likelihoods = circuit(batch)

            # We take the negated average log-likelihood as loss
            loss = -torch.mean(log_likelihoods)
            loss.backward()
            # Update the parameters of the circuits, as any other model in PyTorch
            optimizer.step()
            optimizer.zero_grad()
            running_loss += loss.detach() * len(batch)
            running_samples += len(batch)
            step_idx += 1
            if step_idx % 200 == 0:
                average_nll = running_loss / running_samples
                print(f"Step {step_idx}: Average NLL: {average_nll:.3f}")
                running_loss = 0.0
                running_samples = 0

                stats["train loss"].append(average_nll.cpu().item())
    end_train = time.time()

    with torch.no_grad():
        test_lls = 0.0

        for batch, _ in test_dataloader:
            # The circuit expects an input of shape (batch_dim, num_variables)
            BS = batch.shape[0]
            batch=batch.view(BS, -1)
            if patch:
                batch = patch_order(batch)
            batch = batch.to(device)

            # Compute the log-likelihoods of the batch
            print(batch.shape)
            log_likelihoods = circuit(batch)
            print(log_likelihoods.shape)

            # Accumulate the log-likelihoods
            test_lls += log_likelihoods.sum().item()

        # Compute average test log-likelihood and bits per dimension
        average_nll = -test_lls / len(data_test)
        bpd = average_nll / (28*28*1 * np.log(2.0))
        print(f"Average test LL: {average_nll:.3f}")
        print(f"Bits per dimension: {bpd:.3f}")
        stats["test loss"] = average_nll
        stats["test bits per dimension"] = bpd
    end_test = time.time()

    # Free GPU memory
    circuit = circuit.to("cpu")
    torch.cuda.empty_cache()

    stats["train loss (min)"] = min(stats["train loss"])
    stats["train time"] = end_train - begin_train
    stats["test time"] = end_test - end_train

    return stats


# results = dict()
# for k, cc in circuits.items():
#     print('\nTraining circuit "%s"' % k)
#     ctype = k.split("+")[0].strip()
#     results[k] = train_and_eval_circuit(cc, patch=False)
#     results[k]["type"] = k.spli512t("+")[0].strip()
#     results[k]["sum product layer"] = k.split("+")[1].strip()
#     results[k]["structure"] = k.split("+")[2].strip()
# base_circ = cirkit_compile(stitched)
# base_circ.reset_parameters()
train_and_eval_circuit(cbase, False)

Step 200: Average NLL: 640.526
Step 400: Average NLL: 625.818
Step 600: Average NLL: 621.707
Step 800: Average NLL: 619.054
Step 1000: Average NLL: 618.383
Step 1200: Average NLL: 617.771
Step 1400: Average NLL: 617.308
Step 1600: Average NLL: 615.868
Step 1800: Average NLL: 615.845
Step 2000: Average NLL: 614.522
Step 2200: Average NLL: 614.142
Step 2400: Average NLL: 613.999
Step 2600: Average NLL: 614.398
Step 2800: Average NLL: 613.354
Step 3000: Average NLL: 612.930
Step 3200: Average NLL: 612.538
Step 3400: Average NLL: 612.082
Step 3600: Average NLL: 613.302
Step 3800: Average NLL: 612.283
Step 4000: Average NLL: 612.930
Step 4200: Average NLL: 612.329
Step 4400: Average NLL: 611.404
Step 4600: Average NLL: 611.644
torch.Size([256, 784])
torch.Size([256, 1, 1])
torch.Size([256, 784])
torch.Size([256, 1, 1])
torch.Size([256, 784])
torch.Size([256, 1, 1])
torch.Size([256, 784])
torch.Size([256, 1, 1])
torch.Size([256, 784])
torch.Size([256, 1, 1])
torch.Size([256, 784])
torch.Size

{'# trainable parameters': 49822722,
 'train loss': [640.5264282226562,
  625.8182373046875,
  621.7069702148438,
  619.0540771484375,
  618.382568359375,
  617.7711791992188,
  617.3084716796875,
  615.86767578125,
  615.8446044921875,
  614.5216674804688,
  614.141845703125,
  613.9989624023438,
  614.3976440429688,
  613.353515625,
  612.930419921875,
  612.5384521484375,
  612.0823974609375,
  613.3015747070312,
  612.2826538085938,
  612.9298095703125,
  612.3287353515625,
  611.4038696289062,
  611.6441040039062],
 'test loss': 621.7103943359375,
 'test bits per dimension': np.float64(1.1440542127265039),
 'train loss (min)': 611.4038696289062,
 'train time': 508.01440167427063,
 'test time': 2.9679906368255615}

In [20]:
import sys
sys.path.append("pconv")

In [None]:
from cirkit.backend.torch.queries import SamplingQuery
from src.utils import unpatchify

unpatch_fn = unpatchify(CIFAR_SIZE, KERNEL_SIZE, KERNEL_SIZE,1)
circuit = cbase.cpu()
query = SamplingQuery(circuit)

samples, _ = query(num_samples=1)
# samples = unpatch_fn(samples).reshape(28,28,1)
plt.imshow(samples.reshape(28,28,1), cmap="grey")

In [7]:
def test_circuit(circuit, patch=False):
    device = torch.device(DEVICE)
    circuit=circuit.to(device)
    patch_fn = patchify(KERNEL_SIZE, KERNEL_SIZE)

    with torch.no_grad():
        test_lls = 0.0

        for batch, _ in test_dataloader:
            # The circuit expects an input of shape (batch_dim, num_variables)
            if patch:
                batch = patch_fn(batch)
            BS = batch.shape[0]
            batch = batch.view(BS, -1).to(device)

            # Compute the log-likelihoods of the batch
            log_likelihoods = circuit(batch)

            # Accumulate the log-likelihoods
            test_lls += log_likelihoods.sum().item()

        # Compute average test log-likelihood and bits per dimension
        average_nll = -test_lls / len(data_test)
        bpd = average_nll / (28*28*1 * np.log(2.0))
        print(f"Average test LL: {average_nll:.3f}")
        print(f"Bits per dimension: {bpd:.3f}")
# test_circuit(circuit)

In [74]:
from cirkit.pipeline import integrate

ic = integrate(circuit)
ic(None).exp()

tensor([[1.0000]], grad_fn=<ExpBackward0>)