In [4]:
import numpy as np
from numpy.linalg import eig
import itertools
import functools
import time
def normalize_complex_arr(a: np.ndarray):
    a_oo = a - a.real.min() - 1j*a.imag.min() # origin offsetted
    return a_oo/np.abs(a_oo).max()
def norm_complex(arr: np.ndarray):
    if len(arr.shape) > 1:
        ret_out = np.zeros_like(arr)
        for i in range(arr.shape[0]):
            ret_out[i,:] = arr[i]/np.sqrt((np.abs(arr[i])**2).sum())  
        return ret_out
    else:
        return arr/np.sqrt((np.abs(arr)**2).sum())  

In [5]:
v = np.random.multivariate_normal(np.zeros(10),np.eye(10)/100, size=(1)).view(np.complex128)
print(np.linalg.norm(v))
np.linalg.norm(norm_complex(v))

0.33548807629082206


0.9999999999999999

In [6]:
sx = np.array([[0, 1], [1, 0]])
sy = np.array([[0, 1j], [-1j, 0]])
sz = np.array([[1, 0], [0, -1]])
basis = np.stack([np.eye(2), sx, sy, sz])

n = 4 # Nb of qubits
J = 4**n # Matches the number of bases
I = 6**n # Matches  R*A = 2^n * 3^n = 6^n
d = R = 2**n # matrix dimension and number of possibilities for R^a_s ({-1, 1}^n)
A = 3**n # Number of possible measurements
npa = np.array
b = npa(list(itertools.product(range(4), repeat=n))) # {I, x, y, z}^n
a = npa(list(itertools.product(range(1, 4), repeat=n))) # {x,y,z}^n
r = npa(list(itertools.product([0, 1], repeat=n))) # {0, 1}^n -> acts like a mask for which bases to select
r_neg = npa(list(itertools.product([-1, 1], repeat=n)))
def projectors(idx_list, r_):
    """
    Returns the P^{a_i}_{s_i} list of projection matrices 
    """
    evs = [eig(basis[i])[1] for i in idx_list]
    selected_evs = np.array([ev[:, r_[i]] for i, ev in enumerate(evs)])
    # print(selected_evs)
    ret = npa([np.outer(np.conj(ev), ev) for ev in selected_evs])
    # print(ret)
    return ret

#### All possible combinations of a (kron product of all permutations)
# Pauli basis for n qubit 
sig_b = npa([functools.reduce(np.kron, (basis[b[i,:], :, :])) for i in range(J)])

### Maybe matches to p_a,s = P(R^a = s) in paper, size: 6^n x 4^n
# For every comb of the bases (j in 0:J), then for every activation of the bases (r[s in S, ])
# Matrix P_{(r,a),b} 
P_rab = np.zeros((I, J))
for j in range(J):
    tmp = np.zeros((R, A))
    for s in range(R): # r_neg[s] = [-1, 1, -1, 1], b[j]=[1, 2, 0, 2], a[l] = [2, 4, 1, 1] 
        for l in range(A): #  r_neg[s, b[j] != 0] = [-1, 1, 1]
            val = np.prod(r_neg[s, b[j] != 0])\
                * np.prod(a[l, b[j] != 0] == b[j, b[j]!=0]) # a[l, b[j] != 0] == b[j, b[j] != 0] <=> [2, 4, 1] != [1, 2, 2] 
            tmp[s,l] = val
    P_rab[:, j] = tmp.flatten(order="F")

### Normally, should correspond to P^a_s in paper, size: 2^n x 3^n flattened
# The projectors matrices 
Pra = []
for j in range(A):
    for i in range(R):
        # print(a[j], r[i])
        # print(projectors(a[j], r[i]))
        # print()
        Pra.append(npa(functools.reduce(np.kron, projectors(a[j], r[i]))).flatten(order="F")[::-1])
    # break

In [7]:
# Pure state
# dens_ma = np.zeros((d,d))
# dens_ma[0,0] = 1
# Mixed state (but in reality pure state, in paper is psi * psi^T)
u = norm_complex(np.random.multivariate_normal(np.zeros(d*2),np.eye(d*2)/100, size=(d)).view(np.complex128))
dens_ma = np.conj(u.T) @ u /d # = \sum_i \gamma_i V_i V_i^T
print(dens_ma.shape)

Prob_ar = np.zeros((A, R)) # Corresponds to Tr(rho \dot P^a_s) 
if n==1:
    for i in range(A):
        for j in range(R):
            Prob_ar[i,j] = dens_ma.flatten(order="F") @ projectors(a[i], r[j])
else:
    for i in range(A):
        for j in range(R):
            Prob_ar[i,j] = np.diag(dens_ma @ npa(functools.reduce(np.kron, projectors(a[i], r[j])))).sum()
Prob_ar = np.real(Prob_ar)
# Nb of times w e repeat the measurements
n_size = 2000
p_ra = np.zeros((R, A)) # = \hat{p}_a,s
for i, x in enumerate(Prob_ar):
    H = np.random.choice(R, n_size, replace=True, p=x) #n_size elements
    out = []
    for s in range(R):
        out.append((H==s).sum()/n_size)
    p_ra[:, i] = out
# Transform matrix to vector form
p_ra1 = p_ra.flatten(order="F")
temp1 = p_ra1 @ P_rab
temp1 = temp1/d

# Calculate coefs rho_b
rho_b = [0] * J
for i in range(J):
    rho_b[i] = temp1[i]/(3**((b[i] == 0).sum()))

# Calculate density using inversion technique
rho_hat = np.zeros((d, d), dtype=np.complex128)
for s in range(J):
    rho_hat += rho_b[s] * sig_b[s]
u_hat = eig(rho_hat)[1]

# renormalize lambda_hat
lamb_til = eig(rho_hat)[0]
lamb_til[lamb_til < 0] = 0
lamb_hat = lamb_til/lamb_til.sum()
lamb_hat

(16, 16)


  Prob_ar[i,j] = np.diag(dens_ma @ npa(functools.reduce(np.kron, projectors(a[i], r[j])))).sum()


array([0.19569545+6.14199396e-18j, 0.16257078+3.34519846e-18j,
       0.14810537-6.62496584e-18j, 0.10953456+7.03510099e-18j,
       0.09494648+9.52868573e-19j, 0.0745722 +6.50006993e-19j,
       0.06343105+6.89826341e-18j, 0.        +0.00000000e+00j,
       0.        +0.00000000e+00j, 0.04447084-4.67739165e-18j,
       0.        +0.00000000e+00j, 0.0043523 +1.83183361e-18j,
       0.03615615-5.91400433e-18j, 0.02889249+1.45866089e-18j,
       0.02171634-8.81432044e-18j, 0.015556  -2.28324462e-18j])

In [5]:
p_ra1.shape

(1296,)

In [6]:
np.ones((3,))[np.newaxis, :].shape

(1, 3)

In [11]:
((np.repeat((npa(Pra).reshape((I, J)) @ tem_can)[:, np.newaxis], p_ra1.shape[0], axis=1))- p_ra1[:, np.newaxis])

array([[ 7.85044841e-03-2.11419424e-18j,  7.85044841e-03-2.11419424e-18j,
         7.85044841e-03-2.11419424e-18j, ...,
         7.85044841e-03-2.11419424e-18j,  7.85044841e-03-2.11419424e-18j,
         7.85044841e-03-2.11419424e-18j],
       [ 3.29262216e-03+1.02999206e-18j,  3.29262216e-03+1.02999206e-18j,
         3.29262216e-03+1.02999206e-18j, ...,
         3.29262216e-03+1.02999206e-18j,  3.29262216e-03+1.02999206e-18j,
         3.29262216e-03+1.02999206e-18j],
       [ 5.61159672e-05-1.57209315e-18j,  5.61159672e-05-1.57209315e-18j,
         5.61159672e-05-1.57209315e-18j, ...,
         5.61159672e-05-1.57209315e-18j,  5.61159672e-05-1.57209315e-18j,
         5.61159672e-05-1.57209315e-18j],
       ...,
       [ 1.53457122e-02-8.67361738e-19j,  1.53457122e-02-8.67361738e-19j,
         1.53457122e-02-8.67361738e-19j, ...,
         1.53457122e-02-8.67361738e-19j,  1.53457122e-02-8.67361738e-19j,
         1.53457122e-02-8.67361738e-19j],
       [-5.50978330e-03+0.00000000e+00j, -5.

In [8]:
rho = np.zeros((d, d))
Te = np.random.standard_exponential(d) # Initial Y_i^0
Id = np.eye(d)
U = u_hat # eigenvectors of \hat(rho) found using inversion, initial V_i^0
Lamb = Te/Te.sum() # gamma^0
U = u_hat # eigenvectors of \hat(rho) found using inversion, initial V_i^0
be = 1

for i in range(10):
    for j in range(d):
        Te_can = Te.copy() 
        Te_can[j] = Te[j] * np.exp(be * np.random.uniform(-0.5, 0.5, 1)) # \tilde(Y)_i = exp(y ~ U(-0.5, 0.5)) Y_i^t-1
        L_can = Te_can/Te_can.sum() # \tilde(gamma)_i = \tilde(Y_i)/sum_j^d(\tilde(Y_j))
        tem_can = (U @ np.diag(L_can) @ np.conj(U.T)).flatten(order="F") # gamma * U * U^T (U = V in paper)
        tem = (U @ np.diag(Lamb) @ np.conj(U.T)).flatten(order="F")
        Pra_m = npa(Pra).reshape((I, J))
        ss1 = (np.repeat((Pra_m @ tem_can)[:, np.newaxis], p_ra1.shape[0], axis=1) - p_ra1[:, np.newaxis])**2
        ss2 = (np.repeat((Pra_m @ tem)[:, np.newaxis], p_ra1.shape[0], axis=1)- p_ra1[:, np.newaxis])**2
        ss = (ss1 - ss2).sum()

In [14]:
def g():
    import jax
    import jax.random as rd
    import jax.numpy as jnp
    key = jax.random.PRNGKey(0)
    rho = jnp.zeros((d, d))
    key, subkey = rd.split(key)
    Te = rd.exponential(subkey, (d,))# np.random.standard_exponential(d) # Initial Y_i^0
    Id = jnp.eye(d)#np.eye(d)
    U = jnp.asarray(u_hat) # eigenvectors of \hat(rho) found using inversion, initial V_i^0
    Lamb = Te/Te.sum() # gamma^0
    be = 1
    p_ra1 = jnp.asarray(p_ra1)
    Pra_m = jnp.asarray(Pra).reshape((I, J))
    for i in range(10):
        for j in range(d):
            Te_can = Te.copy()
            key, subkey = rd.split(key)
            Te_can.at[j].set(Te[j] * jnp.exp(be * rd.uniform(subkey, (1,), minval=-0.5, maxval=0.5))[0]) # \tilde(Y)_i = exp(y ~ U(-0.5, 0.5)) Y_i^t-1
            L_can = Te_can/Te_can.sum() # \tilde(gamma)_i = \tilde(Y_i)/sum_j^d(\tilde(Y_j))
            tem_can = (U @ jnp.diag(L_can) @ jnp.conj(U.T)).flatten(order="F") # gamma * U * U^T (U = V in paper)
            tem = (U @ jnp.diag(Lamb) @ jnp.conj(U.T)).flatten(order="F")
            ss1 = (jnp.repeat((Pra_m @ tem_can)[:, jnp.newaxis], p_ra1.shape[0], axis=1) - p_ra1[:, jnp.newaxis])**2
            ss2 = (jnp.repeat((Pra_m @ tem)[:, jnp.newaxis], p_ra1.shape[0], axis=1)- p_ra1[:, jnp.newaxis])**2
            ss = (ss1 - ss2).sum()

In [15]:
%lprun -f g g()

UnboundLocalError: local variable 'p_ra1' referenced before assignment

## Main MH

In [9]:
%load_ext line_profiler

In [10]:
# def f():
# Main part
rho = np.zeros((d, d))
Te = np.random.standard_exponential(d) # Initial Y_i^0
Id = np.eye(d)
U = u_hat # eigenvectors of \hat(rho) found using inversion, initial V_i^0
Lamb = Te/Te.sum() # gamma^0
ro = 1/2
S = (rho_hat + np.conj(rho_hat.T))/2
be = 1

gamm = n_size/2 # lambda in paper 
entry = []
Iter = 500
burnin = 100
start_time = time.time()
Pra_m = npa(Pra).reshape((I, J))
for t in range(Iter + burnin):
    print(t)
    for j in range(d): # Loop for Y_i
        
        Te_can = Te.copy() 
        Te_can[j] = Te[j] * np.exp(be * np.random.uniform(-0.5, 0.5, 1)) # \tilde(Y)_i = exp(y ~ U(-0.5, 0.5)) Y_i^t-1
        L_can = Te_can/Te_can.sum() # \tilde(gamma)_i = \tilde(Y_i)/sum_j^d(\tilde(Y_j))
        tem_can = (U @ np.diag(L_can) @ np.conj(U.T)).flatten(order="F") # gamma * U * U^T (U = V in paper)
        tem = (U @ np.diag(Lamb) @ np.conj(U.T)).flatten(order="F") # prev gamma * U * U^T
        #ss = (npa([tem_can.T @ x - p_ra1 for x in Pra])**2 - npa([tem.T @ x - p_ra1 for x in Pra])**2).sum() # l^prob: sum_a sum_s (Tr(v P^a_s) - hat(p^_a,s))^2
        ss1 = (np.repeat((Pra_m @ tem_can)[:, np.newaxis], p_ra1.shape[0], axis=1) - p_ra1[:, np.newaxis])**2
        ss2 = (np.repeat((Pra_m @ tem)[:, np.newaxis], p_ra1.shape[0], axis=1)- p_ra1[:, np.newaxis])**2
        ss = (ss1 - ss2).sum()
        r_prior = (ro-1) * np.log(Te_can[j]/Te[j]) - Te_can[j] + Te[j] # other part of R acceptance ratio
        ap = -gamm*np.real(ss) # other part (why use np.real?)
        if np.log(np.random.uniform(0, 1, 1)) <= ap + r_prior: Te = Te_can # if value above draw from U(0,1), then update
        Lamb = Te/Te.sum() # gamma
    print("mid")
    for j in range(d): # Loop for V_i
        U_can = U.copy()
        U_can[:, j] = norm_complex(U[:,j] + np.random.multivariate_normal(np.zeros(d*2),np.eye(d*2)/100, size=(1)).view(np.complex128)) # Sample U/V from the unit sphere (not sure why we add to previous value)
        tem_can = (U_can @ np.diag(Lamb) @ np.conj(U_can.T)).flatten(order="F") # gamma * U * U^T
        tem = (U @ np.diag(Lamb) @ np.conj(U.T)).flatten(order="F") # gamma * U_t-1 * U^T_t-1
        # ss = (npa([tem_can.T @  x - p_ra1 for x in Pra])**2 - npa([tem.T @ x - p_ra1 for x in Pra])**2).sum() # l^prob: sum_a sum_s (Tr(v P^a_s) - hat(p^_a,s))^2
        ss1 = (np.repeat((Pra_m @ tem_can)[:, np.newaxis], p_ra1.shape[0], axis=1) - p_ra1[:, np.newaxis])**2
        ss2 = (np.repeat((Pra_m @ tem)[:, np.newaxis], p_ra1.shape[0], axis=1)- p_ra1[:, np.newaxis])**2
        ss = (ss1 - ss2).sum()
        ap = -gamm * np.real(ss) # other part of A accep ratio
        if np.log(np.random.uniform(0, 1, 1)) <= ap: U = U_can # if value above draw from U(0,1), then update

    if t > burnin:
        rho = U @ np.diag(Lamb) @ np.conj(U.T)/(t - burnin) + rho*(1-1/(t-burnin)) # approximate rho each time as rho_t = gamma_t * V_t * V_t^T /(t-burnin) + rho_t-1 / (1 - 1/(t-burnin)) -> the later we are, the more importance we give to prev rho
end_time = time.time()

0
mid
1
mid
2


KeyboardInterrupt: 

In [9]:
f()

0
mid
1
mid
2
mid
3
mid
4


KeyboardInterrupt: 

In [41]:
%lprun -f f f()

0
mid
1
mid
2
mid
3
mid
4
*** KeyboardInterrupt exception caught in code being profiled.

Timer unit: 1e-09 s

Total time: 5.37994 s
File: /tmp/ipykernel_27398/1258079928.py
Function: f at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def f():
     2                                               # Main part
     3         1      16203.0  16203.0      0.0      rho = np.zeros((d, d))
     4         1      20603.0  20603.0      0.0      Te = np.random.standard_exponential(d) # Initial Y_i^0
     5         1      31010.0  31010.0      0.0      Id = np.eye(d)
     6         1       1886.0   1886.0      0.0      U = u_hat # eigenvectors of \hat(rho) found using inversion
     7         1      49238.0  49238.0      0.0      Lamb = Te/Te.sum() # gamma^0
     8         1       1397.0   1397.0      0.0      ro = 1/2
     9         1      31010.0  31010.0      0.0      S = (rho_hat + np.conj(rho_hat.T))/2
    10         1       1956.0   1956.0      0.0      be = 1
    11                                          

In [None]:
print(f"Took: {end_time - start_time} s")
mean_rho = np.mean((dens_ma - rho) @ np.conj((dens_ma - rho).T))
mean_rho_hat = np.mean((dens_ma - rho_hat) @ np.conj((dens_ma - rho_hat).T))
rho_evs = eig(rho)[0]

In [155]:
#np.ones((5, 10)).sum(axis=0)
# dens_ma.flatten() @ projectors(a[0], r[0])
p_ra

array([[0.0515, 0.045 , 0.0565, ..., 0.0615, 0.057 , 0.0495],
       [0.047 , 0.0575, 0.053 , ..., 0.0465, 0.038 , 0.0465],
       [0.061 , 0.054 , 0.067 , ..., 0.044 , 0.0435, 0.0505],
       ...,
       [0.036 , 0.0515, 0.0635, ..., 0.068 , 0.088 , 0.0875],
       [0.075 , 0.064 , 0.0685, ..., 0.09  , 0.068 , 0.0735],
       [0.068 , 0.0695, 0.071 , ..., 0.0705, 0.0825, 0.083 ]])

In [20]:
functools.reduce(np.kron, [np.eye(2)] * 4).shape

(16, 16)

In [134]:
np.random.multivariate_normal(np.zeros(d*2),np.eye(d*2)/100, 16).view(np.complex128)

array([[-0.06956268+0.02435846j, -0.14123075-0.06311271j,
         0.00222598+0.2609201j ,  0.05914259+0.08152464j,
        -0.03022365+0.05777163j,  0.02326001-0.08326039j,
        -0.0484067 +0.24582793j, -0.00301281-0.13517634j,
        -0.02032145+0.06912466j, -0.10548587-0.04543596j,
         0.07877333+0.08333259j, -0.10900354+0.03113042j,
        -0.06166528+0.16194991j,  0.00724261-0.03640795j,
        -0.06860541-0.11615742j, -0.02843429+0.03530594j],
       [-0.11057946-0.1204106j , -0.05896407+0.08861196j,
         0.11388097+0.13442166j, -0.06069431+0.05797215j,
         0.11024054+0.03092668j,  0.20188125+0.08169993j,
        -0.13928047-0.09284634j, -0.12224631+0.14017381j,
        -0.07238691-0.23412461j,  0.14743275-0.07562126j,
         0.08320665-0.04646794j, -0.06249008+0.02412488j,
         0.01166913-0.01180255j, -0.02507991-0.12516436j,
         0.02586024+0.11108448j, -0.14223691+0.11283102j],
       [ 0.11396121-0.09717666j, -0.13510539+0.10105309j,
         0.1

In [142]:
cpl_v = npa([-0.06956268+0.02435846j, -0.14123075-0.06311271j,
         0.00222598+0.2609201j ,  0.05914259+0.08152464j,
        -0.03022365+0.05777163j,  0.02326001-0.08326039j,
        -0.0484067 +0.24582793j, -0.00301281-0.13517634j,
        -0.02032145+0.06912466j, -0.10548587-0.04543596j,
         0.07877333+0.08333259j, -0.10900354+0.03113042j,
        -0.06166528+0.16194991j,  0.00724261-0.03640795j,
        -0.06860541-0.11615742j, -0.02843429+0.03530594j])

In [146]:
def norm_complex(arr: np.ndarray):
    ret_out = np.zeros_like(arr)
    if len(arr.shape) > 1:
        for i in range(arr.shape[0]):
            ret_out[i,:] = arr[i]/np.sqrt((np.abs(arr[i])**2).sum())  
    else:
        return arr/np.sqrt((np.abs(arr)**2).sum())       
norm_complex(cpl_v)

array([-0.12776887+0.04474027j, -0.2594048 -0.11592192j,
        0.00408856+0.47924354j,  0.10862982+0.14973993j,
       -0.05551312+0.10611172j,  0.04272269-0.15292806j,
       -0.08891074+0.45152308j, -0.00553376-0.24828439j,
       -0.03732531+0.12696434j, -0.19375059-0.08345425j,
        0.14468648+0.15306067j, -0.20021165+0.05717863j,
       -0.11326336+0.2974606j ,  0.01330282-0.0668721j ,
       -0.12601061-0.2133515j , -0.05222652+0.06484799j])

In [23]:
def print_hjello():
    print("hello")
    print(5)

In [25]:
%lprun -f print_hjello print_hjello()

hello
5


Timer unit: 1e-09 s

Total time: 6.4954e-05 s
File: /tmp/ipykernel_12423/3878689494.py
Function: print_hjello at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def print_hjello():
     2         1      59087.0  59087.0     91.0      print("hello")
     3         1       5867.0   5867.0      9.0      print(5)