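In this notebook we give a proof of concept of unitary compiling with TensorFlow Quantum (TFQ): a run of single-qubit rotations in a circuit is compiled into three Euler rotations.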

In [7]:
%load_ext autoreload
%autoreload 2

import sympy 
import numpy as np 
import pandas as pd 
import tensorflow as tf
from utilities.circuit_database import CirqTranslater
from utilities.templates import *
from utilities.variational import Minimizer
from utilities.misc import get_qubits_involved, reindex_symbol, shift_symbols_down
import matplotlib.pyplot as plt 
import tensorflow_quantum as tfq
import cirq
from utilities.compiling import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
def get_positional_dbs(circuit, circuit_db, translator=None):
    """
    Group the gates of `circuit_db` by the qubit(s) they act on.

    Args:
        circuit: circuit built from `circuit_db`; only forwarded to get_qubits_involved.
        circuit_db: gate database whose "ind" column holds the gate indices, in circuit order.
        translator: CirqTranslater defining the gate indices. Defaults to the notebook-level
            `translator`, which earlier versions of this function used implicitly.

    Returns:
        (gates_on_qubit, on_qubit_order): for each involved qubit, the gate indices acting on
        it and the 0-based positions of those gates in `circuit_db`, both in circuit order.
        A CNOT is listed under both its control and its target qubit.

    Raises:
        NameError: no `translator` was passed and none exists at notebook level.
    """
    if translator is None:
        translator = globals().get("translator")
        if translator is None:
            raise NameError("get_positional_dbs needs a `translator` (argument or notebook global)")

    qubits_involved = get_qubits_involved(circuit, circuit_db)
    gates_on_qubit = {q: [] for q in qubits_involved}
    on_qubit_order = {q: [] for q in qubits_involved}

    for order_gate, ind_gate in enumerate(circuit_db["ind"]):
        if ind_gate < translator.number_of_cnots:
            # CNOT: indexed_cnots is keyed by the string form of the index
            control, target = translator.indexed_cnots[str(ind_gate)]
            acted_on = (control, target)
        else:
            # rotation index = number_of_cnots + type*n_qubits + qubit, so the qubit is the
            # remainder after removing the CNOT offset (not an n_qubits offset)
            acted_on = ((ind_gate - translator.number_of_cnots) % translator.n_qubits,)
        for qubit in acted_on:
            gates_on_qubit[qubit].append(ind_gate)
            on_qubit_order[qubit].append(order_gate)
    return gates_on_qubit, on_qubit_order

In [137]:
# Toy 2-qubit circuit made of two stacked u1 layers, plus the per-qubit gate bookkeeping
# that the next cell consumes.
translator = CirqTranslater(2)
db1 = u1_layer(translator)
circuit, circuit_db = translator.give_circuit(concatenate_dbs([db1, db1]))
gates_on_qubit, on_qubit_order = get_positional_dbs(circuit, circuit_db)
# working copy edited by the next cell, so circuit_db stays the reference
simplified_db = circuit_db.copy()


In [138]:
# Inline proof of concept of the Euler-rotation rule (packaged later as rule_5):
# find the first run of >= 4 consecutive single-qubit rotations on one qubit whose
# 1st and 3rd gates share a rotation type, fit three rotations (types 0, 1, 0) to it
# with the variational compiler, and splice them into `simplified_db`.
simplification = False

# rotation index = number_of_cnots + type*n_qubits + qubit  ->  type
type_get = lambda x, translator: (x-translator.number_of_cnots)//translator.n_qubits
# True for single-qubit rotation indices, False for CNOT indices
check_rot = lambda ind_gate: translator.number_of_cnots <= ind_gate < (3*translator.n_qubits + translator.number_of_cnots)

for q, qubit_gates_path in gates_on_qubit.items():
    if simplification:
        break
    for order_gate_on_qubit, ind_gate in enumerate(qubit_gates_path[:-2]):
        if simplification:
            break
        ind_gate_p1 = qubit_gates_path[order_gate_on_qubit+1]
        ind_gate_p2 = qubit_gates_path[order_gate_on_qubit+2]
        if not (check_rot(ind_gate) and check_rot(ind_gate_p1) and check_rot(ind_gate_p2)):
            continue
        if type_get(ind_gate, translator) != type_get(ind_gate_p2, translator):
            continue

        # extend the window with every rotation that directly follows it on this qubit
        types = [type_get(g, translator) for g in (ind_gate, ind_gate_p1, ind_gate_p2)]
        for ind_gate_next in qubit_gates_path[order_gate_on_qubit+3:]:
            if not check_rot(ind_gate_next):
                break
            types.append(type_get(ind_gate_next, translator))
            simplification = True
        if not simplification:
            continue

        # on_qubit_order holds positions; they are used as row labels (assumes a 0..N-1 index)
        indices_to_compile = [on_qubit_order[q][order_gate_on_qubit+k] for k in range(len(types))]

        # re-index the run's rotations on an auxiliary 2-qubit translator, keeping each type
        translator.translator_ = CirqTranslater(n_qubits=2)
        u_to_compile_db = simplified_db.loc[indices_to_compile].copy()
        u_to_compile_db["ind"] = translator.translator_.n_qubits*type_get(u_to_compile_db["ind"], translator) + translator.translator_.number_of_cnots
        u_to_compile_db["symbol"] = None  # just to be sure it makes no interference with the compiler

        compile_circuit, compile_circuit_db = construct_compiling_circuit(translator.translator_, u_to_compile_db)
        minimizer = Minimizer(translator.translator_, mode="compiling", hamiltonian="Z")
        cost, resolver, history = minimizer.minimize([compile_circuit], symbols=translator.get_symbols(compile_circuit_db))

        # fitted angles (negated) become the parameters of a 1-qubit u1 block
        # NOTE(review): assumes resolver.values() is ordered like the rows of u1s — confirm
        OneQbit_translator = CirqTranslater(n_qubits=1)
        u1s = u1_db(OneQbit_translator, 0, params=True)
        u1s["param_value"] = -np.array(list(resolver.values()))
        resu_comp, resu_db = OneQbit_translator.give_circuit(u1s, unresolved=False)

        u_to_compile_db_1q = u_to_compile_db.copy()
        u_to_compile_db_1q["ind"] = type_get(u_to_compile_db["ind"], translator.translator_)
        cc, cdb = OneQbit_translator.give_circuit(u_to_compile_db_1q, unresolved=False)
        c = cc.unitary()
        r = resu_comp.unitary()

        # Rz(theta + 2*pi) = -Rz(theta): if the fit matches the run only up to a global phase
        # far from 1, shift the first angle by 2*pi. The phase is read off tr(r^dagger c), which
        # stays defined when the unitaries have zero entries (the element-wise ratio c/r does not).
        overlap = np.trace(r.conj().T @ c)
        if np.abs(overlap) > 0 and np.abs(overlap/np.abs(overlap) - 1) > 1:
            u1s.loc[0, "param_value"] = u1s.loc[0, "param_value"] + 2*np.pi

        # each fitted rotation takes the slot of one of the first three gates of the run,
        # inheriting its block_id and symbol; then the whole run is dropped
        first_symbols = simplified_db.loc[indices_to_compile, "symbol"].iloc[:3]
        for new_ind, typ, pval in zip(indices_to_compile[:3], [0, 1, 0], list(u1s["param_value"])):
            simplified_db.loc[new_ind+0.1] = gate_template(translator.number_of_cnots + q + typ*translator.n_qubits,
                                                           param_value=pval, block_id=simplified_db.loc[new_ind]["block_id"],
                                                           trainable=True, symbol=first_symbols[new_ind])
        simplified_db = simplified_db.drop(labels=indices_to_compile, axis=0)
        simplified_db = simplified_db.sort_index().reset_index(drop=True)

        # call shift_symbols_down once per removed gate (presumably renumbering the symbols
        # of the gates after the run — confirm)
        killed_indices = indices_to_compile[3:]
        db_follows = circuit_db[circuit_db.index > indices_to_compile[-1]]
        if len(db_follows) > 0:
            gates_to_lower = list(db_follows.index)
            number_of_shifts = len(killed_indices)
            for _ in range(number_of_shifts):
                simplified_db = shift_symbols_down(translator, gates_to_lower[0]-number_of_shifts, simplified_db)



6
7
8
6
7
8
6
7
8


In [141]:
# Circuit of the simplified db; unresolved=False presumably binds the stored param_value
# numbers (rather than symbols) so that .unitary() can be taken — confirm
s, scdb = translator.give_circuit(simplified_db, unresolved=False)

In [142]:
# display the simplified circuit
s

In [145]:
# original circuit with concrete parameters, for comparison with `s`
c, cc = translator.give_circuit(circuit_db, unresolved=False)
c

In [146]:
# unitary of the simplified circuit
s.unitary()

array([[ 0.55148605+0.19670999j,  0.8072302 -0.03219678j,
         0.01633568-0.03587059j,  0.00243843-0.05432856j],
       [-0.80654442-0.0462961j ,  0.55483885-0.18704419j,
        -0.00769946+0.05383546j, -0.00938563-0.03828137j],
       [ 0.00938563-0.03828137j, -0.00769946-0.05383546j,
         0.55483885+0.18704419j,  0.80654442-0.0462961j ],
       [ 0.00243843+0.05432856j, -0.01633568-0.03587059j,
        -0.8072302 -0.03219678j,  0.55148605-0.19670999j]])

In [150]:
# largest entry-wise deviation between the original and the simplified unitary
np.abs(c.unitary() - s.unitary()).max()

0.0009719984141248654

In [158]:
def _is_rotation(translator, ind_gate):
    """Return True when `ind_gate` is a single-qubit rotation index (not a CNOT) for `translator`."""
    return translator.number_of_cnots <= ind_gate < (3*translator.n_qubits + translator.number_of_cnots)


def _rotation_type(ind_gate, translator):
    """Rotation type of a rotation index (index = number_of_cnots + type*n_qubits + qubit); works on Series too."""
    return (ind_gate - translator.number_of_cnots)//translator.n_qubits


def _find_rotation_run(translator, gates_on_qubit):
    """Return (qubit, start, length) of the first compilable rotation run over all qubits, or None.

    A run is compilable when it has at least 4 consecutive rotations on one qubit and its
    1st and 3rd gates have the same rotation type.
    """
    for q, path in gates_on_qubit.items():
        for start in range(len(path) - 2):
            first, second, third = path[start:start+3]
            if not (_is_rotation(translator, first) and _is_rotation(translator, second)
                    and _is_rotation(translator, third)):
                continue
            if _rotation_type(first, translator) != _rotation_type(third, translator):
                continue
            length = 3
            for ind_gate in path[start+3:]:
                if not _is_rotation(translator, ind_gate):
                    break
                length += 1
            if length > 3:
                return q, start, length
    return None


def rule_5(translator, simplified_db, on_qubit_order, gates_on_qubit):
    """
    Compile 1-qubit gates into Euler rotations.

    Scans every qubit for the first run of >= 4 consecutive rotations whose 1st and 3rd
    gates share a rotation type. It fits three rotations (types 0, 1, 0) to that run with the
    variational compiler and replaces the run with them. At most one run is compiled per call.

    Args:
        translator: CirqTranslater of the full circuit. Side effect: an auxiliary 2-qubit
            translator is stored on `translator.translator_`.
        simplified_db: circuit db to simplify. It is not modified; a copy is edited.
        on_qubit_order: {qubit: positions in the db of the gates touching it} (get_positional_dbs).
        gates_on_qubit: {qubit: gate indices touching it, in circuit order} (get_positional_dbs).

    Returns:
        (simplification, simplified_db): whether a run was compiled, and the resulting db
        (the input object itself when nothing was compiled).
    """
    run = _find_rotation_run(translator, gates_on_qubit)
    if run is None:
        return False, simplified_db
    q, start, run_length = run

    input_db = simplified_db
    # work on a copy: the .loc[label + 0.1] insertions below would otherwise enlarge the caller's frame
    simplified_db = simplified_db.copy()
    # on_qubit_order holds positions; they are used as row labels (assumes a 0..N-1 index)
    indices_to_compile = [on_qubit_order[q][start + k] for k in range(run_length)]

    # re-index the run's rotations on an auxiliary 2-qubit translator, keeping each rotation's type
    translator.translator_ = CirqTranslater(n_qubits=2)
    aux_translator = translator.translator_
    u_to_compile_db = simplified_db.loc[indices_to_compile].copy()
    u_to_compile_db["ind"] = aux_translator.n_qubits*_rotation_type(u_to_compile_db["ind"], translator) + aux_translator.number_of_cnots
    u_to_compile_db["symbol"] = None  # just to be sure it makes no interference with the compiler

    compile_circuit, compile_circuit_db = construct_compiling_circuit(aux_translator, u_to_compile_db)
    minimizer = Minimizer(aux_translator, mode="compiling", hamiltonian="Z")
    cost, resolver, history = minimizer.minimize([compile_circuit], symbols=translator.get_symbols(compile_circuit_db))

    # fitted angles (negated) become the parameters of a 1-qubit u1 block
    # NOTE(review): assumes resolver.values() is ordered like the rows of u1s — confirm
    one_qubit_translator = CirqTranslater(n_qubits=1)
    u1s = u1_db(one_qubit_translator, 0, params=True)
    u1s["param_value"] = -np.array(list(resolver.values()))
    fitted_circuit, _ = one_qubit_translator.give_circuit(u1s, unresolved=False)

    u_to_compile_db_1q = u_to_compile_db.copy()
    u_to_compile_db_1q["ind"] = _rotation_type(u_to_compile_db["ind"], aux_translator)
    target_circuit, _ = one_qubit_translator.give_circuit(u_to_compile_db_1q, unresolved=False)
    target_unitary = target_circuit.unitary()
    fitted_unitary = fitted_circuit.unitary()

    # Rz(theta + 2*pi) = -Rz(theta): if the fit matches the run only up to a global phase far
    # from 1, shift the first angle by 2*pi. The phase is read off tr(r^dagger c), which stays
    # defined when the unitaries have zero entries (the element-wise ratio c/r does not).
    overlap = np.trace(fitted_unitary.conj().T @ target_unitary)
    if np.abs(overlap) > 0 and np.abs(overlap/np.abs(overlap) - 1) > 1:
        u1s.loc[0, "param_value"] = u1s.loc[0, "param_value"] + 2*np.pi

    # each fitted rotation takes the slot of one of the first three gates of the run,
    # inheriting its block_id and symbol; then the whole run is dropped
    first_symbols = simplified_db.loc[indices_to_compile, "symbol"].iloc[:3]
    for new_ind, typ, pval in zip(indices_to_compile[:3], [0, 1, 0], list(u1s["param_value"])):
        simplified_db.loc[new_ind + 0.1] = gate_template(translator.number_of_cnots + q + typ*translator.n_qubits,
                                                         param_value=pval, block_id=simplified_db.loc[new_ind]["block_id"],
                                                         trainable=True, symbol=first_symbols[new_ind])
    simplified_db = simplified_db.drop(labels=indices_to_compile, axis=0)
    simplified_db = simplified_db.sort_index().reset_index(drop=True)

    # call shift_symbols_down once per removed gate (presumably renumbering the symbols of the
    # gates after the run — confirm); "after" is measured in the db that was passed in
    killed_indices = indices_to_compile[3:]
    db_follows = input_db[input_db.index > indices_to_compile[-1]]
    if len(db_follows) > 0:
        gates_to_lower = list(db_follows.index)
        number_of_shifts = len(killed_indices)
        for _ in range(number_of_shifts):
            simplified_db = shift_symbols_down(translator, gates_to_lower[0]-number_of_shifts, simplified_db)
    return True, simplified_db


def apply_rule(original_circuit_db, rule, **kwargs):
    """
    Apply `rule` repeatedly until it reports no change or `max_cnt` rounds have run.

    Uses the notebook-level `translator`. Before each round the circuit is rebuilt from the
    current db and the per-qubit gate bookkeeping is recomputed with get_positional_dbs.

    Args:
        original_circuit_db: circuit db to simplify.
        rule: callable (translator, db, on_qubit_order, gates_on_qubit) -> (changed, new_db).
        max_cnt: keyword-only upper bound on the number of rounds (default 10).

    Returns:
        (rounds, db): the number of rounds run (including the final one that made no change)
        and the last db produced.
    """
    budget = kwargs.get('max_cnt', 10)
    base_circuit, current_db = translator.give_circuit(original_circuit_db)
    gates_on_qubit, on_qubit_order = get_positional_dbs(base_circuit, current_db)
    current_db = current_db.copy()

    rounds = 0
    keep_going = True
    while keep_going and rounds < budget:
        keep_going, candidate_db = rule(translator, current_db, on_qubit_order, gates_on_qubit)
        new_circuit, current_db = translator.give_circuit(candidate_db)
        gates_on_qubit, on_qubit_order = get_positional_dbs(new_circuit, current_db)
        rounds += 1
    return rounds, current_db

In [152]:
# Rebuild the two-u1-layer toy circuit and apply rule_5 once.
translator = CirqTranslater(2)
db1 = u1_layer(translator)
circuit, circuit_db = translator.give_circuit(concatenate_dbs([db1, db1]))
gates_on_qubit, on_qubit_order = get_positional_dbs(circuit, circuit_db)
simplified_db = circuit_db.copy()
simplification, ssimplified_db = rule_5(translator, simplified_db, on_qubit_order, gates_on_qubit)

In [153]:
# circuits with concrete parameters for the original db and for the db returned by rule_5
circuit, circuit_db  = translator.give_circuit(circuit_db, unresolved=False)
scircuit, scircuit_db  = translator.give_circuit(ssimplified_db, unresolved=False)

In [154]:
# unitary after rule_5
scircuit.unitary()

array([[ 0.71045872-0.51652071j, -0.04064287+0.08568959j,
         0.15623585+0.43877286j, -0.03749268-0.03351489j],
       [ 0.0891612 -0.03232383j,  0.58163915-0.65821056j,
         0.00307766+0.0501944j ,  0.24761015+0.39448761j],
       [-0.24761015+0.39448761j,  0.00307766-0.0501944j ,
         0.58163915+0.65821056j, -0.0891612 -0.03232383j],
       [-0.03749268+0.03351489j, -0.15623585+0.43877286j,
         0.04064287+0.08568959j,  0.71045872+0.51652071j]])

In [155]:
# unitary of the original circuit, to compare with the one above
circuit.unitary()

array([[ 0.71040036-0.51653149j, -0.04063678+0.0856876j ,
         0.15608009+0.43891008j, -0.03748501-0.03353595j],
       [ 0.08915547-0.0323267j ,  0.58157985-0.65820843j,
         0.00305739+0.05020397j,  0.24748783+0.39465532j],
       [-0.24748783+0.39465532j,  0.00305739-0.05020397j,
         0.58157985+0.65820843j, -0.08915547-0.0323267j ],
       [-0.03748501+0.03353595j, -0.15608009+0.43891008j,
         0.04063678+0.0856876j ,  0.71040036+0.51653149j]])

In [156]:
# display the original circuit
circuit

In [157]:
# display the circuit after rule_5
scircuit

In [168]:
# Larger example: two u1 layers followed by two u2 layers, then one application of rule_5.
translator = CirqTranslater(2)
db1 = u1_layer(translator)
db2 = u2_layer(translator)
circuit, circuit_db = translator.give_circuit(concatenate_dbs([db1, db1, db2, db2]))
gates_on_qubit, on_qubit_order = get_positional_dbs(circuit, circuit_db)
simplified_db = circuit_db.copy()
simplification, ssimplified_db = rule_5(translator, simplified_db, on_qubit_order, gates_on_qubit)

In [169]:
circuit, circuit_db = translator.give_circuit(circuit_db, unresolved=False)
scircuit, scircuit_db = translator.give_circuit(ssimplified_db, unresolved=False)

# largest entry-wise deviation between the original and the simplified unitary
np.abs(circuit.unitary() - scircuit.unitary()).max()

0.00011662576940821021

In [170]:
# element-wise difference between the original and the simplified unitaries
circuit.unitary() - scircuit.unitary()

array([[ 4.47957068e-05+8.33739899e-05j, -2.24269189e-05+8.39146801e-06j,
        -3.27533560e-05-1.84682919e-05j,  9.69874299e-05-3.76264632e-05j],
       [ 9.10239306e-05+2.49819556e-05j,  3.66639517e-05-5.90326355e-05j,
        -1.00604967e-06+4.52783057e-05j, -1.20188435e-05+7.63759925e-05j],
       [ 5.27058440e-05+2.83923398e-05j, -5.10068951e-05+1.27153106e-05j,
         6.31555666e-05-9.16520777e-05j, -5.43449698e-05+8.82696372e-06j],
       [ 1.77769970e-05-2.21881278e-07j,  4.46477532e-05+1.07741117e-04j,
         7.01994105e-05+3.13966556e-05j,  3.36612250e-05+2.83373811e-05j]])