In [None]:
from bloqade import squin
from kirin.dialects.ilist import IList
import numpy as np
import bloqade.stim
import bloqade.tsim

In [None]:
def syndrome_extraction(theta, phi, lam, phys_qubits=7, basis="z"):
    """Build a squin kernel that encodes a U3-parameterized state and extracts syndromes.

    The returned kernel allocates 14 qubits: ``q[0..6]`` hold the encoded
    injected state and ``q[7..13]`` form a 7-qubit ancilla block that is
    reused for two extraction rounds:

    1. The ancilla block is encoded from ``u3(pi/2, 0, pi)`` (a Hadamard under
       the standard U3 convention). Transversal CNOTs are applied with data as
       control and ancilla as target, then the ancillas are measured in Z.
    2. The ancilla block is reset and re-encoded from ``u3(0, 0, 0)``.
       Transversal CNOTs are applied with ancilla as control and data as
       target, then the ancillas are measured in X (Hadamard, then measure).

    No classical decoding or correction is applied yet (see the TODO in the
    kernel body).

    Args:
        theta, phi, lam: U3 angles applied to qubit 6 of the data block before
            the encoding circuit runs.
        phys_qubits: Currently unused (it was only referenced by an earlier
            decoding draft); kept for backward compatibility.
        basis: Currently unused; kept for backward compatibility.

    Returns:
        The ``circ`` squin kernel, which takes no arguments and returns nothing.
    """

    @squin.kernel
    # General MSD encoding with parameterized input state
    def parameterized_MSD_encoding(q_subset, theta, phi, lam):
        # Prepare the injected state on the last qubit of the 7-qubit block.
        squin.u3(theta, phi, lam, q_subset[6])

        # Apply the MSD encoding circuit (gate order is significant; do not reorder).
        for i in range(6):
            squin.sqrt_y_adj(q_subset[i])
        squin.cz(q_subset[1], q_subset[2])
        squin.cz(q_subset[3], q_subset[4])
        squin.cz(q_subset[5], q_subset[6])
        squin.sqrt_y(q_subset[6])
        squin.cz(q_subset[0], q_subset[3])
        squin.cz(q_subset[2], q_subset[5])
        squin.cz(q_subset[4], q_subset[6])
        for i in range(5):
            squin.sqrt_y(q_subset[i + 2])
        squin.cz(q_subset[0], q_subset[1])
        squin.cz(q_subset[2], q_subset[3])
        squin.cz(q_subset[4], q_subset[5])
        squin.sqrt_y(q_subset[1])
        squin.sqrt_y(q_subset[2])
        squin.sqrt_y(q_subset[4])

    @squin.kernel
    def circ():
        # Allocate 14 qubits: q[0..6] hold the injected state,
        # q[7..13] are an ancilla block reused for both extraction rounds.
        q = squin.qalloc(14)
        data = [q[0], q[1], q[2], q[3], q[4], q[5], q[6]]
        ancilla = [q[7], q[8], q[9], q[10], q[11], q[12], q[13]]

        # Encode the injected state and the Z-basis ancilla.
        # u3(pi/2, 0, pi) is a Hadamard under the standard U3 convention.
        parameterized_MSD_encoding(data, theta=theta, phi=phi, lam=lam)
        parameterized_MSD_encoding(ancilla, theta=np.pi / 2, phi=0, lam=np.pi)

        # Transversal CNOT: data qubit i controls ancilla qubit i.
        for i in range(7):
            squin.cx(q[i], q[i + 7])

        # Measure the ancilla block in the Z basis, then reset it for reuse.
        z_measurement = squin.broadcast.measure(ancilla)
        squin.broadcast.reset(ancilla)

        # Re-encode the ancilla block from |0> for the X-basis round.
        parameterized_MSD_encoding(ancilla, theta=0, phi=0, lam=0)

        # Transversal CNOT: ancilla qubit i controls data qubit i.
        for i in range(7):
            squin.cx(q[i + 7], q[i])

        # Measure the ancilla block in the X basis (Hadamard, then Z measurement).
        squin.broadcast.h(ancilla)
        x_measurement = squin.broadcast.measure(ancilla)

        # TODO: classical decoding / correction is not implemented. An earlier
        # commented-out draft decoded `x_measurement` using stabilizer supports
        # [[0, 2, 4, 6], [3, 4, 5, 6], [1, 2, 5, 6]] and the syndrome -> data-qubit
        # table {3: 0, 6: 1, 2: 2, 5: 3, 1: 4, 4: 5, 0: 6} (7 = no correction),
        # then applied squin.x to the flagged qubit.
        # NOTE(review): that draft had the following defects:
        #   - it indexed 7 entries of a 3-element syndrome;
        #   - it turned `syndrome` into a tuple inside the loop, so the next
        #     append fails;
        #   - it mixed +/-1 eigenvalues with a 0/1 bit-packing helper.
        # Also confirm the correction Pauli: in Steane-style extraction, X-basis
        # ancilla outcomes flag Z-type errors on the data block.

    return circ

In [None]:
# Build the syndrome-extraction kernel for the injected state U3(theta=1, phi=0, lam=0)|0>,
# convert it into a tsim circuit, and render its diagram as this cell's output.
MSD_enc = syndrome_extraction(theta=1, phi=0, lam=0)
tsim_circ = bloqade.tsim.Circuit(MSD_enc)
tsim_circ.diagram(height=400)