In [1]:
from qiskit import QuantumRegister, ClassicalRegister, QuantumCircuit

In [2]:
def get_bits(x, n):
    return [int(x) for x in '{:0{size}b}'.format(x, size=n)]

In [3]:
def generate_carry(constant, n):
    name = 'Carry ({0})'.format(constant)
    constant_bits = get_bits(constant, n)
    
    qrA = QuantumRegister(n, name='a') # |a> - input register
    qrC = QuantumRegister(1, name='c') # |c> - output register - information about carry
    
    # special case for n = 1: 
    if n == 1:
        qc = QuantumCircuit(qrA, qrC, name=name)
    
        if constant_bits[0] == 1:
            qc.cx(qrA[0], qrC)
        
        return qc.to_instruction()
    
    # for n >= 2:
    qrG = QuantumRegister(n-1, name='g') # |g> - dirty ancilla register, state have to be restored
    
    qc = QuantumCircuit(qrA, qrG, qrC, name=name)

    qc.cx(qrG[n - 2], qrC)

    for i in range(n - 1, 1, -1):
        if constant_bits[n - (i + 1)] == 1:
            qc.cx(qrA[i], qrG[i - 1])
            qc.x(qrA[i])
        qc.ccx(qrG[i - 2], qrA[i], qrG[i - 1])

    if constant_bits[n - 2] == 1:
        qc.cx(qrA[1], qrG[0])
        qc.x(qrA[1])
        
    if constant_bits[n - 1] == 1:
        qc.ccx(qrA[0], qrA[1], qrG[0])

    for i in range(2, n):
        qc.ccx(qrG[i - 2], qrA[i], qrG[i - 1])

    qc.cx(qrG[n - 2], qrC)
   
    for i in range(n - 1, 1, -1):
        qc.ccx(qrG[i - 2], qrA[i], qrG[i - 1])

    if constant_bits[n - 1] == 1:
        qc.ccx(qrA[0], qrA[1], qrG[0])
        
    if constant_bits[n - 2] == 1:
        qc.x(qrA[1])
        qc.cx(qrA[1], qrG[0])

    for i in range(2, n):
        qc.ccx(qrG[i - 2], qrA[i], qrG[i - 1])
        if constant_bits[n - (i + 1)] == 1:
            qc.x(qrA[i])
            qc.cx(qrA[i], qrG[i - 1])
    
    return qc.to_instruction()

In [4]:
def generate_negation(n):
    qrA = QuantumRegister(n, name='a')
    
    qc = QuantumCircuit(qrA, name='Negation')
    
    for i in range(n):
        qc.x(qrA[i])
    
    return qc.to_instruction()

In [5]:
def generate_comparator(constant, n):
    name = 'Comp ({0})'.format(constant)
    
    qrA = QuantumRegister(n, name='a')
    qrC = QuantumRegister(1, name='cmp')
    
    carry = generate_carry(constant, n)
    neg = generate_negation(n)
    
    if n == 1:
        qc = QuantumCircuit(qrA, qrC, name=name)
        qc.append(neg, qrA[:])
        qc.append(carry, qrA[:] + qrC[:])
        qc.append(neg, qrA[:])
        return qc.to_instruction()

    qrG = QuantumRegister(n-1, name='g')
    qc = QuantumCircuit(qrA, qrG, qrC, name=name)
    
    qc.append(neg, qrA[:])
    qc.append(carry, qrA[:] + qrG[:] + qrC[:])
    qc.append(neg, qrA[:])
    
    return qc.to_instruction()

In [6]:
# _qrA = QuantumRegister(5, name='a')
# _qrG = QuantumRegister(4, name='g')
# _qrC = QuantumRegister(1, name='cmp')
# _qc = QuantumCircuit(_qrA, _qrG, _qrC)

# comp = generate_comparator(21, 5)
# _qc.append(comp, _qrA[:] + _qrG[:] + _qrC[:])
# _qc.draw(output='mpl')

In [7]:
# _dec = _qc.decompose()
# _dec.draw(output='mpl')

In [8]:
def generate_init_part(qc, a, g, n):
    if n == 1:
        [qrA, qrC] = qc.qregs
    
        a_bits = get_bits(a, n)
        for i in range(n):
            if a_bits[n-(i+1)] == 1:
                qc.x(qrA[i])
                
        qc.barrier(qrA[:])
        return
    
    [qrA, qrG, qrC] = qc.qregs
    
    a_bits = get_bits(a, n)
    g_bits = get_bits(g, n)
    
    for i in range(n):
        if a_bits[n-(i+1)] == 1:
            qc.x(qrA[i])

    for i in range(n):
        if g_bits[n-(i+1)] == 1:
            qc.x(qrG[i])
            
    qc.barrier(qrA[:], qrG[:])

In [9]:
# _qrA = QuantumRegister(3, name='a')
# _qrG = QuantumRegister(2, name='g')
# _qrC = QuantumRegister(1, name='cmp')
# _qc = QuantumCircuit(_qrA, _qrG, _qrC)

# generate_init_part(_qc, 6, 3, 3)
# _qc.draw(output='mpl')

In [10]:
def generate_measure_part(n):
    qrA = QuantumRegister(n, name='a')
    qrC = QuantumRegister(1, name='cmp')
    crA = ClassicalRegister(n, name='aValue')
    crC = ClassicalRegister(1, name='cValue')
    
    if n == 1:
        qc = QuantumCircuit(qrA, qrC, crA, crC)
        qc.barrier(qrA[:], qrC[:])
        qc.measure(qrA[:], crA[:])
        qc.measure(qrC[:], crC[:])
        
        return qc
    
    qrG = QuantumRegister(n - 1, name='g')
    crG = ClassicalRegister(n - 1, name='gValue')
    
    qc = QuantumCircuit(qrA, qrG, qrC, crA, crG, crC)
    qc.barrier(qrA[:], qrG[:], qrC[:])
    qc.measure(qrA[:], crA[:])
    qc.measure(qrG[:], crG[:])
    qc.measure(qrC[:], crC[:])
    
    return qc

In [11]:
def test_cmp(a, constant, n, cmp_value):
    expected_value = 1 if a < constant else 0
    print(expected_value, cmp_value, expected_value == cmp_value)

In [12]:
test_cmp(6, 11, 4, 1)
test_cmp(6, 11, 4, 0)

1 1 True
1 0 False


In [13]:
from qiskit import Aer
from qiskit import execute

backend = Aer.get_backend('qasm_simulator')

In [14]:
def test(a, g, constant, n):
    qrA = QuantumRegister(n, name='a')
    qrC = QuantumRegister(1, 'cmp')
    
    if n == 1:
        qc = QuantumCircuit(qrA, qrC)
    else:
        qrG = QuantumRegister(n-1, name='g')
        qc = QuantumCircuit(qrA, qrG, qrC)
        
    generate_init_part(qc, a, g, n)
    if n == 1:
        qc.append(generate_comparator(constant, n), qrA[:] + qrC[:])
    else:
        qc.append(generate_comparator(constant, n), qrA[:] + qrG[:] + qrC[:])
        
    qc += generate_measure_part(n)

    job = execute(qc, backend, shots=1)
    result = job.result()
    values = list(result.get_counts(qc).keys())

    assert len(values) == 1
    values = [int(v, 2) for v in values[0].split(' ')]

    if n == 1:
        c_value = values[0]
        a_value = values[1]

        print(a, a_value, a == a_value)
        test_cmp(a, constant, n, c_value)
    else:
        c_value = values[0]
        g_value = values[1]
        a_value = values[2]

        print(a, a_value, a == a_value)
        print(g, g_value, g == g_value)
        test_cmp(a, constant, n, c_value)

In [15]:
test(6, 5, 11, 4)

  qc += generate_measure_part(n)
  return self.extend(rhs)


6 6 True
5 5 True
1 1 True


In [16]:
# for n in range(5):
#     n = n + 1
#     print(n)
#     G = 2 ** (n-1)
#     N = 2 ** n
    
#     for a in range(N):
#         for g in range(G):
#             for c in range(N):
#                 print('---', a, g, c, '---')
#                 test(a, g, c, n)