In [1]:
%load_ext autoreload
%autoreload 2

In [63]:
from eqmarl import *
import pennylane as qml
import numpy as np
import functools as ft

In [64]:
def expectedValueOfObservable(psi, obs):
    """Return the expectation value <psi|obs|psi>.

    `psi` is the state (a 1-D vector or a column vector) and `obs` is the
    observable matrix. Because the product is evaluated with `@`, a stack of
    observables is handled by matmul broadcasting and yields one value per
    observable.
    """
    # <psi| is the conjugate transpose of |psi>.
    bra_psi = np.conjugate(psi).T
    return bra_psi @ obs @ psi

def kron(*args):
    """Fold the operands, left to right, into a single Kronecker product.

    A single operand is returned unchanged; calling with no operands raises
    `TypeError`.
    """
    if not args:
        # Same failure mode as reducing an empty sequence without an initial value.
        raise TypeError('reduce() of empty iterable with no initial value')
    product = args[0]
    for factor in args[1:]:
        product = np.kron(product, factor)
    return product

In [104]:
def dagger(x):
    """Hermitian adjoint of `x`: complex-conjugate every entry, then transpose."""
    conjugated = x.conjugate()
    return conjugated.transpose()

def ket(*x):
    """Column vector for the computational-basis state |x0, x1, ...>.

    Every argument must be 0 or 1. Each bit becomes a 2x1 basis vector and the
    per-qubit vectors are combined with `kron`, first argument leftmost.
    """
    assert not any(i not in (0, 1) for i in x), 'ket only accepts binary 0 or 1'
    factors = []
    for bit in x:
        if bit:
            factors.append(np.array([[0], [1]]))  # |1>
        else:
            factors.append(np.array([[1], [0]]))  # |0>
    return kron(*factors)

def bra(*x):
    """Row vector <x0, x1, ...|, i.e. the conjugate transpose of `ket(*x)`.

    Every argument must be 0 or 1.
    """
    assert not any(i not in (0, 1) for i in x), 'bra only accepts binary 0 or 1'
    return dagger(ket(*x))

In [195]:
def bell(src=0, tgt=1, bellstate: Literal[0,1,2,3] = 0, nwires=2, op=qml.CNOT):
    """Matrix of a circuit that prepares a Bell state on the `src` and `tgt` wires.

    The circuit optionally flips the inputs with Pauli-X gates, applies a
    Hadamard to `src`, and finishes with `op(wires=[src, tgt])`. With the
    default `op=qml.CNOT`, `bellstate` selects the state produced from |0...0>:
    - 0 == Phi+ (no flips)
    - 1 == Psi+ (flip `tgt`)
    - 2 == Phi- (flip `src`)
    - 3 == Psi- (flip `tgt`, then `src`)
    Any other value is treated like 0.

    NOTE(review): the labels above hold for CNOT. With a diagonal gate such as
    `op=qml.CZ` the target stays in a computational-basis state, so the
    H-then-op sequence only changes the phase on `src` and the two wires are
    not entangled.

    The matrix is taken over wires `range(nwires)`, so it can be applied to a
    register larger than the two wires it acts on.

    Apply the matrix to the all-zero state to recover the Bell state vector:
    >>> U = bell(0, 1, bellstate=0, nwires=2) # Creates matrix which generates Phi+
    >>> U @ np.array([[1], [0], [0], [0]]) # Dot product with |00>
    array([[0.70710678],
       [0.        ],
       [0.        ],
       [0.70710678]])
    """
    def prepare():
        # Selector values 1 and 3 flip the target; values 2 and 3 flip the source.
        if bellstate in (1, 3):
            qml.PauliX(wires=tgt)
        if bellstate in (2, 3):
            qml.PauliX(wires=src)
        # Superpose the source, then couple it to the target.
        qml.Hadamard(wires=src)
        op(wires=[src, tgt])
    return qml.matrix(prepare, wire_order=range(nwires))()


def ghz(src: int = 0, tgt: int | list[int] = 1, bellstate: Literal[0,1,2,3] = 0, nwires=2, op=qml.CNOT):
    """Matrix of a circuit that entangles `src` with every wire in `tgt` (GHZ-style).

    The circuit applies a Hadamard to `src` and then `op(wires=[src, t])` for
    each target `t` in order. With the defaults and the |0...0> input this
    prepares (|0...0> + |1...1>)/sqrt(2) on the source and target wires.

    `bellstate` selects Pauli-X flips applied beforehand, mirroring `bell`:
    - 0: no flips
    - 1: flip every target before it is coupled to the source
    - 2: flip the source before the Hadamard
    - 3: both of the above
    Any other value is treated like 0.

    Args:
        src: source (control) wire.
        tgt: a single target wire or a sequence of target wires. Any integer
            scalar (including NumPy integers) is treated as one target.
        bellstate: flip selector described above.
        nwires: the matrix is taken over wires `range(nwires)`.
        op: two-wire operation called as `op(wires=[src, t])`.

    Returns:
        The circuit matrix as returned by `qml.matrix`.
    """
    import numbers  # local import keeps this cell self-contained

    # Accept a lone wire index; numbers.Integral also covers NumPy integer scalars,
    # which are not instances of the builtin int.
    if isinstance(tgt, numbers.Integral):
        tgt = [tgt]

    def circuit():
        if bellstate in (2, 3):
            qml.PauliX(wires=src)
        qml.Hadamard(wires=src)
        for t in tgt:
            if bellstate in (1, 3):
                qml.PauliX(wires=t)
            op(wires=[src, t])
    return qml.matrix(circuit, wire_order=range(nwires))()


In [143]:
# Matrix representations of the single-qubit Pauli operators and other common gates.
X = qml.matrix(qml.PauliX(wires=0))
Y = qml.matrix(qml.PauliY(wires=0))
Z = qml.matrix(qml.PauliZ(wires=0))
H = qml.matrix(qml.Hadamard(wires=0))
I = qml.matrix(qml.Identity(wires=0))


def CNOT(src=0, tgt=1, nwires=2):
    """Matrix of a CNOT with control `src` and target `tgt`, taken over wires `range(nwires)`."""
    return qml.matrix(qml.CNOT(wires=[src, tgt]), wire_order=range(nwires))


def CZ(src=0, tgt=1, nwires=2):
    """Matrix of a CZ acting on `src` and `tgt`, taken over wires `range(nwires)`."""
    return qml.matrix(qml.CZ(wires=[src, tgt]), wire_order=range(nwires))

In [196]:
# Two-qubit preparation matrices, one per Bell state, in selector order 0..3
# (source is qubit 0, target is its next neighbour, qubit 1).
bell_phi_plus, bell_psi_plus, bell_phi_minus, bell_psi_minus = (
    bell(src=0, tgt=1, bellstate=index, nwires=2) for index in range(4)
)

In [182]:
n_agents = 2
d_qubits = 2
n_layers = 1

circuit = MARLCircuit(
    n_agents=n_agents,
    d_qubits=d_qubits,
    n_layers=n_layers,
    )

weight_shapes = circuit.weight_shapes

# Draw the circuit parameters from a seeded generator so that Restart & Run All
# reproduces the same unitary (and therefore the same expectation values).
# NOTE: outputs recorded before the seed was introduced came from unseeded draws
# and will not match a fresh run.
rng = np.random.default_rng(seed=0)
agents_var_thetas = rng.uniform(0, 2*np.pi, size=weight_shapes['agents_var_thetas'])
agents_enc_inputs = rng.uniform(0, 2*np.pi, size=weight_shapes['agents_enc_inputs'])

print(qml.draw(circuit, wire_order=circuit.wires)(agents_var_thetas, agents_enc_inputs))


# System unitary (all agents).
U = qml.matrix(circuit, wire_order=circuit.wires)(agents_var_thetas, agents_enc_inputs)

0: ─╭VariationalEncodingPQC─┤  
1: ─╰VariationalEncodingPQC─┤  
2: ─╭VariationalEncodingPQC─┤  
3: ─╰VariationalEncodingPQC─┤  


In [203]:
# Per-agent observables on 4 wires: Z on the agent's first qubit, Z on its second
# qubit, and Z(x)Z on both. Agent 0 owns wires [0, 1], agent 1 owns wires [2, 3].
obs = [
    kron(*local, I, I) for local in ((Z, I), (I, Z), (Z, Z))
] + [
    kron(I, I, *local) for local in ((Z, I), (I, Z), (Z, Z))
]
n_obs = len(obs)
n_obs_per_agent = n_obs // 2

In [218]:
# Baseline without any entangling stage: the agents act directly on |0000>.
M = U @ ket(0, 0, 0, 0)

# Expectation value of every observable in `obs` for the output state M.
E = expectedValueOfObservable(M, obs)

# Agent 0 minus agent 1, observable by observable.
diffs = E[:n_obs_per_agent] - E[n_obs_per_agent:2*n_obs_per_agent]

E, diffs

(array([[[-0.18761012+0.j]],
 
        [[ 0.15938534+0.j]],
 
        [[-0.11824798+0.j]],
 
        [[-0.22005712+0.j]],
 
        [[ 0.03948987+0.j]],
 
        [[-0.03002138+0.j]]]),
 array([[[ 0.032447  +0.j]],
 
        [[ 0.11989547+0.j]],
 
        [[-0.0882266 +0.j]]]))

In [219]:
# Input stage: Phi+ on wires [0,1] and on wires [2,3] (each pair lies within one agent).
B = kron(*[bell_phi_plus] * 2)
M = U @ B @ ket(0, 0, 0, 0)

# Expectation value of every observable in `obs` for the output state M.
E = expectedValueOfObservable(M, obs)

# Agent 0 minus agent 1, observable by observable.
diffs = E[:n_obs_per_agent] - E[n_obs_per_agent:2*n_obs_per_agent]

E, diffs

(array([[[-0.70911625+0.j]],
 
        [[ 0.07179181+0.j]],
 
        [[-0.12724351+0.j]],
 
        [[ 0.45386768+0.j]],
 
        [[ 0.15901161+0.j]],
 
        [[-0.00454907+0.j]]]),
 array([[[-1.16298392+0.j]],
 
        [[-0.08721981+0.j]],
 
        [[-0.12269444+0.j]]]))

In [220]:
# Input stage: Phi+ on wires [0,2] and on wires [1,3] (each pair spans both agents).
pair_02 = dict(src=0, tgt=2)
pair_13 = dict(src=1, tgt=3)
B = bell(**pair_02, bellstate=0, nwires=4) @ bell(**pair_13, bellstate=0, nwires=4)

M = U @ B @ ket(0, 0, 0, 0)

# Expectation value of every observable in `obs` for the output state M.
E = expectedValueOfObservable(M, obs)

# Agent 0 minus agent 1, observable by observable.
diffs = E[:n_obs_per_agent] - E[n_obs_per_agent:2*n_obs_per_agent]

E, diffs

(array([[[-1.52655666e-16+0.j]],
 
        [[ 2.77555756e-17+0.j]],
 
        [[ 5.55111512e-17+0.j]],
 
        [[-1.66533454e-16+0.j]],
 
        [[-8.32667268e-17+0.j]],
 
        [[-5.55111512e-17+0.j]]]),
 array([[[1.38777878e-17+0.j]],
 
        [[1.11022302e-16+0.j]],
 
        [[1.11022302e-16+0.j]]]))

In [221]:
# Input stage: GHZ-style preparation with wire 0 as the source and wires 1-3 as targets.
B = ghz(src=0, tgt=[1, 2, 3], bellstate=0, nwires=4)
M = U @ B @ ket(0, 0, 0, 0)

# Expectation value of every observable in `obs` for the output state M.
E = expectedValueOfObservable(M, obs)

# Agent 0 minus agent 1, observable by observable.
diffs = E[:n_obs_per_agent] - E[n_obs_per_agent:2*n_obs_per_agent]

E, diffs

(array([[[ 0.06245469+0.j]],
 
        [[ 0.39410024+0.j]],
 
        [[ 0.2657979 +0.j]],
 
        [[-0.2183963 +0.j]],
 
        [[ 0.01837139+0.j]],
 
        [[ 0.02608205+0.j]]]),
 array([[[0.28085099+0.j]],
 
        [[0.37572885+0.j]],
 
        [[0.23971586+0.j]]]))

In [222]:
# Input stage: GHZ-style preparation sourced from wire 2, with selector 3
# (source and every target flipped with Pauli-X before being coupled).
B = ghz(src=2, tgt=[0, 1, 3], bellstate=3, nwires=4)
M = U @ B @ ket(0, 0, 0, 0)

# Expectation value of every observable in `obs` for the output state M.
E = expectedValueOfObservable(M, obs)

# Agent 0 minus agent 1, observable by observable.
diffs = E[:n_obs_per_agent] - E[n_obs_per_agent:2*n_obs_per_agent]

E, diffs

(array([[[ 0.06245469+0.j]],
 
        [[ 0.39410024+0.j]],
 
        [[ 0.2657979 +0.j]],
 
        [[ 0.2183963 +0.j]],
 
        [[-0.01837139+0.j]],
 
        [[-0.02608205+0.j]]]),
 array([[[-0.15594161+0.j]],
 
        [[ 0.41247163+0.j]],
 
        [[ 0.29187995+0.j]]]))

In [223]:
# Input stage: the selector-0 Bell circuit with CZ in place of CNOT, on wires [0,1] and [2,3].
# NOTE(review): H on the source followed by CZ leaves a |0> target untouched, so this
# stage produces a product state rather than a Bell state -- confirm this is intended.
cz_pair = bell(src=0, tgt=1, bellstate=0, nwires=2, op=qml.CZ)
B = kron(cz_pair, bell(src=0, tgt=1, bellstate=0, nwires=2, op=qml.CZ))

M = U @ B @ ket(0, 0, 0, 0)

# Expectation value of every observable in `obs` for the output state M.
E = expectedValueOfObservable(M, obs)

# Agent 0 minus agent 1, observable by observable.
diffs = E[:n_obs_per_agent] - E[n_obs_per_agent:2*n_obs_per_agent]

E, diffs

(array([[[-0.53881275+0.j]],
 
        [[ 0.3194031 +0.j]],
 
        [[-0.1619899 +0.j]],
 
        [[-0.14426168+0.j]],
 
        [[ 0.73252616+0.j]],
 
        [[-0.05320409+0.j]]]),
 array([[[-0.39455107+0.j]],
 
        [[-0.41312306+0.j]],
 
        [[-0.10878582+0.j]]]))

In [224]:
# Input stage: the selector-0 Bell circuit with CZ in place of CNOT, on wires [0,2] and [1,3].
# NOTE(review): H on the source followed by CZ leaves a |0> target untouched, so wires 2
# and 3 stay in |0> and no entanglement is created -- confirm this is intended.
cz_02 = bell(src=0, tgt=2, bellstate=0, nwires=4, op=qml.CZ)
cz_13 = bell(src=1, tgt=3, bellstate=0, nwires=4, op=qml.CZ)
B = cz_02 @ cz_13

M = U @ B @ ket(0, 0, 0, 0)

# Expectation value of every observable in `obs` for the output state M.
E = expectedValueOfObservable(M, obs)

# Agent 0 minus agent 1, observable by observable.
diffs = E[:n_obs_per_agent] - E[n_obs_per_agent:2*n_obs_per_agent]

E, diffs

(array([[[-0.52618419+0.j]],
 
        [[-0.59077723+0.j]],
 
        [[ 0.23428383+0.j]],
 
        [[-0.22005712+0.j]],
 
        [[ 0.03948987+0.j]],
 
        [[-0.03002138+0.j]]]),
 array([[[-0.30612707+0.j]],
 
        [[-0.6302671 +0.j]],
 
        [[ 0.26430521+0.j]]]))

In [225]:
# Eigendecomposition of the system unitary U; the columns of `eig_vects` are used
# as input states in the cells that follow.
# NOTE(review): `cirq` is never imported explicitly in this notebook -- presumably it
# leaks in through `from eqmarl import *`; confirm, or add an explicit `import cirq`.
eig_vals, eig_vects = cirq.unitary_eig(U)

In [231]:
# Input state: the first eigenvector of U (column 0), as a 1-D vector.
B = eig_vects[:, 0]

M = U @ B

# Expectation value of every observable in `obs` for the output state M.
E = expectedValueOfObservable(M, obs)

# Agent 0 minus agent 1, observable by observable.
diffs = E[:n_obs_per_agent] - E[n_obs_per_agent:2*n_obs_per_agent]

E, diffs

(array([0.45806   +0.j, 0.21493029+0.j, 0.74224179+0.j, 0.45951492+0.j,
        0.77555546+0.j, 0.44991555+0.j]),
 array([-0.00145492+0.j, -0.56062517+0.j,  0.29232624+0.j]))

In [238]:
# Input state: equal-weight superposition of the first and last eigenvectors of U.
first_vec, last_vec = eig_vects[:, 0], eig_vects[:, -1]
B = 1/np.sqrt(2) * first_vec + 1/np.sqrt(2) * last_vec

M = U @ B

# Expectation value of every observable in `obs` for the output state M.
E = expectedValueOfObservable(M, obs)

# Agent 0 minus agent 1, observable by observable.
diffs = E[:n_obs_per_agent] - E[n_obs_per_agent:2*n_obs_per_agent]

E, diffs

(array([ 0.02529203+0.j,  0.0308524 +0.j,  0.50321821+0.j,  0.52017234+0.j,
         0.2707908 +0.j, -0.09157733+0.j]),
 array([-0.49488031+0.j, -0.2399384 +0.j,  0.59479553+0.j]))

In [236]:
# Input state: equal-weight superposition of every eigenvector of U (the columns are
# summed along the last axis and scaled by 1/sqrt(number of eigenvectors)).
n_eig_vects = eig_vects.shape[-1]
B = 1/np.sqrt(n_eig_vects) * np.sum(eig_vects, axis=-1)

M = U @ B

# Expectation value of every observable in `obs` for the output state M.
E = expectedValueOfObservable(M, obs)

# Agent 0 minus agent 1, observable by observable.
diffs = E[:n_obs_per_agent] - E[n_obs_per_agent:2*n_obs_per_agent]

E, diffs

(array([-0.05078311+0.j, -0.23662673+0.j,  0.12823168+0.j,  0.17400391+0.j,
        -0.0772174 +0.j,  0.16059303+0.j]),
 array([-0.22478702+0.j, -0.15940933+0.j, -0.03236136+0.j]))

---

Now repeat the experiment with 3 agents (6 qubits).

In [240]:
n_agents = 3
d_qubits = 2
n_layers = 1

circuit = MARLCircuit(
    n_agents=n_agents,
    d_qubits=d_qubits,
    n_layers=n_layers,
    )

weight_shapes = circuit.weight_shapes

# Draw the circuit parameters from a seeded generator so that Restart & Run All
# reproduces the same unitary (and therefore the same expectation values).
# NOTE: outputs recorded before the seed was introduced came from unseeded draws
# and will not match a fresh run.
rng = np.random.default_rng(seed=0)
agents_var_thetas = rng.uniform(0, 2*np.pi, size=weight_shapes['agents_var_thetas'])
agents_enc_inputs = rng.uniform(0, 2*np.pi, size=weight_shapes['agents_enc_inputs'])

print(qml.draw(circuit, wire_order=circuit.wires)(agents_var_thetas, agents_enc_inputs))


# System unitary (all agents).
U = qml.matrix(circuit, wire_order=circuit.wires)(agents_var_thetas, agents_enc_inputs)

0: ─╭VariationalEncodingPQC─┤  
1: ─╰VariationalEncodingPQC─┤  
2: ─╭VariationalEncodingPQC─┤  
3: ─╰VariationalEncodingPQC─┤  
4: ─╭VariationalEncodingPQC─┤  
5: ─╰VariationalEncodingPQC─┤  


In [241]:
# Per-agent observables on 6 wires: Z on the agent's first qubit, Z on its second
# qubit, and Z(x)Z on both. Agent k owns wires [2k, 2k+1]; every other wire gets I.
obs = [
    kron(*([I] * (2 * agent)), *local, *([I] * (2 * (2 - agent))))
    for agent in range(3)
    for local in ((Z, I), (I, Z), (Z, Z))
]
n_obs = len(obs)
n_obs_per_agent = n_obs // 3

In [243]:
# Input stage: two GHZ-style preparations spanning all three agents -- one over the
# agents' first qubits (wires 0, 2, 4) and one over their second qubits (wires 1, 3, 5).
B = ghz(src=0, tgt=[2,4], bellstate=3, nwires=6) @ ghz(src=1, tgt=[3,5], bellstate=3, nwires=6)

M = (U @ B) @ ket(0,0,0,0,0,0)

# Expectation value of every observable in `obs` for the output state M.
E = expectedValueOfObservable(M, obs)

# Largest discrepancy between agent 0 and the other two agents, per observable.
# The differences are compared by magnitude (key=np.abs): comparing the signed values
# would under-report the discrepancy whenever both differences are negative.
# The sign of the selected difference is kept.
diffs = np.array([
    max(E[i]-E[i+n_obs_per_agent], E[i]-E[i+2*n_obs_per_agent], key=np.abs)
    for i in range(n_obs_per_agent)
])

E, diffs

(array([[[ 6.93889390e-17+0.j]],
 
        [[-5.55111512e-17+0.j]],
 
        [[-5.55111512e-17+0.j]],
 
        [[-1.38777878e-17+0.j]],
 
        [[-6.93889390e-17+0.j]],
 
        [[-1.38777878e-17+0.j]],
 
        [[-4.16333634e-17+0.j]],
 
        [[-5.55111512e-17+0.j]],
 
        [[ 1.04083409e-17+0.j]]]),
 array([[[ 1.11022302e-16+0.j]],
 
        [[ 1.38777878e-17+0.j]],
 
        [[-4.16333634e-17+0.j]]]))