In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Flexible_IM_to_HW_Encoder(nn.Module):
    """
    A module that:
      1) Takes an image of shape (B, C, input_size, input_size).
      2) Averages it down to (B, C, m, m) via adaptive pooling.
      3) Flattens to (B, C*m*m) => 'IM basis' with dimension m^2.
      4) Embeds into the HW=2 subspace for 2m qubits => dimension choose(2m,2).
         Zeroes for pairs on the same side, IM values for cross-side pairs.
         
    Args:
        input_size (int): The height/width of the input image (e.g., 28 for MNIST).
        m (int): The dimension to pool down to (e.g., 10).
        in_channels (int): Number of input channels (e.g., 1 for MNIST).
        
    Example: 
        For MNIST, you'd typically do:
        
            encoder = Flexible_IM_to_HW_Encoder(input_size=28, m=10, in_channels=1)
            # Then feed a batch of shape (B, 1, 28, 28).
    """
    def __init__(self, input_size=28, m=10, in_channels=1):
        super().__init__()
        self.input_size = input_size
        self.m = m
        self.in_channels = in_channels

        # We'll adaptively pool from (input_size x input_size) to (m x m)
        self.pool = nn.AdaptiveAvgPool2d((m, m))

        # Dimension of IM basis = m*m
        # Dimension of HW=2 subspace for 2m qubits = comb(2m, 2) = (2m)(2m-1)/2
        self.hw_dim = (2*m)*(2*m - 1)//2

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (B, in_channels, input_size, input_size)
        Returns:
            x_hw: Tensor of shape (B, self.hw_dim)
                  (which is (2m choose 2)) 
        """
        B = x.size(0)

        # 1) Average-pool down to (B, C, m, m)
        x_pooled = self.pool(x)  # shape = (B, in_channels, m, m)

        # 2) Flatten the pooled output to (B, in_channels * m * m)
        x_im = x_pooled.view(B, -1)  # shape = (B, in_channels*m*m)

        # For typical MNIST (C=1), x_im has shape (B, m*m).
        # If C>1, then effectively we have C copies of (m*m) we can treat as separate channels 
        # for the "IM basis." 
        # But let's keep it straightforward and just treat it as one long vector of length C*m*m.

        # 3) Prepare output tensor in "HW basis" dimension = choose(2m,2)
        x_hw = x_im.new_zeros(B, self.hw_dim)  # shape = (B, hw_dim)

        # We'll define a helper for indexing the combination (a,b) where 0 <= a < b < 2m.
        def comb_index(a, b, n=2*self.m):
            """
            Map the pair (a,b) with a<b, a,b in [0..n-1]
            into a unique index in [0..nC2-1].
            nC2 = n*(n-1)/2.
            
            We'll use a known formula:
              index = (2n - a - 1)*a/2 + (b - a - 1)
            """
            return ( (2*n - a - 1)*a )//2 + (b - a - 1)

        # 4) Fill in cross-side pairs.
        #    - The left bits are [0..(m-1)]
        #    - The right bits are [m..(2m-1)]
        #
        # If we treat (C*m*m) as "C channels of an m x m patch," 
        # for each channel c in [0..C-1], 
        # for each row i in [0..m-1],
        # for each col j in [0..m-1],
        #   the amplitude is x_im[:, c*m*m + i*m + j].
        #
        # That amplitude goes into the HW-basis index comb_index(i, m+j).

        # Number of channels we flattened = C
        # Each channel has m*m entries.
        # We'll map each (channel, i, j) triple to the appropriate index in HW basis.

        total_entries_per_channel = self.m * self.m
        for c in range(self.in_channels):
            channel_offset = c*total_entries_per_channel
            for i in range(self.m):
                for j in range(self.m):
                    a = i
                    b = self.m + j
                    idx = comb_index(a, b, 2*self.m)

                    # The amplitude in x_im is at position [channel_offset + i*m + j]
                    x_hw[:, idx] += x_im[:, channel_offset + i*self.m + j]

        return x_hw



In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from math import comb
from torch.utils.data import Subset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

############################################
# 1) Perceval-related helper code (minimal)
############################################
def all_bitstrings_with_k_ones(n, k):
    from itertools import combinations
    bitstrings = []
    for ones_positions in combinations(range(n), k):
        bits = ['0'] * n
        for pos in ones_positions:
            bits[pos] = '1'
        bitstrings.append(''.join(bits))
    return bitstrings

def state_str_to_bits(state_str):
    bit_str_list = state_str.strip('|>').split(',')
    return ''.join(bit_str_list)

def distribution_to_vector(prob_dict, n, k):
    bitstrings = all_bitstrings_with_k_ones(n, k)
    idx_map = {bs: i for i, bs in enumerate(bitstrings)}
    output_vec = [0]*len(bitstrings)
    for k_str, v in prob_dict.items():
        bs = state_str_to_bits(str(k_str))
        if bs.count('1') == k and bs in idx_map:
            output_vec[idx_map[bs]] = v
    return output_vec

def generate_perceval_circuit(m, gate_list, encode_angles, train_params):
    """
    Minimal circuit builder:
      * gate_list = list of (i, j) 
      * first len(encode_angles) parameters => encode portion
      * next len(train_params) => training portion
    """
    import perceval as pcvl
    from perceval.components import BS, PERM

    circuit = pcvl.Circuit(m)
    param_index = 0

    # Combine them for demonstration (encode + train gates).
    # Or keep them separate if you prefer. We'll just do a single loop:
    for (i, j) in gate_list:
        # Possibly do some permutation
        if i+1 != j:
            n_ = abs(j - i)
            permutation = [n_-1] + list(range(1, n_-1)) + [0]
            circuit.add(i+1, PERM(permutation))

        # Insert a parameterized BS
        circuit.add(i, BS.H(theta=pcvl.P(f'phi_{param_index}')))
        param_index += 1

        if i+1 != j:
            n_ = abs(j - i)
            permutation = [n_-1] + list(range(1, n_-1)) + [0]
            circuit.add(i+1, PERM(permutation))

    # Set parameter values
    params = circuit.get_parameters()
    # encode portion
    for i in range(len(encode_angles)):
        params[i].set_value(encode_angles[i])
    # train portion
    start = len(encode_angles)
    for i in range(len(train_params)):
        params[start + i].set_value(train_params[i])

    return circuit

def run_perceval_circuit(m, n, circuit, input_state_list, postselect, samples):
    """
    Return sqrt(prob) as (dim_state)-tensor
    """
    import perceval as pcvl
    proc = pcvl.Processor("SLOS", m)
    proc.set_circuit(circuit)
    proc.min_detected_photons_filter(postselect)
    proc.thresholded_output(True)
    proc.with_input(pcvl.BasicState(input_state_list))

    sampler = pcvl.algorithm.Sampler(proc, max_shots_per_call=samples)
    res = sampler.probs(samples)
    distribution = res["results"]

    vec = distribution_to_vector(distribution, m, n)
    t = torch.tensor(vec, dtype=torch.float32)
    return torch.sqrt(t)


############################################
# 2) Custom autograd with finite difference
############################################
class FiniteDiffFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, theta, m, n, gate_list, encode_angles, input_state_list, postselect, samples, eps):
        """
        Returns (dim_state)-tensor from the circuit
        """
        ctx.m = m
        ctx.n = n
        ctx.gate_list = gate_list
        ctx.encode_angles = encode_angles
        ctx.input_state_list = input_state_list
        ctx.postselect = postselect
        ctx.samples = samples
        ctx.eps = eps
        ctx.save_for_backward(theta)

        # forward pass
        circuit = generate_perceval_circuit(m, gate_list, encode_angles, theta.detach().cpu().numpy())
        output = run_perceval_circuit(m, n, circuit, input_state_list, postselect, samples)
        return output  # shape (comb(m, n),)

    @staticmethod
    def backward(ctx, grad_output):
        (theta,) = ctx.saved_tensors
        m = ctx.m
        n = ctx.n
        gate_list = ctx.gate_list
        encode_angles = ctx.encode_angles
        input_state_list = ctx.input_state_list
        postselect = ctx.postselect
        samples = ctx.samples
        eps = ctx.eps

        D = theta.shape[0]
        grad_theta = torch.zeros_like(theta)
        half = eps / 2.0

        # For each param, do central difference
        for i in range(D):
            theta_plus = theta.clone()
            theta_minus = theta.clone()
            theta_plus[i] += half
            theta_minus[i] -= half

            circuit_plus = generate_perceval_circuit(m, gate_list, encode_angles, theta_plus.detach().cpu().numpy())
            out_plus = run_perceval_circuit(m, n, circuit_plus, input_state_list, postselect, samples)

            circuit_minus = generate_perceval_circuit(m, gate_list, encode_angles, theta_minus.detach().cpu().numpy())
            out_minus = run_perceval_circuit(m, n, circuit_minus, input_state_list, postselect, samples)

            diff = (out_plus - out_minus) / eps
            grad_theta[i] = torch.dot(grad_output, diff)

        return grad_theta, None, None, None, None, None, None, None, None


class PercevalCircuitModule(nn.Module):
    def __init__(self, m, n, gate_list, encode_angles, init_params,
                 postselect, samples, eps=1e-4):
        super().__init__()
        self.theta = nn.Parameter(init_params.clone())
        self.m = m
        self.n = n
        self.gate_list = gate_list
        self.encode_angles = encode_angles
        self.input_state_list = [1]*n + [0]*(m-n)
        self.postselect = postselect
        self.samples = samples
        self.eps = eps

    def forward(self):
        return FiniteDiffFunction.apply(
            self.theta,
            self.m, self.n,
            self.gate_list,
            self.encode_angles,
            self.input_state_list,
            self.postselect,
            self.samples,
            self.eps
        )


m = 4
n = 2
postselect = n
samples = 1
dim_state = comb(m, n)
batch_size = 1
# A list of gates
gate_list = [(i,j) for i in range(m) for j in range(m) if i<j]
encode_angles = []
init_params = torch.rand(len(gate_list), dtype=torch.float32)

# Build model
model = PercevalCircuitModule(
    m, n, gate_list,
    encode_angles, init_params,
    postselect=postselect,
    samples=samples,
    eps=1e-5  # smaller FD step => more accurate grads
)

N = 10
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root=".", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=".", train=False, transform=transform, download=True)
train_subset_indices = torch.arange(N)  # Select the first N indices
train_dataset = Subset(train_dataset, train_subset_indices)
test_subset_indices = torch.arange(N)  # Select the first N indices
test_dataset = Subset(test_dataset, test_subset_indices)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

encoding_parameters = []
for images, labels in train_loader:
    # Let's define a small target distribution in 10D 
    # (random for demonstration).  We'll treat it as sqrt(prob).
    # target_dist = torch.rand(dim_state)
    # target_dist /= target_dist.sum()  # sum=1
    # target_sqrt = torch.sqrt(target_dist.float())

    # dummy_x = torch.randn(1, 1, 28, 28)
    # Instantiate our encoder with m=10
    encoder = Flexible_IM_to_HW_Encoder(input_size=28, m=m//2, in_channels=1)
    # Forward-pass
    out_hw = encoder(images)
    target_sqrt = out_hw.squeeze()
    # MSE Loss
    criterion = nn.MSELoss()

    # Use Adam
    optimizer = optim.Adam(model.parameters(), lr=3e-2)  # smaller LR

    # We can do more iterations to let Adam converge
    epochs = 1
    steps_per_epoch = 50

    for e in range(epochs):
        for step in range(steps_per_epoch):
            optimizer.zero_grad()
            output = model()  # shape(10,)

            loss = criterion(output, target_sqrt)
            loss.backward()
            optimizer.step()

            # if step % 10 == 0:
            #     print(f"Epoch={e+1}, Step={step}/{steps_per_epoch}, Loss={loss.item():.5f}")

    # Print final results
    final_output = model().detach()
    # final_output = final_output / final_output.sum()  # re-normalize to compare
    target_sqrt_renorm = target_sqrt / target_sqrt.sum()

    # print("\n--- Training finished ---")
    # print("Final target:", target_sqrt_renorm)
    # print("Final output:", final_output)
    encoding_parameters.append(model.theta.data)
    # print("Gate angles:", model.theta.data)
    print("Final MSE Loss:", criterion(final_output, target_sqrt_renorm).item())

Final MSE Loss: 0.033728644251823425
Final MSE Loss: 0.04162076860666275
Final MSE Loss: 0.03639378771185875
Final MSE Loss: 0.03733782470226288
Final MSE Loss: 0.03822031617164612
Final MSE Loss: 0.038432639092206955
Final MSE Loss: 0.039317868649959564
Final MSE Loss: 0.033871494233608246
Final MSE Loss: 0.02241045981645584
Final MSE Loss: 0.043453801423311234


# second part
Now we have encoding parameters! We can start our training to dignit labels:

In [12]:
len(encoding_parameters)

10

In [None]:
class Net(nn.Module):
    """
    Simple network for MNIST. It:
    1) Flattens the 28x28 input into a 784-dim vector
    2) Feeds it to a hidden layer (fc1)
    3) Passes it to Blackbox
    4) Feeds the result to the final classification layer (fc2)
    """
    def __init__(self, n, m, list_gates):
        super(Net, self).__init__()
        self.perceval_circuit = PercevalCircuitModule(
            m, n, list_gates,
            encode_angles, init_params,
            postselect=postselect,
            samples=samples,
            eps=1e-5  # smaller FD step => more accurate grads
        )

    def forward(self, x):
        # Flatten the input from (B, 1, 28, 28) -> (B, 784)
        # x = x.view(x.size(0), -1)
        x = self.perceval_circuit()  # Apply Blackbox to the hidden representation
        return x


