<a href="https://colab.research.google.com/github/chaor11/twoqubitec/blob/master/Bacon_Shor_Circuit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install cirq~=0.9.0

Collecting cirq~=0.9.0
[?25l  Downloading https://files.pythonhosted.org/packages/18/05/39c24828744b91f658fd1e5d105a9d168da43698cfaec006179c7646c71c/cirq-0.9.1-py3-none-any.whl (1.6MB)
[K     |████████████████████████████████| 1.6MB 3.9MB/s 
Collecting freezegun~=0.3.15
  Downloading https://files.pythonhosted.org/packages/17/5d/1b9d6d3c7995fff473f35861d674e0113a5f0bd5a72fe0199c3f254665c7/freezegun-0.3.15-py2.py3-none-any.whl
Installing collected packages: freezegun, cirq
Successfully installed cirq-0.9.1 freezegun-0.3.15


In [None]:
import cirq
import numpy as np
#import cirq.ion as ci
#from cirq import Simulator
#import itertools
#import random

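## Bacon-Shor Code
$$\newcommand{\ket}[1]{\left|{#1}\right\rangle}$$
Logical operators: $X_L = X_1X_4X_7$ and $Z_L = Z_1Z_2Z_3$, with physical qubits numbered 1–9 (indices 0–8 in the code below); each tensor factor below is one row of three qubits (1–3, 4–6, 7–9).
$$\ket{+_L} = \left(\ket{000}+\ket{111}\right)^{\otimes 3}, \qquad
\ket{0_L} = \left(\ket{+++}+\ket{---}\right)^{\otimes 3}$$
(states written without normalization).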


In [None]:
### 9-qubit Bacon-Shor code, following the qubit layout and circuits from https://arxiv.org/pdf/2009.11482.pdf
class BS3x3Code:
    """3x3 Bacon-Shor code written with native trapped-ion gates (MS, Rx, Ry, Rz).

    Data qubits 0-8 form a 3x3 grid filled row by row: rows {0,1,2},
    {3,4,5}, {6,7,8}; columns {0,3,6}, {1,4,7}, {2,5,8}.  Ancillas 9 and 10
    are read out by z_stab(), ancillas 11 and 12 by x_stab().

    Every method is a generator of cirq operations / moments intended to be
    passed to ``cirq.Circuit.append``.
    """

    def __init__(self):
        self.num_physical_qubits = 13  # 9 data + 4 ancilla, indexed 0 to 12
        self.physical_qubits = cirq.LineQubit.range(self.num_physical_qubits)

    def encode(self):
        """Entangle each row of the grid with an MS-Rz-MS-Rz sequence.

        For each row start i in {0, 3, 6}: MS(pi/4) on (i, i+1), Rz(pi/2) on
        i+1, MS(-pi/4) on (i+1, i+2), Rz(-pi/2) on i+2.
        """
        # Author's intent: encode into the logical-X basis.
        # NOTE(review): by hand calculation the all-zeros input ends up with
        # every row in (|+++> + |--->) (up to normalization / global phase),
        # which the markdown cell calls |0_L>; confirm which convention is meant.
        index = [0, 3, 6]  # first qubit of each row
        yield [cirq.ms(0.25*np.pi).on(self.physical_qubits[i], self.physical_qubits[i+1]) for i in index]
        yield [cirq.rz(0.5*np.pi).on(self.physical_qubits[i+1]) for i in index]
        yield [cirq.ms(-0.25*np.pi).on(self.physical_qubits[i+1], self.physical_qubits[i+2]) for i in index]
        yield [cirq.rz(-0.5*np.pi).on(self.physical_qubits[i+2]) for i in index]

    def logical_ry(self):
        """Apply Ry(pi/2) to every data qubit (0-8) in a single moment.

        Author's description: converts from the logical-X basis to the
        logical-Z basis, with logical X = X_1X_4X_7 and logical Z = Z_1Z_2Z_3
        (1-based qubit labels).
        """
        yield cirq.Moment([cirq.ry(0.5*np.pi).on(self.physical_qubits[i]) for i in range(9)])

    def z_stab(self):
        """One round of Z-stabilizer extraction into ancillas 9 and 10.

        The data qubits are first rotated by Ry(pi/2).  Ancilla 9 is coupled
        by MS gates to rows 0 and 1 (qubits 0-5), and ancilla 10 to rows 1
        and 2 (qubits 3-8), one MS gate per moment.  Rx corrections on rows 0
        and 2 share a moment with the two ancilla measurements.  Ry(-pi/2)
        then undoes the basis change.

        Row 1 gets no Rx correction: it sees MS(-pi/4) from ancilla 9 and
        MS(+pi/4) from ancilla 10, whose single-qubit phases presumably
        cancel.  No reset is applied (the reset line is commented out), so
        the ancillas are presumably expected to start in |0>.
        """
        #yield [cirq.reset(self.physical_qubits[9]),cirq.reset(self.physical_qubits[10])]
        yield cirq.Moment([cirq.ry(0.5*np.pi).on(self.physical_qubits[i]) for i in range(9)])
        # Ancilla 9: MS(+pi/4) with row 0, MS(-pi/4) with row 1.
        yield [
            [cirq.Moment(cirq.ms(0.25*np.pi).on(self.physical_qubits[9], self.physical_qubits[i])),
             cirq.Moment(cirq.ms(-0.25*np.pi).on(self.physical_qubits[9], self.physical_qubits[i+3]))]
            for i in range(3)
        ]
        # Ancilla 10: MS(+pi/4) with row 1, MS(-pi/4) with row 2.
        yield [
            [cirq.Moment(cirq.ms(0.25*np.pi).on(self.physical_qubits[10], self.physical_qubits[i+3])),
             cirq.Moment(cirq.ms(-0.25*np.pi).on(self.physical_qubits[10], self.physical_qubits[i+6]))]
            for i in range(3)
        ]
        yield cirq.Moment(
            [cirq.rx(-0.5*np.pi).on(self.physical_qubits[i]) for i in range(3)],
            [cirq.rx(0.5*np.pi).on(self.physical_qubits[i]) for i in range(6, 9)],
            [cirq.measure(self.physical_qubits[9]), cirq.measure(self.physical_qubits[10])]
        )
        yield cirq.Moment([cirq.ry(-0.5*np.pi).on(self.physical_qubits[i]) for i in range(9)])

    def x_stab(self):
        """One round of X-stabilizer extraction into ancillas 11 and 12.

        Ancilla 11 is coupled by MS gates to columns 0 and 1 (qubits i, i+1
        for i in {0, 3, 6}), and ancilla 12 to columns 1 and 2, one MS gate
        per moment.  Rx corrections on columns 0 and 2 share a moment with
        the two ancilla measurements.  There is no basis change around this
        round.

        Column 1 gets no Rx correction: it sees MS(+pi/4) from ancilla 11
        and MS(-pi/4) from ancilla 12, which presumably cancel.
        """
        #yield [cirq.reset(self.physical_qubits[11]),cirq.reset(self.physical_qubits[12])]
        # Ancilla 11: MS(-pi/4) with column 0, MS(+pi/4) with column 1.
        yield [
            [cirq.Moment(cirq.ms(-0.25*np.pi).on(self.physical_qubits[11], self.physical_qubits[i])),
             cirq.Moment(cirq.ms(0.25*np.pi).on(self.physical_qubits[11], self.physical_qubits[i+1]))]
            for i in [0, 3, 6]
        ]
        # Ancilla 12: MS(-pi/4) with column 1, MS(+pi/4) with column 2.
        yield [
            [cirq.Moment(cirq.ms(-0.25*np.pi).on(self.physical_qubits[12], self.physical_qubits[i+1])),
             cirq.Moment(cirq.ms(0.25*np.pi).on(self.physical_qubits[12], self.physical_qubits[i+2]))]
            for i in [0, 3, 6]
        ]
        yield cirq.Moment(
            [
                [cirq.rx(-0.5*np.pi).on(self.physical_qubits[i]) for i in [0, 3, 6]],
                [cirq.rx(0.5*np.pi).on(self.physical_qubits[i+2]) for i in [0, 3, 6]],
                [cirq.measure(self.physical_qubits[11]), cirq.measure(self.physical_qubits[12])]
            ]
        )

# Build the full 3x3 round: encoding, logical Ry, then one Z- and one X-stabilizer round.
code = BS3x3Code()
circuit = cirq.Circuit()
circuit.append([code.encode(), code.logical_ry(), code.z_stab(), code.x_stab()])
# This is a combination of Figs. S12, S15 and S16 in https://arxiv.org/pdf/2009.11482.pdf
print(circuit)

0: ────MS(0.25π)───────────────────────────────────────Ry(0.5π)───Ry(0.5π)───MS(0.25π)─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────Rx(-0.5π)───Ry(-0.5π)───MS(-0.25π)────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────Rx(-0.5π)───
       │                                                                     │                                                                                                                                                                             │
1: ────MS(0.25π)───Rz(0.5π)───MS(-0.25π)───────────────Ry(0.5π)───Ry(0.5π)───┼────────────────────────MS(0.25π)────────────────────────────────────────────────────────────────────────────────────────────────────────────────────Rx(-0.5π)───Ry(-0.5π)───┼────────────MS(0.25π)─────────────────────────────────────────────────────MS(-0.2

In [None]:
### Compared to BS3x3, remove qubit 6,7,8 and ancilla 10
class BS3x2Code:
    """Two-row (2x3) variant of BS3x3Code.

    Data qubits 0-5 form rows {0,1,2} and {3,4,5}.  Ancilla 9 is read out by
    z_stab(), ancillas 11 and 12 by x_stab().  Indexing matches BS3x3Code,
    so qubits 6-8 and 10 exist but are never used.  Every method is a
    generator of cirq operations / moments for ``cirq.Circuit.append``.
    """

    def __init__(self):
        self.num_physical_qubits = 13  # same indexing as BS3x3Code (0 to 12); only 9 are used
        self.physical_qubits = cirq.LineQubit.range(self.num_physical_qubits)

    def encode(self):
        """Entangle each of the two rows with an MS-Rz-MS-Rz sequence.

        Same per-row sequence as BS3x3Code.encode, for row starts {0, 3}.
        The author's intent is encoding into the logical-X basis.
        """
        index = [0, 3]  # first qubit of each row
        yield [cirq.ms(0.25*np.pi).on(self.physical_qubits[i], self.physical_qubits[i+1]) for i in index]
        yield [cirq.rz(0.5*np.pi).on(self.physical_qubits[i+1]) for i in index]
        yield [cirq.ms(-0.25*np.pi).on(self.physical_qubits[i+1], self.physical_qubits[i+2]) for i in index]
        yield [cirq.rz(-0.5*np.pi).on(self.physical_qubits[i+2]) for i in index]

    def logical_ry(self):
        """Apply Ry(pi/2) to data qubits 0-5 in a single moment.

        Author's description: converts from the logical-X basis to the
        logical-Z basis.
        NOTE(review): the comment copied from the 3x3 code gave
        logical X = X_1X_4X_7, but qubit 7 does not exist here; the column
        operator is presumably X_1X_4 (qubits 0 and 3) with
        logical Z = Z_1Z_2Z_3 unchanged -- confirm.
        """
        yield cirq.Moment([cirq.ry(0.5*np.pi).on(self.physical_qubits[i]) for i in range(6)])

    def z_stab(self):
        """One round of Z-stabilizer extraction into ancilla 9.

        Data qubits are rotated by Ry(pi/2).  Ancilla 9 gets MS(+pi/4) with
        row 0 and MS(-pi/4) with row 1, one gate per moment.  Rx(-pi/2) on
        row 0 and Rx(+pi/2) on row 1 share a moment with the ancilla
        measurement.  Ry(-pi/2) then undoes the basis change.
        """
        #yield cirq.reset(self.physical_qubits[9])
        yield cirq.Moment([cirq.ry(0.5*np.pi).on(self.physical_qubits[i]) for i in range(6)])
        yield [
            [cirq.Moment(cirq.ms(0.25*np.pi).on(self.physical_qubits[9], self.physical_qubits[i])),
             cirq.Moment(cirq.ms(-0.25*np.pi).on(self.physical_qubits[9], self.physical_qubits[i+3]))]
            for i in range(3)
        ]
        yield cirq.Moment(
            [cirq.rx(-0.5*np.pi).on(self.physical_qubits[i]) for i in range(3)],
            [cirq.rx(0.5*np.pi).on(self.physical_qubits[i]) for i in range(3, 6)],
            cirq.measure(self.physical_qubits[9])
        )
        yield cirq.Moment([cirq.ry(-0.5*np.pi).on(self.physical_qubits[i]) for i in range(6)])

    def x_stab(self):
        """One round of X-stabilizer extraction into ancillas 11 and 12.

        Ancilla 11 is coupled to columns 0 and 1 (qubits 0, 1, 3, 4), and
        ancilla 12 to columns 1 and 2 (qubits 1, 2, 4, 5), one MS gate per
        moment.  Rx corrections on columns 0 and 2 share a moment with both
        ancilla measurements.  Column 1 sees opposite MS angles from the two
        ancillas and gets no correction.
        """
        index = [0, 3]  # first qubit of each row
        #yield [cirq.reset(self.physical_qubits[11]),cirq.reset(self.physical_qubits[12])]
        yield [
            [cirq.Moment(cirq.ms(-0.25*np.pi).on(self.physical_qubits[11], self.physical_qubits[i])),
             cirq.Moment(cirq.ms(0.25*np.pi).on(self.physical_qubits[11], self.physical_qubits[i+1]))]
            for i in index
        ]
        yield [
            [cirq.Moment(cirq.ms(-0.25*np.pi).on(self.physical_qubits[12], self.physical_qubits[i+1])),
             cirq.Moment(cirq.ms(0.25*np.pi).on(self.physical_qubits[12], self.physical_qubits[i+2]))]
            for i in index
        ]
        yield cirq.Moment(
            [
                [cirq.rx(-0.5*np.pi).on(self.physical_qubits[i]) for i in index],
                [cirq.rx(0.5*np.pi).on(self.physical_qubits[i+2]) for i in index],
                [cirq.measure(self.physical_qubits[11]), cirq.measure(self.physical_qubits[12])]
            ]
        )

# Build the 2x3 round: encoding, logical Ry, then one Z- and one X-stabilizer round.
code = BS3x2Code()
circuit = cirq.Circuit()
circuit.append([code.encode(), code.logical_ry(), code.z_stab(), code.x_stab()])
print(circuit)

0: ────MS(0.25π)───────────────────────────────────────Ry(0.5π)───Ry(0.5π)───MS(0.25π)──────────────────────────────────────────────────────────────────Rx(-0.5π)───Ry(-0.5π)───MS(-0.25π)──────────────────────────────────────────────────────────────────────────────────────────Rx(-0.5π)───
       │                                                                     │                                                                                                  │
1: ────MS(0.25π)───Rz(0.5π)───MS(-0.25π)───────────────Ry(0.5π)───Ry(0.5π)───┼────────────────────────MS(0.25π)─────────────────────────────────────────Rx(-0.5π)───Ry(-0.5π)───┼────────────MS(0.25π)────────────────────────────MS(-0.25π)────────────────────────────────────────────────────
                              │                                              │                        │                                                                         │            │                                    │
2: ─────────────

Encoded state: 0.35|000000000⟩ + 0.35|000000111⟩ + 0.35|000111000⟩ + 0.35|000111111⟩ + 0.35|111000000⟩ + 0.35|111000111⟩ + 0.35|111111000⟩ + 0.35|111111111⟩
After stabilizer measurement: 0.35|00000000000⟩ + 0.35|00000011101⟩ + 0.35|00011100011⟩ + 0.35|00011111110⟩ + 0.35|11100000010⟩ + 0.35|11100011111⟩ + 0.35|11111100001⟩ + 0.35|11111111100⟩
0: ────MS(0.25π)───────────────────────────────────────Ry(0.5π)───Ry(0.5π)───MS(0.25π)─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────Rx(-0.5π)───Ry(-0.5π)───
       │                                                                     │
1: ────MS(0.25π)───Rz(0.5π)───MS(-0.25π)───────────────Ry(0.5π)───Ry(0.5π)───┼────────────────────────MS(0.25π)────────────────────────────────────────────────────────────────────────────────────────────────────────────────────Rx(-0.5π)───Ry(-0.5π)───
                              │                                         

In [None]:
# Test session
def test_encoding():
    """Print the encoder output for the two reference inputs, then the same
    inputs after logical_ry(), each group followed by the circuit diagram.

    Only data qubits 0-8 appear in this circuit, so the 9-bit integers map
    directly onto them (big-endian, qubit 0 first).
    """
    code = BS3x3Code()
    circuit = cirq.Circuit()

    def show(labelled_inputs):
        # Simulate the circuit as built so far for each (label, input) pair.
        for label, data_bits in labelled_inputs:
            state = circuit.final_state_vector(initial_state=data_bits)
            print(label, cirq.dirac_notation(state))
        print(circuit)

    circuit.append(code.encode())
    show([("logical state |1>:", 0b001001001), ("logical state |0>:", 0)])

    circuit.append(code.logical_ry())
    show([("logical state |+>:", 0b001001001), ("logical state |->:", 0)])

def test_z_stab():
    """Encode the two reference inputs, run one Z-stabilizer round, print the states.

    Each printed ket lists data qubits 0-8 followed by ancillas 9 and 10, so
    the last two digits are the Z syndrome bits.  For a stabilized input
    they are expected to read 00 in every term.
    """
    circuit = cirq.Circuit()
    code = BS3x3Code()
    circuit.append(code.encode())
    #circuit.append(code.logical_ry())
    circuit.append(code.z_stab())

    # final_state_vector() reads an integer initial_state big-endian over *all*
    # qubits in the circuit (data 0-8 plus the ancillas).  A 9-bit data pattern
    # must therefore be shifted past the ancilla bits.  Unshifted, 0b001001001
    # would set qubits 4, 7 and ancilla 10 instead of data qubits 2, 5, 8.
    num_data = 9
    num_ancillas = len(circuit.all_qubits()) - num_data
    for data_bits in (0b001001001, 0):
        state = circuit.final_state_vector(initial_state=data_bits << num_ancillas)
        print("After stabilizer measurement:", cirq.dirac_notation(state))
    print(circuit)

def test_x_stab():
    """Encode the two reference inputs, run one X-stabilizer round, print the states.

    The circuit contains data qubits 0-8 plus ancillas 11 and 12, so the last
    two digits of each printed ket are the X syndrome bits.
    """
    circuit = cirq.Circuit()
    code = BS3x3Code()
    circuit.append(code.encode())
    #circuit.append(code.logical_ry())
    circuit.append(code.x_stab())

    # Integer initial_state is big-endian over all qubits in the circuit, so the
    # 9 data bits are shifted past the ancilla bits (see test_z_stab).
    num_data = 9
    num_ancillas = len(circuit.all_qubits()) - num_data
    for data_bits in (0b001001001, 0):
        state = circuit.final_state_vector(initial_state=data_bits << num_ancillas)
        print("After stabilizer measurement:", cirq.dirac_notation(state))
    print(circuit)

#test_encoding()
test_z_stab()
#test_x_stab()

After stabilizer measurement: -0.12|00000100111⟩ - 0.12|00000101011⟩ + 0.12|00000110011⟩ + 0.12|00000111111⟩ - 0.12|00001000111⟩ - 0.13|00001001011⟩ + 0.12|00001010011⟩ + 0.12|00001011111⟩ + 0.12|00010000111⟩ + 0.13|00010001011⟩ - 0.12|00010010011⟩ - 0.12|00010011111⟩ + 0.12|00011100111⟩ + 0.12|00011101011⟩ - 0.12|00011110011⟩ - 0.12|00011111111⟩ - 0.12|01100100111⟩ - 0.12|01100101011⟩ + 0.12|01100110011⟩ + 0.12|01100111111⟩ - 0.12|01101000111⟩ - 0.12|01101001011⟩ + 0.12|01101010011⟩ + 0.12|01101011111⟩ + 0.12|01110000111⟩ + 0.12|01110001011⟩ - 0.12|01110010011⟩ - 0.12|01110011111⟩ + 0.12|01111100111⟩ + 0.12|01111101011⟩ - 0.12|01111110011⟩ - 0.12|01111111111⟩ - 0.12|10100100111⟩ - 0.12|10100101011⟩ + 0.12|10100110011⟩ + 0.12|10100111111⟩ - 0.13|10101000111⟩ - 0.13|10101001011⟩ + 0.12|10101010011⟩ + 0.12|10101011111⟩ + 0.12|10110000111⟩ + 0.12|10110001011⟩ - 0.12|10110010011⟩ - 0.12|10110011111⟩ + 0.12|10111100111⟩ + 0.12|10111101011⟩ - 0.12|10111110011⟩ - 0.12|10111111111⟩ - 0.12|1100

In [None]:
# check circuit from Fig.5 https://arxiv.org/pdf/1810.01040.pdf
def test_stabilizer_slicing():
    """Compare a CNOT-based parity check with its MS-gate (native) version.

    Six data qubits (0-5) are checked by ancilla qubit 6 in two ways that are
    intended to be equivalent.  Circuit 1 uses H and CNOT gates.  Circuit 2
    uses MS gates followed by Rx corrections.  Both final state vectors are
    printed (the terminal measurement is ignored by final_state_vector).  The
    check passes if the last digit (ancilla 6) is 0 in every term.
    """
    # (data qubit, sign) pairs in coupling order.  Circuit 1 uses the qubit
    # order for its CNOTs; circuit 2 uses the sign for the MS angle and the
    # matching Rx correction.
    couplings = [(0, 1), (2, -1), (1, 1), (3, -1), (5, +1), (4, -1)]

    # define Clifford circuit
    circuit1 = cirq.Circuit()
    qubit_list1 = cirq.LineQubit.range(7)
    ancilla1 = qubit_list1[6]
    circuit1.append([
        [cirq.H(qubit_list1[i]) for i in range(6)],
        cirq.H(ancilla1),
        [cirq.CNOT(ancilla1, qubit_list1[i]) for i, _ in couplings],
        cirq.H(ancilla1),
        cirq.measure(ancilla1),
    ])
    print(circuit1)

    # define native circuit
    circuit2 = cirq.Circuit()
    qubit_list2 = cirq.LineQubit.range(7)
    ancilla2 = qubit_list2[6]
    circuit2.append([
        [cirq.H(qubit_list2[i]) for i in range(6)],
        [cirq.ms(0.25*np.pi*j).on(ancilla2, qubit_list2[i]) for i, j in couplings],
        cirq.Moment([cirq.rx(0.5*np.pi*j).on(qubit_list2[i]) for i, j in couplings]),
        cirq.measure(ancilla2),
    ])
    print(circuit2)

    # An integer initial_state is read big-endian over all 7 qubits, so the six
    # data bits are shifted past the ancilla (the last qubit).  Unshifted,
    # 0b111111 would set qubits 1-6: the ancilla would start in |1> and data
    # qubit 0 would stay in |0>.
    initial_state = 0b111111 << 1
    print("circuit 1:", cirq.dirac_notation(circuit1.final_state_vector(initial_state=initial_state)))
    print("circuit 2:", cirq.dirac_notation(circuit2.final_state_vector(initial_state=initial_state)))
    print("check passes if return 0 on last bit all the time")

test_stabilizer_slicing()


0: ───H───X───────────────────────────────
          │
1: ───H───┼───────X───────────────────────
          │       │
2: ───H───┼───X───┼───────────────────────
          │   │   │
3: ───H───┼───┼───┼───X───────────────────
          │   │   │   │
4: ───H───┼───┼───┼───┼───────X───────────
          │   │   │   │       │
5: ───H───┼───┼───┼───┼───X───┼───────────
          │   │   │   │   │   │
6: ───H───@───@───@───@───@───@───H───M───
0: ───H───MS(0.25π)──────────────────────────────────────────────────────────────────Rx(0.5π)────
          │
1: ───H───┼────────────────────────MS(0.25π)─────────────────────────────────────────Rx(0.5π)────
          │                        │
2: ───H───┼───────────MS(-0.25π)───┼─────────────────────────────────────────────────Rx(-0.5π)───
          │           │            │
3: ───H───┼───────────┼────────────┼───────────MS(-0.25π)────────────────────────────Rx(-0.5π)───
          │           │            │           │
4: ───H───┼───────────┼─────────

In [None]:
# Sanity check: final_state_vector() ignores the terminal measurement, and
# only qubit 0 appears in the circuit, so the prepared |1> is printed as-is.
qubit_list = cirq.LineQubit.range(2)
c = cirq.Circuit()
c.append(cirq.measure(qubit_list[0]))
print("circuit:", cirq.dirac_notation(c.final_state_vector(initial_state=1)))

circuit: |1⟩
