In [6]:
import cirq
import numpy as np
import scipy.linalg

def _append_ripple_carry_adder(
    circuit: cirq.Circuit,
    a: list[cirq.Qid],
    b: list[cirq.Qid],
    c: list[cirq.Qid],
):
    """
    Appends a ripple-carry adder to a circuit in-place.

    This computes a += b, where a and b are quantum registers.
    The sum is stored in register a.

    Args:
        circuit: The circuit to append the gates to.
        a: A list of qubits for the first summand (and result).
        b: A list of qubits for the second summand.
        c: A list of carry ancilla qubits. Must have len(c) == len(a).
    """
    n = len(a)
    if len(b) != n or len(c) != n:
        raise ValueError("All qubit registers must have the same length.")

    # MAJ (Majority) gate decomposition
    def maj(p, q, r):
        yield cirq.CNOT(r, q)
        yield cirq.CNOT(r, p)
        yield cirq.TOFFOLI(p, q, r)

    # UMA (Un-Majority and Add) gate decomposition
    def uma(p, q, r):
        yield cirq.TOFFOLI(p, q, r)
        yield cirq.CNOT(r, p)
        yield cirq.CNOT(p, q)

    # 1. Compute carries
    for i in range(n - 1):
        circuit.append(maj(c[i], a[i], b[i]))

    # 2. Compute final sum bit and last carry
    circuit.append(maj(c[n - 1], a[n - 1], b[n - 1]))
    circuit.append(cirq.CNOT(a[n - 1], b[n - 1]))

    # 3. Uncompute carries to restore b
    for i in range(n - 1, -1, -1):
        circuit.append(uma(c[i], a[i], b[i]))


class LessThanGate(cirq.Gate):
    """
    A quantum gate that checks if a register `s` is less than a classical value `d`.

    This gate performs the operation: |s>|anc> -> |s>|anc ⊕ (s < d)>.
    It uses a quantum adder to compute s + (2^n - d), where the final carry
    bit indicates if s >= d.
    """

    def __init__(self, n: int, d: int):
        if not (0 < d <= 2**n):
            raise ValueError(f"d must be in the range [1, 2**{n}].")
        self._n = n
        self._d = d

    def _num_qubits_(self) -> int:
        # n data qubits, n carry ancillas, 1 result ancilla
        return 2 * self._n + 1

    def _decompose_(self, qubits: list[cirq.Qid]) -> cirq.OP_TREE:
        """Constructs the comparator circuit."""
        data_qubits = qubits[0 : self._n]
        carry_qubits = qubits[self._n : 2 * self._n]
        ancilla = qubits[2 * self._n]

        # The value to add for the comparison s + (2^n - d)
        add_val = 2**self._n - self._d
        add_val_bits = bin(add_val)[2:].zfill(self._n)

        # Represent the classical value `add_val` in a temporary register
        # We use CNOTs to flip the bits and unflip them after.
        add_val_circuit = cirq.Circuit()
        add_val_qubits = [cirq.NamedQubit(f"val_{i}") for i in range(self._n)]
        for i, bit in enumerate(add_val_bits):
            if bit == '1':
                add_val_circuit.append(cirq.X(add_val_qubits[i]))

        # The full adder circuit
        adder_circuit = cirq.Circuit()
        _append_ripple_carry_adder(
            adder_circuit, data_qubits, add_val_qubits, carry_qubits
        )

        # The final carry bit is 1 if s >= d. We want s < d.
        # So we flip the carry bit before CNOTing to the ancilla.
        final_carry = carry_qubits[-1]

        # Full sequence: prepare value, add, check carry, un-add, un-prepare
        yield add_val_circuit
        yield adder_circuit
        yield cirq.X(final_carry)
        yield cirq.CNOT(final_carry, ancilla)
        yield cirq.X(final_carry)
        yield cirq.inverse(adder_circuit)
        yield cirq.inverse(add_val_circuit)

def qft_d_gate(n: int, d: int) -> cirq.MatrixGate:
    """
    Creates a gate for a d-dimensional QFT embedded in a 2^n space.

    The resulting unitary matrix is block-diagonal with a d x d QFT
    matrix in the top-left and an identity matrix for the rest.

    Args:
        n: The number of qubits the gate will act on.
        d: The dimension of the QFT.

    Returns:
        A cirq.MatrixGate that implements the desired unitary.
    """
    if not (0 < d <= 2**n):
        raise ValueError(f"d must be in the range [1, 2**{n}].")

    # Create the d x d QFT matrix
    qft_d_matrix = scipy.linalg.dft(d, scale='sqrtn')

    # Embed it into a 2^n x 2^n identity matrix
    full_unitary = np.identity(2**n, dtype=np.complex128)
    full_unitary[0:d, 0:d] = qft_d_matrix

    return cirq.MatrixGate(full_unitary, name=f'QFT({d})')


def create_block_qft_circuit(n: int, d: int) -> cirq.Circuit:
    """
    Creates the full circuit for the block-diagonal QFT unitary.

    Args:
        n: Total number of data qubits.
        d: The dimension of the QFT block.

    Returns:
        A cirq.Circuit that implements the operation.
    """
    # Define the qubits
    data_qubits = cirq.LineQubit.range(n)
    # Ancillas for the comparator's internal logic (carry bits)
    comparator_ancillas = [cirq.NamedQubit(f'c_{i}') for i in range(n)]
    # The main ancilla that controls the QFT application
    control_ancilla = cirq.NamedQubit('anc')

    # Instantiate the gates
    less_than_d_gate = LessThanGate(n, d)
    qft_gate = qft_d_gate(n, d)
    controlled_qft = qft_gate.controlled(1)

    # Build the circuit
    circuit = cirq.Circuit()
    comparator_qubits = data_qubits + comparator_ancillas + [control_ancilla]

    # 1. Compare: Set control_ancilla if data < d
    circuit.append(less_than_d_gate.on(*comparator_qubits))

    # 2. Act: Apply controlled-QFT
    circuit.append(controlled_qft.on(control_ancilla, *data_qubits))

    # 3. Un-compute: Reset the comparator and its ancillas
    circuit.append(cirq.inverse(less_than_d_gate.on(*comparator_qubits)))

    return circuit

# --- Main execution and verification block ---
# --- Parameters for the test ---
N_QUBITS = 3
D_DIM = 2  # The size of the QFT block

print(f"Constructing a circuit for n={N_QUBITS} qubits and d={D_DIM}.")
print("-" * 30)

# --- Create the circuit ---
block_qft_circuit = create_block_qft_circuit(n=N_QUBITS, d=D_DIM)

# To see the full gate decomposition, uncomment the following line:
# print("Circuit Decomposition:\n", block_qft_circuit)

# --- Verification ---
print("Verifying the circuit's unitary matrix...")
try:
    # Get the unitary matrix of the constructed circuit
    # Note: Cirq's qubit ordering is little-endian (q0 is the rightmost bit)
    # This requires passing the ancillas to the unitary calculation.
    all_qubits = (
        cirq.LineQubit.range(N_QUBITS)
        + [cirq.NamedQubit(f'c_{i}') for i in range(N_QUBITS)]
        + [cirq.NamedQubit('anc')]
    )
    circuit_unitary = block_qft_circuit.unitary(qubit_order=all_qubits)

    # Construct the expected unitary matrix directly
    qft_block = scipy.linalg.dft(D_DIM, scale='sqrtn')
    identity_block = np.identity(2**N_QUBITS - D_DIM)
    expected_top_block = scipy.linalg.block_diag(qft_block, identity_block)

    # The full unitary acts on data qubits and ancillas.
    # The ancillas should end up in the |0> state, so the full unitary
    # should be a kronecker product with |0><0| for the ancillas.
    num_ancillas = N_QUBITS + 1
    ancilla_identity = np.zeros((2**num_ancillas, 2**num_ancillas))
    ancilla_identity[0, 0] = 1.0 # |0...0><0...0| projector
    
    # Note: np.kron uses big-endian, cirq.unitary is little-endian.
    # So we kron the data qubit matrix on the left.
    expected_unitary = np.kron(expected_top_block, ancilla_identity)


    # Compare the matrices
    if np.allclose(circuit_unitary, expected_unitary):
        print("\n✅ Verification successful!")
        print("The circuit's unitary matches the expected block-diagonal form.")
    else:
        print("\n❌ Verification failed.")
        print("The circuit's unitary does not match the expected matrix.")

except Exception as e:
    print(f"\nAn error occurred during verification: {e}")
    print("Verification is memory-intensive and may fail for n > 3.")


Constructing a circuit for n=3 qubits and d=2.
------------------------------
Verifying the circuit's unitary matrix...

❌ Verification failed.
The circuit's unitary does not match the expected matrix.


In [7]:
circuit_unitary

array([[0.70710678+0.j, 0.        +0.j, 0.        +0.j, ...,
        0.        +0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 1.        +0.j, 0.        +0.j, ...,
        0.        +0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.        +0.j, 1.        +0.j, ...,
        0.        +0.j, 0.        +0.j, 0.        +0.j],
       ...,
       [0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
        1.        +0.j, 0.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
        0.        +0.j, 1.        +0.j, 0.        +0.j],
       [0.        +0.j, 0.        +0.j, 0.        +0.j, ...,
        0.        +0.j, 0.        +0.j, 1.        +0.j]], shape=(128, 128))