In [1]:
from qiskit import *
import matplotlib.pyplot as plt
from qiskit import Aer
from qiskit.quantum_info import partial_trace, Statevector
import numpy as np

In [2]:
def encoded_state(sv):
    # traces out ancilla qubits

    # full_statevector = Statevector(cir)
    partial_dm = partial_trace(sv, [5, 6, 7, 8])
    partial_sv = np.diagonal(partial_dm)

    return partial_sv

In [3]:
def stabilizer_measurements(qr, cr):
    gens = [
        ['x','z','z','x','i'],
        ['i','x','z','z','x'],
        ['x','i','x','z','z'],
        ['z','x','i','x','z']
    ]

    cir = QuantumCircuit(qr, cr)
    
    for i, s in enumerate(gens):
        cir.h(i+5)
        for j, g in enumerate(s):
            if (g == 'x'):
                cir.cnot(i+5, j)
            elif (g == 'z'):
                cir.cz(i+5, j)
            else:
                pass
        cir.h(i+5)
        cir.measure(i+5,i)
        # cir.reset(i+5)

    return cir

def encoding(qr):
    cir = QuantumCircuit(qr)

    cir.h(0)
    cir.s(0)
    cir.cz(0,1)
    cir.cz(0,3)
    cir.cy(0,4)

    cir.h(1)
    cir.cz(1,2)
    cir.cz(1,3)
    cir.cx(1,4)

    cir.h(2)
    cir.cz(2,0)
    cir.cz(2,1)
    cir.cx(2,4)

    cir.h(3)
    cir.s(3)
    cir.cz(3,0)
    cir.cz(3,2)
    cir.cy(3,4)
    
    return cir

def error_correction(qr, syn):
    # syn in the error syndrome returned from measuring the stabilizers
    # cir.id(0).c_if(syn, 0)
    cir = QuantumCircuit(qr, syn)

    cir.x(0).c_if(syn, 8) # 8 or 1
    cir.y(0).c_if(syn, 13) # 13 or 11
    cir.z(0).c_if(syn, 5) # 5 or 10

    cir.x(1).c_if(syn, 1) # 1 or 8
    cir.y(1).c_if(syn, 11) # 11 or 13
    cir.z(1).c_if(syn, 10) # 10 or 5

    cir.x(2).c_if(syn, 3) # 3 or 12
    cir.y(2).c_if(syn, 7) # 7 or 14
    cir.z(2).c_if(syn, 4) # 4 or 2

    cir.x(3).c_if(syn, 6)
    cir.y(3).c_if(syn, 15)
    cir.z(3).c_if(syn, 9)

    cir.x(4).c_if(syn, 12) # 12 or 3
    cir.y(4).c_if(syn, 14) # 14 or 7
    cir.z(4).c_if(syn, 2) # 2 or 4

    return cir

In [25]:
q = QuantumRegister(9)
c = ClassicalRegister(4)
cir = QuantumCircuit(q, c)


cir += encoding(q)
# x
cir.x([0,1,2,3,4])

# y
# cir.y([0,1,2,3,4])

# h
cir.h([0,1,2,3,4])
cir.swap(0, 1)
cir.swap(3, 4)
cir.swap(1, 3)

# z
# cir.z([0,1,2,3,4])
cir.x(1)
# s
# cir.s([0,1,2,3,4])
# cir.s([0,1,2,3,4])

# s dag
# cir.sdg([0,1,2,3,4])
# cir.sdg([0,1,2,3,4])
cir += stabilizer_measurements(q, c)

cir += error_correction(q, c)

cir += encoding(q).inverse() # decoding
cir.draw()

In [26]:
backend = Aer.get_backend('statevector_simulator')
job_sim = backend.run(transpile(cir, backend), shots=10000)
result_sim = job_sim.result()
counts = result_sim.get_counts(cir)
sv = result_sim.get_statevector(cir, decimals=3)

In [27]:
print(counts)
for i, s in enumerate(encoded_state(sv)):
# for i, s in enumerate(sv):
    print('{:05b}'.format(i), s)
    # print(i)

{'0001': 1}
00000 (0.49984899999999993+0j)
00001 0j
00010 0j
00011 0j
00100 0j
00101 0j
00110 0j
00111 0j
01000 0j
01001 0j
01010 0j
01011 0j
01100 0j
01101 0j
01110 0j
01111 0j
10000 (0.49984899999999993+0j)
10001 0j
10010 0j
10011 0j
10100 0j
10101 0j
10110 0j
10111 0j
11000 0j
11001 0j
11010 0j
11011 0j
11100 0j
11101 0j
11110 0j
11111 0j
