In [59]:
import numpy as np
from numpy import *
import qutip
from qutip import *
from numpy.random import *
from qutip.measurement import measure, measurement_statistics
from copy import deepcopy
# from qutip.states import ket2dm, basis

In [60]:
class Node():
    def get_alpha(self):
        return self.alpha
    def get_children(self):
        return self.children

class DecisionTree():

    # this is an M-ary tree of depth N, i.e. it has M**(N+1) - 1 nodes in total
    # each node corresponds to a measurement history,
    # i.e. a list x_j = [k_0, k_1, ... k_j] of measurement outcomes.

    
    # initialize as as random. each node has value between 0 and 1/2 + i
    def __init__(self, M, N):
        self.M = M
        self.N = N
        self.alpha = 0.5*rand() + 1j*rand()
        if N == 1:
            self.children = None
        else:
            self.children = [DecisionTree(M, N-1)]*M
        #self.tree = self.makeRandTree(M, N)



    # first element is value at this location, the second is a list (the children)
    def makeRandTree(self, M, N):
        if N == 1:
            return [0.2*rand()]*M
        else:
            return [0.2*rand(), self.makeRandTree(M, N-1)] * M

    def get_alpha(self, x):
        r = self
        for j in x:
            r = r.children[j]
            
        return r.alpha

    def loss(self, code, priors):
        return 0

    # adds a normally-distributed term to each node in the tree
    # make sure to take a deepcopy of the tree, so the original tree is preserved
    # i may revert to the former tree, depending on the value of the loss function
    def tweak(self):
        mean = [0, 0]
        std = 0.01*np.array([[1, 0], [0, 1]])
        noise = normal(mean, std)
        self.alpha = self.alpha + noise[0][0] + 1j*noise[1][1]
        if self.children != None:
            for j in self.children:
                j.tweak()
            
    # this hashing protocol only works on complete histories
def xhash(x, M):
    h = 0
    for j in range(len(x)):
        h += x[j] * M**j
    return h

def unhash(h, M, N):
    x = []
    j = 0
    for j in range(N):
        rem = h % M**(j+1)
        x.append(int(rem/(M**j)))
        h -= rem
    return x

In [79]:
# given M and N, gives a list of length M**N, such that each element is a list of length M, containing
# ints between 0 and M-1
# gives every possible measurement set
def make_xs(M, N):
    if N == 1:
        return [[j] for j in range(M)]
    else:
        xsm = make_xs(M, N-1)
        res = []
        for xm in xsm:
            for j in range(M):
                xm.append(j)
                res.append(xm.copy())
                xm.pop()
        return res
        
# takes a decision tree, and an input codeword. yields the measurement history, hashed
def process(tree, code, codeword):

    N = tree.N
    M = tree.M
    projs = [qutip.states.ket2dm(qutip.states.basis(M, j)) for j in range(M)] # single-mode projectors, |j><j|

    tapped = coherent(M, code[codeword]/N)
    x = []
    for _ in range(N):
        alpha = tree.get_alpha(x)
        state = displace(M, alpha) @ tapped
        [outcome, _] = measure(state, projs)
        x.append(outcome)

    return xhash(x, M)

# takes a decision tree, and Q
# runs the circuit Q times with tree, and tells
# inb4 is this loss? (yes)
def loss(tree, code, Q, priors):
    M = tree.M
    N = tree.N
    codewords = [randint(len(code)) for _ in range(Q)]
    table = {}
    meas = {}
    xs = [xhash(x, M) for x in make_xs(M, N)]
    for j in xs:
        meas[j] = [0]*len(code)
            
    for j in range(Q):
        #code_in = code[randint(len(code))]
        outcome = process(tree, code, codewords[j])
        meas[outcome][codewords[j]] += 1

    loss = 0
    for x in xs:
        posteriors = [meas[x][codewords[j]]*priors[j] for j in range(len(code))]
        # this is the (unnormalized) prob. P(x|y), where y is input state and x is measurement outcome

        
        # this is the input codeword that maximizes the likelihood of the given measurement
        table[x] = posteriors.index(max(posteriors)) # if there are shared first-places, it just picks one of them
        # probably the first in the list (lowest code index)

    loss = 0
    for x in xs:
        for i in range(len(code)):
            error = (table[x] != i)
            prob = sum(meas[x])/Q
            prior = priors[i]
            print(f"error: {error}")
            print(f"prob: {prob}")
            print(f"prior: {prior}")
            print(error*prob)
            print(loss)
            loss += error*prob*prior
    return loss


In [83]:

M = 3# dimension of Hilbert space. This is a truncated Fock state, with the basis (|0>, |1>, ... |M-1>).
# Hilbert space has dimension 2**M
N = 4 # depth of circuit, including the root node
beta = 5e-3 # this is a small parameter to ensure the coherent states are "weak"
Q = 200 # number of circuit runs per tree
# parameters of six coherent states to distinguish between. I will generally try to handle/store only the indices
# (0, 1 ...) when possible because they're lighter than complex values.
code = {
    0: beta*exp(1j * atan(0.5)),
    1: beta*exp(-1j * atan(0.5)),
    2: 0,
    3: 1 + 1j,
    4: -1
}

# prior probabilities of the input states. just a uniform for now
priors = {}
for j in range(len(code)):
    priors[j] = 1/len(code)

meas1 = [0, 2, 1]
h = hash(meas1, M)
meas2 = unhash(h, M, N)
print(f"meas1: {meas1}")
print(f"hash: {h}")
print(f"meas2: {meas2}")
print(f"{meas1 == meas2}")

tree1 = DecisionTree(M, N)
loss1 = loss(tree1, code, Q, priors)
tree2 = deepcopy(tree1)
tree2.tweak()
loss2 = loss(tree2, code, Q, priors)
print(f"loss 1: {loss1}")
print(f"loss 2: {loss2}")



print(make_xs(M, 1))
print(make_xs(M, 2))




meas1: [0, 2, 1]
hash: 15
meas2: [0, 2, 1, 0]
False
error: True
prob: 0.19
prior: 0.2
0.19
0
error: True
prob: 0.19
prior: 0.2
0.19
0.038000000000000006
error: True
prob: 0.19
prior: 0.2
0.19
0.07600000000000001
error: False
prob: 0.19
prior: 0.2
0.0
0.11400000000000002
error: True
prob: 0.19
prior: 0.2
0.19
0.11400000000000002
error: True
prob: 0.085
prior: 0.2
0.085
0.15200000000000002
error: True
prob: 0.085
prior: 0.2
0.085
0.16900000000000004
error: True
prob: 0.085
prior: 0.2
0.085
0.18600000000000005
error: True
prob: 0.085
prior: 0.2
0.085
0.20300000000000007
error: False
prob: 0.085
prior: 0.2
0.0
0.22000000000000008
error: True
prob: 0.01
prior: 0.2
0.01
0.22000000000000008
error: True
prob: 0.01
prior: 0.2
0.01
0.2220000000000001
error: False
prob: 0.01
prior: 0.2
0.0
0.2240000000000001
error: True
prob: 0.01
prior: 0.2
0.01
0.2240000000000001
error: True
prob: 0.01
prior: 0.2
0.01
0.2260000000000001
error: True
prob: 0.035
prior: 0.2
0.035
0.2280000000000001
error: True
pro

In [None]:
print(range(0, 3))

In [18]:
print(list(range(0, 3)))

[0, 1, 2]
