In [1]:
from functools import cache
from collections import defaultdict
from collections import Counter
from typing import Tuple # TODO: See pbs. Should we add type annotations?

import math
import numpy as np
from numpy import sqrt, exp, cos, sin, cosh, sinh, conj
from numpy.polynomial import Polynomial as P
from scipy.special import factorial, comb
from permanent import permanent

import matplotlib.pyplot as plt
import mpl_scatter_density  # noqa
from matplotlib.offsetbox import AnchoredText

# The bare 'seaborn' style name was deprecated in matplotlib 3.6 (renamed to
# 'seaborn-v0_8') and later removed. Prefer the new name when this matplotlib
# provides it and fall back to the legacy name on older installations.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

# Notes

#### The code below uses the following conventions for variable names:

`gamma` = measurement threshold

`r` = squeezing strength

`N` = number of samples

`n` = number of particles

`m` = number of modes

`M` = dimension of Bosonic Hilbert space; $M := \binom{n + m - 1}{n}$

Most matrices are `m` x `N`, meaning that there is a row for each mode, and each column is one sample.

#### Section headers use the following hierarchy:

\# Main Section Header

\#\# Subsection Header

\#\#\# Large group of related functions/code/markdown

\#\#\#\# Single function or tightly related code/markdown blocks

# Simulation Model: Threshold Detectors and Squeezed States

### Measurement

Detectors take in matrices of `N` samples and return boolean arrays of detection events. 

In [2]:
# Default measurement threshold (`gamma`); used as the default argument of the detector functions.
DEFAULT_MEAS_THRESH = 1.9494

#### `threshold_detector`
Standard many-mode threshold detector. Imagines every mode has its own detector which independently clicks if the amplitude on that mode exceeds the measurement threshold.

In [3]:
def threshold_detector(a, gamma=DEFAULT_MEAS_THRESH):
    '''Per-mode threshold detection: a detector clicks when the magnitude
    of the amplitude on its mode exceeds the measurement threshold.
    a : m x N complex matrix = [a_11 ... a_1N; ...; a_m1 ... a_mN]
        Each column is a sample with m components (modes)
    gamma : double = measurement threshold
    output : m x N bool array = [b_11 ... b_1N; ...; b_m1 ... b_mN]
        b_ij := (|a_ij| > gamma)'''

    magnitudes = abs(a)
    return magnitudes > gamma

#### `h_or_v_detector`

Effectively the same as `threshold_detector(a, gamma).any(axis=0)` limited to matrices with two rows.

In [4]:
def h_or_v_detector(a: np.ndarray, gamma=DEFAULT_MEAS_THRESH):
    '''Single detector that clicks when either polarization component of a
    sample exceeds the measurement threshold in magnitude.
    a : 2 x N complex matrix = [h_1 ... h_N; v_1 ... v_N]
        Each column is a sample [h_i; v_i] with a horizontal and vertical polarization component
    gamma : double = measurement threshold
    output : 1 x N bool array = [b_1 ... b_N]
        b_i := (|h_i| > gamma) or (|v_i| > gamma)
    Raises ValueError if a does not have exactly 2 rows.'''

    num_rows = a.shape[0]
    if num_rows != 2:
        raise ValueError("Sample array 'a' must have 2 rows.")

    h_clicks = abs(a[0, :]) > gamma
    v_clicks = abs(a[1, :]) > gamma
    return h_clicks | v_clicks

#### `norm_detector`

Not currently used anywhere.

In [5]:
def norm_detector(a: np.ndarray, gamma=DEFAULT_MEAS_THRESH):
    '''Clicks if norm of sample exceeds measurement threshold.
    a : m x N complex matrix = [s_1 ... s_N]
        Each column is a sample s_i with m components (modes)
    gamma : double = measurement threshold
    output : 1 x N bool array = [b_1 ... b_N]
        b_i := (||s_i|| > gamma)'''

    # Samples are columns, so the norm must be reduced over the mode axis
    # (axis=0) to yield one value per sample. Reducing over axis=1 would
    # instead give one value per mode, taken across all samples.
    sample_norms = np.linalg.norm(a, axis=0)
    return sample_norms > gamma

#### `get_coincidence_count`

In [6]:
def get_coincidence_count(detections: np.ndarray):
    '''Counts the samples in which every detector clicked simultaneously.
    detections : m x N bool matrix
    output : int = number of all-one columns in detections matrix'''
    num_detectors = detections.shape[0]
    clicks_per_sample = np.sum(detections, axis=0)
    all_clicked = clicks_per_sample == num_detectors
    return np.count_nonzero(all_clicked)

#### `get_all_coincidence_counts`

In [7]:
def get_all_coincidence_counts(detections: np.ndarray):
    '''Tallies how often each detection event occurs.
    A detection event is a bitstring with 1's at the indices of all detectors
    that clicked and 0's at the indices of all detectors which did not.
    The result is a Counter (dict-like) keyed by detection events as tuples.
    detections : m x N bool matrix
    output : Dict[tuple, int]
        Example: if output[(1,0,0)] == 21, then the event (1,0,0), aka
        "only the detector on mode 0 clicked," happened 21 times.'''
    # Transposing makes each row one sample, i.e. one detection event.
    samples = detections.T
    return Counter(tuple(sample) for sample in samples)

#### `print_all_coincidence_counts`

Exponential runtime in m, the number of modes. Linear runtime in N, the number of samples. 

Should only be used on inputs with few modes. Can be used when there are many samples.

In [8]:
def print_all_coincidence_counts(detections: np.ndarray):
    '''Prints one line per possible detection event: the event as a bitstring
    (mode 0 leftmost) followed by the number of samples showing that event.
    Iterates over all 2**m events, including those that never occurred.
    detections : m x N bool matrix'''
    counts = get_all_coincidence_counts(detections)
    num_modes = detections.shape[0]
    for event_index in range(2**num_modes):
        bitstring = np.binary_repr(event_index, width=num_modes)
        event = tuple(int(bit) for bit in bitstring)
        print(bitstring, counts[event])

#### `trials_with_outcome`

In [9]:
def trials_with_outcome(outcome_tup, detections):
    '''Marks the samples whose detection pattern equals a given outcome.
    Caution: only works on antibunching states (boolean valued Fock states)
    outcome_tup : m x 1 bool tuple
    detections : m x N bool matrix
    output : 1 x N bool row vector
        true at index i iff the ith column of detections equals outcome'''
    # Column vector so that it broadcasts against every sample (column).
    target = np.array([outcome_tup], dtype=bool).T
    mismatches = np.logical_xor(detections, target)
    # A sample matches when none of its modes mismatch.
    return ~np.any(mismatches, axis=0)

#### `trials_with_outcome_in`

In [10]:
def trials_with_outcome_in(outcomes, detections):
    '''Returns a 1 x N boolean array which is true at
    index i iff the ith column of detections equals any of the given outcomes.
    outcomes : iterable of length-m bool tuples (each is passed to trials_with_outcome)
    detections : m x N bool matrix
    output : 1 x N bool row vector
        true at index i iff the ith column of detections is in outcomes;
        all False when outcomes is empty'''
    # One accumulator entry per sample (column of detections), not per mode:
    # sizing it by the number of rows breaks the OR below whenever m != N.
    num_samples = detections.shape[1]
    matches = np.zeros(num_samples, dtype=bool)
    for outcome in outcomes:
        matches |= trials_with_outcome(outcome, detections)
    return matches

### State Preparation

#### `zpf`

In [11]:
def zpf(N: int, m=1, sigma=1/sqrt(2)):
    '''Zero-point field
    N is the number of samples
    m is the number of modes
    sigma is the standard deviation of the quantum noise
    sigma = 1/sqrt(2) corresponds to the vacuum state
    sigma > 1/sqrt(2) corresponds to a thermal state
    sigma = 0 corresponds to a classical (i.e., no ZPF) state

    Returns an m x N matrix of complex Gaussian random variables,
    drawn from numpy's global random state.'''

    shape = (m, N)
    # Draw order matters for reproducibility under a fixed seed: real part first.
    real_part = np.random.normal(size=shape)
    imag_part = np.random.normal(size=shape)
    return sigma * (real_part + 1j * imag_part) / sqrt(2)

#### `laser`

In [12]:
def laser(N: int, alphaH=1, alphaV=0):
    '''Laser
    alphaH and alphaV are each complex numbers

    Returns a 2 x N matrix of N samples of coherent light.
        Row 0 is the horizontal polarization component, row 1 the vertical.
        Each row is a vacuum-level zpf offset by alphaH (horizontal)
        or alphaV (vertical).'''

    # scale of vacuum fluctuations
    vacuum_sigma = 1/sqrt(2)

    # independent vacuum noise on each polarization (horizontal drawn first)
    horizontal = alphaH + zpf(N, 1, vacuum_sigma)
    vertical = alphaV + zpf(N, 1, vacuum_sigma)

    return np.concatenate((horizontal, vertical))

#### `ent`

In [13]:
def ent(N: int, r: float, t=1, phase=0):
    '''Entanglement source
    N is the number of samples
    r is the squeezing strength (non-negative)
    t (type) is 1 or 2
    phase is in degrees.

    Returns a tuple of two 2 x N matrices representing entangled
        light with horizontal and vertical polarization modes.
    Raises ValueError if t is neither 1 nor 2.'''

    # Validate before drawing any random numbers. Printing and returning None
    # would only surface later as a confusing unpacking error in the caller,
    # and would still have advanced the global random state.
    if t not in (1, 2):
        raise ValueError(f"{t} is not a valid type.")

    # convert degrees to radians
    phase = phase * np.pi/180

    # input random variables for the entanglement source
    z1H = zpf(N, 1)
    z1V = zpf(N, 1)
    z2H = zpf(N, 1)
    z2V = zpf(N, 1)

    c = cosh(r)
    s = sinh(r)
    s_phased = exp(1j*phase) * s  # squeezing term carrying the relative phase

    if t == 1:
        # type 1: each polarization of beam a is paired with the same polarization of beam b
        aH = c*z1H + s*conj(z2H)
        aV = c*z1V + s_phased*conj(z2V)
        bH = c*z2H + s*conj(z1H)
        bV = c*z2V + s_phased*conj(z1V)
    else:
        # type 2: each polarization of beam a is paired with the opposite polarization of beam b
        aH = c*z1H + s*conj(z2V)
        aV = c*z1V + s_phased*conj(z2H)
        bH = c*z2H + s_phased*conj(z1V)
        bV = c*z2V + s*conj(z1H)

    return np.concatenate((aH, aV)), np.concatenate((bH, bV))

### Gates 

Filters, Waveplates, and Beamsplitters

#### `ndf` (neutral density filter)

In [14]:
def ndf(a: np.ndarray, d=10):
    '''Neutral density filter
    a is an m x N complex matrix (each column is a sample)
    d is the optical density (a non-negative number)

    Returns an m x N matrix: the input attenuated by the amplitude factor
    10**(-d/2), plus freshly drawn zero-point noise weighted by the remainder.
    Raises ValueError if d is negative.'''

    if d < 0:
        raise ValueError("Optical density d must be non-negative.")

    # Amplitude attenuation factor for optical density d.
    transmission = 10**(-d/2)

    # Scalar weights must use elementwise `*`; the matrix-product operator `@`
    # rejects scalar operands. The noise is drawn with as many modes as `a`.
    # NOTE(review): a loss model that preserves the vacuum noise level would
    # weight the noise by sqrt(1 - transmission**2); the (1 - transmission)
    # weight of the original code is kept here -- confirm which is intended.
    noise = zpf(a.shape[1], a.shape[0])
    return transmission * a + (1 - transmission) * noise

#### `hwp`

In [15]:
def hwp(a: np.ndarray, theta=0):
    '''Half-wave plate
    a is a 2 x N complex matrix
    theta is the fast-axis angle in degrees

    Returns the 2 x N matrix obtained by applying the plate to every sample.'''

    # convert degrees to radians
    theta = theta * np.pi/180

    c2 = cos(2*theta)
    s2 = sin(2*theta)
    jones = np.array([[c2, s2],
                      [s2, -c2]])

    return jones @ a

#### `qwp`

In [16]:
def qwp(a: np.ndarray, theta=0):
    '''Quarter-wave plate
    a is a 2 x N complex matrix
    theta is the fast-axis angle in degrees

    Returns the 2 x N matrix obtained by applying the plate to every sample.'''

    # convert degrees to radians
    theta = theta * np.pi/180

    c = cos(theta)
    s = sin(theta)

    # Squares are written with `**`: in Python `^` is bitwise XOR (and binds
    # more loosely than `+`), so `cos(theta)^2` raises TypeError on floats.
    u = np.array([
        [c**2 + 1j*s**2, (1-1j)*c*s],
        [(1-1j)*c*s, s**2 + 1j*c**2]])

    return u @ a

#### `polarizer`

In [17]:
def polarizer(a: np.ndarray, theta=0, phi=0):
    '''Polarizer
    a is a 2 x N complex matrix
    theta, phi are in degrees

    Returns a 2 x N matrix: the component of each sample along the
    polarization state [cos(theta); exp(i*phi)*sin(theta)], plus freshly
    drawn zero-point noise in the orthogonal component.'''

    # convert degrees to radians
    theta = theta * np.pi/180
    phi = phi * np.pi/180

    # make projector p = |v><v|
    # The bra is the conjugate transpose of the ket. A plain transpose only
    # gives a projector when phi == 0 (real ket); for phi != 0 the result
    # would be neither Hermitian nor idempotent.
    ket = np.array([[cos(theta)], [exp(1j*phi)*sin(theta)]])
    p = ket @ ket.conj().T

    return p @ a + (np.eye(2) - p) @ zpf(a.shape[1], 2)

#### `bs`

In [18]:
def bs(a: np.ndarray, b: np.ndarray, r=1/sqrt(2)):
    '''Beam splitter
    a and b are each 2 x N complex matrices
    r is the amplitude reflection coefficient (0 <= r <= 1); the
        transmission coefficient is t = sqrt(1 - r**2)

    Returns a tuple of two 2 x N matrices (t*a + r*b, r*a - t*b).
    Raises ValueError if r is outside [0, 1].'''

    #TODO: where are the defaults / vacuum state logic? (see polarizing beam splitter)

    if not 0 <= r <= 1:
        raise ValueError("Reflection coefficient r must satisfy 0 <= r <= 1.")

    # `**` is exponentiation; `^` is bitwise XOR and raises TypeError on floats.
    t = sqrt(1 - r**2)

    # Equivalent to kron([[t, r], [r, -t]], eye(2)) @ concatenate((a, b)):
    # the beams are mixed with each other, polarization by polarization.
    # Note this is NOT u @ a, u @ b with u = [[t, r], [r, -t]], which would
    # instead mix the two polarizations within each beam.
    return t * a + r * b, r * a - t * b

#### `pbs`

In [19]:
# 4 x 4 basis matrices for the polarizing beam splitter. Each acts on the
# stacked input [a_H; a_V; b_H; b_V] and must be unitary so that the beam
# splitter neither amplifies nor attenuates the field.

# Horizontal/vertical basis: a permutation that swaps the V components of the two beams.
HV = np.array([[1, 0, 0, 0],
               [0, 0, 0, 1],
               [0, 0, 1, 0],
               [0, 1, 0, 0]])

# Second basis (name suggests diagonal/antidiagonal). The rows are mutually
# orthogonal with squared norm 4, hence the factor 1/2 to make it unitary.
DA = np.array([[1, 1, 1, -1],
               [1, 1, -1, 1],
               [1, -1, 1, 1],
               [-1, 1, 1, 1]]) / 2

# Third basis (name suggests right/left circular). Same 1/2 normalization as DA.
RL = np.array([[1, -1j, 1, 1j],
               [1j, 1, -1j, 1],
               [1, 1j, 1, -1j],
               [-1j, 1, 1j, 1]]) / 2

In [20]:
def pbs(
    a: np.ndarray = None, 
    b: np.ndarray = None, 
    basis=HV
    ) -> Tuple[np.ndarray, np.ndarray]:
    """Polarizing beam splitter
    a and b are each 2 x N complex matrices; a missing input is replaced
        by a 2-mode zero-point field with the same number of samples
    basis is a 4 x 4 complex matrix

    Returns the two 2 x N output beams.
    Raises ValueError if neither input beam is given."""

    if a is None:
        if b is None:
            raise ValueError("At least one input beam must be specified.")
        a = zpf(b.shape[1], 2)
    elif b is None:
        b = zpf(a.shape[1], 2)

    stacked_inputs = np.concatenate((a, b))
    stacked_outputs = basis @ stacked_inputs

    first_beam = stacked_outputs[0:2, :]
    second_beam = stacked_outputs[2:4, :]
    return first_beam, second_beam

# Simulation Model: QM

### Fock States

#### `fock_dim`

In [21]:
def fock_dim(n, m):
    '''Returns M := (n + m - 1) Choose n, the number of Fock basis
    states with n particles in m modes.
    For integer n and m the result is an exact Python int (a dimension is a
    count and is typically used as a size or index). Other inputs, e.g.
    arrays, fall back to the floating point evaluation.'''
    if isinstance(n, (int, np.integer)) and isinstance(m, (int, np.integer)):
        return comb(int(n + m - 1), int(n), exact=True)
    return comb(n + m - 1, n)

#### `is_fock_basis_state`

In [22]:
def is_fock_basis_state(tup):
    '''Checks whether tup is a tuple whose entries are all
    non-negative Python ints (occupation numbers, one per mode).'''
    if not isinstance(tup, tuple):
        return False
    for occupation in tup:
        if not isinstance(occupation, int) or occupation < 0:
            return False
    return True

#### `is_fock_state`

In [23]:
def is_fock_state(amp_dict):
    '''Checks whether a dict of Fock basis state tuples to amplitudes is a valid Fock state.
    A valid state is a non-empty dict whose keys are Fock basis states that
    all share the same number of modes and the same number of particles, and
    whose amplitudes have squared magnitudes summing to 1.'''

    # An empty dict has no amplitudes at all, so it cannot be normalized.
    if not isinstance(amp_dict, dict) or not amp_dict:
        return False

    # The first key fixes the mode count and particle number for all others.
    reference = next(iter(amp_dict))
    if not is_fock_basis_state(reference):
        return False
    m = len(reference)
    n = sum(reference)

    total_pr = 0.0
    for tup, amp in amp_dict.items():
        if not (is_fock_basis_state(tup) and len(tup) == m and sum(tup) == n):
            return False
        # |amp|^2 is always real, so the total can be passed to math.isclose
        # (amp * conj(amp) would leave a complex-typed total).
        total_pr += abs(amp)**2

    return math.isclose(total_pr, 1)

#### `get_fock_basis_states`

This function was written recursively with memoization (see `@cache`) in order to make it faster. The tradeoff is that it rapidly fills up computer memory with cached data.

Even with memoization, this function's runtime is worse than exponential in n, the number of particles. 

Use sparingly, and only on inputs with small n.


In [24]:
def get_fock_basis_states(n, m):
    '''Returns a python array of all (n + m - 1 Choose n) Fock basis 
    states (as tuples) with n particles in m modes.
    The returned list is a fresh copy that the caller may freely modify.
    Raises ValueError if n is not a non-negative int or m is not a positive int.'''

    # These checks are a loop invariant of the 
    # loop in the helper, so we only need to check 
    # them before starting the loop.
    if not isinstance(n, int) or n < 0:
        raise ValueError("n must be a non-negative integer.")
    if not isinstance(m, int) or m < 1:
        raise ValueError("m must be a positive integer.")

    # Copy into a new list: the helper's result lives in the memoization
    # cache, so handing it out directly would let callers corrupt the cache.
    return list(_get_fock_basis_states(n, m))

# TODO: unroll recursion into loop
@cache
def _get_fock_basis_states(n, m):
    '''This helper should never be called outside of get_fock_basis_states. 
    See get_fock_basis_states.
    Returns an (immutable) tuple of states so cached results cannot be mutated.'''

    # A single mode must hold all n particles.
    if m == 1:
        return ((n,),)
    # A single particle sits in exactly one of the m modes.
    if n == 1:
        return tuple((0,)*i + (1,) + (0,)*(m-i-1) for i in range(m))

    # Put i particles in the first mode and distribute the
    # remaining n - i particles over the other m - 1 modes.
    return tuple(
        (i,) + rest
        for i in range(n + 1)
        for rest in _get_fock_basis_states(n - i, m - 1)
    )

_get_fock_basis_states.cache_clear()

#### `get_antibunching_states`

Similar to the above, but for states where all particle counts are 0 or 1.

In [25]:
def get_antibunching_states(n, m):
    '''Returns a python array of all (m Choose n) antibunching
    states (as tuples) with n particles in m modes.
    The returned list is a fresh copy that the caller may freely modify.
    Raises ValueError if n is not a non-negative int, m is not a
    positive int, or n > m.'''
    
    # These checks are a loop invariant of the 
    # loop in the helper, so we only need to check 
    # them before starting the loop.
    if not isinstance(n, int) or n < 0:
        raise ValueError("n must be a non-negative integer.")
    if not isinstance(m, int) or m < 1:
        raise ValueError("m must be a positive integer.")
    if n > m:
        raise ValueError("n must be less than or equal to m.")

    # Copy into a new list: the helper's result lives in the memoization
    # cache, so handing it out directly would let callers corrupt the cache.
    return list(_get_antibunching_states(n, m))

# TODO: unroll recursion into loop
@cache
def _get_antibunching_states(n, m):
    '''This helper should never be called outside of get_antibunching_states. 
    See get_antibunching_states.
    Returns an (immutable) tuple of states so cached results cannot be mutated.'''

    # No particles left: the only state is the all-empty one. This is the
    # single base case; without it n == 0 would recurse without terminating.
    if n == 0:
        return ((0,)*m,)

    # Place the first particle in mode i (modes before it stay empty), then
    # place the remaining n - 1 particles in the m - i - 1 modes after it.
    return tuple(
        (0,)*i + (1,) + rest
        for i in range(m - n + 1)
        for rest in _get_antibunching_states(n - 1, m - i - 1)
    )

_get_antibunching_states.cache_clear()

### Unitaries

#### `is_unitary`

In [26]:
def is_unitary(U):
    '''Checks whether U is a square matrix satisfying U U^dagger = I
    (up to numpy's default allclose tolerances).'''
    size = U.shape[0]
    if U.shape != (size, size):
        return False
    gram = U.dot(U.T.conj())
    return np.allclose(np.eye(size), gram)

#### `direct_sum`

In [27]:
def direct_sum(A, B):
    '''Returns the block-diagonal matrix [[A, 0], [0, B]] for 2-D arrays A and B.'''
    rows_a, cols_a = A.shape[0], A.shape[1]
    rows_b, cols_b = B.shape[0], B.shape[1]

    upper_right = np.zeros((rows_a, cols_b))
    lower_left = np.zeros((rows_b, cols_a))

    return np.block([[A, upper_right],
                     [lower_left, B]])

#### `HilbertSpaceUnitary`

Recall 

$$\langle S | \varphi (U) | T \rangle = \frac{Per(U_{S,T})}{\sqrt{s_1! \dots s_m! t_1! \dots t_m!}}$$

See pg 95 of https://www.scottaaronson.com/qisii.pdf for more info.

In [28]:
class HilbertSpaceUnitary(object):
    '''Converts the m x m unitary of a linear optical network to 
    the larger M x M unitary acting on the entire Hilbert space,
    where M := n+m-1 choose n. Lazily computes entries as needed
    and memoizes them.'''

    def __init__(self, U, n, state_space=None):
        '''U is an m x m unitary (m is the number of modes)
        n is the number of particles
        state_space is an optional collection of Fock basis states used as
            the default set of output states in apply; when omitted (or
            empty) all Fock basis states with n particles in m modes are used.
        Raises ValueError if U is not unitary.'''
        if not is_unitary(U):
            raise ValueError("U must be an m x m unitary.")
        self.U = U
        self.m = U.shape[0]
        self.n = n
        self.state_space = (state_space if state_space else get_fock_basis_states(self.n, self.m))
        # memoized entries, keyed by (S, T)
        self.entries = {}

    def get_entry(self, S, T):
        '''This gives the entry <S|phi(U)|T>, that is,
        the amplitude of a T to S transition.
        S and T must both be Fock basis states (tuples) with n particles.
        Raises ValueError otherwise.'''

        if np.sum(S) != self.n or np.sum(T) != self.n:
            raise ValueError(f"Fock states must have {self.n} particles.")
        if not is_fock_basis_state(S):
            raise ValueError("S must be a Fock state.")
        if not is_fock_basis_state(T):
            raise ValueError("T must be a Fock state.")

        key = (S,T)
        if key not in self.entries:
            # U_{S,T}: row i of U repeated s_i times, column j repeated t_j times.
            U_ST = np.repeat(np.repeat(self.U, S, axis=0), T, axis=1).astype(complex)
            self.entries[key] = ( 
                permanent.permanent(U_ST) 
                / sqrt(np.prod(factorial(S)))
                / sqrt(np.prod(factorial(T)))
            ) # watch out for multiplication overflows on normalization factor
        return self.entries[key]
  
    def __getitem__(self, S_and_T):
        '''Makes object subscriptable.
        self[S,T] is equivalent to self.get_entry(S,T)'''
        
        S,T = S_and_T
        return self.get_entry(S,T)
    
    def apply(self, state, post_selection=None):
        '''Applies the Hilbert space unitary to a state.
        state is either a single Fock basis state (tuple) or a dict mapping
            Fock basis state tuples to amplitudes
        post_selection is an optional collection of output basis states to
            compute amplitudes for; defaults to self.state_space
        Returns a defaultdict mapping each output basis state S to its
            amplitude sum_T <S|phi(U)|T> * state[T].
        Raises ValueError if state is not a valid Fock state.'''
        
        if is_fock_basis_state(state):
            state = {state: 1}
        
        if not is_fock_state(state):
            raise ValueError("Can only apply HilbertSpaceUnitary to valid Fock states.")
        if not post_selection:
            post_selection = self.state_space
            
        result = defaultdict(lambda: 0)
        for T, amplitude in state.items():
            for S in post_selection:
                # Weight each transition by the input amplitude of T; without
                # it every superposition would be treated as an unweighted sum
                # of its basis states.
                result[S] += amplitude * self.get_entry(S,T)
        return result

#### `get_hilbert_space_unitary_matrix`

Exponential runtime - use only if you need to see the whole Hilbert space unitary matrix at once. 

When possible, use HilbertSpaceUnitary to calculate entries lazily instead.

In [29]:
def get_hilbert_space_unitary_matrix(U, n, p=False):
    '''Builds the full M x M Hilbert space unitary of a linear optical
    network from its m x m mode unitary.
    U is an m x m unitary (m is number of modes)
    n is number of particles
    p determines whether or not to print the matrix (rounded to 2 decimals,
        one row per basis state)
    M := n+m-1 choose n
    Rows and columns are ordered like get_fock_basis_states(n, m).'''
    m = U.shape[0]
    basis_states = get_fock_basis_states(n, m)
    hsu = HilbertSpaceUnitary(U, n)
    M = len(basis_states)

    hsu_matrix = np.zeros((M, M), dtype=complex)
    for row, S in enumerate(basis_states):
        for col, T in enumerate(basis_states):
            hsu_matrix[row, col] = hsu.get_entry(S, T)

    if p:
        rounded_hsu = np.around(hsu_matrix, 2)
        for S, rounded_row in zip(basis_states, rounded_hsu):
            print(S, rounded_row, sep=':\t')

    return hsu_matrix

#### `qr_haar`

In [30]:
# Source: https://pennylane.ai/qml/demos/tutorial_haar_measure.html

def qr_haar(m):
    """Generate a Haar-random m x m unitary using the QR decomposition.
    m is the number of modes.
    Random numbers are drawn from numpy's global random state."""
    # Step 1: complex Gaussian matrix (real part drawn first)
    real_part = np.random.normal(size=(m, m))
    imag_part = np.random.normal(size=(m, m))
    Z = real_part + 1j * imag_part

    # Step 2: QR decomposition
    Q, R = np.linalg.qr(Z)

    # Step 3: diagonal matrix holding the phases of R's diagonal entries
    r_diag = np.array([R[k, k] for k in range(m)])
    phases = np.diag(r_diag / np.abs(r_diag))

    # Step 4: multiply Q's columns by those phases
    return np.dot(Q, phases)