In this notebook we give a proof of concept of unitary compiling using TFQ. 

In [1]:
%load_ext autoreload
%autoreload 2

import sympy 
import numpy as np 
import pandas as pd 
import tensorflow as tf
from utilities.circuit_database import CirqTranslater
from utilities.templates import *
from utilities.variational import Minimizer
from utilities.misc import get_qubits_involved, reindex_symbol, shift_symbols_down
import matplotlib.pyplot as plt 
import tensorflow_quantum as tfq
import cirq

In [2]:
translator = CirqTranslater(n_qubits = 2) #system + ancilla

In [3]:
from utilities.compiling import *

In [4]:
u1ss = u1_db(translator, 1, params=True)
v_to_compile_db = concatenate_dbs([u1ss]*2) ###target gates to compile
compile_circuit, compile_circuit_db = construct_compiling_circuit(translator, conjugate_db(translator, v_to_compile_db))
minimizer = Minimizer(translator, mode="compiling", hamiltonian="Z")

In [5]:
cost, resolver, history = minimizer.minimize([compile_circuit], symbols=translator.get_symbols(compile_circuit_db))

In [6]:
u1s = u1_db(translator, 0, params=True)
u1s["param_value"] = list(resolver.values())
resu_comp, resu_db = translator.give_circuit(u1s,unresolved=False)

In [7]:
target_qubit = 1
resu_db["ind"] = resu_db["ind"] + target_qubit

In [None]:
def get_db_to_compile(simplified_db):

False

In [8]:
def get_positional_dbs(circuit, circuit_db):

    qubits_involved = get_qubits_involved(circuit, circuit_db)
    
    gates_on_qubit = {q:[] for q in qubits_involved}
    on_qubit_order = {q:[] for q in qubits_involved}

    for order_gate, ind_gate in enumerate( circuit_db["ind"]):
        if ind_gate < translator.number_of_cnots:
            control, target = translator.indexed_cnots[str(ind_gate)]
            gates_on_qubit[control].append(ind_gate)
            gates_on_qubit[target].append(ind_gate)
            on_qubit_order[control].append(order_gate)                
            on_qubit_order[target].append(order_gate)  
        else:
            gates_on_qubit[(ind_gate-translator.n_qubits)%translator.n_qubits].append(ind_gate)
            on_qubit_order[(ind_gate-translator.n_qubits)%translator.n_qubits].append(order_gate)        
    return gates_on_qubit, on_qubit_order

In [65]:
def rule_1(translator, simplified_db, on_qubit_order, gates_on_qubit):
    simplification = False
    
    for q, qubit_gates_path in gates_on_qubit.items():
        if simplification is True:
            break
        for order_gate_on_qubit, ind_gate in enumerate(qubit_gates_path):
            if ind_gate < translator.number_of_cnots:
                control, target = translator.indexed_cnots[str(ind_gate)]
                if (q == control) and (order_gate_on_qubit == 0):
                    pos_gate_to_drop = on_qubit_order[q][order_gate_on_qubit]
                    
                    block_id = circuit_db.loc[pos_gate_to_drop]["block_id"]
                    simplified_db.loc[int(pos_gate_to_drop)+0.1] = gate_template(translator.number_of_cnots + translator.n_qubits + control, param_value=0.0, block_id=circuit_db.loc[0]["block_id"])
                    simplified_db.loc[int(pos_gate_to_drop)+0.11] = gate_template(translator.number_of_cnots + translator.n_qubits + target, param_value=0.0, block_id=circuit_db.loc[0]["block_id"])
                    
                    simplified_db = simplified_db.drop(labels=[pos_gate_to_drop],axis=0)
                    
                    simplification = True
                    break
    simplified_db = simplified_db.sort_index().reset_index(drop=True)
    return simplification, simplified_db


def rule_2(translator, simplified_db, on_qubit_order, gates_on_qubit):
    simplification = False
    
    for q, qubit_gates_path in gates_on_qubit.items():
        if simplification is True:
            break
        for order_gate_on_qubit, ind_gate in enumerate(qubit_gates_path[:-1]):
            
            next_ind_gate = qubit_gates_path[order_gate_on_qubit+1]
            if (ind_gate < translator.number_of_cnots) and (ind_gate == next_ind_gate):
                control, target = translator.indexed_cnots[str(ind_gate)]
                not_gates_in_between = False
                this_qubit = q
                other_qubits = [control, target]
                other_qubits.remove(q)
                other_qubit = other_qubits[0]
                
                ## now we need to check what happens in the other_qubit
                for qord_other, ind_gate_other in enumerate(gates_on_qubit[other_qubit][:-1]):
                    if (ind_gate_other == ind_gate) and (gates_on_qubit[other_qubit][qord_other +1] == ind_gate):
                        ## if we append the CNOT for q and other_q on the same call, and also for the consecutive
                        ## note that in between there can be other calls for other qubits
                        order_call_q = on_qubit_order[q][order_gate_on_qubit]
                        order_call_other_q = on_qubit_order[other_qubit][qord_other]
                        
                        order_call_qP1 = on_qubit_order[q][order_gate_on_qubit+1]
                        order_call_other_qP1 = on_qubit_order[other_qubit][qord_other+1]
                        
                        ## then it's kosher to say they are consecutively applied (if only looking at the two qubits)
                        if (order_call_q == order_call_other_q) and (order_call_qP1 == order_call_other_qP1):
                            
                            pos_gate_to_drop = on_qubit_order[q][order_gate_on_qubit]
                            simplified_db = simplified_db.drop(labels=[pos_gate_to_drop],axis=0)
                            pos_gate_to_drop = on_qubit_order[q][order_gate_on_qubit+1]
                            simplified_db = simplified_db.drop(labels=[pos_gate_to_drop],axis=0)

                            simplification = True
                            break
                if simplification is True:
                    break
    simplified_db = simplified_db.reset_index(drop=True)
    return simplification, simplified_db



def rule_3(translator, simplified_db, on_qubit_order, gates_on_qubit):
    simplification = False
    for q, qubit_gates_path in gates_on_qubit.items():
        if simplification is True:
            break
        for order_gate_on_qubit, ind_gate in enumerate(qubit_gates_path[:-1]):
            if order_gate_on_qubit == 0 and (translator.number_of_cnots <= ind_gate< translator.number_of_cnots+ translator.n_qubits ):
                pos_gate_to_drop = on_qubit_order[q][order_gate_on_qubit]
                simplified_db = simplified_db.drop(labels=[pos_gate_to_drop],axis=0)
                simplified_db = simplified_db.reset_index(drop=True)
                simplified_db = shift_symbols_down(translator, pos_gate_to_drop, simplified_db)
                simplification = True
                break
    return simplification, simplified_db



def rule_4(translator, simplified_db, on_qubit_order, gates_on_qubit):
    """
    Repeated rotations: add the values
    """
    simplification = False
    for q, qubit_gates_path in gates_on_qubit.items():
        if simplification is True:
            break
        for order_gate_on_qubit, ind_gate in enumerate(qubit_gates_path[:-1]):
            if ind_gate>=translator.number_of_cnots:
                next_ind_gate = qubit_gates_path[order_gate_on_qubit+1]
                if next_ind_gate == ind_gate:
                    pos_gate_to_drop = on_qubit_order[q][order_gate_on_qubit]
                    pos_gate_to_add = on_qubit_order[q][order_gate_on_qubit+1]
                    
                    value_1 = simplified_db.loc[pos_gate_to_drop]["param_value"]
                    value_2 = simplified_db.loc[pos_gate_to_add]["param_value"]
                    
                    simplified_db.loc[pos_gate_to_add] = simplified_db.loc[pos_gate_to_add].replace(to_replace=value_2, value=value_1 + value_2)
                    simplified_db = simplified_db.drop(labels=[pos_gate_to_drop],axis=0)
                    simplified_db = simplified_db.reset_index(drop=True)

                    simplified_db = shift_symbols_down(translator, pos_gate_to_drop, simplified_db)
                    simplification = True
                    break
    return simplification, simplified_db


def rule_5(translator, simplified_db, on_qubit_order, gates_on_qubit):
    """
    compile 1-qubit gates into euler rotations
    """
    simplification = False
    for q, qubit_gates_path in gates_on_qubit.items():
        if simplification is True:
            break
        for order_gate_on_qubit, ind_gate in enumerate(qubit_gates_path[:-2]):
            if simplification is True:
                break
            ind_gate_p1 = qubit_gates_path[order_gate_on_qubit+1]
            ind_gate_p2 = qubit_gates_path[order_gate_on_qubit+2]
            check_rot = lambda ind_gate: translator.number_of_cnots<= ind_gate <(3*translator.n_qubits + translator.number_of_cnots)
            
            if (check_rot(ind_gate) == True) and (check_rot(ind_gate_p1) == True) and (check_rot(ind_gate_p2) == True):

                type_get = lambda x: (x-translator.number_of_cnots)//translator.n_qubits
                type_0 = type_get(ind_gate)
                type_1 = type_get(ind_gate_p1)
                type_2 = type_get(ind_gate_p2)
                
                
                if type_0 == type_2:
                    types = [type_0, type_1, type_2]
                    for next_order_gate_on_qubit, ind_gate_next in enumerate(qubit_gates_path[order_gate_on_qubit+3:-2]):
                        if (check_rot(ind_gate_next) == True) and (next_order_gate_on_qubit != len():
                            types.append(type_get(ind_gate_next))
                        else:
                            indices_to_compile = [on_qubit_order[q][order_gate_on_qubit+k] for k in range(len(types))]
                            if len(indices_to_compile)>3:
                                v_to_compile_db = simplified_db.loc[indices_to_compile]
                                v_to_compile_db["ind"] = (v_to_compile_db["ind"]-q +1) #target qubit in unitary compilation set to 1
                                v_to_compile_db["symbol"] = None ##just to be sure it makes no interference with the compiler...
                            compile_circuit, compile_circuit_db = construct_compiling_circuit(translator, conjugate_db(translator, v_to_compile_db))
                            if hasattr(translator, "compiler") == False:
                                translator.compiler = Minimizer(translator, mode="compiling", hamiltonian="Z")
                            cost, resolver, history = compiler.minimize([compile_circuit], symbols=translator.get_symbols(compile_circuit_db))
                    
                            for new_ind, typ, pval in zip(indices_to_compile[:3],[0,1,0], list(resolver.values())):
                                simplified_db.loc[new_ind + 0.1] = gate_template(translator.number_of_cnots + q + typ*translator.n_qubits,
                                                                                 param_value=pval, block_id=simplified_db.loc[new_ind]["block_id"], 
                                                                                 trainable=True, symbol=simplified_db.loc[new_ind]["symbol"])
                            for old_inds in indices_to_compile:
                                simplified_db = simplified_db.drop(labels=[old_inds],axis=0)
                            
                            ## check this..!
                            for down_symbols in range(len(indices_to_compile)-3):
                                simplified_db = shift_symbols_down(translator, old_inds, simplified_db)

                            simplification = True
                            break
    return simplification, simplified_db


def apply_rule(original_circuit_db, rule, **kwargs):
    max_cnt = kwargs.get('max_cnt',10)
    simplified, cnt = True, 0
    original_circuit, original_circuit_db = translator.give_circuit(original_circuit_db)
    gates_on_qubit, on_qubit_order = get_positional_dbs(original_circuit, original_circuit_db)
    simplified_db = original_circuit_db.copy()
    while simplified and cnt < max_cnt:
        simplified, simplified_circuit_db = rule(translator, simplified_db, on_qubit_order, gates_on_qubit)
        circuit, simplified_db = translator.give_circuit(simplified_circuit_db)
        gates_on_qubit, on_qubit_order = get_positional_dbs(circuit, simplified_db)
        cnt+=1
    return cnt, simplified_db

In [66]:
translator = CirqTranslater(3)
db1 = u1_layer(translator)
circuit_db = concatenate_dbs([db1]*5)
circuit, circuit_db  = translator.give_circuit(circuit_db)
gates_on_qubit, on_qubit_order = get_positional_dbs(circuit, circuit_db)
simplified_db = circuit_db.copy()
rule_5(translator, simplified_db, on_qubit_order, gates_on_qubit)

(False,
     ind symbol  param_value  trainable  block_id
 0     6   th_0    -2.493851       True         0
 1     9   th_1     0.502264       True         0
 2     6   th_2    -1.157961       True         0
 3     7   th_3     7.267639       True         0
 4    10   th_4   -10.027060       True         0
 5     7   th_5     4.884089       True         0
 6     8   th_6    -1.345441       True         0
 7    11   th_7     8.410865       True         0
 8     8   th_8    -4.651597       True         0
 9     6   th_9    -2.493851       True         0
 10    9  th_10     0.502264       True         0
 11    6  th_11    -1.157961       True         0
 12    7  th_12     7.267639       True         0
 13   10  th_13   -10.027060       True         0
 14    7  th_14     4.884089       True         0
 15    8  th_15    -1.345441       True         0
 16   11  th_16     8.410865       True         0
 17    8  th_17    -4.651597       True         0
 18    6  th_18    -2.493851       True   

139
139
139
139
139
139
139
139
139
139
139
139
139
139
139


(False,
     ind symbol  param_value  trainable  block_id
 0     6   th_0    10.463439       True         0
 1     9   th_1     5.316989       True         0
 2     6   th_2     1.903573       True         0
 3     7   th_3     1.766271       True         0
 4    10   th_4    -0.315403       True         0
 5     7   th_5    -0.443162       True         0
 6     8   th_6     5.933082       True         0
 7    11   th_7    -1.573780       True         0
 8     8   th_8     5.308694       True         0
 9     6   th_9    10.463439       True         0
 10    9  th_10     5.316989       True         0
 11    6  th_11     1.903573       True         0
 12    7  th_12     1.766271       True         0
 13   10  th_13    -0.315403       True         0
 14    7  th_14    -0.443162       True         0
 15    8  th_15     5.933082       True         0
 16   11  th_16    -1.573780       True         0
 17    8  th_17     5.308694       True         0
 18    6  th_18    10.463439       True   

In [64]:
cnt, simplified_db = apply_rule(circuit_db, rule_4)

In [65]:
translator.give_circuit(simplified_db)[0]