In [1]:
from collections.abc import Generator

import pyzx
import pennylane as qml
from pennylane import numpy as np

In [2]:
# Quantum RPN calculator with ZX calculus
# q_1(n1) q_2(n2) op -> could we resolve the number of qubits needed?
# resolve 'q_1(n1) q_2(n2) +' = q_3(2*max(n1, n2)+1)
# resolve 'q_1(n1) q_2(n2) -' = 'complement(q_1(n1) q_2(n2) +)' = q_3(2*max(n1, n2)+1)
# resolve 'q_1(n1) q_2(n2) *' = q_3(4*max(n1, n2)+1)
# resolve 'q_1(n1) q_2(n2) /' = q_3(3*max(n1, n2))

In [3]:
def get_bit_field(k: int) -> Generator:
    for digit in bin(k)[2:]:
        yield 1 if digit=='1' else 0


def pad_zero(bit_field, pad_size: int) -> Generator:
    if pad_size > 0:
        for _ in range(pad_size):
            yield 0
    for bit in bit_field:
        yield bit


def pre_process(a, b) -> tuple[list[int], list[int]]:
    if isinstance(a, int):
        a = list(get_bit_field(a))
    if isinstance(b, int):
        b = list(get_bit_field(b))
    if len(a) != len(b):
        a = list(pad_zero(a, len(b)-len(a)))
        b = list(pad_zero(b, len(a)-len(b)))
    return a, b


def did_padding_work(a, b):
    assert len(a) == len(b)
    return len(a)

In [189]:
def load_to_wires(value, control, wires):
    for j, wire in enumerate(wires):
        qml.CRZ(value * np.pi / (2**j), wires=[control, wire])


def add_(a_wires, b_wires, qft_wires):
    qml.QFT(qft_wires)

    # add a to the counter
    for i, a_wire in enumerate(a_wires):
        load_to_wires(2**(len(a_wires) - i - 1), a_wire, qft_wires)

    # add b to the counter
    for i, b_wire in enumerate(b_wires):
        load_to_wires(2**(len(b_wires) - i - 1), b_wire, qft_wires)

    # return to computational basis
    qml.adjoint(qml.QFT)(wires=qft_wires)


def add(a: int|list, b: int|list) -> (np.tensor, dict):
    a, b = pre_process(a, b)
    n = did_padding_work(a, b)
    num_wires = 3*n + 1

    dev = qml.device("default.qubit", wires=num_wires, shots=1)
    @qml.qnode(dev)
    def circuit():
        total_wires=list(range(num_wires))
        a_wires = total_wires[:n]
        b_wires = total_wires[n:2*n]
        qft_wires = total_wires[2*n:]

        qml.BasisEmbedding(features=a, wires=a_wires)
        qml.BasisEmbedding(features=b, wires=b_wires)
        
        add_(a_wires, b_wires, qft_wires)

        return qml.sample(wires=qft_wires)

    return circuit, qml.specs(circuit)

In [228]:
def subtract_(a_wires, b_wires, qft_wires):
    qml.broadcast(qml.PauliX, pattern="single", wires=a_wires)
    add_(a_wires, b_wires, qft_wires)
    qml.broadcast(qml.PauliX, pattern="single", wires=qft_wires)


def subtract(a: int|list, b: int|list) -> (np.tensor, dict):
    if b > a:
        raise NotImplementedError("Arithmetic operations for negative integers are not supported yet")

    a, b = pre_process(a, b)
    n = did_padding_work(a, b)
    num_wires = 3*n+1

    dev = qml.device("default.qubit", wires=num_wires, shots=1)
    @qml.qnode(dev)
    def circuit():
        total_wires=list(range(num_wires))
        a_wires = total_wires[:n]
        b_wires = total_wires[n:2*n]
        qft_wires = total_wires[2*n:]

        qml.BasisEmbedding(features=a, wires=a_wires)
        qml.BasisEmbedding(features=b, wires=b_wires)

        subtract_(a_wires, b_wires, qft_wires)

        return qml.sample(wires=qft_wires[1:])

    return circuit, qml.specs(circuit)

In [235]:
def multiply_(a_wires, b_wires, qft_wires):
    qml.QFT(qft_wires)
    for i, a_wire in enumerate(a_wires):
        for i, b_wire in enumerate(b_wires):
            qml.ctrl(load_to_wires, control=a_wire)(
                2**(len(a_wires) - i - 1) + 2**(len(b_wires) - i - 1), control=b_wire, wires=qft_wires)
    # return to computational basis
    qml.adjoint(qml.QFT)(wires=qft_wires)


def multiply(a: int|list, b: int|list) -> (np.tensor, dict):
    a, b = pre_process(a, b)
    n = did_padding_work(a, b)
    num_wires = 4*n

    dev = qml.device("default.qubit", wires=num_wires, shots=1)
    @qml.qnode(dev)
    def circuit():
        total_wires=list(range(num_wires))
        a_wires = total_wires[:n]
        b_wires = total_wires[n:2*n]
        qft_wires = total_wires[2*n:]

        qml.BasisEmbedding(features=a, wires=a_wires)
        qml.BasisEmbedding(features=b, wires=b_wires)
        
        multiply_(a_wires, b_wires, qft_wires)

        return qml.sample(wires=qft_wires)

    return circuit, qml.specs(circuit)

In [238]:
circuit, _ = multiply(4,8)

In [239]:
circuit()

array([0, 0, 0, 1, 0, 0, 0, 0])

In [None]:
def _unitary_copy(a_wires, b_wires):
    