In [1]:
import pennylane as qml
from pennylane import numpy as np

In [2]:
# Oracle unitary operators U_f |x, y> = |x, y XOR f(x)> for Deutsch's algorithm.
# Wire 0 (query x) is the most significant bit, so the basis index is 2*x + y;
# e.g. U_f3 below is exactly CNOT(wires=[0, 1]) in this ordering.

# f(x) = 0 (constant): identity.
U_f1 = np.array([[1,0,0,0],
                 [0,1,0,0],
                 [0,0,1,0],
                 [0,0,0,1]])

# f(x) = 1 (constant): flip the target y for both x, i.e. I (x) X on wires [0, 1].
# Must act on wire 1 only; swapping |0y> <-> |1y> would flip the query wire instead.
U_f2 = np.array([[0,1,0,0],
                 [1,0,0,0],
                 [0,0,0,1],
                 [0,0,1,0]])

# f(x) = x (balanced): flip y when x = 1, i.e. CNOT with control 0, target 1.
U_f3 = np.array([[1,0,0,0],
                 [0,1,0,0],
                 [0,0,0,1],
                 [0,0,1,0]])

# f(x) = NOT x (balanced): flip y when x = 0.
U_f4 = np.array([[0,1,0,0],
                 [1,0,0,0],
                 [0,0,1,0],
                 [0,0,0,1]])


In [3]:
# Analytic (shots=None) two-wire simulator: wire 0 is the query qubit, wire 1 the target.
device = qml.device('default.qubit', wires=2, shots=None)

In [4]:
def f1():
    """Constant oracle f(x) = 0: the identity, so no gates are queued."""

def f2():
    """Constant oracle f(x) = 1: unconditionally flip the target wire 1."""
    target = 1
    qml.PauliX(wires=target)

def f3():
    """Balanced oracle f(x) = x: flip target wire 1 when query wire 0 is |1>."""
    query, target = 0, 1
    qml.CNOT(wires=[query, target])

def f4():
    """Balanced oracle f(x) = NOT x: flip target wire 1 when query wire 0 is |0>.

    Sandwiching the CNOT control between two X gates turns it into a
    control-on-|0> while leaving the query wire unchanged afterwards.
    """
    query, target = 0, 1
    qml.PauliX(wires=query)           # |0> <-> |1> on the control ...
    qml.CNOT(wires=[query, target])
    qml.PauliX(wires=query)           # ... and undo it

In [5]:
@qml.qnode(device)
def circuit(U):
    """Deutsch's algorithm for a two-qubit oracle unitary ``U``.

    Wire 0 is the query qubit, wire 1 the target. The target is prepared in
    |-> so the oracle kicks the phase (-1)^f(x) back onto wire 0; the final
    Hadamard turns that relative phase into a computational-basis outcome.

    Args:
        U: 4x4 oracle matrix applied on wires [0, 1].

    Returns:
        Probabilities [p(0), p(1)] of measuring wire 0.
    """
    qml.Hadamard(wires=0)
    # X then H prepares the target in |-> = (|0> - |1>) / sqrt(2).
    qml.PauliX(wires=1)
    qml.Hadamard(wires=1)
    qml.QubitUnitary(U, wires=[0, 1])
    qml.Hadamard(wires=0)
    return qml.probs(wires=[0])

probs = circuit(U_f1)
# qml.draw(qnode)(*args) replaces the legacy QNode.draw() method, which newer
# PennyLane releases no longer provide.
print(qml.draw(circuit)(U_f1))
print(f"probs: {probs}")

 0: ──H─────╭U0──H──┤ Probs 
 1: ──X──H──╰U0─────┤       
U0 =
[[1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]
 [0 0 0 1]]

probs: [1. 0.]


## With a Cat

In [6]:
def circuit_U(U, theta=3 * np.pi / 4):
    """Deutsch circuit whose target qubit is prepared as RY(theta)|1> instead of |->.

    RY(theta)|1> = -sin(theta/2)|0> + cos(theta/2)|1>, so theta = pi/2 gives
    -|-> (the ideal Deutsch setup, up to global phase). For other angles a
    balanced oracle yields p(0) = (1 - sin(theta)) / 2 on wire 0, while a
    constant oracle still yields p(0) = 1.

    Not decorated: it is wrapped in a QNode below so it can be bound to any device.

    Args:
        U: 4x4 oracle matrix applied on wires [0, 1].
        theta: RY angle used to prepare the target wire 1 (default 3*pi/4).

    Returns:
        Probabilities [p(0), p(1)] of measuring wire 0.
    """
    qml.Hadamard(wires=0)
    qml.PauliX(wires=1)
    qml.RY(theta, wires=1)
    qml.QubitUnitary(U, wires=[0, 1])
    qml.Hadamard(wires=0)
    return qml.probs(wires=[0])

# Each QNode execution starts from |00>, so one device/QNode serves every oracle;
# no need to rebuild (and rebind the global) `device` on every iteration.
qnode_U = qml.QNode(circuit_U, device)
for U in [U_f1, U_f2, U_f3, U_f4]:
    probs = qnode_U(U)
    print(f"probs: {probs.round(6)}")

probs: [1. 0.]
probs: [1. 0.]
probs: [0.146447 0.853553]
probs: [0.146447 0.853553]


## Broken circuit

In [7]:
# NOTE: this redefines `circuit` from the first Deutsch cell; code run after
# this cell uses this broken version.
@qml.qnode(device)
def circuit(U):
    """Deliberately broken Deutsch circuit: target prepared as RY(pi/4)|1> instead of |->.

    RY(pi/4)|1> is not an eigenvector of X, so the phase kickback from a
    balanced oracle is incomplete and wire 0 is no longer deterministic:
    for a balanced oracle p(0) = (1 - sin(pi/4)) / 2 ~= 0.146.

    Args:
        U: 4x4 oracle matrix applied on wires [0, 1].

    Returns:
        Probabilities [p(0), p(1)] of measuring wire 0.
    """
    qml.Hadamard(wires=0)
    qml.PauliX(wires=1)
    qml.RY(np.pi / 4, wires=1)
    qml.QubitUnitary(U, wires=[0, 1])
    qml.Hadamard(wires=0)
    return qml.probs(wires=[0])

probs = circuit(U_f3)
# qml.draw(qnode)(*args) replaces the legacy QNode.draw() method.
print(qml.draw(circuit)(U_f3))
print(f"probs: {probs}")

 0: ──H─────────────╭U0──H──┤ Probs 
 1: ──X──RY(0.785)──╰U0─────┤       
U0 =
[[1 0 0 0]
 [0 1 0 0]
 [0 0 0 1]
 [0 0 1 0]]

probs: [0.14644661 0.85355339]
