In [None]:
import z3
import numpy as np
from mqt.qecc.circuit_synthesis.synthesis_utils import symbolic_scalar_mult, symbolic_vector_add, symbolic_vector_eq
import itertools
import numpy.typing as npt

In [None]:
w = 4
# Rows 0 .. w-1 are the single-qubit X errors; they are followed by one row per
# aligned pair of qubits ({0,1}, {2,3}). Compare `generate_errors_power_two`.
errors = np.vstack([
    np.eye(w, dtype=np.int8),
    np.array([[int(qubit // 2 == pair) for qubit in range(w)] for pair in range(w // 2)]),
])

n_errors = errors.shape[0]
solver = z3.Solver()

In [None]:
w = 8
# Single-qubit X errors first, then one row per aligned block of 2 qubits,
# then one row per aligned block of 4 qubits (same rows as the former literals).
errors = np.vstack([
    np.eye(w, dtype=np.int8),
    np.array([[int(qubit // size == block) for qubit in range(w)]
              for size in (2, 4)
              for block in range(w // size)]),
])

n_errors = errors.shape[0]
solver = z3.Solver()

In [None]:
w = 32
# Error patterns of a perfect balanced binary CNOT tree on w qubits: for each level
# k = 0 .. log2(w) - 1 there is one row per aligned block of 2**k qubits (k = 0 gives
# the identity). These are the same rows `generate_errors_power_two(w)` returns. They
# are built inline because that helper is only defined in a later cell, and the
# previously called `generate_errors` is not defined anywhere in this notebook.
errors = np.vstack([
    np.kron(np.eye(w // 2**k, dtype=np.int8), np.ones((1, 2**k), dtype=np.int8))
    for k in range(int(np.log2(w)))
])
solver = z3.Solver()
n_errors = errors.shape[0]

In [None]:
w = 16

# Single-qubit X errors, followed by one row per aligned block of 2, then 4, then 8
# qubits. This is exactly the pair/quad/octet matrices previously spelled out by hand.
errors = np.vstack([
    np.eye(w, dtype=np.int8),
    np.array([[int(qubit // size == block) for qubit in range(w)]
              for size in (2, 4, 8)
              for block in range(w // size)]),
])

n_errors = errors.shape[0]
solver = z3.Solver()

In [None]:
# Search for a permutation P of the w qubits such that the following holds.
# Whenever a combination of t errors (v) equals the P-permuted combination of
# t' errors (vp), that pattern has weight <= t + t' or >= w - t - t'.
# Every permutation satisfying all (t, t') constraints is enumerated, printed,
# and collected in `ft_permutations`.
n_bits = max(1, (w - 1).bit_length())  # equals log2(w) when w is a power of two
positions = [z3.BitVecVal(j, n_bits) for j in range(w)]  # loop-invariant, built once

permutation = [z3.BitVec(f"P_{i}", n_bits) for i in range(w)]
solver.add(z3.Distinct(permutation))
if (1 << n_bits) != w:
    # w is not a power of two: keep every image inside 0 .. w-1.
    solver.add([z3.ULT(p, w) for p in permutation])

# Hand-picked pins for the first qubits (presumably symmetry breaking tuned for
# w = 16 — TODO confirm). A pin whose target is >= w is skipped. As an n_bits-wide
# value it would wrap around (e.g. 15 -> 3 for w = 4) and collide with another pin,
# which made the problem unsat for w = 4 and w = 8.
for pinned_qubit, pinned_target in enumerate((0, 3, 7, 15)):
    if pinned_target < w:
        solver.add(permutation[pinned_qubit] == pinned_target)

for t in range(1, w // 2 - 1):
    for t_prime in range(1, w // 2 - t):
        product_vars = [z3.Bool(f"e_{i}_{t}_{t_prime}") for i in range(n_errors)]
        permuted_product_vars = [z3.Bool(f"ep_{i}_{t}_{t_prime}") for i in range(n_errors)]

        # Exactly t (resp. t') error rows are selected.
        weight_t = z3.PbEq([(p, 1) for p in product_vars], t)
        weight_t_prime = z3.PbEq([(p, 1) for p in permuted_product_vars], t_prime)

        # Symbolically encode products (combinations) of the selected errors.
        v = symbolic_scalar_mult(errors[0], product_vars[0])
        vp = symbolic_scalar_mult(errors[0], permuted_product_vars[0])
        for i in range(1, n_errors):
            v = symbolic_vector_add(v, symbolic_scalar_mult(errors[i], product_vars[i]))
            vp = symbolic_vector_add(vp, symbolic_scalar_mult(errors[i], permuted_product_vars[i]))

        # permuted_vp is vp with entry i moved to position permutation[i].
        permuted_vp = [z3.Bool(f"pvp_{i}_{t}_{t_prime}") for i in range(len(vp))]
        permutation_constr = z3.And([
            z3.Implies(permutation[i] == positions[j], permuted_vp[j] == vp[i])
            for i in range(w)
            for j in range(w)
        ])

        solver.add(z3.ForAll(
            product_vars + permuted_product_vars + permuted_vp,
            z3.Implies(
                z3.And(weight_t, weight_t_prime, permutation_constr),
                z3.Implies(
                    symbolic_vector_eq(v, permuted_vp),
                    z3.Or(
                        z3.PbLe([(entry, 1) for entry in v], t_prime + t),
                        z3.PbGe([(entry, 1) for entry in v], w - t_prime - t),
                    ),
                ),
            ),
        ))

ft_permutations = []
while solver.check() == z3.sat:
    model = solver.model()
    found = [model.eval(permutation[i]) for i in range(w)]
    ft_permutations.append(found)
    print(found)
    # Block this exact assignment so the next check must find a different permutation.
    solver.add(z3.Or([permutation[i] != found[i] for i in range(w)]))

In [None]:
# The enumeration loop in the previous cell only stops once the solver is unsat,
# because every found permutation has been blocked. `solver.model()` raises in
# that state, so re-check first instead of asking for a model unconditionally.
if solver.check() == z3.sat:
    m = solver.model()
    print([m.eval(permutation[i]) for i in range(w)])
else:
    print("No (further) permutation satisfies the current constraints.")

In [None]:
def generate_errors_power_two(w: int) -> npt.NDArray[np.int8]:
    """Generate error patterns for perfect balanced binary tree CNOT circuits to prepare cat states.

    The first ``w`` rows are the single-qubit errors (identity matrix). Then, for each
    level ``i = 1 .. floor(log2(w)) - 1``, one row per aligned block of ``2**i`` qubits
    (qubits ``j * 2**i`` up to ``(j + 1) * 2**i - 1``) is appended.

    Args:
        w: number of qubits in the cat state (intended to be a power of two).
    Returns:
        numpy array of dtype int8 holding the binary representations of all possible
        X errors stemming from one propagated error. Every row corresponds to one error.
    """
    errors = np.eye(w, dtype=np.int8)
    # floor(log2(w)) computed exactly with integer arithmetic rather than float np.log2.
    n_levels = int(w).bit_length() - 1
    for i in range(1, n_levels):
        block = 2**i
        new_errors = np.zeros((w // block, w), dtype=np.int8)
        for j in range(w // block):
            new_errors[j, j * block:(j + 1) * block] = 1
        errors = np.vstack([errors, new_errors])
    return errors


def find_ft_permutation_smt(errors_circ_1:npt.ANDArray[np.int8], errors_circ_2:npt.ANDArray[np.int8])->list[int]|None:
    """Try to find a permutation of qubits that ensures that two error sets do not overlap too much when one of them is permuted.

    Args:
        errors_circ_1: error matrix for first circuit.
        errors_circ_2: error matrix for second circuit.
    Returns: 
        A list of integers representing an ft permutation of qubits if it exists, else None"""
    assert errors_circ_1.shape[1] == errors_circ_2.shape[1]
    n_errors_1 = errors_circ_1.shape[0]
    n_errors_2 = errors_circ_2.shape[0]
    w = errors_circ_1.shape[1]
    
    solver = z3.Solver()
    permutation = [z3.BitVec(f"P_{i}", int(np.log2(w))) for i in range(w)]
    solver.add(z3.Distinct(permutation))
    for t in range(1, w//2-1):
        for t_prime in range(1, w//2-t):
            product_vars = [z3.Bool(f"e_{i}_{t}_{t_prime}") for i in range(n_errors_1)]
            permuted_product_vars = [z3.Bool(f"ep_{i}_{t}_{t_prime}") for i in range(n_errors_2)]
            
            weight_t = z3.PbEq([(p,1) for p in product_vars], t)
            weight_t_prime = z3.PbEq([(p,1) for p in permuted_product_vars], t_prime)
            
            # symbolically encode products of errors
            v = symbolic_scalar_mult(errors[0], product_vars[0])
            vp = symbolic_scalar_mult(errors[0], permuted_product_vars[0])
            for i in range(1, n_errors_1):
                v = symbolic_vector_add(v, symbolic_scalar_mult(errors[i], product_vars[i]))
            for i in range(1, n_errors_2):
                vp = symbolic_vector_add(vp, symbolic_scalar_mult(errors[i], permuted_product_vars[i]))
            permuted_vp = [z3.Bool(f"pvp_{i}_{t}_{t_prime}") for i in range(len(vp))]
            permutation_constr = []
           
            for i in range(w):
                for j in range(w):
                    j_bv = z3.BitVecVal(j, int(np.log2(w)))
                    permutation_constr.append(z3.Implies(permutation[i]==j_bv, permuted_vp[j] == vp[i]))
            permutation_constr = z3.And(permutation_constr)
            
            solver.add(z3.ForAll( # This is really expensive!
                product_vars+permuted_product_vars+permuted_vp,
                z3.Implies(
                    z3.And(
                        weight_t,
                        weight_t_prime,
                        permutation_constr,
                    ),
                    z3.And(
                        z3.Implies(
                            symbolic_vector_eq(v, permuted_vp),
                                z3.Or(
                                    z3.PbLe([(entry, 1) for entry in v], t_prime+t),
                                    z3.PbGe([(entry, 1) for entry in v], w-t_prime-t)
                                ),
                            )
                        )
                )))
    
    result = solver.check()
    if result == "sat":
        m = solver.model()
        return [m.eval(permutation[i]) for i in range(w)])
    return None