Classical code to build the Wallace Tree for integer multiplication, as described in [this paper](https://journals.aps.org/pra/abstract/10.1103/PhysRevA.107.042621).

In [1]:
def inv_perm(perm):
    """Return the inverse of a permutation given as a list.

    If ``perm[i] == j`` then the result satisfies ``result[j] == i``.
    ``perm`` is expected to contain each of ``0..len(perm)-1`` exactly once.
    """
    inverse = [0] * len(perm)
    for position, image in enumerate(perm):
        inverse[image] = position
    return inverse

def apply_perm_to_op(op, perm):
    """Relabel the qubit indexes of one operation through ``perm``.

    ``op`` is a sequence such as ``("AND", a, b, c)``.  Every element whose
    type is exactly ``int`` is replaced by ``perm[element]``; anything else
    (e.g. the operation name) is copied unchanged.  Returns a new tuple.
    """
    relabeled = []
    for item in op:
        if type(item) is int:
            relabeled.append(perm[item])
        else:
            relabeled.append(item)
    return tuple(relabeled)

# 4 types of operations:
#   1. AND(a,b,c):         c:=a&b.
#   2. HalfAdder(a,b,c):   b,c:=a⊕b,a&b.
#   3. FullAdder(a,b,c,d): c,d:=MAJ(a,b,c),a⊕b⊕c.
#   4. CNOT(a,b):          b:=a⊕b.

def BuildWallaceTree(n1, n2):
    """Build the operation list of a Wallace-tree multiplier.

    The circuit multiplies an ``n1``-bit number by an ``n2``-bit number.  It
    is emitted as a list of tuples whose first element is the operation name
    ("AND", "HalfAdder", "FullAdder" or "CNOT") and whose remaining elements
    are qubit indexes.

    Qubit index convention of the returned operations:
      * ``[0:n1]``                - input A (least significant bit first).
      * ``[n1:n1+n2]``            - input B.
      * ``[n1+n2:2*(n1+n2)]``     - product bits.
      * everything above that     - ancillas.

    Args:
        n1: bit width of the first factor, must be >= 2.
        n2: bit width of the second factor, must be >= 2.

    Returns:
        ``(ops, num_ancillas)`` where ``num_ancillas`` is the number of
        allocated qubits that are neither inputs nor product bits.

    Side effects:
        Prints the number of reduction layers that were generated.
    """
    assert n1 >= 2 and n2 >= 2

    num_inputs = n1 + n2
    num_allocated = 0

    def allocate_ancilla():
        # Fresh qubits are numbered consecutively right after the inputs.
        nonlocal num_allocated
        num_allocated += 1
        return num_inputs + num_allocated - 1

    ops = []
    # groups[k] holds the qubits that currently carry weight 2**k.
    groups = [[] for _ in range(num_inputs)]

    # AND array: one partial-product bit per pair of input bits.
    for i1 in range(n1):
        for i2 in range(n2):
            target = allocate_ancilla()
            ops.append(("AND", i1, n1 + i2, target))
            groups[i1 + i2].append(target)
    lcnt = 0

    # Add reduction layers until every group has at most 2 qubits.
    old_max = max(len(g) for g in groups)
    while old_max > 2:
        new_groups = [[] for _ in range(num_inputs)]
        for i, old_group in enumerate(groups):
            # While some group still has 4+ bits, a group that (together with
            # the carries it already received in this layer) has at most 3
            # bits is passed to the next layer untouched.
            if len(old_group) + len(new_groups[i]) <= 3 and old_max >= 4:
                new_groups[i] += old_group
                continue
            num_triples = len(old_group) // 3
            for j in range(num_triples):
                # Full adder: the third input qubit is overwritten with the
                # carry, the sum goes to a fresh ancilla.
                triple = old_group[3 * j:3 * j + 3]
                sum_bit = allocate_ancilla()
                carry_bit = triple[2]
                ops.append(("FullAdder", triple[0], triple[1], carry_bit, sum_bit))
                new_groups[i].append(sum_bit)
                new_groups[i + 1].append(carry_bit)
            rem = old_group[3 * num_triples:]
            assert 0 <= len(rem) <= 2
            if len(rem) == 1:
                new_groups[i].append(rem[0])
            elif len(rem) == 2:
                # Half adder: rem[1] is overwritten with the sum, the carry
                # goes to a fresh ancilla.
                carry_bit = allocate_ancilla()
                ops.append(("HalfAdder", rem[0], rem[1], carry_bit))
                new_groups[i].append(rem[1])
                new_groups[i + 1].append(carry_bit)
        groups = new_groups
        old_max = max(len(g) for g in groups)
        lcnt += 1

    # The last layer is a regular (ripple-carry) addition of 2 binary numbers.
    # Carries are appended to groups[i+1] before that group is visited, so a
    # group may hold up to 3 qubits by the time it is processed.
    output_bits = []
    for i in range(num_inputs):
        group = groups[i]
        assert 1 <= len(group) <= 3
        if len(group) == 1:
            output_bits.append(group[0])
        elif len(group) == 2:
            if i == num_inputs - 1:
                # Last bit cannot overflow, so we don't need the carry bit.
                ops.append(("CNOT", group[0], group[1]))
            else:
                carry_bit = allocate_ancilla()
                ops.append(("HalfAdder", group[0], group[1], carry_bit))
                groups[i + 1].append(carry_bit)
            output_bits.append(group[1])
        else:
            sum_bit = allocate_ancilla()
            carry_bit = group[2]
            ops.append(("FullAdder", group[0], group[1], carry_bit, sum_bit))
            output_bits.append(sum_bit)
            groups[i + 1].append(carry_bit)

    # Remap qubit indexes so outputs are written right after inputs.
    # A set makes the membership test O(1) per allocated qubit instead of a
    # linear scan of the output list.
    output_set = set(output_bits)
    leftover = [q for q in range(num_inputs, num_inputs + num_allocated)
                if q not in output_set]
    perm_before = list(range(num_inputs)) + output_bits + leftover
    perm = inv_perm(perm_before)
    ops = [apply_perm_to_op(op, perm) for op in ops]
    num_ancillas = num_allocated - num_inputs

    print("reduction layers", lcnt)
    return ops, num_ancillas

In [2]:
import random

def test_wallace_tree(n1, n2):
    """Check the circuit from ``BuildWallaceTree(n1, n2)`` on random inputs.

    Builds the operation list, prints its size, then classically simulates it
    on 100 random pairs of factors.  For every pair it asserts that all qubits
    stay 0/1, that the input registers are unchanged, and that the output
    register holds the product.
    """
    ops, ancilla_total = BuildWallaceTree(n1, n2)
    width = n1 + n2
    num_qubits = 2 * width + ancilla_total
    print(f"{n1}x{n2} num_ops={len(ops)} num_ancialls={ancilla_total}")
    print("==========")

    for _ in range(100):
        num1, num2 = random.randint(0, 2**n1 - 1), random.randint(0, 2**n2 - 1)

        # Inputs are loaded least-significant bit first: A, then B.
        input_bits = [(num1 >> k) & 1 for k in range(n1)]
        input_bits += [(num2 >> k) & 1 for k in range(n2)]
        state = [0] * num_qubits
        for pos, bit in enumerate(input_bits):
            state[pos] = bit

        for name, *args in ops:
            if name == "AND":
                a, b, c = args
                state[c] = state[a] * state[b]
            elif name == "HalfAdder":
                a, b, c = args
                # Both results are computed from the old values before either is stored.
                half_sum = state[a] ^ state[b]
                half_carry = state[a] * state[b]
                state[b], state[c] = half_sum, half_carry
            elif name == "FullAdder":
                a, b, c, d = args
                majority = (state[a] + state[b] + state[c]) // 2
                parity = state[a] ^ state[b] ^ state[c]
                state[c], state[d] = majority, parity
            else:
                assert name == "CNOT"
                a, b = args
                state[b] ^= state[a]

        assert all(bit in [0, 1] for bit in state)

        # Verify inputs were not changed.
        assert state[:width] == input_bits

        # Reconstruct and check output.
        product = 0
        for k in range(width):
            product += state[width + k] << k
        assert product == num1 * num2
    
# Run the randomized check over a mix of square and lopsided operand widths.
for n1, n2 in (
    (2, 2), (2, 4), (2, 3), (3, 2),
    (3, 4), (4, 4), (4, 5), (5, 4),
    (8, 8), (10, 10), (16, 16),
    (2, 32), (32, 32), (64, 10),
    (64, 64), (128, 128), (256, 256),
):
    test_wallace_tree(n1, n2)
print("OK")


reduction layers 0
2x2 num_ops=6 num_ancialls=2
reduction layers 0
2x4 num_ops=12 num_ancialls=6
reduction layers 0
2x3 num_ops=9 num_ancialls=4
reduction layers 0
3x2 num_ops=9 num_ancialls=4
reduction layers 1
3x4 num_ops=20 num_ancialls=13
reduction layers 2
4x4 num_ops=28 num_ancialls=20
reduction layers 2
4x5 num_ops=35 num_ancialls=26
reduction layers 2
5x4 num_ops=35 num_ancialls=26
reduction layers 4
8x8 num_ops=127 num_ancialls=110
reduction layers 5
10x10 num_ops=200 num_ancialls=179
reduction layers 6
16x16 num_ops=520 num_ancialls=487
reduction layers 0
2x32 num_ops=96 num_ancialls=62
reduction layers 8
32x32 num_ops=2093 num_ancialls=2028
reduction layers 5
64x10 num_ops=1280 num_ancialls=1205
reduction layers 10
64x64 num_ops=8350 num_ancialls=8221
reduction layers 11
128x128 num_ops=33235 num_ancialls=32978
reduction layers 13
256x256 num_ops=132307 num_ancialls=131794
OK
