We'll write a simplifier a la Luckasz

In [1]:
import numpy as np
import sympy
import cirq
import tensorflow as tf
import tensorflow_quantum as tfq
import matplotlib.pyplot as plt
from tqdm import tqdm


class Solver:
    def __init__(self, n_qubits=3, qlr=0.01, qepochs=200,verbose=0, g=1, J=0):

        """"solver with n**2 possible actions: n(n-1) CNOTS + n 1-qubit unitary"""
        self.n_qubits = n_qubits
        self.qubits = cirq.GridQubit.rect(1, n_qubits)
        self.lower_bound_Eg = -2*self.n_qubits
        
        self.qlr = qlr
        self.qepochs=qepochs
        self.verbose=verbose


        self.indexed_cnots = {}
        self.cnots_index = {}
        count = 0
        for control in range(self.n_qubits):
            for target in range(self.n_qubits):
                if control != target:
                    self.indexed_cnots[str(count)] = [control, target]
                    self.cnots_index[str([control,target])] = count
                    count += 1
        self.number_of_cnots = len(self.indexed_cnots)
        
        self.final_params = []
        self.parametrized_unitary = [cirq.rz, cirq.rx, cirq.rz]
        
        self.observable=self.ising_obs(g=g, J=J)
        
    def ising_obs(self, g=1, J=0):
        # g \sum_i Z_i - J \sum_{i} X_i X_{i+1}
        observable = [float(g)*cirq.Z.on(q) for q in self.qubits] 
        for q in range(len(self.qubits)):
            observable.append(float(J)*cirq.X.on(self.qubits[q])*cirq.X.on(self.qubits[(q+1)%len(self.qubits)]))
        self.ground_energy = -g*np.sum(np.sqrt([1+(J/4*g)**2 - (np.cos(q)*(J/2*g)) for q in range(self.n_qubits)]))
        return observable
        
    def index_meaning(self,index):
        if index<self.number_of_cnots:
            print("cnot: ",self.indexed_cnots[str(index)])
            return
        else:
            print("1-qubit unitary on: ",(index-self.number_of_cnots)%self.n_qubits)
            return

    def append_to_circuit(self, ind, circuit, params):
        """
        appends to circuit the index of the gate;
        and if one_hot_gate it implies a rotation,
        appends to params a symbol
        """

        if ind < self.number_of_cnots:
            control, target = self.indexed_cnots[str(ind)]
            circuit.append(cirq.CNOT.on(self.qubits[control], self.qubits[target]))
            return circuit, params
        else:
            qubit = self.qubits[(ind-self.number_of_cnots)%self.n_qubits]
            for par, gate in zip(range(3),self.parametrized_unitary):
                new_param = "th_"+str(len(params))
                params.append(new_param)
                circuit.append(gate(sympy.Symbol(new_param)).on(qubit))
            return circuit, params
        
    def give_circuit(self, lista,one_hot=False):
        circuit, symbols = [], []
        for k in lista:
            circuit, symbols = self.append_to_circuit(k,circuit,symbols)
        circuit = cirq.Circuit(circuit)
        return circuit, symbols

    def dressed_cnot(self,q1,q2):
        u1 = self.number_of_cnots + q1
        u2 = self.number_of_cnots + q2
        cnot = self.cnots_index[str([q1,q2])]
        u3 = self.number_of_cnots + q1
        u4 = self.number_of_cnots + q2
        return [u1,u2,cnot,u3,u4]
    
    def dressed_ansatz(self, layers=1):
        c=[]
        for layer in range(layers):
            qubits = list(range(self.n_qubits))
            qdeph = qubits[layers:]
            for q in qubits[:layers]:
                qdeph.append(q)
            for ind1, ind2 in zip(qubits,qdeph):
                for k in self.dressed_cnot(ind1,ind2):
                    c.append(k)
        return c


    def TFQ_model(self, symbols):
        circuit_input = tf.keras.Input(shape=(), dtype=tf.string)
        output = tfq.layers.Expectation()(
                circuit_input,
                symbol_names=symbols,
                operators=tfq.convert_to_tensor([self.observable]),
                initializer=tf.keras.initializers.RandomNormal()) #we may change this!!!
        model = tf.keras.Model(inputs=circuit_input, outputs=output)
        adam = tf.keras.optimizers.Adam(learning_rate=self.qlr)
        model.compile(optimizer=adam, loss='mse')
        return model


    def run_circuit(self, gates_index, sim_q_state=False):
        """
        takes as input vector with actions described as integer
        and outputsthe energy of that circuit (w.r.t self.observable)
        """
        ### create a vector with the gates on the corresponding qubit(s)
        circuit, symbols = self.give_circuit(gates_index)
        
        ### this is because each qubit should be "activated" in TFQ to do the optimization (if the observable has support on this qubit as well and you don't add I then error)
        effective_qubits = list(circuit.all_qubits())
        for k in self.qubits:
            if k not in effective_qubits:
                circuit.append(cirq.I.on(k))

        tfqcircuit = tfq.convert_to_tensor([circuit])
        if len(symbols) == 0:
            expval = tfq.layers.Expectation()(
                                            tfqcircuit,
                                            operators=tfq.convert_to_tensor([self.observable]))
            energy = np.float32(np.squeeze(tf.math.reduce_sum(expval, axis=-1, keepdims=True)))
            self.final_params = []

        else:
            model = self.TFQ_model(symbols)
            qoutput = tf.ones((1, 1))*self.lower_bound_Eg
            model.fit(x=tfqcircuit, y=qoutput, batch_size=1, epochs=self.qepochs, verbose=self.verbose)
            energy = np.squeeze(tf.math.reduce_sum(model.predict(tfqcircuit), axis=-1))
            self.final_params = [np.squeeze(k.numpy()) for k in model.trainable_variables]

        #if sim_q_state:
            #simulator = cirq.Simulator()
            #result = simulator.simulate(circuit, qubit_order=self.qubits)
            #probs = np.abs(result.final_state)**2
            #return energy, probs
        return energy

In [75]:
sol = Solver()

In [3]:
sol.indexed_cnots

{'0': [0, 1], '1': [0, 2], '2': [1, 0], '3': [1, 2], '4': [2, 0], '5': [2, 1]}

In [14]:
indexed_circuit = [0,7,7,8]
sol.give_circuit(indexed_circuit)

((0, 0): ───@────────────────────────────────────────────────────────────────────────────
           │
(0, 1): ───X──────────Rz(th_0)───Rx(th_1)───Rz(th_2)───Rz(th_3)───Rx(th_4)───Rz(th_5)───

(0, 2): ───Rz(th_6)───Rx(th_7)───Rz(th_8)───────────────────────────────────────────────,
 ['th_0', 'th_1', 'th_2', 'th_3', 'th_4', 'th_5', 'th_6', 'th_7', 'th_8'])

In [71]:
indexed_circuit = [0,0,7,7,7]
sol.give_circuit(indexed_circuit)


((0, 0): ───@───@──────────────────────────────────────────────────────────────────────────────────────────────────────
           │   │
(0, 1): ───X───X───Rz(th_0)───Rx(th_1)───Rz(th_2)───Rz(th_3)───Rx(th_4)───Rz(th_5)───Rz(th_6)───Rx(th_7)───Rz(th_8)───,
 ['th_0', 'th_1', 'th_2', 'th_3', 'th_4', 'th_5', 'th_6', 'th_7', 'th_8'])

In [58]:
def create_connections(sol,indexed_circuit):
    """this function loads the circuit into a dictionary identifying on which qubit each gate acts."""
    #load circuit on each qubit
    connections={str(q):["init"] for q in range(sol.n_qubits)}
    simplified = []
    cnots_order = []
    for gg in indexed_circuit:
        if gg<sol.number_of_cnots:
            cnots_order.append(gg)
    flagged = [False]*len(indexed_circuit) #to check if you have seen the cnots
    for q in range(sol.n_qubits):
        for nn,idq in enumerate(indexed_circuit):
            if idq<sol.number_of_cnots:
                control, target = sol.indexed_cnots[str(idq)]
                if q in [control, target] and not flagged[nn]:
                    connections[str(control)].append(idq) 
                    connections[str(target)].append(idq)
                    flagged[nn] = True #so you don't add the other
            else:
                if idq%sol.n_qubits == q:
                    connections[str(q)].append("u")
    return connections, cnots_order

In [181]:
def mark_to_simplify(connections,cnots_order):
    simplified_connections = connections.copy()
    simplified_cnots_order = cnots_order.copy()

    for q, path in connections.items():
        for ind,gate in enumerate(path):
            
            ##### IF GATE IS SINGLE QUIT UNITARY, CHECK IF THE NEXT ONES ARE ALSO UNITARIES AND KILL 'EM
            if gate == "u":
                for k in range(len(path)-ind-1):
                    if path[ind+k+1]=="u":
                        simplified_connections[q][ind+k+1] = -1
                    else:
                        break
            elif gate in range(sol.number_of_cnots) and ind<len(path)-1: ### this is a CNOT
                ### if the next one is the same CNOT
                if path[ind+1]==gate:
                    qs = sol.indexed_cnots[str(gate)].copy()
                    qs.remove(int(q)) #the other qubit
                    idx = -2
                    for indO, gO in enumerate(connections[str(qs[0])]): #check in the path of the other qubit
                        if gO==gate: #same CNOT
                            if idx+1 == indO: #and repeated! ---> identity !
                                simplified_connections[str(q)][ind] = -1
                                simplified_connections[str(qs[0])][indO-1] = -1
                                simplified_connections[str(q)][ind+1] = -1
                                simplified_connections[str(qs[0])][indO] = -1
                                break
                            else:
                                idx = indO
                    else:
                        break

    return simplified_connections


def simplify(solver,idx_circuit):
    connections,cnots_order = create_connections(solver,idx_circuit) 
    marked_connections = mark_to_simplify(connections)
    simplified_connections={}
    for q,g in marked_connections.items():
        string_gates = []
        for gate in g:
            if gate != -1 or gate != "init":
                string_gates.append(gate)
        simplified_connections[q] = string_gates
    return simplified_connections

In [180]:
def connections_to_vector(connections, cnots_order):
    vector = []
    flagged_connections = connections.copy()
    cnots_order_marked = cnots_order.copy()
    for q, g in connections.items():
        for ind, gate_ind in enumerate(g):
            if gate_ind >= sol.number_of_cnots:#unitary
                if flagged_connections[q][ind] != -2:
                    vector.append(gate_ind)
                    flagged_connections[q][ind]=-2
            else:
                done = False
                while not done:
                    qs = sol.indexed_cnots[str(gate)].copy()
                    qs.remove(int(q)) #the other qubit
                    other_qubit = qs[0]
                    for ind_int, gate_ind_int in enumerate(connections[str(other_qubit)]):
                        if flagged_connections[str(other_qubit)][ind_int] != -2:
                            if gate_ind_int >= sol.number_of_cnots:
                                    vector.append(gate_ind_int)
                                    flagged_connections[str(other_qubit)][ind_int]=-2
                            else:
                                qs_int = sol.indexed_cnots[str(gate_ind_int)].copy()
                                qs.remove(int(other_qubit)) #the other qubit int
                                other_qubit_int = qs[0]
                                ### two scenarios: other_qubit_int = q or not.
                                if other_qubit_int == q:
                                    ### same gate, so I just append it
                                    if gate_ind == gate_ind_int:
                                        vector.append(gate_ind)
                                        flagged_connections[q][ind]=-2
                                        flagged_connections[str(other_qubit)][ind_int]=-2
                                        done = True
                                        break
                                    else:
                                        ### different gate (tehy differ in which is the control). I check the order.
                                        for icnn,cnn in enumerate(cnots_order):
                                            if cnn==gate_ind and cnots_order_marked[icnn]!=-3:
                                                if cnots_order[icnn-1] ==gate_ind_int and icnn>0:
                                                    vector.append(gate_ind_int)
                                                    vector.append(gate)
                                                    cnots_order_marked[icnn] = -3
                                                    cnots_order_marked[icnn-1] = -3
                                                else:
                                                    vector.append(gate)
                                                    vector.append(gate_ind_int)
                                                    cnots_order_marked[icnn] = -3
                                                    cnots_order_marked[icnn+1] = -3
                                            flagged_connections[q][ind]=-2
                                            flagged_connections[str(other_qubit)][ind_int]=-2
                                            break
                                        done = True
                                        break
                                else: #### if I find a CNOT before
                                    
                                    

{'0': ['init'], '1': ['init', 'u', 3], '2': ['init', 3]}

In [186]:
cirq.unitary(cirq.Circuit([cirq.CNOT(sol.qubits[0],sol.qubits[1]),cirq.CNOT(sol.qubits[1],sol.qubits[0])]))

array([[1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j],
       [0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j]])

In [187]:
cirq.unitary(cirq.Circuit([cirq.CNOT(sol.qubits[1],sol.qubits[0]),cirq.CNOT(sol.qubits[0],sol.qubits[1])]))

array([[1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j],
       [0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j]])

In [171]:
sol = Solver()
indexed_circuit = [3,3,7,3,]

In [172]:
sol.give_circuit(indexed_circuit)

((0, 1): ───@───@───Rz(th_0)───Rx(th_1)───Rz(th_2)───@───
           │   │                                    │
(0, 2): ───X───X────────────────────────────────────X───,
 ['th_0', 'th_1', 'th_2'])

In [173]:
create_connections(sol,indexed_circuit)

{'0': ['init'], '1': ['init', 3, 3, 'u', 3], '2': ['init', 3, 3, 3]}

In [174]:
mark_to_simplify(create_connections(sol,indexed_circuit))

{'0': ['init'], '1': ['init', -1, -1, 'u', 3], '2': ['init', -1, -1, 3]}

In [178]:
simplify(mark_to_simplify(create_connections(sol,indexed_circuit)))

{'0': ['init'], '1': ['init', 'u', 3], '2': ['init', 3]}

In [159]:
indexed_circuit = [0,7,7,7]
create_connections(sol,indexed_circuit)

{'0': ['init', 0], '1': ['init', 0, 'u', 'u', 'u'], '2': ['init']}

In [9]:
sol.give_circuit(indexed_circuit)

((0, 0): ───@───@──────────────────────────────────────────────────────────────────────────────────────────────────────
           │   │
(0, 1): ───X───X───Rz(th_0)───Rx(th_1)───Rz(th_2)───Rz(th_3)───Rx(th_4)───Rz(th_5)───Rz(th_6)───Rx(th_7)───Rz(th_8)───,
 ['th_0', 'th_1', 'th_2', 'th_3', 'th_4', 'th_5', 'th_6', 'th_7', 'th_8'])

In [108]:
sol.give_circuit(simplified)

((0, 0): ───@────────────────────────────────────
           │
(0, 1): ───X───Rz(th_0)───Rx(th_1)───Rz(th_2)───,
 ['th_0', 'th_1', 'th_2'])

In [87]:
casa

[1, 2]

((0, 0): ───@───@───@───@───@──────────────────────────────────────────────────────────────────────────────────────────────────────
           │   │   │   │   │
(0, 1): ───X───X───X───X───X───Rz(th_0)───Rx(th_1)───Rz(th_2)───Rz(th_3)───Rx(th_4)───Rz(th_5)───Rz(th_6)───Rx(th_7)───Rz(th_8)───,
 ['th_0', 'th_1', 'th_2', 'th_3', 'th_4', 'th_5', 'th_6', 'th_7', 'th_8'])

In [81]:
sol.give_circuit(simplified)

((0, 0): ───@────────────────────────────────────
           │
(0, 1): ───X───Rz(th_0)───Rx(th_1)───Rz(th_2)───,
 ['th_0', 'th_1', 'th_2'])