In [19]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("pconv")
import cirkit
from cirkit.templates import data_modalities
from cirkit.templates import utils
from cirkit.pipeline import compile, PipelineContext
from cirkit.symbolic.io import plot_circuit
import torch
from src.utils import patchify, unpatchify
from src.benchmark_logic import BenchPCImage
ctx = PipelineContext(fold=False, optimize=False)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Side length (pixels) of the square example image and of each square patch.
image_size=32
kernel_size=4
# Draw the example image from a dedicated, seeded generator so the tensors
# shown below are identical after every kernel restart; the global RNG stream
# (used for parameter initialisation elsewhere) is left untouched.
example_rng = torch.Generator().manual_seed(0)
example_data = torch.randint(256,(1,3,image_size,image_size), generator=example_rng)
# Both arguments are (kernel_size, kernel_size) — presumably patch size and
# stride, i.e. non-overlapping patches; confirm in src.utils.
patch_fn = patchify((kernel_size,kernel_size), (kernel_size,kernel_size))
unpatch_fn = unpatchify((image_size,image_size),(kernel_size,kernel_size), (kernel_size,kernel_size), 3)
patched=patch_fn(example_data)
# Display: the image, its patches, the round trip through unpatch_fn, and the
# patches flattened to one row of 3*kernel_size*kernel_size values each.
example_data, patched, unpatch_fn(patched), patched.reshape(-1, 3*kernel_size*kernel_size)

(tensor([[[[ 71, 209,  19,  ..., 200,  18, 244],
           [223,  49,  97,  ...,  20, 106,  36],
           [119,  71, 128,  ..., 199, 227, 218],
           ...,
           [ 44,  16, 150,  ..., 100, 106, 156],
           [ 98,   9, 167,  ...,  40, 107, 253],
           [ 11, 238,  88,  ...,  33,  53, 245]],
 
          [[211,  27, 153,  ..., 163,  86,  25],
           [173, 150,  61,  ...,  14, 236, 230],
           [ 77, 158, 244,  ..., 236, 153, 219],
           ...,
           [175, 230,  37,  ..., 181,  85,  77],
           [ 53, 235, 195,  ..., 203,  64, 150],
           [188,  31,  48,  ..., 116, 171,  28]],
 
          [[ 46,  30,  41,  ..., 214,  48,  53],
           [103,  48, 164,  ..., 138, 115,  16],
           [ 25,  51,  88,  ..., 225, 146,  22],
           ...,
           [ 68,  12, 227,  ..., 130,   0, 128],
           [ 65, 223,   8,  ..., 250, 151,   7],
           [214,   4, 187,  ..., 155, 162,  84]]]]),
 tensor([[[[ 71, 209,  19, 110],
           [223,  49,  97, 

In [71]:
patched.shape

torch.Size([4, 3, 2, 2])

In [None]:
from copy import copy
from cirkit.symbolic.circuit import Scope
from cirkit.symbolic.circuit import Circuit


from cirkit.symbolic.layers import InputLayer, SumLayer
from cirkit.symbolic.parameters import (
    Parameter,
    ReferenceParameter,
    TensorParameter,
)

def copy_circuit(graph: Circuit,root_node_outputs=None):
    """Clone every layer of ``graph`` via ``copyref()`` and mirror its wiring.

    Args:
        graph: Object exposing ``topological_ordering()``, ``node_inputs(layer)``
            and ``outputs``.
        root_node_outputs: When not ``None``, written to ``num_output_units`` of
            every cloned output layer.

    Returns:
        A 3-tuple ``(layers, in_nodes, outputs)``: the clones in topological
        order, a dict mapping each clone that has inputs to the list of its
        cloned inputs, and the clones of ``graph.outputs``.
    """
    cloned_layers = []
    clone_of = {}
    wiring = {}
    for source in graph.topological_ordering():
        clone = source.copyref()
        cloned_layers.append(clone)
        clone_of[source] = clone
        # Inputs precede their consumers in a topological order, so every
        # input already has an entry in clone_of.
        clone_inputs = [clone_of[feeder] for feeder in graph.node_inputs(source)]
        if clone_inputs:
            wiring[clone] = clone_inputs

    cloned_outputs = [clone_of[root] for root in graph.outputs]
    if root_node_outputs is None:
        return cloned_layers, wiring, cloned_outputs

    for root in cloned_outputs:
        root.num_output_units = root_node_outputs
    return cloned_layers, wiring, cloned_outputs


def patch_circuits(top_circuit_param, patch_circuit_param):
    """Build one circuit by replacing each input layer of a top circuit with a patch circuit.

    Both arguments are keyword dictionaries forwarded to
    ``data_modalities.image_data``; each must contain ``"image_shape"`` as a
    ``(channels, height, width)`` triple. The combined image shape is the
    element-wise product of the two shapes.

    For every input layer of the top circuit a fresh patch circuit is built,
    the scopes of its input layers are re-labelled with the pixel ids of the
    patch it covers, and its outputs are wired in where the top input layer
    used to be.

    Changes with respect to the previous version:
      * the patch row is selected by the top input layer's own variable id
        instead of by its position in the layer ordering, mirroring how the
        patch input layers are already re-labelled;
      * the adjacency lists of ``top`` are copied, so the freshly built top
        circuit is no longer mutated through shared lists;
      * patch outputs are spliced in at the position of the replaced input,
        so the operand order of each consumer is preserved.

    Args:
        top_circuit_param: ``image_data`` keyword arguments for the top circuit.
        patch_circuit_param: ``image_data`` keyword arguments for each patch.

    Returns:
        A ``Circuit`` made of the top circuit's non-input layers plus one patch
        circuit per top input layer, with the top circuit's outputs.
    """
    kernel_shape = patch_circuit_param["image_shape"]
    im_shape = [i*k for i,k in zip(top_circuit_param["image_shape"], kernel_shape)]

    # Label every pixel of the combined image with its flat index and cut the
    # image into patches: row r of scope_order lists the pixel ids of patch r.
    example_data = torch.arange(im_shape[0]*im_shape[1]*im_shape[2]).reshape(1, *im_shape)
    patch_fn = patchify(kernel_shape[1:], kernel_shape[1:])
    scope_order = patch_fn(example_data).reshape(-1, kernel_shape[0]*kernel_shape[1]*kernel_shape[2])
    print(scope_order)

    top = data_modalities.image_data(**top_circuit_param)
    new_layers = top._nodes.copy()
    # Copy each adjacency list as well: a plain dict.copy() would share the
    # lists with `top`, and the rewiring below would then mutate `top` itself.
    new_inputs = {node: list(inputs) for node, inputs in top._in_nodes.items()}

    top_input_nodes = next(iter(top.layerwise_topological_ordering()))
    for input_node in top_input_nodes:
        # Pick the patch by the variable this input layer is defined on, not by
        # where the layer happens to sit in the ordering.
        # assumes each top input layer has a single-variable scope — TODO confirm
        new_scope = scope_order[list(input_node.scope)[0]]

        # The top input layer is replaced wholesale by a new patch circuit.
        new_layers.remove(input_node)
        patch = data_modalities.image_data(**patch_circuit_param)
        for inp in next(iter(patch.layerwise_topological_ordering())):
            # Local variable id inside the patch -> pixel id in the full image.
            inp.scope = Scope([new_scope[list(inp.scope)[0]].item()])
        new_layers.extend(patch._nodes)

        # Rewire every consumer of the removed input layer to the patch outputs.
        for node, inputs in top._in_nodes.items():
            if input_node in inputs:
                rewired = new_inputs[node]
                position = rewired.index(input_node)
                rewired[position:position + 1] = patch.outputs
        new_inputs.update(patch._in_nodes)

    return Circuit(new_layers, new_inputs, top.outputs)

# Unit count used for every input and sum layer of both circuits below.
units=512
# First dict: top circuit over a 1x16x16 grid. Second dict: the
# 3 x kernel_size x kernel_size patch circuit that patch_circuits substitutes
# for each top input layer. Only the patch dict sets num_classes (= units);
# the top dict leaves it unset.
new_circ=patch_circuits(
    {
        "image_shape": (1,16,16),
        "region_graph": "quad-graph",
        "input_layer": "categorical",
        "num_input_units":units,
        "sum_product_layer": "cp",
        "num_sum_units": units,
        "sum_weight_param": utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    },
     {
        "image_shape": (3, kernel_size, kernel_size),
        "region_graph": "quad-graph",
        "input_layer": "categorical",
        "num_input_units": units,
        "sum_product_layer": "cp",
        "num_sum_units": units,
        "num_classes":units,
        "sum_weight_param": utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
    }
)
# `compile` is cirkit.pipeline.compile imported in the first cell; it shadows
# the Python builtin of the same name.
cnew=compile(new_circ)

# plot_circuit(new_circ)
# new_circ.layers[-1].params['weight'].nodes

tensor([[    0,     1,     2,  ...,  8385,  8386,  8387],
        [    4,     5,     6,  ...,  8389,  8390,  8391],
        [    8,     9,    10,  ...,  8393,  8394,  8395],
        ...,
        [ 3892,  3893,  3894,  ..., 12277, 12278, 12279],
        [ 3896,  3897,  3898,  ..., 12281, 12282, 12283],
        [ 3900,  3901,  3902,  ..., 12285, 12286, 12287]])


In [11]:
# Reference circuit built directly over a 3x64x64 image; later cells compare
# layers of its compiled form (ctest) with those of cnew.
test_circ=data_modalities.image_data(
        (3,64,64),
        region_graph="quad-graph",
        input_layer="categorical",
        num_input_units=128,
        sum_product_layer="cp",
        num_sum_units=128,
        num_classes=1,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
)
# First variable id in the scope of every node of the first layer of the
# layerwise topological ordering.
scopes=list(map(lambda x:list(x.scope)[0], next(test_circ.layerwise_topological_ordering())))
# Nodes of the seventh layer (index 6) of the same ordering.
nodes=list(test_circ.layerwise_topological_ordering())[6]
sub_circ=test_circ.subgraph(nodes[0])
# Bare expression that is not the last statement of the cell: it is evaluated
# but nothing is displayed.
sub_circ.scope
ctest=compile(test_circ)

# ctest


In [13]:
ctest.layers[18]

TorchSumLayer(
  folds: 32  arity: 1  input-units: 128  output-units: 128
  input-shape: (32, 1, -1, 128)
  output-shape: (32, -1, 128)
  (weight): TorchParameter(
    shape: (32, 128, 128)
    (0): TorchTensorParameter(output-shape: (32, 128, 128))
    (1): TorchSoftmaxParameter(
      input-shapes: [(32, 128, 128)]
      output-shape: (32, 128, 128)
    )
  )
)

In [14]:

cnew.layers[18]

TorchSumLayer(
  folds: 32  arity: 1  input-units: 512  output-units: 512
  input-shape: (32, 1, -1, 512)
  output-shape: (32, -1, 512)
  (weight): TorchParameter(
    shape: (32, 512, 512)
    (0): TorchTensorParameter(output-shape: (32, 512, 512))
    (1): TorchSoftmaxParameter(
      input-shapes: [(32, 512, 512)]
      output-shape: (32, 512, 512)
    )
  )
)

In [15]:
ckpt=torch.load("pconv/shared.ckpt")

In [18]:
ckpt['hyper_parameters']['config']

{'circuit_type': 'shared',
 'layer_type': 'cp-t',
 'region_graph': 'quad-graph',
 'num_units': 128,
 'lr': 0.05,
 'dataset': 'celeba',
 'kernel_size': [2, 2],
 'colour_transform': 'ycc_lossless',
 'batch_size': 32,
 'early_stopping_delta': 10,
 'experiment_path': 'experiments/bench-shared-64/',
 'image_size': (64, 64),
 'channel': 3,
 'num_classes': 128}

In [20]:
module = BenchPCImage(ckpt['hyper_parameters']['config'])

initializing benchmark PC
benchmark PC initialized


In [23]:
module.circuit

TorchCircuit(
  (0): TorchCategoricalLayer(
    folds: 12288  variables: 1  output-units: 128
    input-shape: (12288, 1, -1, 1)
    output-shape: (12288, -1, 128)
    (probs): TorchParameter(
      shape: (12288, 128, 256)
      (0): TorchTensorParameter(output-shape: (12288, 128, 256))
      (1): TorchSoftmaxParameter(
        input-shapes: [(12288, 128, 256)]
        output-shape: (12288, 128, 256)
      )
    )
  )
  (1): TorchHadamardLayer(
    folds: 4096  arity: 3  input-units: 128  output-units: 128
    input-shape: (4096, 3, -1, 128)
    output-shape: (4096, -1, 128)
  )
  (2): TorchCPTLayer(
    folds: 4096  arity: 2  input-units: 128  output-units: 128
    input-shape: (4096, 2, -1, 128)
    output-shape: (4096, -1, 128)
    (weight): TorchParameter(
      shape: (4096, 128, 128)
      (0): TorchTensorParameter(output-shape: (4096, 128, 128))
      (1): TorchSoftmaxParameter(
        input-shapes: [(4096, 128, 128)]
        output-shape: (4096, 128, 128)
      )
    )
  )
  

In [22]:
module.load_state_dict(ckpt['state_dict'])

RuntimeError: Error(s) in loading state_dict for BenchPCImage:
	Missing key(s) in state_dict: "circuit._nodes.0.probs._nodes.0._ptensor", "circuit._nodes.0.probs._address_book._in_fold_idx_2_0", "circuit._nodes.2.weight._nodes.0._ptensor", "circuit._nodes.2.weight._address_book._in_fold_idx_2_0", "circuit._nodes.3.weight._nodes.0._ptensor", "circuit._nodes.3.weight._address_book._in_fold_idx_2_0", "circuit._nodes.4.weight._nodes.0._ptensor", "circuit._nodes.4.weight._address_book._in_fold_idx_3_0". 
	Unexpected key(s) in state_dict: "circuit._nodes.0.probs.internal_param.0._ptensor", "circuit._nodes.2.weight.internal_param.0._ptensor", "circuit._nodes.3.weight.internal_param.0._ptensor", "circuit._nodes.4.weight.internal_param.0._ptensor". 
	size mismatch for circuit._nodes.18.weight._nodes.0._ptensor: copying a param with shape torch.Size([2, 1, 128]) from checkpoint, the shape in current model is torch.Size([2, 128, 128]).
	size mismatch for circuit._nodes.19.weight._nodes.0._ptensor: copying a param with shape torch.Size([1, 1, 2]) from checkpoint, the shape in current model is torch.Size([1, 128, 2]).

## Test Connection Base Circuit

In [4]:
from cirkit.backend.torch.layers.inner import TorchSumLayer
from cirkit.backend.torch.layers.optimized import TorchCPTLayer
from cirkit.backend.torch.layers.input import TorchInputLayer
from cirkit.backend.torch.circuits import TorchCircuit
from cirkit.backend.torch.parameters.nodes import TorchParameterInput, TorchMixingWeightParameter, TorchTensorParameter, TorchUnaryParameterOp
from cirkit.backend.torch.parameters.parameter import TorchParameter



def copy_parameter(graph: TorchParameter, new_shape):
    """Rebuild a parameter graph node by node with a different tensor shape.

    Every node of ``graph`` is re-instantiated from its own type and ``config``:
      * a ``TorchTensorParameter`` is created with ``new_shape`` as its shape
        and receives a zero-filled ``_ptensor`` of shape
        ``(num_folds, *new_shape)``;
      * a ``TorchUnaryParameterOp`` is created with ``in_shape=new_shape``.

    Fixes with respect to the previous version:
      * the fold dimension of ``_ptensor`` is taken from the node's
        ``num_folds`` rather than from ``graph.shape[0]`` (the parameter shape
        does not include the fold dimension, so that value was a unit count);
      * ``n.config`` is copied before being edited, so the source node's
        configuration is never modified;
      * an unsupported node type raises ``TypeError`` instead of silently
        re-using the node built in the previous iteration.

    Args:
        graph: Parameter graph to copy; must expose ``topological_ordering()``,
            ``node_inputs(node)`` and ``outputs``.
        new_shape: Shape (without the fold dimension) for the rebuilt nodes.

    Returns:
        A new ``TorchParameter`` with the same wiring as ``graph``.

    Raises:
        TypeError: If ``graph`` contains a node that is neither a
            ``TorchTensorParameter`` nor a ``TorchUnaryParameterOp``.
    """
    new_param_nodes = []
    copy_map = {}
    in_nodes = {}
    for n in graph.topological_ordering():
        node_type = type(n)
        # Private copy: the edits below must not leak into the source node.
        config = dict(n.config)
        if isinstance(n, TorchTensorParameter):
            # The shape is passed positionally, so drop it from the keywords.
            del config["shape"]
            new_param = node_type(*new_shape, **config)
            # Give the new node storage right away so it can be read/overwritten.
            new_param._ptensor = torch.nn.Parameter(torch.zeros((n.num_folds, *new_shape)))
        elif isinstance(n, TorchUnaryParameterOp):
            config["in_shape"] = new_shape
            new_param = node_type(**config)
        else:
            raise TypeError(
                f"copy_parameter cannot rebuild a parameter node of type {node_type.__name__}"
            )
        new_param_nodes.append(new_param)
        copy_map[n] = new_param
        inputs = [copy_map[in_node] for in_node in graph.node_inputs(n)]
        if len(inputs) > 0:
            in_nodes[new_param] = inputs
    outputs = [copy_map[out_node] for out_node in graph.outputs]
    return TorchParameter(modules=new_param_nodes, in_modules=in_nodes, outputs=outputs)

class TorchSharedParameter(TorchParameterInput):
    """Parameter input that evaluates a shared chain of modules and tiles the result over folds.

    The wrapped modules are run as a linear chain: the first is called without
    arguments, each following one receives the previous output. The resulting
    tensor of shape ``(share_fold, *inner)`` is then repeated along the fold
    dimension until it has ``num_folds`` folds.
    """

    def __init__(self,
        in_shape:tuple[int,...],
        parameter:list[torch.nn.Module],
        num_folds:int
):
        """
        Args:
            in_shape: Value reported by the ``shape`` property.
            parameter: Modules forming the chain, in evaluation order.
            num_folds: Number of folds of the tensor returned by ``forward``.

        Raises:
            ValueError: If ``parameter`` is empty.
        """
        super().__init__()
        self._num_folds=num_folds
        self.in_shape=in_shape
        # A plain list assigned to a Module attribute is not registered, so its
        # tensors would be missing from parameters()/state_dict()/.to(). Wrap
        # anything that is not already a ModuleList.
        if not isinstance(parameter, torch.nn.ModuleList):
            parameter = torch.nn.ModuleList(parameter)
        if len(parameter) == 0:
            raise ValueError("TorchSharedParameter needs at least one module to evaluate")
        self.internal_param = parameter

    def forward(self):
        """Evaluate the chain and tile its output to ``num_folds`` folds.

        Returns:
            Tensor of shape ``(num_folds, *inner)`` where fold ``i`` equals fold
            ``i % share_fold`` of the chain output.

        Raises:
            ValueError: If ``num_folds`` is not a multiple of the number of
                folds produced by the chain.
        """
        chain = iter(self.internal_param)
        current_input = next(chain)()
        for param in chain:
            current_input = param(current_input)

        share_fold, *inner_units = current_input.shape
        # _num_folds is the value stored by __init__ above.
        num_folds = self._num_folds
        if num_folds % share_fold != 0:
            raise ValueError(
                f"cannot tile {share_fold} shared folds over {num_folds} folds"
            )
        # expand() adds a leading repeat dimension without copying; reshape()
        # then merges it with the shared folds.
        expanded = current_input.expand(num_folds//share_fold, share_fold, *inner_units).reshape(num_folds, *inner_units)
        return expanded

    @property
    def shape(self):
        """Shape passed to the constructor as ``in_shape``."""
        return self.in_shape

def share_param_like(base_circ:TorchCircuit, share_struct:TorchCircuit, should_init_mean=False):
    """Make the leading layers of ``base_circ`` reuse the parameters of ``share_struct``.

    Layers are matched by index. For every input layer of ``share_struct`` the
    ``probs`` of the layer at the same index in ``base_circ`` is replaced by a
    ``TorchSharedParameter`` over the donor's parameter nodes; for every CP-T or
    sum layer the same is done with ``weight``. When the two layers differ in
    ``num_output_units``, the donor weight graph is first rebuilt with
    ``copy_parameter`` at the target's tensor shape and filled with the donor
    tensor repeated along its last two dimensions.

    Args:
        base_circ: Circuit whose layer parameters are replaced in place.
        share_struct: Circuit providing the parameters to share.
        should_init_mean: When true, ``init_mean`` is run on the layers of
            ``base_circ`` from index ``len(share_struct.layers)`` onwards.
    """
    weighted_layer_types = (TorchCPTLayer, TorchSumLayer)
    for position, donor in enumerate(share_struct.layers):
        print(donor)

        if isinstance(donor, TorchInputLayer):
            target = base_circ.layers[position]
            fold_count = target.probs.num_folds
            target.probs = TorchSharedParameter(
                target.probs.shape, parameter=donor.probs.nodes, num_folds=fold_count
            )
            continue

        if not isinstance(donor, weighted_layer_types):
            # Layers without probs/weight have nothing to share.
            continue

        target = base_circ.layers[position]
        chain = donor.weight.nodes
        if donor.num_output_units != target.num_output_units:
            # Unit counts differ: rebuild the donor graph at the target's shape
            # and tile the donor tensor into it.
            resized = copy_parameter(donor.weight, target.weight.nodes[0].shape)
            donor_tensor = chain[0]._ptensor
            print(donor_tensor.data.shape)

            print(resized.nodes[0]._ptensor.data.shape)
            _, donor_out, donor_in = donor_tensor.data.shape
            _, goal_out, goal_in = resized.nodes[0]._ptensor.data.shape
            repeats = (1, goal_out // donor_out, goal_in // donor_in)
            resized.nodes[0]._ptensor.data = donor_tensor.data.clone().repeat(repeats)
            print(resized.nodes[0]._ptensor.data.shape)

            chain = resized.nodes

        fold_count = target.weight.num_folds
        target.weight = TorchSharedParameter(
            target.weight.shape, parameter=chain, num_folds=fold_count
        )

    if should_init_mean:
        init_mean(base_circ, len(share_struct.layers))


def init_mean(circ, start_idx):
    """Fill the weight tensor of every CP-T / sum layer from ``start_idx`` on with a constant.

    The constant is ``exp(1 / inputs)`` where ``inputs`` is the size of the last
    dimension of the layer's first weight node tensor. Each layer is printed as
    it is visited.

    The replacement tensor is created with ``torch.full_like``, so it keeps the
    dtype and device of the tensor it replaces (the previous ``torch.full``
    call always produced a default-dtype CPU tensor).

    Args:
        circ: Object exposing ``layers``, a sliceable sequence of layers.
        start_idx: Index of the first layer to consider.
    """
    for layer in circ.layers[start_idx:]:
        print(layer)
        if not isinstance(layer, (TorchCPTLayer, TorchSumLayer)):
            continue
        leaf = layer.weight.nodes[0]
        tensor = leaf._ptensor
        inputs = tensor.shape[-1]
        # Evaluated exactly as before (float32 exp), then used as a plain number.
        fill_value = torch.exp(torch.tensor(1/inputs)).item()
        leaf._ptensor.data = torch.full_like(tensor.data, fill_value)

# Stand-alone patch circuit over a 3 x kernel_size x kernel_size image. It uses
# num_classes=1, whereas the patches embedded in new_circ were built with
# num_classes=units; share_param_like handles layers whose num_output_units
# differ by rebuilding the weight at the target shape.
patch_circ=data_modalities.image_data(
        (3,kernel_size,kernel_size),
        region_graph="quad-graph",
        input_layer="categorical",
        num_input_units=units,
        sum_product_layer="cp",
        num_sum_units=units,
        num_classes=1,
        sum_weight_param=utils.Parameterization(
            activation="softmax", initialization="normal"
        ),
)
# top_circ=data_modalities.image_data(
#         (1,2,2),
#         region_graph="quad-graph",
#         input_layer="categorical",
#         num_input_units=2,
#         sum_product_layer="cp-t",
#         num_sum_units=2,
#         sum_weight_param=utils.Parameterization(
#             activation="softmax", initialization="normal"
#         ),
# )
# Compile both circuits; cbase is a fresh compilation of new_circ, separate
# from the earlier cnew.
cpatch = compile(patch_circ)
cbase = compile(new_circ)
# ctop=compile(top_circ)

# Replace the parameters of cbase's leading layers (matched by index) with
# shared views of cpatch's parameters. Prints every visited cpatch layer.
share_param_like(cbase, cpatch, should_init_mean=False)
# cbase

TorchCategoricalLayer(
  folds: 48  variables: 1  output-units: 512
  input-shape: (48, 1, -1, 1)
  output-shape: (48, -1, 512)
  (probs): TorchParameter(
    shape: (48, 512, 256)
    (0): TorchTensorParameter(output-shape: (48, 512, 256))
    (1): TorchSoftmaxParameter(
      input-shapes: [(48, 512, 256)]
      output-shape: (48, 512, 256)
    )
  )
)
TorchHadamardLayer(
  folds: 16  arity: 3  input-units: 512  output-units: 512
  input-shape: (16, 3, -1, 512)
  output-shape: (16, -1, 512)
)
TorchSumLayer(
  folds: 32  arity: 1  input-units: 512  output-units: 512
  input-shape: (32, 1, -1, 512)
  output-shape: (32, -1, 512)
  (weight): TorchParameter(
    shape: (32, 512, 512)
    (0): TorchTensorParameter(output-shape: (32, 512, 512))
    (1): TorchSoftmaxParameter(
      input-shapes: [(32, 512, 512)]
      output-shape: (32, 512, 512)
    )
  )
)
TorchCPTLayer(
  folds: 16  arity: 2  input-units: 512  output-units: 512
  input-shape: (16, 2, -1, 512)
  output-shape: (16, -1, 512

In [9]:
copy_parameter(cpatch.layers[-1].weight,(2,2))


TorchParameter(
  shape: (1, 2, 4)
  (0): TorchTensorParameter(output-shape: (1, 2, 2))
  (1): TorchSoftmaxParameter(
    input-shapes: [(1, 2, 2)]
    output-shape: (1, 2, 2)
  )
  (2): TorchMixingWeightParameter(
    input-shapes: [(1, 2, 2)]
    output-shape: (1, 2, 4)
  )
)

In [43]:
patched.reshape(-1, 3*4*4)

tensor([[  0,   1,   2,   3,   8,   9,  10,  11,  16,  17,  18,  19,  24,  25,
          26,  27,  64,  65,  66,  67,  72,  73,  74,  75,  80,  81,  82,  83,
          88,  89,  90,  91, 128, 129, 130, 131, 136, 137, 138, 139, 144, 145,
         146, 147, 152, 153, 154, 155],
        [  4,   5,   6,   7,  12,  13,  14,  15,  20,  21,  22,  23,  28,  29,
          30,  31,  68,  69,  70,  71,  76,  77,  78,  79,  84,  85,  86,  87,
          92,  93,  94,  95, 132, 133, 134, 135, 140, 141, 142, 143, 148, 149,
         150, 151, 156, 157, 158, 159],
        [ 32,  33,  34,  35,  40,  41,  42,  43,  48,  49,  50,  51,  56,  57,
          58,  59,  96,  97,  98,  99, 104, 105, 106, 107, 112, 113, 114, 115,
         120, 121, 122, 123, 160, 161, 162, 163, 168, 169, 170, 171, 176, 177,
         178, 179, 184, 185, 186, 187],
        [ 36,  37,  38,  39,  44,  45,  46,  47,  52,  53,  54,  55,  60,  61,
          62,  63, 100, 101, 102, 103, 108, 109, 110, 111, 116, 117, 118, 119,
         12

In [5]:
# print("Circuit Input: ", patched.reshape(-1,4).shape, patched.reshape(-1,4, 1).permute(1,0,2))
res = cpatch(patched.reshape(-1, 3*kernel_size*kernel_size))
res.sum()

TorchCategoricalLayer:
                       Input shape: torch.Size([48, 64, 1])
                       Input: tensor([[[235],
         [ 26],
         [ 72],
         ...,
         [202],
         [183],
         [ 23]],

        [[ 16],
         [234],
         [201],
         ...,
         [ 73],
         [135],
         [217]],

        [[101],
         [211],
         [218],
         ...,
         [204],
         [117],
         [239]],

        ...,

        [[172],
         [173],
         [ 98],
         ...,
         [184],
         [227],
         [250]],

        [[ 61],
         [156],
         [141],
         ...,
         [ 37],
         [134],
         [141]],

        [[ 98],
         [  1],
         [155],
         ...,
         [ 14],
         [182],
         [110]]])
Output (torch.Size([48, 64, 512])):  tensor([[[-5.9017, -5.2065, -4.6380,  ..., -7.0506, -6.0913, -5.5275],
         [-6.8701, -6.3171, -7.0039,  ..., -5.7161, -6.9219, -6.8037],
         [-6.8242, -4.

tensor(-17045.1230, grad_fn=<SumBackward0>)

In [6]:
print(example_data.reshape(1,-1))

cbase(example_data.reshape(1,-1))

tensor([[235,  44, 101,  ...,  28, 137, 141]])
TorchCategoricalLayer:
                       Input shape: torch.Size([3072, 1, 1])
                       Input: tensor([[[235]],

        [[ 16]],

        [[101]],

        ...,

        [[250]],

        [[141]],

        [[110]]])
Output (torch.Size([3072, 1, 512])):  tensor([[[-5.9017, -5.2065, -4.6380,  ..., -7.0506, -6.0913, -5.5275]],

        [[-5.3310, -6.3777, -3.8229,  ..., -6.1503, -6.5037, -6.8531]],

        [[-7.7115, -4.8916, -5.9795,  ..., -5.3686, -5.3837, -7.6764]],

        ...,

        [[-4.3712, -5.9586, -4.8683,  ..., -5.4677, -5.5597, -6.6288]],

        [[-4.2110, -5.9805, -4.9656,  ..., -7.2668, -5.6935, -5.9949]],

        [[-7.6926, -5.0970, -7.2634,  ..., -6.9023, -7.5560, -6.1037]]],
       grad_fn=<IndexBackward0>)
TorchHadamardLayer:
                       Input shape: torch.Size([1024, 3, 1, 512])
                       Input: tensor([[[[-5.9017, -5.2065, -4.6380,  ..., -7.0506, -6.0913, -5.5275]],

    

tensor([[[-17045.1230]]], grad_fn=<TransposeBackward0>)

In [39]:
cbase.layers[0].probs

TorchSharedParameter(
  output-shape: (3072, 512, 256)
  (internal_param): ModuleList(
    (0): TorchTensorParameter(output-shape: (48, 512, 256))
    (1): TorchSoftmaxParameter(
      input-shapes: [(48, 512, 256)]
      output-shape: (48, 512, 256)
    )
  )
)

In [33]:
cbase.layers[0].probs()

tensor([[[1.2929e-02, 1.2673e-02, 1.3991e-03,  ..., 4.2933e-03,
          5.4251e-03, 1.1941e-03],
         [1.6076e-03, 1.9501e-04, 1.3110e-03,  ..., 3.0377e-03,
          1.2048e-03, 1.7922e-03],
         [2.3351e-03, 1.1709e-03, 2.8509e-03,  ..., 4.5627e-03,
          3.3654e-03, 1.8465e-03],
         ...,
         [3.8859e-03, 3.2517e-04, 6.2499e-03,  ..., 7.5306e-04,
          1.3511e-03, 2.3571e-03],
         [6.6722e-03, 4.4425e-03, 6.9162e-03,  ..., 2.1398e-03,
          1.6363e-03, 2.9410e-03],
         [2.3152e-03, 7.2537e-04, 2.9835e-03,  ..., 1.7328e-03,
          1.0447e-03, 1.4942e-03]],

        [[1.2807e-03, 2.0195e-02, 1.2609e-03,  ..., 1.2189e-03,
          1.3611e-03, 5.2363e-04],
         [1.0251e-02, 2.2320e-03, 6.3639e-03,  ..., 8.4052e-04,
          2.1866e-03, 1.2582e-03],
         [7.5096e-03, 2.1310e-03, 1.2494e-03,  ..., 1.2873e-03,
          9.7402e-03, 5.6151e-03],
         ...,
         [4.3466e-04, 3.4471e-03, 2.9194e-04,  ..., 2.3836e-03,
          5.281

In [35]:
cpatch.layers[0]

TorchCategoricalLayer(
  folds: 48  variables: 1  output-units: 512
  input-shape: (48, 1, -1, 1)
  output-shape: (48, -1, 512)
  (probs): TorchParameter(
    shape: (48, 512, 256)
    (0): TorchTensorParameter(output-shape: (48, 512, 256))
    (1): TorchSoftmaxParameter(
      input-shapes: [(48, 512, 256)]
      output-shape: (48, 512, 256)
    )
  )
)

In [34]:
cpatch.layers[0].probs()


tensor([[[1.2929e-02, 1.2673e-02, 1.3991e-03,  ..., 4.2933e-03,
          5.4251e-03, 1.1941e-03],
         [1.6076e-03, 1.9501e-04, 1.3110e-03,  ..., 3.0377e-03,
          1.2048e-03, 1.7922e-03],
         [2.3351e-03, 1.1709e-03, 2.8509e-03,  ..., 4.5627e-03,
          3.3654e-03, 1.8465e-03],
         ...,
         [3.8859e-03, 3.2517e-04, 6.2499e-03,  ..., 7.5306e-04,
          1.3511e-03, 2.3571e-03],
         [6.6722e-03, 4.4425e-03, 6.9162e-03,  ..., 2.1398e-03,
          1.6363e-03, 2.9410e-03],
         [2.3152e-03, 7.2537e-04, 2.9835e-03,  ..., 1.7328e-03,
          1.0447e-03, 1.4942e-03]],

        [[1.2807e-03, 2.0195e-02, 1.2609e-03,  ..., 1.2189e-03,
          1.3611e-03, 5.2363e-04],
         [1.0251e-02, 2.2320e-03, 6.3639e-03,  ..., 8.4052e-04,
          2.1866e-03, 1.2582e-03],
         [7.5096e-03, 2.1310e-03, 1.2494e-03,  ..., 1.2873e-03,
          9.7402e-03, 5.6151e-03],
         ...,
         [4.3466e-04, 3.4471e-03, 2.9194e-04,  ..., 2.3836e-03,
          5.281

In [61]:
ctest.layers[0].scope_idx.reshape(-1,3)

tensor([[  0,  64, 128],
        [ 65, 129,   1],
        [  2,  66, 130],
        [  3,  67, 131],
        [132,   4,  68],
        [  5,  69, 133],
        [134,  70,   6],
        [  7,  71, 135],
        [  8,  72, 136],
        [ 73, 137,   9],
        [ 10,  74, 138],
        [ 11,  75, 139],
        [140,  12,  76],
        [ 13,  77, 141],
        [142,  78,  14],
        [ 15,  79, 143],
        [ 16,  80, 144],
        [ 81, 145,  17],
        [ 18,  82, 146],
        [ 19,  83, 147],
        [148,  20,  84],
        [ 21,  85, 149],
        [150,  86,  22],
        [ 23,  87, 151],
        [ 24,  88, 152],
        [ 89, 153,  25],
        [ 26,  90, 154],
        [ 27,  91, 155],
        [156,  28,  92],
        [ 29,  93, 157],
        [158,  94,  30],
        [ 31,  95, 159],
        [ 32,  96, 160],
        [ 33, 161,  97],
        [162,  34,  98],
        [ 35,  99, 163],
        [164, 100,  36],
        [ 37, 101, 165],
        [ 38, 166, 102],
        [103,  39, 167],


In [130]:

# For each of the first 16 entries of cbase.layers, overwrite the raw probs
# tensor with a (1, 1, 256) tensor that is 20 at index `variable_id` and 0
# elsewhere.
# NOTE(review): this indexes layers, not variables, and needs `probs` to expose
# `.nodes`; presumably written for an unfolded cbase whose probs had not been
# replaced by share_param_like — confirm before re-running.
for variable_id in range(16):
    probs_four = torch.zeros((1,1,256))
    probs_four[0,0,variable_id]=1
    cbase.layers[variable_id].probs.nodes[0]._ptensor.data=probs_four.data*20

In [311]:
cbase_folded =compile(new_circ)
cbase_folded(example_data.reshape(1,-1))
# cbase_folded.layers[0].probs.nodes[0]._ptensor.data.shape
# patched.reshape(1,-1)

tensor([[[-94.7456]]], grad_fn=<TransposeBackward0>)

In [312]:
cbase_folded

TorchCircuit(
  (0): TorchCategoricalLayer(
    folds: 16  variables: 1  output-units: 2
    input-shape: (16, 1, -1, 1)
    output-shape: (16, -1, 2)
    (probs): TorchParameter(
      shape: (16, 2, 256)
      (0): TorchTensorParameter(output-shape: (16, 2, 256))
      (1): TorchSoftmaxParameter(
        input-shapes: [(16, 2, 256)]
        output-shape: (16, 2, 256)
      )
    )
  )
  (1): TorchCPTLayer(
    folds: 16  arity: 2  input-units: 2  output-units: 2
    input-shape: (16, 2, -1, 2)
    output-shape: (16, -1, 2)
    (weight): TorchParameter(
      shape: (16, 2, 2)
      (0): TorchTensorParameter(output-shape: (16, 2, 2))
      (1): TorchSoftmaxParameter(
        input-shapes: [(16, 2, 2)]
        output-shape: (16, 2, 2)
      )
    )
  )
  (2): TorchCPTLayer(
    folds: 8  arity: 2  input-units: 2  output-units: 2
    input-shape: (8, 2, -1, 2)
    output-shape: (8, -1, 2)
    (weight): TorchParameter(
      shape: (8, 2, 2)
      (0): TorchTensorParameter(output-shape: 

In [None]:
# One (2, 256) block of raw probs values per row; 16 rows in total.
folded_tensor=torch.zeros((16,2,256))
# NOTE(review): idx runs over every element of `patched`, but folded_tensor has
# only 16 rows — this assumes `patched` holds exactly 16 values; confirm.
for idx,variable_id in enumerate(patched.reshape(-1).long()):
    # probs_four = torch.arange(1,257).reshape(1,1,256)
    probs_four=torch.zeros((1,2,256))
    # Unit 0 is set at category `variable_id`, unit 1 at the following category.
    # NOTE(review): variable_id+1 is out of range when variable_id == 255.
    probs_four[0,0,variable_id]=1
    probs_four[0,1,variable_id+1]=1
    # Marked entries become 20, all others stay 0.
    folded_tensor[idx, :, :]=(probs_four.data*20)
# Rows are gathered again using the values of `patched` as row indices before
# being written into the first layer's probs tensor.
cbase_folded.layers[0].probs.nodes[0]._ptensor.data = folded_tensor.data[patched.reshape(-1).long(),:,:]
# Display the log of the evaluated probs parameter.
cbase_folded.layers[0].probs().log()

tensor([[[-4.7684e-07, -2.0000e+01, -2.0000e+01,  ..., -2.0000e+01,
          -2.0000e+01, -2.0000e+01],
         [-2.0000e+01, -4.7684e-07, -2.0000e+01,  ..., -2.0000e+01,
          -2.0000e+01, -2.0000e+01]],

        [[-2.0000e+01, -4.7684e-07, -2.0000e+01,  ..., -2.0000e+01,
          -2.0000e+01, -2.0000e+01],
         [-2.0000e+01, -2.0000e+01, -4.7684e-07,  ..., -2.0000e+01,
          -2.0000e+01, -2.0000e+01]],

        [[-2.0000e+01, -2.0000e+01, -2.0000e+01,  ..., -2.0000e+01,
          -2.0000e+01, -2.0000e+01],
         [-2.0000e+01, -2.0000e+01, -2.0000e+01,  ..., -2.0000e+01,
          -2.0000e+01, -2.0000e+01]],

        ...,

        [[-2.0000e+01, -2.0000e+01, -2.0000e+01,  ..., -2.0000e+01,
          -2.0000e+01, -2.0000e+01],
         [-2.0000e+01, -2.0000e+01, -2.0000e+01,  ..., -2.0000e+01,
          -2.0000e+01, -2.0000e+01]],

        [[-2.0000e+01, -2.0000e+01, -2.0000e+01,  ..., -2.0000e+01,
          -2.0000e+01, -2.0000e+01],
         [-2.0000e+01, -2.0000e+0

tensor([[ 0.,  1.,  4.,  5.,  2.,  3.,  6.,  7.,  8.,  9., 12., 13., 10., 11.,
         14., 15.]])

In [242]:
cbase_folded.layers[0].log_unnormalized_likelihood(example_data.reshape(16,-1, 1))

cbase_folded.layers[0].log_unnormalized_likelihood(patched.reshape(1,-1).reshape(16,-1, 1))


tensor([[[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]],

        [[-4.7684e-07, -2.0000e+01]]], grad_fn=<IndexBackward0>)