In [1]:
import numpy as np
from numpy import *
import qutip
from qutip import *
from numpy.random import *
from qutip.measurement import measure, measurement_statistics
from copy import deepcopy
import matplotlib.pyplot as plt
# from qutip.states import ket2dm, basis

In [2]:
class DecisionTree():
    """M-ary tree of complex displacement parameters for an adaptive measurement.

    The tree has N levels (a depth-1 tree is a single leaf), so level l
    (l = 0 .. N-1, root at l = 0) holds M**l nodes and the tree has
    sum_{l<N} M**l nodes in total (M**N - 1 for M = 2). Each node corresponds
    to a measurement history, i.e. a list x_j = [k_1, k_2, ... k_j] of
    measurement outcomes with 0 <= j <= N-1, and stores a complex `alpha`
    (initially 0).
    """

    def __init__(self, M, N):
        """Build a tree with branching factor M and N levels, every alpha = 0.

        Raises:
            ValueError: if N < 1 (such a tree has no nodes; the previous
                recursion never terminated for it).
        """
        if N < 1:
            raise ValueError(f"tree depth N must be >= 1, got {N}")
        self.M = M
        self.N = N
        self.alpha = 0
        if N == 1:
            self.children = None
        else:
            # Build M *independent* subtrees. `[DecisionTree(...)] * M` would
            # repeat one shared object M times, so every branch would carry the
            # same alpha and tweak() would add noise to it M times over.
            self.children = [DecisionTree(M, N - 1) for _ in range(M)]

    def get_alpha(self, x):
        """Return the alpha of the node reached by following history x from the root.

        x may be partial; [] gives the root. len(x) must be at most N-1,
        because leaves have no children to descend into.
        """
        node = self
        for outcome in x:
            node = node.children[outcome]
        return node.alpha

    def tweak(self, std):
        """Add complex Gaussian noise to the alpha of every node, in place.

        The real and imaginary parts each receive an independent Normal(0, std)
        term, drawn separately for every node. The tree is mutated, so callers
        that may want to revert (e.g. depending on the loss) must deepcopy it first.
        """
        noise = normal(0, std, size=2)
        self.alpha = self.alpha + noise[0] + 1j * noise[1]
        if self.children is not None:
            for child in self.children:
                child.tweak(std)

# Converts a complete measurement history (N elements) to an integer, in an
# unambiguous and reversible manner, so that it can be used as a dict key.
# This hashing protocol only works on complete histories ([0] and [0, 0] collide).
def xhash(x, M):
    """Encode history x as an int, reading x as base-M digits with x[0] least significant."""
    return sum(digit * M**position for position, digit in enumerate(x))

# Inverse of xhash() for complete histories of length N.
def unhash(h, M, N):
    """Decode h into its N base-M digits, least significant first (inverse of xhash)."""
    digits = []
    for _ in range(N):
        h, digit = divmod(h, M)
        digits.append(int(digit))
    return digits


In [3]:
# given M and N, gives a list of length M**N, such that each element is a list of length N, containing
# ints between 0 and M-1
# gives every possible complete measurement history
def make_xs(M, N) -> list:
    if N == 1:
        return [[j] for j in range(M)]
    else:
        xsm = make_xs(M, N-1)
        res = []
        for xm in xsm:
            for j in range(M):
                xm.append(j)
                res.append(xm.copy())
                xm.pop()
        return res
        
def process(tree: DecisionTree, code: dict, code_index: int, sigma: float) -> int:
    """Run one shot of the adaptive measurement circuit; return the hashed history.

    The input amplitude ``code[code_index]`` is split over ``tree.N`` detectors
    (amplitude / sqrt(N) each). Before each detector the amplitude gets an
    independent random phase exp(1j*gamma), gamma ~ Normal(0, sigma), and the
    tree's alpha for the history observed so far is added. The resulting
    coherent state (truncated to dimension tree.M) is measured in the number
    basis |0>, ..., |M-1>.

    Args:
        tree: DecisionTree supplying alpha for each partial history.
        code: mapping from codeword index to complex coherent-state amplitude.
        code_index: key into ``code`` selecting the input state.
        sigma: std of the phase noise; 0 is valid and disables the noise.

    Returns:
        ``xhash`` of the complete length-N outcome history. The result is
        random: it depends on the global NumPy RNG (phase noise) and on
        qutip's ``measure``.
    """
    M, N = tree.M, tree.N
    # single-mode projectors |j><j| onto the truncated number basis
    projs = [qutip.ket2dm(qutip.basis(M, j)) for j in range(M)]
    amplitude = code[code_index] / sqrt(N)  # share of the input amplitude per detector

    x = []  # measurement history so far
    for _ in range(N):
        alpha = tree.get_alpha(x)  # displacement D(alpha) chosen from the history so far
        gamma = normal(0, sigma)  # phase noise; sigma = 0 always yields 0
        # coherent state |beta_i/sqrt(N) * e^{i gamma} + alpha>
        state = qutip.coherent(M, amplitude * exp(1j * gamma) + alpha)
        outcome, _ = measure(state, projs)  # measure in the (truncated) number basis
        x.append(outcome)

    return xhash(x, M)  # convert history to int

def loss(tree, code, Q, priors, sigma):
    """Count how many of Q simulated shots the tree's decision table misidentifies.

    Q codeword indices are drawn uniformly from range(len(code)), so ``code``
    is assumed to be keyed 0..len(code)-1. Each index is sent through
    process(). A decision table is then built from the observed
    (history, codeword) counts. It maps each complete history x to the index j
    maximising count(x, j) * priors[j]. Because the codewords are drawn
    uniformly, count(x, j) is (in expectation) proportional to P(x | j). The
    product is therefore proportional to the posterior P(j | x), and the table
    is the MAP decision rule under ``priors``.

    The same Q shots are then scored against that table, so the result is an
    in-sample (optimistic) error count.

    Args:
        tree: DecisionTree used by process().
        code: mapping codeword index -> coherent-state amplitude.
        Q: number of shots.
        priors: prior probability per codeword index 0..len(code)-1.
        sigma: phase-noise std passed to process().

    Returns:
        Number of mispredicted shots, an int in 0..Q.
    """
    M, N = tree.M, tree.N
    n_codes = len(code)
    codewords = [randint(n_codes) for _ in range(Q)]  # true input index per shot
    histories = [xhash(x, M) for x in make_xs(M, N)]

    # counts[x][j]: how often hashed history x was observed when codeword j was sent
    counts = {h: [0] * n_codes for h in histories}
    outcomes = []  # hashed history per shot, aligned with `codewords`
    for codeword in codewords:
        outcome = process(tree, code, codeword, sigma)  # run the circuit
        counts[outcome][codeword] += 1
        outcomes.append(outcome)

    table = {}
    for h in histories:
        scores = [counts[h][j] * priors[j] for j in range(n_codes)]  # ∝ P(j | x)
        # max() over indices returns the first maximiser, so ties (including
        # histories never observed) go to the lowest codeword index.
        table[h] = max(range(n_codes), key=scores.__getitem__)

    # total number of misidentifications
    return sum(table[outcome] != codeword for outcome, codeword in zip(outcomes, codewords))



# Dimension of each detector's truncated Fock space, with basis (|0>, |1>, ... |M-1>).
# This is also the number of possible outcomes per detector.
M = 2
# Depth of the circuit, including the root node. This is also the number of detectors.
N = 5
# Small amplitude to ensure the coherent states are "weak".
beta = 5e-1
# Number of circuit runs per loss() evaluation.
Q = 1000
# Seed NumPy's global RNG, which drives normal()/randint() in this notebook,
# so that Restart & Run All reproduces the same results.
SEED = 1234
np.random.seed(SEED)

# Amplitudes of the coherent states to distinguish between. The rest of the
# notebook handles only the indices (0, 1 ...) where possible, because they are
# lighter than complex values.
code = {
    0: beta,
    1: beta*1j,
    2: -beta,
    3: -beta*1j
}

# Prior probabilities of the input states, keyed like `code` (uniform for now).
priors = {j: 1/len(code) for j in range(len(code))}



In [4]:
def optimize(sigma, Ni=50, step_std=0.1):
    """Train a DecisionTree by random-walk hill climbing at phase-noise std `sigma`.

    Starts from an all-zero DecisionTree(M, N). Each of the `Ni` iterations
    tweaks a copy of the current best tree with Gaussian noise of std
    `step_std`. The copy is kept only if its loss() (using the notebook-level
    `code`, `Q` and `priors`) is strictly lower.

    Draws two new figures: the accuracy of every evaluated tree (accepted or
    not) per iteration, and whether each proposal was accepted.

    Args:
        sigma: phase-noise std passed to loss().
        Ni: number of proposals, i.e. optimisation iterations (default 50).
        step_std: std passed to DecisionTree.tweak (default 0.1).

    Returns:
        The best tree found.
    """
    print("starting")
    tree = DecisionTree(M, N)
    cur_loss = loss(tree, code, Q, priors, sigma)
    accuracies = [1 - cur_loss/Q]  # index 0: the untrained tree
    accepts = []
    for i in range(1, Ni + 1):
        # tweak() mutates in place, so work on a copy in order to be able to reject it
        new_tree = deepcopy(tree)
        new_tree.tweak(step_std)  # add gaussian term to each node in tree
        new_loss = loss(new_tree, code, Q, priors, sigma)
        accuracies.append(1 - new_loss/Q)
        accepted = new_loss < cur_loss
        if accepted:
            # new_tree is already a private copy; later steps start from it
            tree, cur_loss = new_tree, new_loss
        accepts.append(int(accepted))
        print(f"{i}/{Ni}")

    print("done!")

    # New figures on every call: fixed figure numbers would make every call
    # in a sweep draw onto the same two figures.
    fig_acc, ax_acc = plt.subplots()
    ax_acc.plot(range(Ni + 1), accuracies, "r.")
    ax_acc.axhline(1/len(code), linestyle="--", color="k", label="random guess")
    ax_acc.legend()
    ax_acc.set_xlabel("iteration")
    ax_acc.set_ylabel("accuracy")
    ax_acc.set_title(f"Training a tree on std {sigma}")

    fig_acc_steps, ax_steps = plt.subplots()
    # proposal i corresponds to accuracies[i]
    ax_steps.plot(range(1, Ni + 1), accepts, "b.")
    ax_steps.set_xlabel("iteration")
    ax_steps.set_ylabel("accepted (1) / rejected (0)")
    ax_steps.set_title(f"Accepted proposals on std {sigma}")
    return tree

In [5]:
def test(tree, sigma, Nt=50):
    """Evaluate `tree` Nt times at phase-noise std `sigma`; return the mean accuracy.

    Each evaluation is a separate loss() call with Q freshly drawn shots
    (notebook-level `code`, `Q`, `priors`), scored as accuracy = 1 - loss/Q.
    The per-evaluation accuracies are plotted on a new figure next to the
    random-guess level.

    Args:
        tree: DecisionTree to evaluate.
        sigma: phase-noise std passed to loss().
        Nt: number of evaluations (default 50).

    Raises:
        ValueError: if Nt < 1 (the mean would be undefined).
    """
    if Nt < 1:
        raise ValueError(f"Nt must be >= 1, got {Nt}")
    accuracies = [1 - loss(tree, code, Q, priors, sigma)/Q for _ in range(Nt)]

    fig, ax = plt.subplots()
    ax.plot(range(Nt), accuracies, 'g.')
    ax.axhline(1/len(code), linestyle="--", color="k", label="random guess")
    ax.legend()
    ax.set_xlabel("iteration")
    ax.set_ylabel("accuracy")
    ax.set_title(f"Testing a tree on std {sigma}")
    return sum(accuracies)/len(accuracies)

In [6]:
# This would train a tree for a noise-free system (sigma = 0).
# tree_pure = optimize(0)  # optimize() returns only the trained tree, not (tree, accuracies)

In [None]:
# Train one tree per noise level, then test every trained tree at every noise level.
Ns = 2  # number of noise levels in the sweep
sigmas = np.linspace(0, 0.08, Ns)
print(sigmas)

trees = [optimize(s) for s in sigmas]

# accs[j][k]: mean accuracy of the tree trained at sigmas[j], tested at sigmas[k]
accs = [[test(trees[j], sigmas[k]) for k in range(Ns)] for j in range(Ns)]

fig, ax = plt.subplots()
im = ax.imshow(accs)
fig.colorbar(im, ax=ax, label="mean accuracy")
ax.set_ylabel("Training std")
ax.set_xlabel("Testing std")
ax.set_title("Accuracy vs. training and testing noise")
# Short labels: raw linspace values (e.g. 0.026666666666666668) clutter the axes.
sigma_labels = [f"{s:.3g}" for s in sigmas]
ax.set_xticks(range(Ns), labels=sigma_labels)
ax.set_yticks(range(Ns), labels=sigma_labels)



[0.   0.08]
starting
1/50
2/50


In [None]:
# Apparently a scratch check of the heatmap layout (imshow orientation and
# custom tick labels) used for the sweep plot; the data here is dummy data.
a = [[1, 4], [0, 1]]
b = [3, 4]
fig, ax = plt.subplots()
plt.imshow(a)
positions = range(len(b))
ax.set_xticks(positions, labels=b)
ax.set_yticks(positions, labels=b)
plt.xlabel("Training sigma")
plt.colorbar()