In [None]:
import cirq
import numpy as np

# Z
def sum_k_cirq(k, wires):
    """
    Generates gates for a phase shift sum.

    This applies a rotation gate to each wire, rotating by an angle 'k * pi / 2^j'.

    Args:
        k (int): The integer to be added in the phase.
        wires (list[cirq.Qid]): A list of qubits to apply the rotations to.

    Yields:
        cirq.Operation: The RZ gate operations.
    """
    for j, wire in enumerate(wires):
        angle = k * np.pi / (2**j)
        if angle != 0:
            yield cirq.rz(angle).on(wire)


def add_in_k_N_cirq(k, N, wires_a, wires_aux, is_in_standard_basis=True):
    """
    Generates a circuit for in-place modular addition of a classical integer 'k'.

    Args:
        k (int): The classical integer to add.
        N (int): The modulus.
        wires_a (list[cirq.Qid]): The qubit register to add 'k' into.
        wires_aux (list[cirq.Qid]): A list containing a single auxiliary qubit.
        is_in_standard_basis (bool): If True, applies QFT at the beginning and end.

    Yields:
        cirq.Operation: The operations for the modular adder.
    """
    if is_in_standard_basis:
        yield cirq.qft(*wires_a)

    # Step 1: Add 'k' to the register via phase shifts
    yield from sum_k_cirq(k, wires_a)

    # Step 2: Subtract 'N'
    yield from cirq.inverse(list(sum_k_cirq(N, wires_a)))

    # Step 3: Conditionally add 'N' back
    yield cirq.inverse(cirq.qft(*wires_a))
    yield cirq.CNOT(wires_a[0], wires_aux[0]) # Use MSB to control aux bit
    yield cirq.qft(*wires_a)

    # Step 4: Add 'N' to the register conditionally
    # qml.ctrl(op, control=c) is equivalent to op.controlled_by(c)
    controlled_sum_N = [
        op.controlled_by(wires_aux[0]) for op in sum_k_cirq(N, wires_a)
    ]
    yield from controlled_sum_N

    # Step 5: Clear the auxiliary bit
    yield from cirq.inverse(list(sum_k_cirq(k, wires_a)))

    # Step 6: Conditionally add 'N' back (undo)
    yield cirq.inverse(cirq.qft(*wires_a))
    yield cirq.X(wires_a[0])
    yield cirq.CNOT(wires_a[0], wires_aux[0])
    yield cirq.X(wires_a[0])
    yield cirq.qft(*wires_a)

    # Step 7: Add 'k' to the register
    yield from sum_k_cirq(k, wires_a)

    # Recover the standard basis if needed
    if is_in_standard_basis:
        yield cirq.inverse(cirq.qft(*wires_a))


def add_in_N_cirq(N, wires_a, wires_b, aux1, aux2):
    """
    Generates a circuit for inplace modular addition of two quantum registers.
    This is the Cirq equivalent of the PennyLane function add_in_N.

    Args:
        N (int): The modulus.
        wires_a (list[cirq.Qid]): Quantum register for the first input number.
        wires_b (list[cirq.Qid]): Quantum register for the second input number.
        aux1 (list[cirq.Qid]): A list containing one auxiliary qubit.
        aux2 (list[cirq.Qid]): A list containing one auxiliary qubit.

    Yields:
        cirq.Operation: The operations for the full modular addition.
    """
    # The adder circuit works on a register combining wires_b and an auxiliary qubit.
    new_wires_b = aux1 + wires_b

    # Apply the quantum Fourier transform to the target register
    yield cirq.qft(*new_wires_b)

    # Create controlled modular adders, controlled by each qubit in wires_a
    for i, control_qubit in enumerate(wires_a):
        value = 2**(len(wires_a) - 1 - i)

        # Generate the operations for the modular adder for 'value'
        adder_ops = list(
            add_in_k_N_cirq(value, N, new_wires_b, aux2, is_in_standard_basis=False)
        )

        # Control each operation by the corresponding qubit from the 'a' register
        controlled_adder_ops = [op.controlled_by(control_qubit) for op in adder_ops]
        yield from controlled_adder_ops

    # Inverse quantum Fourier transform on 'wires_b'
    yield cirq.inverse(cirq.qft(*new_wires_b))

def add_in_N_cirq_inverse(N, wires_a, wires_b, aux1, aux2):
    yield cirq.inverse(list(add_in_N_cirq(N, wires_a, wires_b, aux1, aux2)))
    
def test_modular_adder(a, b, N):
    """
    Tests the modular adder circuit for specific inputs a, b, and N.

    Args:
        a (int): The first number to add.
        b (int): The second number to add.
        N (int): The modulus.

    Returns:
        bool: True if the test passes, False otherwise.
    """
    # Determine the number of qubits needed to represent numbers up to N-1
    n_qubits = (N - 1).bit_length()

    # Define the qubit registers
    wires_a = [cirq.NamedQubit(f'a{i}') for i in range(n_qubits)]
    wires_b = [cirq.NamedQubit(f'b{i}') for i in range(n_qubits)]
    aux1 = [cirq.NamedQubit('aux1')]
    aux2 = [cirq.NamedQubit('aux2')]
    all_qubits = wires_a + wires_b + aux1 + aux2

    # Create a circuit
    circuit = cirq.Circuit()

    # --- Prepare the initial state ---
    # Set register 'a' to |a>
    for i, bit in enumerate(f'{a:0{n_qubits}b}'):
        if bit == '1':
            circuit.append(cirq.X(wires_a[i]))

    # Set register 'b' to |b>
    for i, bit in enumerate(f'{b:0{n_qubits}b}'):
        if bit == '1':
            circuit.append(cirq.X(wires_b[i]))

    # --- Add the modular addition circuit ---
    # This will compute (a + b) mod N and store it in wires_b
    circuit.append(add_in_N_cirq(N, wires_a, wires_b, aux1, aux2))

    # --- Measure the result register ---
    circuit.append(cirq.measure(*wires_b, key='result_b'))

    # --- Simulate and check the result ---
    simulator = cirq.Simulator()
    result = simulator.run(circuit, repetitions=10) # Run a few times for confidence
    
    # Get the measurement outcome and convert from bitstring to integer
    measurement = result.measurements['result_b'][0]
    actual_result = int("".join(map(str, measurement)), 2)
    
    expected_result = (a + b) % N

    # --- Print results ---
    print(f"--- Testing: ({a} + {b}) mod {N} ---")
    print(f"Expected result in register 'b': {expected_result}")
    print(f"Actual result from simulation:   {actual_result}")
    
    if actual_result == expected_result:
        print(">>> PASS\n")
        return True
    else:
        print(">>> FAIL\n")
        return False


if __name__ == "__main__":
    # --- Test Suite ---
    test_cases = [
        (1, 2, 4),
        (3, 3, 4),
        (5, 3, 8),
        (7, 7, 8),
        (6, 5, 11),
        (10, 15, 16)
    ]

    passed_count = 0
    for a, b, N in test_cases:
        if test_modular_adder(a, b, N):
            passed_count += 1
    
    print("--- Test Summary ---")
    print(f"{passed_count} / {len(test_cases)} tests passed.")



--- Testing: (1 + 2) mod 4 ---
Expected result in register 'b': 3
Actual result from simulation:   3
>>> PASS

--- Testing: (3 + 3) mod 4 ---
Expected result in register 'b': 2
Actual result from simulation:   2
>>> PASS

--- Testing: (5 + 3) mod 8 ---
Expected result in register 'b': 0
Actual result from simulation:   0
>>> PASS

--- Testing: (7 + 7) mod 8 ---
Expected result in register 'b': 6
Actual result from simulation:   6
>>> PASS

--- Testing: (6 + 5) mod 11 ---
Expected result in register 'b': 0
Actual result from simulation:   0
>>> PASS

--- Testing: (10 + 15) mod 16 ---
Expected result in register 'b': 9
Actual result from simulation:   9
>>> PASS

--- Test Summary ---
6 / 6 tests passed.


In [247]:
from cirq_sic import *

n = 1
d = 2**n
d_s = 2
phi = load_sic_fiducial(d_s)
embedded_phi = np.concatenate([phi, np.zeros(d-d_s)])
ansatz_preparation = ansatz_circuit(embedded_phi)

In [None]:
system_qubits = cirq.LineQubit.range(n)
ancilla_qubits = cirq.LineQubit.range(n, 2*n)
aux1 = [cirq.NamedQubit("aux1")]
aux2 = [cirq.NamedQubit("aux2")]

circ = cirq.Circuit((ansatz_preparation(ancilla_qubits, conjugate=True),
                     ansatz_preparation(system_qubits, conjugate=False),
                     CXdag(ancilla_qubits, system_qubits),
                     #add_in_N_cirq(d_s, ancilla_qubits, system_qubits, aux1, aux2),
                    ))#)Fdag(ancilla_qubits)))
                    # cirq.measure(system_qubits+ancilla_qubits, key="result")))
res = cirq.Simulator().simulate(circ)
res.density_matrix_of(ancilla_qubits+system_qubits)

array([[ 0.622+0.j   ,  0.228+0.228j,  0.167-0.j   ,  0.228-0.228j],
       [ 0.228-0.228j,  0.167+0.j   ,  0.061-0.061j, -0.   -0.167j],
       [ 0.167+0.j   ,  0.061+0.061j,  0.045+0.j   ,  0.061-0.061j],
       [ 0.228+0.228j, -0.   +0.167j,  0.061+0.061j,  0.167+0.j   ]],
      dtype=complex64)

In [261]:
circ = cirq.Circuit((ansatz_preparation(ancilla_qubits, conjugate=True),
                     ansatz_preparation(system_qubits, conjugate=False),
                     #CXdag(ancilla_qubits, system_qubits),
                     add_in_N_cirq(d_s, ancilla_qubits, system_qubits, aux1, aux2),
                    ))#Fdag(ancilla_qubits)
                    #))# ))
                    # cirq.measure(system_qubits+ancilla_qubits, key="result")))
res = cirq.Simulator().simulate(circ)
res.density_matrix_of(ancilla_qubits+system_qubits)

array([[ 0.622+0.j   ,  0.228+0.228j, -0.118-0.118j, -0.   +0.322j],
       [ 0.228-0.228j,  0.167+0.j   , -0.086-0.j   ,  0.118+0.118j],
       [-0.118+0.118j, -0.086+0.j   ,  0.045+0.j   , -0.061-0.061j],
       [-0.   -0.322j,  0.118-0.118j, -0.061+0.061j,  0.167+0.j   ]],
      dtype=complex64)

In [52]:
from cirq_sic import *

n = 2
d = 2**n
d_s = 4
phi = rand_ket(d_s)
embedded_phi = np.concatenate([phi, np.zeros(d-d_s)])
ansatz_preparation = ansatz_circuit(embedded_phi)

In [53]:
system_qubits = cirq.LineQubit.range(n)
aux1 = [cirq.NamedQubit("aux1")]

circ = cirq.Circuit((ansatz_preparation(system_qubits, conjugate=False),
                    X(system_qubits)))
                    #Z(system_qubits)))
res = cirq.Simulator().simulate(circ)
res.density_matrix_of(system_qubits)

array([[ 0.585+0.j   , -0.103-0.168j,  0.181+0.183j,  0.084-0.361j],
       [-0.103+0.168j,  0.067+0.j   , -0.085+0.02j ,  0.089+0.088j],
       [ 0.181-0.183j, -0.085-0.02j ,  0.113+0.j   , -0.087-0.138j],
       [ 0.084+0.361j,  0.089-0.088j, -0.087+0.138j,  0.235+0.j   ]],
      dtype=complex64)

In [54]:
circ = cirq.Circuit((ansatz_preparation(system_qubits, conjugate=False),
                     add_in_k_N_cirq(1, d_s, system_qubits, aux1)))
                     #sum_k_cirq(1, system_qubits)))
                    #))
res = cirq.Simulator().simulate(circ)
res.density_matrix_of(system_qubits)

array([[ 0.585+0.j   ,  0.   +0.j   , -0.181-0.183j,  0.   -0.j   ],
       [ 0.   -0.j   ,  0.067+0.j   , -0.   +0.j   , -0.089-0.088j],
       [-0.181+0.183j, -0.   -0.j   ,  0.113+0.j   , -0.   +0.j   ],
       [ 0.   +0.j   , -0.089+0.088j, -0.   -0.j   ,  0.235+0.j   ]],
      dtype=complex64)

$$ |a,b\rangle \rightarrow |a, a+b \mod d\rangle $$

In [141]:
n = 1
d_s = 2

a = 0
b = 0
system_qubits = cirq.LineQubit.range(n)
ancilla_qubits = cirq.LineQubit.range(n, 2*n)
aux1 = [cirq.NamedQubit("aux1")]
aux2 = [cirq.NamedQubit("aux2")]

circuit = cirq.Circuit((add_in_N_cirq(d_s, ancilla_qubits, system_qubits, aux1, aux2)))
initial_state = kron(np.eye(2**n)[a%d_s],np.eye(2**n)[b%d_s],np.eye(4)[0])
final_state = kron(np.eye(2**n)[a%d_s],np.eye(2**n)[(a+b)%d_s])
sim = cirq.Simulator()
result = sim.simulate(circuit, initial_state=initial_state, qubit_order=ancilla_qubits+system_qubits+aux1+aux2)
np.diag(result.density_matrix_of(system_qubits)).real

array([1., 0.], dtype=float32)

In [None]:
res_vec = np.diag(result.density_matrix_of(ancilla_qubits+system_qubits)).real
res_vec, final_state

In [126]:
cirq.dirac_notation(res_vec),  cirq.dirac_notation(final_state)

('|001001⟩', '|001001⟩')

In [127]:
d_res_vec = cirq.dirac_notation(res_vec)[1:-1]
d_final_state = cirq.dirac_notation(final_state)[1:-1]
int(d_res_vec[:n],2), int(d_res_vec[n:], 2)

(1, 1)

In [128]:
int(d_final_state[:n],2), int(d_final_state[n:], 2)

(1, 1)