In [2]:
import functools

import numpy as np
import pennylane as qml
from scipy.linalg import expm

from graddft_qnn.dft_qnn import DFTQNN
from graddft_qnn.unitary_rep import O_h

# Defining...

## Unitary reps

In [3]:
theta = 1.0

In [4]:
unitary_reps = O_h._180_deg_rot()
unitary_reps_3_axis = O_h._180_deg_rot_3_axis()
unitary_reps_rot_ref = O_h._180_deg_rot_ref()

## Ansatz

In [5]:
XII = functools.reduce(np.kron, [qml.X.compute_matrix(), np.eye(2), np.eye(2)])
IXI = functools.reduce(np.kron, [np.eye(2), qml.X.compute_matrix(), np.eye(2)])
IIX = functools.reduce(np.kron, [np.eye(2), np.eye(2), qml.X.compute_matrix()])

In [6]:
YII = functools.reduce(np.kron, [qml.Y.compute_matrix(), np.eye(2), np.eye(2)])
IYI = functools.reduce(np.kron, [np.eye(2), qml.Y.compute_matrix(), np.eye(2)])
IIY = functools.reduce(np.kron, [np.eye(2), np.eye(2), qml.Y.compute_matrix()])

In [7]:
ZII = functools.reduce(np.kron, [qml.Z.compute_matrix(), np.eye(2), np.eye(2)])
IZI = functools.reduce(np.kron, [np.eye(2), qml.Z.compute_matrix(), np.eye(2)])
IIZ = functools.reduce(np.kron, [np.eye(2), np.eye(2), qml.Z.compute_matrix()])

$$
ZZ(\phi) = \exp\left(-i \frac{\phi}{2} (Z \otimes Z)\right) =
\begin{bmatrix}
    e^{-i \phi / 2} & 0 & 0 & 0 \\
    0 & e^{i \phi / 2} & 0 & 0 \\
    0 & 0 & e^{i \phi / 2} & 0 \\
    0 & 0 & 0 & e^{-i \phi / 2}
\end{bmatrix}
$$

$ZZ(0) = I$

In [8]:
_ZZZ = functools.reduce(
    np.kron,
    [qml.Z.compute_matrix(), qml.Z.compute_matrix(), qml.Z.compute_matrix()],
)
ZZZ_gen = -1j * theta / 2 * _ZZZ

expm(-1 theta X) exmp(-1 theta Z) -> measurement XZ
any output of the twirling is measurement, and generator

In [9]:
ZZZ = expm(ZZZ_gen)

In [10]:
_XYZ = functools.reduce(
    np.kron,
    [qml.X.compute_matrix(), qml.Y.compute_matrix(), qml.Z.compute_matrix()],
)
XYZ_gen = -1j * theta / 2 * _XYZ

XYZ = expm(XYZ_gen)

In [11]:
def process(gate_matrix, u_reprs: list[np.array], return_gen=False):
    gen = DFTQNN.twirling(gate_matrix, unitary_reps=u_reprs)
    if isinstance(gen, np.ndarray):
        lcu = qml.pauli_decompose(
            gen, check_hermitian=False, hide_identity=False, pauli=True
        )
        return (lcu, gen) if return_gen else lcu
    return None, gen if return_gen else None

# Run the twirling + generator

In [12]:
[
    process(XII, unitary_reps),
    process(IXI, unitary_reps),
    process(IIX, unitary_reps),
    process(YII, unitary_reps),
    process(IYI, unitary_reps),
    process(IIY, unitary_reps),
    process(ZII, unitary_reps),
    process(IZI, unitary_reps),
    process(IIZ, unitary_reps),
]

[(1+0j) * X(0),
 (1+0j) * X(1),
 (1+0j) * X(2),
 (None, None),
 (None, None),
 (1+0j) * Y(2),
 (None, None),
 (None, None),
 (1+0j) * Z(2)]

In [13]:
[
    process(XII, unitary_reps_3_axis),
    process(IXI, unitary_reps_3_axis),
    process(IIX, unitary_reps_3_axis),
    process(YII, unitary_reps_3_axis),
    process(IYI, unitary_reps_3_axis),
    process(IIY, unitary_reps_3_axis),
    process(ZII, unitary_reps_3_axis),
    process(IZI, unitary_reps_3_axis),
    process(IIZ, unitary_reps_3_axis),
]

[(1+0j) * X(0),
 (1+0j) * X(1),
 (1+0j) * X(2),
 (-0.3333333432674408+0j) * Y(0),
 (-0.3333333432674408+0j) * Y(1),
 (-0.3333333432674408+0j) * Y(2),
 (-0.3333333432674408+0j) * Z(0),
 (-0.3333333432674408+0j) * Z(1),
 (-0.3333333432674408+0j) * Z(2)]

In [14]:
[
    process(XII, unitary_reps_rot_ref),
    process(IXI, unitary_reps_rot_ref),
    process(IIX, unitary_reps_rot_ref),
    process(YII, unitary_reps_rot_ref),
    process(IYI, unitary_reps_rot_ref),
    process(IIY, unitary_reps_rot_ref),
    process(ZII, unitary_reps_rot_ref),
    process(IZI, unitary_reps_rot_ref),
    process(IIZ, unitary_reps_rot_ref),
]

[(1+0j) * X(0),
 (1+0j) * X(1),
 (1+0j) * X(2),
 (-0.5+0j) * Y(0),
 (None, None),
 (None, None),
 (-0.5+0j) * Z(0),
 (None, None),
 (None, None)]

In [15]:
zzz_pauli_gen, zzz_gen = process(ZZZ, unitary_reps_3_axis, True)

In [16]:
expm((-0.20000000298023224 + 0j) * qml.Y(0).compute_matrix())

array([[1.02006676+0.j        , 0.        +0.20133601j],
       [0.        -0.20133601j, 1.02006676+0.j        ]])

In [17]:
zzz_pauli_gen

(0.8775825500488281+0j) * I
+ -0.4794255495071411j * Z(0) @ Z(1) @ Z(2)

In [18]:
zzz = expm(zzz_gen)

In [19]:
zzz

array([[2.13393-1.1093895j, 0.     +0.j       , 0.     +0.j       ,
        0.     +0.j       , 0.     +0.j       , 0.     +0.j       ,
        0.     +0.j       , 0.     +0.j       ],
       [0.     +0.j       , 2.13393+1.1093895j, 0.     +0.j       ,
        0.     +0.j       , 0.     +0.j       , 0.     +0.j       ,
        0.     +0.j       , 0.     +0.j       ],
       [0.     +0.j       , 0.     +0.j       , 2.13393+1.1093895j,
        0.     +0.j       , 0.     +0.j       , 0.     +0.j       ,
        0.     +0.j       , 0.     +0.j       ],
       [0.     +0.j       , 0.     +0.j       , 0.     +0.j       ,
        2.13393-1.1093895j, 0.     +0.j       , 0.     +0.j       ,
        0.     +0.j       , 0.     +0.j       ],
       [0.     +0.j       , 0.     +0.j       , 0.     +0.j       ,
        0.     +0.j       , 2.13393+1.1093895j, 0.     +0.j       ,
        0.     +0.j       , 0.     +0.j       ],
       [0.     +0.j       , 0.     +0.j       , 0.     +0.j       ,
       

In [20]:
diag = [zzz[i][i] for i in range(len(zzz))]

In [21]:
assert np.allclose(np.diag(diag), zzz)

In [22]:
diag

[(2.13393-1.1093895j),
 (2.13393+1.1093895j),
 (2.13393+1.1093895j),
 (2.13393-1.1093895j),
 (2.13393+1.1093895j),
 (2.13393-1.1093895j),
 (2.13393-1.1093895j),
 (2.13393+1.1093895j)]

In [23]:
(1 + 2j) * (-1j)

(2-1j)

In [24]:
process(XYZ, unitary_reps_rot_ref)

(0.8775825500488281+0j) * I

# Design a circuit

In [25]:
expm(-1j * 2 * qml.X.compute_matrix())

array([[-0.41614684+0.j        ,  0.        -0.90929743j],
       [ 0.        -0.90929743j, -0.41614684+0.j        ]])

In [26]:
qml.RX.compute_matrix(4)

array([[-0.41614684+0.j        ,  0.        -0.90929743j],
       [ 0.        -0.90929743j, -0.41614684+0.j        ]])

In [27]:
def twirling_(ansatz: np.array, unitary_reps: list[np.array]):
    generator = np.zeros_like(ansatz, dtype=np.complex64)
    ansatz = ansatz.astype(np.complex64)
    for unitary_rep in unitary_reps:
        twirled = 0.5 * (ansatz + unitary_rep @ ansatz @ unitary_rep.conjugate())
        if np.allclose(twirled, np.zeros_like(twirled)):
            print("All zero")
        else:
            print(twirled)
            print(qml.pauli_decompose(twirled))
        print()

In [28]:
def commute_(A: np.array, B: np.array):
    return np.allclose(A @ B, B @ A)

In [29]:
for u in unitary_reps_3_axis:
    assert commute_(ZZZ, u)

In [30]:
ZZZ

array([[0.87758256-0.47942554j, 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ],
       [0.        +0.j        , 0.87758256+0.47942554j,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ],
       [0.        +0.j        , 0.        +0.j        ,
        0.87758256+0.47942554j, 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ],
       [0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.87758256-0.47942554j,
        0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0.j        ],
       [0.        +0.j        , 0.        +0.j        ,
        0.        +0.j        , 0.        +0

In [76]:
# important
all_gates = [XII, IXI, IIX, YII, IYI, IIY, ZII, IZI, IIZ]
for i, gate in enumerate(all_gates):
    print(i)
    twirling_(gate, unitary_reps_3_axis)
    print("====")

0
[[0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j]
 [1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]]
1.0 * (X(0) @ I(1) @ I(2))

[[0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j]
 [1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]]
1.0 * (X(0) @ I(1) @ I(2

In [32]:
a = qml.X(0) @ qml.Y(0)
qml.matrix(a)

array([[0.+1.j, 0.+0.j],
       [0.+0.j, 0.-1.j]])

In [33]:
qml.MultiRZ.compute_matrix(1, 4).shape

(16, 16)

In [34]:
qml.ctrl(qml.X, control=0, control_values=(1), work_wires=(1)).compute_matrix()

array([[0, 1],
       [1, 0]])

## Defining the Pooling gate

In [35]:
RZZ = qml.MultiRZ(1,[0,1])

In [36]:
RXX = expm(-1j*0.5*qml.matrix(qml.X(0)@qml.X(1)))
RXY = expm(-1j*0.5*qml.matrix(qml.X(0)@qml.Y(1)))
RXZ = expm(-1j*0.5*qml.matrix(qml.X(0)@qml.Z(1)))

RYX = expm(-1j*0.5*qml.matrix(qml.Y(0)@qml.X(1)))
RYY = expm(-1j*0.5*qml.matrix(qml.Y(0)@qml.Y(1)))
RYZ = expm(-1j*0.5*qml.matrix(qml.Y(0)@qml.Z(1)))

RZX = expm(-1j*0.5*qml.matrix(qml.Z(0)@qml.X(1)))
RZY = expm(-1j*0.5*qml.matrix(qml.Z(0)@qml.Y(1)))
RZZ = expm(-1j*0.5*qml.matrix(qml.Z(0)@qml.Z(1)))

In [37]:
def generate_control_xy(xy):
    controlled_u = qml.ctrl(qml.QubitUnitary(xy, wires=(1,2)), control_values=(1), control=0)
    return qml.matrix(controlled_u)

Test generated control gate

In [38]:
generate_control_xy(RXX) @ functools.reduce(np.kron, [[0, 1], [1,0], [0,1]])

array([0.        +0.j        , 0.        +0.j        ,
       0.        +0.j        , 0.        +0.j        ,
       0.        +0.j        , 0.87758256+0.j        ,
       0.        -0.47942554j, 0.        +0.j        ])

In [39]:
generate_control_xy(RXX) @ functools.reduce(np.kron, [[1, 0], [1,0], [0,1]])

array([0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j])

In [40]:
RXX @ [0, 1, 0, 0]

array([0.        +0.j        , 0.87758256+0.j        ,
       0.        -0.47942554j, 0.        +0.j        ])

Done test

In [41]:
for i, pooling_g in enumerate([RXX, RXY, RXZ, RYX, RYY, RYZ, RZX, RZY, RZZ]):
    for u in unitary_reps_3_axis:
        if not commute_(generate_control_xy(pooling_g), u):
            print(f"{i} failed")
            break

0 failed
1 failed
2 failed
3 failed
4 failed
5 failed
6 failed
7 failed
8 failed


In [45]:
np.allclose(CRXX, np.conj(CRXX))

False

In [59]:
np.set_printoptions(2, suppress=True)

In [74]:
XX = qml.matrix(qml.X(0)@qml.X(1))
XY =qml.matrix(qml.X(0)@qml.Y(1))
XZ =qml.matrix(qml.X(0)@qml.Z(1))

YX =qml.matrix(qml.Y(0)@qml.X(1))
YY =qml.matrix(qml.Y(0)@qml.Y(1))
YZ =qml.matrix(qml.Y(0)@qml.Z(1))

ZX =qml.matrix(qml.Z(0)@qml.X(1))
ZY =qml.matrix(qml.Z(0)@qml.Y(1))
ZZ =qml.matrix(qml.Z(0)@qml.Z(1))

In [77]:
all_gates = [XX, XY, XZ, YX, YY, YZ, ZX, ZY, ZZ]
for i, gate in enumerate(all_gates):
    print(i)
    twirling_(generate_control_xy(gate), unitary_reps_3_axis)
    print("====")

0
[[1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j]
 [0.+0.j 0.+0.j 0.+0.j 0.+0.j 1.+0.j 0.+0.j 0.+0.j 0.+0.j]]
0.5 * (I(0) @ I(1) @ I(2)) + 0.5 * (I(0) @ X(1) @ X(2)) + 0.5 * (Z(0) @ I(1) @ I(2)) + -0.5 * (Z(0) @ X(1) @ X(2))

[[0.5+0.j 0. +0.j 0. +0.j 0.5+0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0.5+0.j 0.5+0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0.5+0.j 0.5+0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0.5+0.j 0. +0.j 0. +0.j 0.5+0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j 0.5+0.j 0. +0.j 0. +0.j 0.5+0.j]
 [0. +0.j 0. +0.j 0. +0.j 0. +0.j 0. +0.j 0.5+0.j 0.5+0.j 0. +0.j]
 [0. +0

In [88]:
_crxx = expm(-1j*generate_control_xy(XX))

In [89]:
_crxx @ functools.reduce(np.kron, [[1,0], [0,1], [0,1]])

array([0.  +0.j  , 0.  +0.j  , 0.  +0.j  , 0.54-0.84j, 0.  +0.j  ,
       0.  +0.j  , 0.  +0.j  , 0.  +0.j  ])

In [91]:
RXX @ functools.reduce(np.kron, [[0,1], [0,1]])

array([0.  -0.48j, 0.  +0.j  , 0.  +0.j  , 0.88+0.j  ])

In [98]:
generate_control_xy(XX)

array([[1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j],
       [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j]])

In [97]:
unitary_reps_3_axis

[array([[0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0]]),
 array([[0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 1, 0],
        [0, 1, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0]]),
 array([[0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0]])]

## 