In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from eqmarl import *
import pennylane as qml
import numpy as np
import functools as ft

2023-12-05 10:58:32.927175: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
def expectedValueOfObservable(psi, obs):
    return np.transpose(np.conjugate(psi)) @ obs @ psi

def kron(*args):
    """Kronecker product of a list of elements."""
    return ft.reduce(np.kron, args)

In [4]:
def dagger(x):
    """Conjugate transpose."""
    return (x.conjugate()).transpose()

def ket(*x):
    """|x0, x1, ...>"""
    assert all(i in (0, 1) for i in x), 'ket only accepts binary 0 or 1'
    x = [np.array([[0], [1]]) if i else np.array([[1], [0]]) for i in x]
    x = kron(*x)
    return x

def bra(*x):
    """<x0, x1, ...|"""
    assert all(i in (0, 1) for i in x), 'bra only accepts binary 0 or 1'
    x = ket(*x)
    return dagger(x)

In [5]:
def bell(src=0, tgt=1, bellstate: Literal[0,1,2,3] = 0, nwires=2, op=qml.CNOT):
    """Matrix representation for circuit that creates a Bell state on the designated source and target qubits.
    
    Supply the `bellstate` argument to prepare a specific Bell state:
    - 0 == Phi+
    - 1 == Psi+
    - 2 == Phi-
    - 3 == Psi-
    
    The Bell state vector can be recovered by applying the dot product with the |0> state as follows:
    >>> U = bell(0, 1, bellstate=0, nwires=2) # Creates matrix which generates Phi+
    >>> U @ np.array([[1], [0], [0], [0]]) # Dot product with |00>
    array([[0.70710678],
       [0.        ],
       [0.        ],
       [0.70710678]])
    """
    def circuit():
        match bellstate:
            case 1: # 01
                qml.PauliX(wires=tgt)
            case 2: # 10
                qml.PauliX(wires=src)
            case 3: # 11
                qml.PauliX(wires=tgt)
                qml.PauliX(wires=src)
            case _: # 00
                pass
        qml.Hadamard(wires=src)
        op(wires=[src, tgt])
    return qml.matrix(circuit, wire_order=range(nwires))()


def ghz(src: int = 0, tgt: list[int] = 1, bellstate: Literal[0,1,2,3] = 0, nwires=2, op=qml.CNOT):
    
    if isinstance(tgt, int):
        tgt = [tgt]
    
    def circuit():
        if bellstate in (2, 3):
            qml.PauliX(wires=src)
        qml.Hadamard(wires=src)
        for t in tgt:
            if bellstate in (1, 3):
                qml.PauliX(wires=t)
            op(wires=[src, t])
    return qml.matrix(circuit, wire_order=range(nwires))()


In [6]:
# Get matrix representations of Pauli operators and other common gates.
X = qml.matrix(qml.PauliX(wires=0))
Y = qml.matrix(qml.PauliY(wires=0))
Z = qml.matrix(qml.PauliZ(wires=0))
H = qml.matrix(qml.Hadamard(wires=0))
I = qml.matrix(qml.Identity(wires=0))
CNOT = lambda src=0, tgt=1, nwires=2: qml.matrix(qml.CNOT(wires=[src, tgt]), wire_order=range(nwires)) # CNOT(source, target, number_of_wires)
CZ = lambda src=0, tgt=1, nwires=2: qml.matrix(qml.CZ(wires=[src, tgt]), wire_order=range(nwires)) # CZ(source, target, number_of_wires)

In [7]:
# Matrices that create specific bell states on 2 qubits next-neighbor style (source is qubit 0, target is qubit 1).
bell_phi_plus = bell(src=0, tgt=1, bellstate=0, nwires=2)
bell_psi_plus = bell(src=0, tgt=1, bellstate=1, nwires=2)
bell_phi_minus = bell(src=0, tgt=1, bellstate=2, nwires=2)
bell_psi_minus = bell(src=0, tgt=1, bellstate=3, nwires=2)

In [19]:
# O = kron(Z, I)
# O = kron(I, Z)
O = kron(Z, Z)
s = bell_phi_plus@ket(0,0)

expectedValueOfObservable(s, O)

array([[1.]])

In [23]:
O = Z

# s = H @ ket(0)
s = ket(1)

expectedValueOfObservable(s, O)

array([[-1]])

## Circuit Tests

### 2 Agents

In [182]:
n_agents = 2
d_qubits = 2
n_layers = 1

circuit = MARLCircuit(
    n_agents=n_agents,
    d_qubits=d_qubits,
    n_layers=n_layers,
    )

weight_shapes = circuit.weight_shapes

agents_var_thetas = np.random.uniform(0, 2*np.pi, size=weight_shapes['agents_var_thetas'])
agents_enc_inputs = np.random.uniform(0, 2*np.pi, size=weight_shapes['agents_enc_inputs'])

print(qml.draw(circuit, wire_order=circuit.wires)(agents_var_thetas, agents_enc_inputs))


# System unitary (all agents).
U = qml.matrix(circuit, wire_order=circuit.wires)(agents_var_thetas, agents_enc_inputs)

0: ─╭VariationalEncodingPQC─┤  
1: ─╰VariationalEncodingPQC─┤  
2: ─╭VariationalEncodingPQC─┤  
3: ─╰VariationalEncodingPQC─┤  


In [25]:
obs = [
    kron(Z, I, I, I),
    kron(I, Z, I, I),
    kron(Z, Z, I, I),
    kron(I, I, Z, I),
    kron(I, I, I, Z),
    kron(I, I, Z, Z),
    ]
n_obs = len(obs)
n_obs_per_agent = int(n_obs / 2)

In [218]:
M = U @ ket(0,0,0,0) # |0> initial state

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)

# Difference between agent observables.
diffs = np.array([E[i]-E[i+n_obs_per_agent] for i in range(n_obs_per_agent)])

E, diffs

(array([[[-0.18761012+0.j]],
 
        [[ 0.15938534+0.j]],
 
        [[-0.11824798+0.j]],
 
        [[-0.22005712+0.j]],
 
        [[ 0.03948987+0.j]],
 
        [[-0.03002138+0.j]]]),
 array([[[ 0.032447  +0.j]],
 
        [[ 0.11989547+0.j]],
 
        [[-0.0882266 +0.j]]]))

In [219]:
B = kron(bell_phi_plus, bell_phi_plus) # Phi+ on [0,1] and [2,3]
M = (U @ B) @ ket(0,0,0,0)

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)

# Difference between agent observables.
diffs = np.array([E[i]-E[i+n_obs_per_agent] for i in range(n_obs_per_agent)])

E, diffs

(array([[[-0.70911625+0.j]],
 
        [[ 0.07179181+0.j]],
 
        [[-0.12724351+0.j]],
 
        [[ 0.45386768+0.j]],
 
        [[ 0.15901161+0.j]],
 
        [[-0.00454907+0.j]]]),
 array([[[-1.16298392+0.j]],
 
        [[-0.08721981+0.j]],
 
        [[-0.12269444+0.j]]]))

In [220]:
B = bell(src=0, tgt=2, bellstate=0, nwires=4) @ bell(src=1, tgt=3, bellstate=0, nwires=4) # Phi+ on [0,2] and [1,3]

M = (U @ B) @ ket(0,0,0,0)

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)

# Difference between agent observables.
diffs = np.array([E[i]-E[i+n_obs_per_agent] for i in range(n_obs_per_agent)])

E, diffs

(array([[[-1.52655666e-16+0.j]],
 
        [[ 2.77555756e-17+0.j]],
 
        [[ 5.55111512e-17+0.j]],
 
        [[-1.66533454e-16+0.j]],
 
        [[-8.32667268e-17+0.j]],
 
        [[-5.55111512e-17+0.j]]]),
 array([[[1.38777878e-17+0.j]],
 
        [[1.11022302e-16+0.j]],
 
        [[1.11022302e-16+0.j]]]))

In [221]:
B = ghz(src=0, tgt=[1,2,3], bellstate=0, nwires=4)
M = (U @ B) @ ket(0,0,0,0)

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)

# Difference between agent observables.
diffs = np.array([E[i]-E[i+n_obs_per_agent] for i in range(n_obs_per_agent)])

E, diffs

(array([[[ 0.06245469+0.j]],
 
        [[ 0.39410024+0.j]],
 
        [[ 0.2657979 +0.j]],
 
        [[-0.2183963 +0.j]],
 
        [[ 0.01837139+0.j]],
 
        [[ 0.02608205+0.j]]]),
 array([[[0.28085099+0.j]],
 
        [[0.37572885+0.j]],
 
        [[0.23971586+0.j]]]))

In [222]:
B = ghz(src=2, tgt=[0,1,3], bellstate=3, nwires=4)
M = (U @ B) @ ket(0,0,0,0)

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)

# Difference between agent observables.
diffs = np.array([E[i]-E[i+n_obs_per_agent] for i in range(n_obs_per_agent)])

E, diffs

(array([[[ 0.06245469+0.j]],
 
        [[ 0.39410024+0.j]],
 
        [[ 0.2657979 +0.j]],
 
        [[ 0.2183963 +0.j]],
 
        [[-0.01837139+0.j]],
 
        [[-0.02608205+0.j]]]),
 array([[[-0.15594161+0.j]],
 
        [[ 0.41247163+0.j]],
 
        [[ 0.29187995+0.j]]]))

In [223]:
B = kron(bell(src=0, tgt=1, bellstate=0, nwires=2, op=qml.CZ), bell(src=0, tgt=1, bellstate=0, nwires=2, op=qml.CZ)) # Bell0 using CZ on [0,1] and [2,3]

M = (U @ B) @ ket(0,0,0,0)

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)

# Difference between agent observables.
diffs = np.array([E[i]-E[i+n_obs_per_agent] for i in range(n_obs_per_agent)])

E, diffs

(array([[[-0.53881275+0.j]],
 
        [[ 0.3194031 +0.j]],
 
        [[-0.1619899 +0.j]],
 
        [[-0.14426168+0.j]],
 
        [[ 0.73252616+0.j]],
 
        [[-0.05320409+0.j]]]),
 array([[[-0.39455107+0.j]],
 
        [[-0.41312306+0.j]],
 
        [[-0.10878582+0.j]]]))

In [224]:
B = bell(src=0, tgt=2, bellstate=0, nwires=4, op=qml.CZ) @ bell(src=1, tgt=3, bellstate=0, nwires=4, op=qml.CZ) # Bell0 using CZ on [0,2] and [1,3]

M = (U @ B) @ ket(0,0,0,0)

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)

# Difference between agent observables.
diffs = np.array([E[i]-E[i+n_obs_per_agent] for i in range(n_obs_per_agent)])

E, diffs

(array([[[-0.52618419+0.j]],
 
        [[-0.59077723+0.j]],
 
        [[ 0.23428383+0.j]],
 
        [[-0.22005712+0.j]],
 
        [[ 0.03948987+0.j]],
 
        [[-0.03002138+0.j]]]),
 array([[[-0.30612707+0.j]],
 
        [[-0.6302671 +0.j]],
 
        [[ 0.26430521+0.j]]]))

In [225]:
eig_vals, eig_vects = cirq.unitary_eig(U)

In [231]:
B = eig_vects[:,0]

M = (U @ B)

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)

# Difference between agent observables.
diffs = np.array([E[i]-E[i+n_obs_per_agent] for i in range(n_obs_per_agent)])

E, diffs

(array([0.45806   +0.j, 0.21493029+0.j, 0.74224179+0.j, 0.45951492+0.j,
        0.77555546+0.j, 0.44991555+0.j]),
 array([-0.00145492+0.j, -0.56062517+0.j,  0.29232624+0.j]))

In [238]:
B = 1/np.sqrt(2) * eig_vects[:,0] + 1/np.sqrt(2) * eig_vects[:,-1]

M = (U @ B)

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)

# Difference between agent observables.
diffs = np.array([E[i]-E[i+n_obs_per_agent] for i in range(n_obs_per_agent)])

E, diffs

(array([ 0.02529203+0.j,  0.0308524 +0.j,  0.50321821+0.j,  0.52017234+0.j,
         0.2707908 +0.j, -0.09157733+0.j]),
 array([-0.49488031+0.j, -0.2399384 +0.j,  0.59479553+0.j]))

In [236]:
B = 1/np.sqrt(eig_vects.shape[-1]) * np.sum(eig_vects, axis=-1)

M = (U @ B)

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)

# Difference between agent observables.
diffs = np.array([E[i]-E[i+n_obs_per_agent] for i in range(n_obs_per_agent)])

E, diffs

(array([-0.05078311+0.j, -0.23662673+0.j,  0.12823168+0.j,  0.17400391+0.j,
        -0.0772174 +0.j,  0.16059303+0.j]),
 array([-0.22478702+0.j, -0.15940933+0.j, -0.03236136+0.j]))

#### Try more than 1 layer using more complex observations.

In [50]:
names, obs = zip(*[
    # Agent 1 observations
    ## Z
    ('Z @ I @ I @ I', kron(Z, I, I, I)),
    ('I @ Z @ I @ I', kron(I, Z, I, I)),
    ('Z @ Z @ I @ I', kron(Z, Z, I, I)),
    ('X @ I @ I @ I', kron(X, I, I, I)),
    ## X
    ('I @ X @ I @ I', kron(I, X, I, I)),
    ('X @ X @ I @ I', kron(X, X, I, I)),
    ## Y
    ('Y @ I @ I @ I', kron(Y, I, I, I)),
    ('I @ Y @ I @ I', kron(I, Y, I, I)),
    ('Y @ Y @ I @ I', kron(Y, Y, I, I)),
    ## Z*
    ('Z @ X @ I @ I', kron(Z, X, I, I)),
    ('Z @ Y @ I @ I', kron(Z, Y, I, I)),
    ## X*
    ('X @ Y @ I @ I', kron(X, Y, I, I)),
    ('X @ Z @ I @ I', kron(X, Z, I, I)),
    ## Y*
    ('Y @ X @ I @ I', kron(Y, X, I, I)),
    ('Y @ Z @ I @ I', kron(Y, Z, I, I)),
    # ----
    # Agent 2 observations
    ## Z
    ('I @ I @ Z @ I', kron(I, I, Z, I)),
    ('I @ I @ I @ Z', kron(I, I, I, Z)),
    ('I @ I @ Z @ Z', kron(I, I, Z, Z)),
    ## X
    ('I @ I @ X @ I', kron(I, I, X, I)),
    ('I @ I @ I @ X', kron(I, I, I, X)),
    ('I @ I @ X @ X', kron(I, I, X, X)),
    ## Y
    ('I @ I @ Y @ I', kron(I, I, Y, I)),
    ('I @ I @ I @ Y', kron(I, I, I, Y)),
    ('I @ I @ Y @ Y', kron(I, I, Y, Y)),
    ## Z*
    ('I @ I @ Z @ X', kron(I, I, Z, X)),
    ('I @ I @ Z @ Y', kron(I, I, Z, Y)),
    ## X*
    ('I @ I @ X @ Y', kron(I, I, X, Y)),
    ('I @ I @ X @ Z', kron(I, I, X, Z)),
    ## Y*
    ('I @ I @ Y @ X', kron(I, I, Y, X)),
    ('I @ I @ Y @ Z', kron(I, I, Y, Z)),
])
n_obs = len(obs)
n_obs_per_agent = int(n_obs / 2)

n_agents = 2
d_qubits = 2
n_layers = 5

circuit = MARLCircuit(
    n_agents=n_agents,
    d_qubits=d_qubits,
    n_layers=n_layers,
    )

weight_shapes = circuit.weight_shapes

agents_var_thetas = np.random.uniform(0, 2*np.pi, size=weight_shapes['agents_var_thetas'])
agents_enc_inputs = np.random.uniform(0, 2*np.pi, size=weight_shapes['agents_enc_inputs'])

print(qml.draw(circuit, wire_order=circuit.wires)(agents_var_thetas, agents_enc_inputs))


# System unitary (all agents).
U = qml.matrix(circuit, wire_order=circuit.wires)(agents_var_thetas, agents_enc_inputs)


B = bell(src=0, tgt=2, bellstate=0, nwires=4) @ bell(src=1, tgt=3, bellstate=0, nwires=4) # Phi+ on [0,2] and [1,3]

M = (U @ B) @ ket(0,0,0,0)

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)
E = E.flatten()

# Difference between agent observables.
diffs = np.array([E[i]-E[i+n_obs_per_agent] for i in range(n_obs_per_agent)])

# Check if all expected values are near zero.
print(f"{np.all(np.abs(E) < 1e-10)=}")

# Check if all expected values are close to each other.
print(f"{np.all(np.isclose(E, E[0]))=}")

for i in range(len(names)):
    print(f"{names[i]}:\t{E[i]}")

0: ─╭VariationalEncodingPQC─┤  
1: ─╰VariationalEncodingPQC─┤  
2: ─╭VariationalEncodingPQC─┤  
3: ─╰VariationalEncodingPQC─┤  
np.all(np.abs(E) < 1e-10)=True
np.all(np.isclose(E, E[0]))=True
Z @ I @ I @ I:	(-3.0531133177191805e-16+0j)
I @ Z @ I @ I:	(-3.885780586188048e-16+0j)
Z @ Z @ I @ I:	(1.3530843112619095e-16+0j)
X @ I @ I @ I:	(5.551115123125783e-17-1.0408340855860843e-17j)
I @ X @ I @ I:	(4.85722573273506e-17+0j)
X @ X @ I @ I:	(4.649058915617843e-16+0j)
Y @ I @ I @ I:	(-9.71445146547012e-17-2.8376923163697078e-18j)
I @ Y @ I @ I:	(1.942890293094024e-16+0j)
Y @ Y @ I @ I:	(5.551115123125783e-17-2.7755575615628914e-17j)
Z @ X @ I @ I:	(-3.5388358909926865e-16+0j)
Z @ Y @ I @ I:	(-5.551115123125783e-17+0j)
X @ Y @ I @ I:	(-1.6653345369377348e-16+1.3877787807814457e-17j)
X @ Z @ I @ I:	(-3.469446951953614e-17-6.938893903907228e-18j)
Y @ X @ I @ I:	(2.983724378680108e-16+6.938893903907228e-18j)
Y @ Z @ I @ I:	(9.71445146547012e-17-5.807524697896793e-18j)
I @ I @ Z @ I:	0j
I @ I @ 

### 3 Agents

In [240]:
n_agents = 3
d_qubits = 2
n_layers = 1

circuit = MARLCircuit(
    n_agents=n_agents,
    d_qubits=d_qubits,
    n_layers=n_layers,
    )

weight_shapes = circuit.weight_shapes

agents_var_thetas = np.random.uniform(0, 2*np.pi, size=weight_shapes['agents_var_thetas'])
agents_enc_inputs = np.random.uniform(0, 2*np.pi, size=weight_shapes['agents_enc_inputs'])

print(qml.draw(circuit, wire_order=circuit.wires)(agents_var_thetas, agents_enc_inputs))


# System unitary (all agents).
U = qml.matrix(circuit, wire_order=circuit.wires)(agents_var_thetas, agents_enc_inputs)

0: ─╭VariationalEncodingPQC─┤  
1: ─╰VariationalEncodingPQC─┤  
2: ─╭VariationalEncodingPQC─┤  
3: ─╰VariationalEncodingPQC─┤  
4: ─╭VariationalEncodingPQC─┤  
5: ─╰VariationalEncodingPQC─┤  


In [241]:
obs = [
    kron(Z, I, I, I, I, I),
    kron(I, Z, I, I, I, I),
    kron(Z, Z, I, I, I, I),
    kron(I, I, Z, I, I, I),
    kron(I, I, I, Z, I, I),
    kron(I, I, Z, Z, I, I),
    kron(I, I, I, I, Z, I),
    kron(I, I, I, I, I, Z),
    kron(I, I, I, I, Z, Z),
    ]
n_obs = len(obs)
n_obs_per_agent = int(n_obs / 3)

In [243]:
B = ghz(src=0, tgt=[2,4], bellstate=3, nwires=6) @ ghz(src=1, tgt=[3,5], bellstate=3, nwires=6)

M = (U @ B) @ ket(0,0,0,0,0,0)

# Expected value of observables w.r.t combined unitary M.
E = expectedValueOfObservable(M, obs)

# Difference between agent observables (maximum difference between agents 0,1 and 0,2).
diffs = np.array([max(E[i]-E[i+n_obs_per_agent], E[i]-E[i+2*n_obs_per_agent]) for i in range(n_obs_per_agent)])

E, diffs

(array([[[ 6.93889390e-17+0.j]],
 
        [[-5.55111512e-17+0.j]],
 
        [[-5.55111512e-17+0.j]],
 
        [[-1.38777878e-17+0.j]],
 
        [[-6.93889390e-17+0.j]],
 
        [[-1.38777878e-17+0.j]],
 
        [[-4.16333634e-17+0.j]],
 
        [[-5.55111512e-17+0.j]],
 
        [[ 1.04083409e-17+0.j]]]),
 array([[[ 1.11022302e-16+0.j]],
 
        [[ 1.38777878e-17+0.j]],
 
        [[-4.16333634e-17+0.j]]]))