In [1]:
import numpy as np
import itertools

import cirq
from cirq import ops, Circuit
from cirq.ops import CZ, H, CNOT, X, Y, Z, SWAP

#from cirq.google import Simulator, ConvertToXmonGates, MergeRotations, MergeInteractions, EjectZ

from cirq.contrib.interaction_gate2 import InteractionGate
from cirq.contrib.control_basis import *

from cirq.testing import assert_allclose_up_to_global_phase

In [2]:
from IPython.display import display

def _ops_to_matrix(*ops, digits=3):
    c = cirq.Circuit.from_ops(*ops)
    display(c)
    m = c.to_unitary_matrix()
    return m.round(digits)

In [3]:
class Point:
    def __init__(self, qubit: ops.QubitId, axis: int):
        assert 0 <= axis < 6
        self.qubit = qubit
        self.axis = axis
        self.interaction = None
    def set_interaction(self, interaction):
        self.interaction = interaction
    def commutes_with(self, other):
        return self.axis % 3 == other.axis % 3
    def __str__(self):
        return ('X', 'Y', 'Z', 'nX', 'nY', 'nZ')[self.axis]

In [4]:
class Interaction:
    def __init__(self, qubit_point_map, half_turns: float = 1.0):
        self.qubit_point_map = qubit_point_map
        self.half_turns = half_turns
        for point in self.qubit_point_map.values():
            point.set_interaction(self)
            
    def is_single(self):
        return len(qubit_point_map) == 1
    
    def __repr__(self):
        qubit_point_list = self.qubit_point_map.items()  # TODO: Sort?
        repr_list = ('{!s}({!s})'.format(point, qubit)
                     for qubit, point in qubit_point_list)
        out = '-'.join(repr_list)
        if self.half_turns != 1.0:
            out += '**{}'.format(round(self.half_turns, 3))
        return out
    
    def as_op(self):
        if not 1 <= len(self.qubit_point_map) <= 2:
            raise NotImplementedError
        if len(self.qubit_point_map) == 1:
            qubit, point = next(iter(self.qubit_point_map.items()))
            return (X, Y, Z, X**-1, Y**-1, Z**-1)[point.axis](qubit) ** self.half_turns
        elif len(self.qubit_point_map) == 2:
            itms = iter(self.qubit_point_map.items())
            qubit0, point0 = next(itms)
            qubit1, point1 = next(itms)
            bases = (X_BASIS, Y_BASIS, Z_BASIS, nX_BASIS, nY_BASIS, nZ_BASIS)
            return InteractionGate(bases[point0.axis], bases[point1.axis], half_turns=self.half_turns)(qubit0, qubit1)


def interaction_from_op(op):
    conversions = (  # All assumed to be EigenGate subclasses
        (ops.RotXGate, (0,)),
        (ops.RotYGate, (1,)),
        (ops.RotZGate, (2,)),
        (ops.Rot11Gate, (2, 2)),
        (ops.CNotGate, (2, 0)),
        (InteractionGate, '?'),
    )
    for cls, axes in conversions:
        if isinstance(op.gate, cls):
            if axes == '?':
                qubit_map = {qubit: Point(qubit, basis.axis) for qubit, basis in zip(op.qubits, op.gate.bases)}
                half_turns = op.gate.half_turns
            else:
                assert len(op.qubits) == len(axes)
                qubit_map = {qubit: Point(qubit, axis) for qubit, axis in zip(op.qubits, axes)}
                half_turns = op.gate._exponent  # TODO: Don't access private value
            return Interaction(qubit_map, half_turns=half_turns)
    raise TypeError('Gate cannot be converted to interaction: {}'.format(op.gate))

In [5]:
class CommutingSet:
    def __init__(self, qubit: ops.QubitId, axis: int):
        assert 0 <= axis < 6
        self.qubit = qubit
        self.axis = axis % 3
        self.points = set()

    def __repr__(self):
        return '{}*{}'.format(('X', 'Y', 'Z', 'nX', 'nY', 'nZ')[self.axis], len(self.points))

    def commutes_with(self, other):
        return self.axis % 3 == other.axis % 3

    def add(self, point):
        assert point.axis % 3 == self.axis
        self.points.add(point)

In [6]:
def circuit_to_interactions(circuit):
    #qubits = circuit.qubits()
    # TODO: Connectivity
    output = {qubit: [] for qubit in circuit.qubits()}
    last_sets = {}
    for op in circuit.iter_ops():
        interaction = interaction_from_op(op)
        for qubit, point in interaction.qubit_point_map.items():
            commuting_set = last_sets.get(qubit, None)
            if commuting_set is None or not commuting_set.commutes_with(point):
                commuting_set = CommutingSet(qubit, point.axis)
                output[qubit].append(commuting_set)
                last_sets[qubit] = commuting_set
            commuting_set.add(point)
    return output

In [7]:
def interactions_to_ops(interactions):
    qubits = interactions.keys()
    last_sets = {}
    last_index = {qubit: -1 for qubit in qubits}
    count_down = len(qubits)
    used_set = set()
    def get_options(qubit):
        nonlocal count_down
        last_set = last_sets.get(qubit, None)
        if last_set is None or last_set.issubset(used_set):
            i = last_index[qubit]
            if i is None:
                return set()
            i += 1
            if i >= len(interactions[qubit]):
                last_index[qubit] = None
                count_down -= 1
                return set()
            last_index[qubit] = i
            last_set = set((point.interaction for point in interactions[qubit][i].points))
            last_sets[qubit] = last_set
        return last_set - used_set
    def get_next_interactions():
        next_interactions = set()
        interactions_map = {qubit: get_options(qubit) for qubit in qubits}
        for interactions in interactions_map.values():
            for interaction in interactions:
                for qubit in interaction.qubit_point_map.keys():
                    if interaction not in interactions_map[qubit]:
                        break
                else:  # If doesn't break
                    used_set.add(interaction)
                    return interaction
        if count_down > 0:
            raise RuntimeError('Invalid interaction dependencies')
                
    while True:
        interaction = get_next_interactions()
        if count_down <= 0:
            return
        op = interaction.as_op()
        print(interaction, op)
        yield op

def interactions_to_circuit(interactions):
    return Circuit.from_ops(interactions_to_ops(interactions), strategy=cirq.InsertStrategy.EARLIEST)

In [8]:
q0, q1, q2 = (cirq.NamedQubit('q{}'.format(i)) for i in range(3))

In [9]:
CNOT = InteractionGate(Z_BASIS, X_BASIS)
CZ = InteractionGate(Z_BASIS, Z_BASIS)

In [10]:
# Linear connectivity toffoli gate without optimization
circuit = Circuit.from_ops(
    Y(q2) ** 0.5,
    X(q2),
    CNOT(q1, q2),
    Z(q2) ** -0.25,
    CNOT(q1, q2),
    CNOT(q2, q1),
    CNOT(q1, q2),
    CNOT(q0, q1),
    CNOT(q1, q2),
    CNOT(q2, q1),
    CNOT(q1, q2),
    Z(q2) ** 0.25,
    CNOT(q1, q2),
    Z(q2) ** -0.25,
    CNOT(q1, q2),
    CNOT(q2, q1),
    CNOT(q1, q2),
    CNOT(q0, q1),
    CNOT(q1, q2),
    CNOT(q2, q1),
    CNOT(q1, q2),
    Z(q2) ** 0.25,
    Z(q1) ** 0.25,
    CNOT(q0, q1),
    Z(q0) ** 0.25,
    Z(q1) ** -0.25,
    CNOT(q0, q1),
    Y(q2) ** 0.5,
    X(q2),
)
display(circuit)
(circuit.to_unitary_matrix() / 1j + 0.000001 + 0.000001j).round(3)

array([[1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j]])

In [11]:
interactions = circuit_to_interactions(circuit)
#interactions

In [13]:
c = interactions_to_circuit(interactions)
display(c)
(c.to_unitary_matrix() / 1j + 0.000001 + 0.000001j).round(3)

Z(q0)**0.25 Z**0.25(q0)
Y(q2)**0.5 Y**0.5(q2)
Z(q1)-X(q2) <cirq.contrib.interaction_gate2.InteractionGate object at 0x7f1b2f14c898>(q1, q2)
X(q2) X(q2)
Z(q2)**-0.25 Z**-0.25(q2)
Z(q1)-X(q2) <cirq.contrib.interaction_gate2.InteractionGate object at 0x7f1b2f14ca90>(q1, q2)
X(q1)-Z(q2) <cirq.contrib.interaction_gate2.InteractionGate object at 0x7f1b2f14ce80>(q1, q2)
Z(q1)-X(q2) <cirq.contrib.interaction_gate2.InteractionGate object at 0x7f1b2f14cc88>(q1, q2)
X(q1)-Z(q0) <cirq.contrib.interaction_gate2.InteractionGate object at 0x7f1b2f12c2e8>(q1, q0)
Z(q1)-X(q2) <cirq.contrib.interaction_gate2.InteractionGate object at 0x7f1b2f14cf60>(q1, q2)
X(q1)-Z(q2) <cirq.contrib.interaction_gate2.InteractionGate object at 0x7f1b2f147048>(q1, q2)
Z(q1)-X(q2) <cirq.contrib.interaction_gate2.InteractionGate object at 0x7f1b2f1470f0>(q1, q2)
Z(q2)**0.25 Z**0.25(q2)
Z(q1)-X(q2) <cirq.contrib.interaction_gate2.InteractionGate object at 0x7f1b2f147278>(q1, q2)
Z(q2)**-0.25 Z**-0.25(q2)
Z(q1)-X(q2) <cirq.co

array([[1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j]])

In [14]:
#c2 = Circuit.from_ops((interaction_from_op(op).as_op() for op in circuit.iter_ops()))
c2 = Circuit.from_ops((op for op in circuit.iter_ops()))
display(c2)
(c2.to_unitary_matrix() / 1j + 0.000001 + 0.000001j).round(3)

array([[1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j]])