In [395]:
import numpy as np
from matplotlib import pyplot as plt
from typing import Literal, List
from functools import reduce
from time import perf_counter

In [396]:
# Define unitary gates

def I(nstates: int=2):
    """Return the ``nstates x nstates`` identity as a float :class:`numpy.matrix`."""
    diagonal = np.ones(nstates, dtype=float)
    return np.matrix(np.diag(diagonal))

def X():
    """Pauli-X (bit flip) gate as a float 2x2 :class:`numpy.matrix`."""
    # Rows of the identity in swapped order: |0> <-> |1>.
    flipped_rows = np.eye(2, dtype=float)[[1, 0]]
    return np.asmatrix(flipped_rows)

def Z():
    """Pauli-Z (phase flip) gate as a float 2x2 :class:`numpy.matrix`."""
    signs = np.array([1.0, -1.0])
    return np.asmatrix(np.diag(signs))

def H():
    """Hadamard gate ``[[1, 1], [1, -1]] / sqrt(2)`` as a float :class:`numpy.matrix`."""
    signs = np.matrix([[1, 1], [1, -1]], dtype=float)
    return signs / np.sqrt(2)

def Rot_X(theta):
    """
    Single-qubit rotation by ``theta`` about the X axis.

    Returns ``cos(theta/2) * I - 1j * sin(theta/2) * X`` as a complex 2x2
    :class:`numpy.matrix`.
    """
    half_angle = theta / 2
    identity_part = np.cos(half_angle) * I()
    flip_part = 1j * X() * np.sin(half_angle)
    return identity_part - flip_part

def SWAP():
    """
    Two-qubit SWAP gate (``|ab> -> |ba>``) as a float 4x4 :class:`numpy.matrix`.

    Basis order is ``|00>, |01>, |10>, |11>``; only ``|01>`` and ``|10>``
    are exchanged.
    """
    permuted_rows = np.eye(4, dtype=float)[[0, 2, 1, 3]]
    return np.asmatrix(permuted_rows)

def CNOT():
    """
    Two-qubit controlled-NOT as a float 4x4 :class:`numpy.matrix`.

    The first qubit is the control and the second the target: in the basis
    ``|00>, |01>, |10>, |11>`` the states ``|10>`` and ``|11>`` are exchanged.
    """
    permuted_rows = np.eye(4, dtype=float)[[0, 1, 3, 2]]
    return np.asmatrix(permuted_rows)

In [397]:
class Gate:
    """
    Wrapper class for unitary matrices used as quantum gates.

    Calling a gate on a density matrix ``rho`` returns ``U @ rho @ U^dagger``,
    so gates act on density matrices, not on state vectors.

    Attributes
    ----------
    U: :class:`numpy.matrix` or :class:`numpy.ndarray`
        Unitary matrix of this gate (square, 2-D).
    """
    def __init__(self, U: np.matrix):
        self.U = U

    def __call__(self, rho: np.matrix):
        # conj().T instead of the np.matrix-only `.H` attribute, so plain
        # ndarray unitaries work too; for np.matrix inputs the result is
        # identical to `.H`.
        U_dagger = self.U.conj().T
        return self.U @ rho @ U_dagger

    def __str__(self):
        return str(self.U)


# For now Moment is just an alias for Gate,
# but in the future it will have a different purpose
Moment = Gate


class QuantumCircuit:
    """
    Class encapsulating a sequence of `Gate`s to be applied on a given initial state.
    
    Attributes
    ----------
    gates: list[:class:`Gate`]
        Sequence of gates
    """

    def __init__(self, gates: List[Gate]):
        self.gates = gates

    def __call__(self, rho):
        return reduce(lambda dm, gate: gate(dm), self.gates, rho)

In [398]:
# Define trace-preserving operations

class DepolarizingChannel:
    """
    Mix a square (density) matrix with the maximally mixed state.

    Applying the channel maps ``rho`` to ``p * rho + (1 - p) / d * I_d``.
    Here ``p`` is the weight *kept* on the input state: ``p = 1`` leaves
    ``rho`` unchanged and ``p = 0`` returns the maximally mixed state.

    Attributes
    ----------
    p: float
        Weight of the input state; asserted to lie in ``[0, 1]``.
    """

    def __init__(self, p: float):
        assert 0 <= p <= 1
        self.p = p

    def __call__(self, rho: np.matrix, nstates: int=None):
        """
        Apply the channel to ``rho``.

        ``nstates`` is the dimension ``d`` of the identity term. It defaults
        to the size of ``rho`` and should normally match it.
        """
        nrows, ncols = np.shape(rho)
        assert nrows == ncols, \
            'Depolarizing channel can be applied only to density matrices'
        dim = nrows if nstates is None else nstates
        keep = self.p
        scaled = keep * rho
        mixed_weight = (1 - keep) / dim
        return scaled + mixed_weight * I(dim)
    

class PartialTrace:
    """
    Compute the partial trace of a multi-qubit density matrix.

    Qubits are indexed by their position in the tensor product, with qubit 0
    the leftmost (most significant) factor: for ``rho = np.kron(A, B)``,
    qubit 0 is the space of ``A`` and qubit 1 the space of ``B``. Qubits that
    are kept retain their relative order in the result.

    Attributes
    ----------
    out_qubits: `list[int]`
        Indices of qubits to be traced out. Duplicates are ignored.
    """

    def __init__(self, out_qubits: list[int]):
        self.out_qubits = out_qubits

    def __call__(self, rho: np.matrix) -> np.matrix:
        """
        Trace out ``self.out_qubits`` from ``rho``.

        Parameters
        ----------
        rho: :class:`numpy.matrix` or array-like
            Square matrix of size ``2**n x 2**n`` for some ``n >= 0``.

        Returns
        -------
        :class:`numpy.matrix`
            Reduced matrix of size ``2**k x 2**k``, where ``k`` is the number
            of kept qubits. Tracing out every qubit gives a ``1 x 1`` matrix.

        Raises
        ------
        ValueError
            If ``rho`` is not square with a power-of-two dimension, or an
            index in ``out_qubits`` is outside ``[0, n)``.
        """
        # Plain ndarray: np.matrix cannot be reshaped to more than 2 dims.
        rho = np.asarray(rho)
        if rho.ndim != 2 or rho.shape[0] != rho.shape[1]:
            raise ValueError(
                'Partial trace can be applied only to square density matrices')
        dim = rho.shape[0]
        nqubits = dim.bit_length() - 1
        if dim < 1 or 2 ** nqubits != dim:
            raise ValueError(f'Matrix dimension {dim} is not a power of 2')

        out = sorted(set(self.out_qubits), reverse=True)
        for q in out:
            if not 0 <= q < nqubits:
                raise ValueError(
                    f'Qubit index {q} out of range for a {nqubits}-qubit state')

        # One row axis and one column axis per qubit: axes 0..n-1 are the row
        # indices and axes n..2n-1 the matching column indices.
        tensor = rho.reshape((2,) * (2 * nqubits))
        remaining = nqubits
        # Trace the highest indices first so that removing axes never shifts
        # the axis numbers of the lower qubits still to be traced.
        for q in out:
            tensor = np.trace(tensor, axis1=q, axis2=q + remaining)
            remaining -= 1

        kept_dim = 2 ** remaining
        return np.asmatrix(np.reshape(tensor, (kept_dim, kept_dim)))

In [399]:
# Define some additional math operations for matrices

def kron(As: list[np.matrix]):
    """
    Kronecker (tensor) product of a sequence of matrices.

    ``kron([A, B, C])`` is mathematically ``np.kron(np.kron(A, B), C)``: the
    first matrix acts on the leftmost (most significant) tensor factor.
    An empty sequence yields the scalar ``1``.
    """
    # NOTE: the product is folded from the right, kron(A, kron(B, C)),
    # which is faster here than a left fold with `reduce`.
    product = 1
    for factor in reversed(As):
        product = np.kron(factor, product)
    return product

In [400]:
def comp_state(n: int, nstates: int=2):
    """
    Computational-basis ket ``|n>`` of an ``nstates``-level system.

    Returns an ``nstates x 1`` float :class:`numpy.matrix` (column vector)
    with a single ``1.0`` at row ``n``.
    """
    assert 0 <= n < nstates
    amplitudes = np.zeros(nstates, dtype=float)
    amplitudes[n] = 1.0
    return np.asmatrix(amplitudes.reshape(-1, 1))

def bell_state(a: Literal[0, 1]=0, b: Literal[0, 1]=0):
    """
    Two-qubit Bell state as a 4x1 :class:`numpy.matrix` column vector.

    Starts from ``(|00> + |11>) / sqrt(2)``, applies Z to the second qubit
    when ``b`` is truthy, then X to the second qubit when ``a`` is truthy:

    - ``(0, 0)``: ``(|00> + |11>) / sqrt(2)``
    - ``(0, 1)``: ``(|00> - |11>) / sqrt(2)``
    - ``(1, 0)``: ``(|01> + |10>) / sqrt(2)``
    - ``(1, 1)``: ``(|01> - |10>) / sqrt(2)``
    """
    state = 1/np.sqrt(2) * np.matrix([1, 0, 0, 1], dtype=float).T
    # Corrections on the second qubit, in application order (Z before X).
    corrections = []
    if b:
        corrections.append(Z())
    if a:
        corrections.append(X())
    for pauli in corrections:
        state = kron([I(), pauli]) @ state
    return state

def werner_state(p: float):
    """
    Werner state ``p |Phi><Phi| + (1 - p) I / 4`` built from ``bell_state(0, 0)``.

    Its overlap with ``|Phi>`` (the fidelity) is ``p + (1 - p) / 4``.
    """
    phi = bell_state(0, 0)
    projector = phi @ phi.H
    channel = DepolarizingChannel(p)
    return channel(projector)

In [401]:
def POVM(nstates: int=2):
    """
    Construct the projective POVM of the computational basis.

    Returns the list ``[|k><k| for k in range(nstates)]`` of ``nstates x
    nstates`` :class:`numpy.matrix` projectors, in basis order.
    """
    kets = (comp_state(k, nstates) for k in range(nstates))
    return [ket @ ket.H for ket in kets]

In [402]:
# DEJMPS entanglement-purification circuit acting on two noisy EPR pairs.
# NOTE(review): this value is passed to werner_state as the depolarizing
# weight p, so each pair's overlap with bell_state(0, 0) is p + (1 - p) / 4
# (= 0.85 here), not 0.8 -- confirm which quantity is intended.
epr_channel_fidelity = 0.8  # depolarizing weight p passed to werner_state
epr1 = werner_state(epr_channel_fidelity)
epr2 = werner_state(epr_channel_fidelity)

# Joint state of the two EPR pairs (qubits in order: |1A, 1B, 2A, 2B>)
rho = kron([epr1, epr2])

# Local rotations: Rx(pi/2) on the A-side qubits (1A, 2A) and Rx(-pi/2) on
# the B-side qubits (1B, 2B).
U_A = Rot_X(np.pi / 2)
U_B = Rot_X(- np.pi / 2)

U = kron([U_A, U_B, U_A, U_B])

rho = U @ rho @ U.H

# Swap qubits 1B, 2A (arrange qubits in order: |1A, 2A, 1B, 2B>) so that each
# two-qubit CNOT below acts on adjacent tensor factors.
U = kron([I(), SWAP(), I()])
rho = U @ rho @ U.H

# Apply CNOT(1A -> 2A) and CNOT(1B -> 2B): CNOT() uses its first qubit as
# control, so pair 1 is the control and pair 2 the target.
U = kron([CNOT(), CNOT()])
rho = U @ rho @ U.H

# Swap back qubits 2A, 1B (rearrange qubits in order: |1A, 1B, 2A, 2B>)
U = kron([I(), SWAP(), I()])
rho = U @ rho @ U.H

# Computational-basis projectors on the 2nd EPR pair (qubits 2A, 2B),
# extended with the identity on pair 1: M[k] projects (2A, 2B) onto |k>,
# so M[0] is outcome 00 and M[3] is outcome 11.
M = POVM(nstates=4)
M = [kron([I(4), x]) for x in M]

# Probability of measuring a = b = 0 in qubits 2A and 2B; np.real drops the
# imaginary part left by the complex rotations (expected to be ~0).
# NOTE(review): only outcome 00 is post-selected; if outcome 11 (M[3]) should
# also count as success, its probability must be added -- confirm.
prob_00 = np.trace(M[0] @ rho)
prob_00 = np.real(prob_00)

# Post-measurement state, renormalised; pair 2 is still included in rho.
rho = (M[0] @ rho @ M[0]) / prob_00

print(prob_00)
# np.real(np.round(rho, 5))

0.41
