In [None]:
import numpy as np
import torch
from tqdm import tqdm
import opt_einsum
import matplotlib.pyplot as plt


import numqi
# notebook-wide random generator (currently not used by any cell below)
np_rng = np.random.default_rng()

# line colors indexed by curve; appears to match matplotlib's 'tableau-colorblind10' cycle
tableau = [
    '#006BA4', '#FF800E', '#ABABAB', '#595959', '#5F9ED1',
    '#C85200', '#898989', '#A2C8EC', '#FFBC79', '#CFCFCF',
]

In [None]:
class DensityMatrixGMEModel(torch.nn.Module):
    """Variational estimate of the geometric measure of entanglement of a density matrix.

    ``set_density_matrix`` factors ``rho ~= sqrt_rho @ sqrt_rho^H`` from its leading ``rank``
    eigenpairs. ``forward`` mixes the ``rank`` columns of ``sqrt_rho`` into ``num_ensemble``
    unnormalized pure states with a Stiefel matrix ``matX`` of shape ``(num_ensemble, rank)``.
    It contracts each of them with a variational state and returns ``1 - sum_e |t_e|^2``, where
    ``t_e`` is that contraction for ensemble member ``e``. The variational state is a product
    state when ``CPrank==1``; otherwise it is a unit-norm sum of ``CPrank`` product states.
    """
    def __init__(self, dim_list:tuple[int], num_ensemble:int, rank:int, CPrank:int=1):
        """
        Parameters:
            dim_list: local dimension of every subsystem (at least two subsystems, each >= 2).
            num_ensemble: number of pure states in the ensemble (rows of the Stiefel matrix).
            rank: number of leading eigenpairs of rho that are kept; must be <= prod(dim_list).
            CPrank: number of product terms in each variational state (1 = plain product state).
        """
        super().__init__()
        dim_list = tuple(int(x) for x in dim_list)
        assert (len(dim_list)>=2) and all(x>=2 for x in dim_list)
        self.dim_list = dim_list
        N0 = np.prod(np.array(dim_list))
        self.num_ensemble = int(num_ensemble)
        assert rank<=N0
        self.rank = int(rank)
        assert CPrank>=1
        self.CPrank = int(CPrank)

        self.manifold_stiefel = numqi.manifold.Stiefel(num_ensemble, rank, dtype=torch.complex128, method='polar')
        # method='QR' performs much worse than 'polar' here
        self.manifold_psi = torch.nn.ModuleList([numqi.manifold.Sphere(x, batch_size=num_ensemble*CPrank, dtype=torch.complex128) for x in dim_list])
        if CPrank>1:
            self.manifold_coeff = numqi.manifold.PositiveReal(num_ensemble*CPrank, dtype=torch.float64)
            N1 = len(dim_list)
            # Squared norm of every ensemble member's CP-rank state:
            #   sum_{a,b} c_a c_b prod_k sum_x psi_{k,a}[x] conj(psi_{k,b}[x]).
            # forward() reshapes coeff/psi to (num_ensemble, CPrank[, d]), so the shapes
            # declared here must use CPrank (not rank); otherwise the contraction path is
            # optimized for operand sizes that never occur.
            tmp0 = [(num_ensemble,CPrank),(num_ensemble,CPrank)] + [(num_ensemble,CPrank,x) for x in dim_list] + [(num_ensemble,CPrank,x) for x in dim_list]
            tmp1 = [(N1,N1+1),(N1,N1+2)] + [(N1,N1+1,x) for x in range(N1)] + [(N1,N1+2,x) for x in range(N1)]
            self.contract_psi_psi = opt_einsum.contract_expression(*[y for x in zip(tmp0,tmp1) for y in x], [N1])

        self._sqrt_rho = None
        self.contract_expr = None
        self.contract_coeff = None  # not used by this class

    def set_density_matrix(self, rho:np.ndarray):
        """Load the target density matrix and build the overlap contraction used by ``forward``.

        The ``self.rank`` largest eigenpairs of ``rho`` are kept, and ``EVC * sqrt(EVL)``
        reshaped to ``(*dim_list, rank)`` is stored in ``self._sqrt_rho``. That tensor is
        baked into ``self.contract_expr`` as a constant operand.

        Parameters:
            rho: Hermitian matrix of shape ``(prod(dim_list), prod(dim_list))``.

        Raises:
            AssertionError: if the shape is wrong, ``rho`` is not Hermitian, or the kept
                eigenvalues do not sum to 1 (``rank`` smaller than the support of ``rho``).
        """
        N0 = np.prod(np.array(self.dim_list))
        assert rho.shape == (N0, N0)
        assert np.abs(rho-rho.T.conj()).max() < 1e-10
        EVL,EVC = np.linalg.eigh(rho)
        # eigh sorts ascending: keep the top `rank` eigenpairs, clipping round-off negatives
        EVL = np.maximum(0, EVL[-self.rank:])
        assert abs(EVL.sum()-1) < 1e-10
        EVC = EVC[:,-self.rank:]
        tmp0 = (EVC * np.sqrt(EVL)).reshape(*self.dim_list, self.rank)
        self._sqrt_rho = torch.tensor(tmp0, dtype=torch.complex128)
        N1 = len(self.dim_list)
        # index layout: 0..N1-1 subsystems, N1 eigen-index (rank), N1+1 ensemble, N1+2 CP term
        if self.CPrank==1:
            psi_index = [(N1+1,x) for x in range(N1)]
            psi_shape = [(self.num_ensemble,x) for x in self.dim_list]
        else:
            psi_index = [(N1+1,N1+2)] + [(N1+1,N1+2,x) for x in range(N1)]
            psi_shape = [(self.num_ensemble,self.CPrank)] + [(self.num_ensemble,self.CPrank,x) for x in self.dim_list]
        psi_args = [y for x in zip(psi_shape,psi_index) for y in x]
        # operand 0 (sqrt_rho) is a constant; matX and the psi/coeff operands come from forward()
        self.contract_expr = opt_einsum.contract_expression(self._sqrt_rho, tuple(range(N1+1)),
                            [self.num_ensemble,self.rank], (N1+1,N1), *psi_args, (N1+1,), constants=[0])

    def forward(self):
        """Return the loss ``1 - sum_e |t_e|^2`` for the current variational parameters.

        Raises:
            RuntimeError: if ``set_density_matrix`` has not been called yet.
        """
        if self.contract_expr is None:
            raise RuntimeError('set_density_matrix() must be called before forward()')
        matX = self.manifold_stiefel()
        psi_list = [x() for x in self.manifold_psi]
        if self.CPrank>1:
            # coeff is real-valued (PositiveReal cast to complex), so it needs no conjugation below
            coeff = self.manifold_coeff().reshape(self.num_ensemble, self.CPrank).to(matX.dtype)
            psi_list = [x.reshape(self.num_ensemble,self.CPrank,-1) for x in psi_list]
            psi_conj_list = [x.conj().resolve_conj() for x in psi_list]
            psi_psi = self.contract_psi_psi(coeff, coeff, *psi_list, *psi_conj_list, backend='torch').real
            # rescale so every ensemble member's CP-rank state has unit norm
            coeff = coeff / torch.sqrt(psi_psi).reshape(-1,1)
            tmp2 = self.contract_expr(matX, coeff, *psi_list, backend='torch')
        else:
            tmp2 = self.contract_expr(matX, *psi_list, backend='torch')
        loss = 1-torch.vdot(tmp2,tmp2).real
        return loss

In [None]:
# Werner state for d=3
alpha_list = np.linspace(0,1,50)
dim = 3

# rank=dim*dim keeps every eigenpair, so set_density_matrix never truncates the spectrum
model = DensityMatrixGMEModel([dim,dim], num_ensemble=2*dim*dim, rank=dim*dim)
ret = []
for alpha_i in tqdm(alpha_list):
    # use `dim` (previously hard-coded 3) so the model and the state always agree
    werner_rho = numqi.state.Werner(d=dim, alpha=alpha_i)
    model.set_density_matrix(werner_rho)
    ret.append(numqi.optimize.minimize(model, num_repeat=3, tol=1e-10, print_every_round=0).fun)
ret = np.array(ret)

In [None]:
# analytic curve from https://doi.org/10.1103/PhysRevA.68.042307:
#   E_G = (1 - sqrt(1 - min(0, p)^2)) / 2  with  p = dim - (1 - dim^2) / (alpha - dim)
werner_p = dim - (1-dim*dim) / (alpha_list - dim)
gme_analytic = 0.5*(1-np.sqrt(1-np.minimum(0, werner_p)**2))

fig,ax = plt.subplots()
ax.plot(alpha_list, ret, label='numerical result', color=tableau[0])
ax.plot(alpha_list, gme_analytic, 'o', markerfacecolor='none', label='analytical result', color=tableau[0])
ax.axvline(1/dim, color='k', linestyle='--', label='sep-ent boundary')
# zero values (separable region) are not drawn on the log axis
ax.set_yscale('log')
ax.set_xlabel(r'$\alpha$')
ax.set_ylabel(r'$E_G(\alpha)$')
ax.legend()
#fig.savefig('data/fig_werner.pdf')

In [None]:
# Random two-qubit (4x4) density matrices: compare the numerical GME with the concurrence-based formula

def get_gme_2qubit(rho:np.ndarray):
    """Analytic geometric measure of entanglement of a two-qubit density matrix.

    Uses ``E_G = (1 - sqrt(1 - C^2)) / 2``, where ``C`` is the concurrence returned by
    ``numqi.entangle.get_concurrence_2qubit``.

    Parameters:
        rho: two-qubit (4x4) density matrix.

    Returns:
        The GME value as a numpy scalar.
    """
    C_rho = numqi.entangle.get_concurrence_2qubit(rho)
    # clip at 0: round-off can push C slightly above 1, which would make sqrt return NaN
    ret = 1/2 * (1-np.sqrt(np.maximum(0.0, 1-C_rho*C_rho)))
    return ret

rand_dm_list = [numqi.random.rand_density_matrix(4) for _ in range(100)]

# rank=4 keeps every eigenpair of a 4x4 density matrix
model = DensityMatrixGMEModel(dim_list=[2,2], num_ensemble=8, rank=4)
err_list = []
for rand_dm in tqdm(rand_dm_list):
    model.set_density_matrix(rand_dm)
    # dedicated name: do not overwrite `ret`, which holds the Werner results used by the Werner plot
    gme_numeric = numqi.optimize.minimize(model, num_repeat=3, tol=1e-10, print_every_round=0).fun
    err_list.append(np.abs(gme_numeric-get_gme_2qubit(rand_dm)))

# max error
print(np.max(err_list))
# mean error
print(np.mean(err_list))

In [None]:
# bound entangled state
rho_bes = numqi.entangle.load_upb('tiles', return_bes=True)[1]
# rank of rho_bes alone (printed for reference only)
rank = (np.linalg.eigvalsh(rho_bes)>1e-5).sum()
print(rank)
dim_bes = rho_bes.shape[0]
alpha_list = np.linspace(0,1,50)
ret = []
# mixing with the maximally mixed state makes rho full rank for alpha<1, so keep every eigenpair
model = DensityMatrixGMEModel(dim_list=[3,3], num_ensemble=2*dim_bes, rank=dim_bes)
for alpha_i in tqdm(alpha_list):
    rho = (1-alpha_i) * np.eye(dim_bes) / dim_bes + alpha_i * rho_bes
    model.set_density_matrix(rho)
    ret.append(numqi.optimize.minimize(model, num_repeat=3, tol=1e-10, print_every_round=0).fun)
ret = np.array(ret)


In [None]:
fig,ax = plt.subplots()
ax.plot(alpha_list, ret, label='numerical result')
# separable/entangled boundary of this noisy family (value used as given)
ax.axvline(0.8649, color='k', linestyle='--', label='sep-ent boundary')
ax.set(xlabel=r'$\alpha$', ylabel=r'$E_G(\alpha)$', yscale='log')
# suppress the minor y ticks introduced by the log scale
ax.tick_params(axis='y', which='minor', bottom=False, top=False, left=False, right=False)
ax.legend()
# fig.savefig('data/fig_bes.pdf')


In [None]:
## higher entangled state
dimA = 4
dimB = 4
# one entry per matrix spanning the subspace; presumably (row, col, value) triples - TODO confirm
subspace_index_value = [
    [(0,0,1), (1,1,1), (2,2,1), (3,3,1)],
    [(0,1,1), (1,2,1), (2,3,1), (3,0,1)],
    [(0,2,1), (1,3,1), (2,0,1), (3,1,-1)],
]
matrix_subspace = np.stack([numqi.matrix_space.build_matrix_with_index_value(dimA, dimB, x) for x in subspace_index_value])
# dm = sum_k |m_k><m_k| with m_k the flattened k-th matrix, normalized to unit trace
dm = np.einsum(matrix_subspace, [0,1,2], matrix_subspace.conj(), [0,3,4], [1,2,3,4], optimize=True).reshape(dimA*dimB,dimA*dimB)
dm = dm / np.trace(dm).real

alpha_list = np.linspace(0, 1, 50)
alpha_fine_list = np.linspace(0.8, 0.9, 50)


def _noisy_dm(alpha):
    """Mix ``dm`` with the maximally mixed state: (1-alpha) I/(dimA*dimB) + alpha dm."""
    return (1-alpha) * np.eye(dimA*dimB) / (dimA*dimB) + alpha * dm


def _solve_gme(rho, num_ensemble, CPrank):
    """Build a fresh model for ``rho`` (all dimA*dimB eigenpairs kept) and return the minimized loss."""
    model = DensityMatrixGMEModel([dimA,dimB], num_ensemble=num_ensemble, rank=dimA*dimB, CPrank=CPrank)
    model.set_density_matrix(rho)
    return numqi.optimize.minimize(model, num_repeat=3, tol=1e-10, print_every_round=0).fun


ret1 = []
ret2 = []
ret3 = []

for alpha_i in tqdm(alpha_list):
    rho = _noisy_dm(alpha_i)
    ret1.append(_solve_gme(rho, num_ensemble=32, CPrank=1))
    ret2.append(_solve_gme(rho, num_ensemble=64, CPrank=2))

for alpha_i in tqdm(alpha_fine_list):
    ret3.append(_solve_gme(_noisy_dm(alpha_i), num_ensemble=64, CPrank=2))


In [None]:
# save data in pickle
import os
import pickle

# open(..., 'wb') raises FileNotFoundError if the target directory does not exist yet
os.makedirs('data', exist_ok=True)
with open('data/fig_higher_entangled_state.pkl', 'wb') as f:
    pickle.dump([alpha_list, alpha_fine_list, ret1, ret2, ret3], f)

In [None]:
fig,ax = plt.subplots(figsize=(6.4,4.8))
# ret1 comes from CPrank=1 models (labeled r=2), ret2 from CPrank=2 models (labeled r=3)
for values, curve_label, curve_color in ((ret1, 'r=2', tableau[0]), (ret2, 'r=3', tableau[1])):
    ax.plot(alpha_list, values, label=curve_label, color=curve_color)
ax.legend(loc='upper left')
ax.set_xlabel(r'$\alpha$')
ax.set_ylabel(r'$E_r(\alpha)$')

# inset: the r=3 curve on the fine alpha grid, log scale
axin = ax.inset_axes([0.1, 0.24, 0.47, 0.47])
axin.plot(alpha_fine_list, ret3, label='r=3', color=tableau[1])
axin.set_xlim(alpha_fine_list[0], alpha_fine_list[-1])
axin.set_yscale('log')
# NOTE(review): 10e-10 == 1e-9; confirm 1e-10 was not intended
axin.set_ylim(10e-10, 1e-2)
# stretch the zoom rectangle on the main axes to y in [-0.005, 0.015]
# NOTE(review): tuple unpacking relies on the pre-3.10 matplotlib return value - confirm on upgrade
hrect,hpatch = ax.indicate_inset_zoom(axin, edgecolor="red")
hrect.set_xy((hrect.get_xy()[0], -0.005))
hrect.set_height(0.02)
fig.tight_layout()
#fig.savefig('data/fig_higher_entangled.pdf')


In [None]:
# four-qubit cluster state (0000 + 0011 + 1100 - 1111)/2

q0 = np.array([1, 0])
q1 = np.array([0, 1])


def _basis_state(bits:str) -> np.ndarray:
    """Computational-basis product state for a bit string, e.g. '0011' -> kron(q0, q0, q1, q1)."""
    ret = np.array([1])
    for b in bits:
        ret = np.kron(ret, q1 if b=='1' else q0)
    return ret


state_0000 = _basis_state('0000')
state_0011 = _basis_state('0011')
state_1100 = _basis_state('1100')
state_1111 = _basis_state('1111')

cluster_state = (state_0000 + state_0011 + state_1100 - state_1111) / 2
rho_cluster = np.outer(cluster_state, cluster_state.conj())

# four-qubit GHZ state (0000 + 1111)/sqrt(2)
ghz_state = numqi.state.GHZ(4)
rho_ghz = np.outer(ghz_state, ghz_state.conj())

# four-qubit W state
W_state = numqi.state.W(4)
rho_W = np.outer(W_state, W_state.conj())

# four-qubit Dicke state: equal superposition of the six weight-2 basis states
state_0101 = _basis_state('0101')
state_1001 = _basis_state('1001')
state_0110 = _basis_state('0110')
state_1010 = _basis_state('1010')

dicke_state = (state_0011+state_0101+state_1001+state_1100+state_0110+state_1010)/np.sqrt(6)
rho_dicke = np.outer(dicke_state, dicke_state.conj())


def _dephase(rho:np.ndarray, x:float) -> np.ndarray:
    """Return a copy of ``rho`` whose off-diagonal entries are scaled by ``x`` (diagonal kept)."""
    ret = rho * x
    np.fill_diagonal(ret, np.diag(rho))
    return ret


def _gme_cluster_dephased(x):
    """Analytic E_G of the cluster state with off-diagonals scaled by ``x``."""
    return (3/8)*(1+x-np.sqrt(1+(2-3*x)*x))


def _gme_ghz_dephased(x):
    """Analytic E_G of the GHZ state with off-diagonals scaled by ``x``."""
    return (1/2)*(1-np.sqrt(1-x*x))


def _gme_W_dephased(x):
    """Analytic E_G of the W state with off-diagonals scaled by ``x``.

    The lower branch is the same expression as the cluster-state curve; the two
    branches meet continuously at x = 2183/2667.
    """
    if x > 2183/2667:
        return 37*(81*x-37)/2816
    return (3/8)*(1+x-np.sqrt(1+(2-3*x)*x))


def _gme_dicke_dephased(x):
    """Analytic E_G of the Dicke state with off-diagonals scaled by ``x``; branches meet at x = 5/7."""
    if x > 5/7:
        return 5*(3*x-1)/16
    return (5/18)*(1+2*x-np.sqrt(1+(4-5*x)*x))


time_list = np.linspace(0, 1, 32)
gamma = 1
cluster_ret0 = []
cluster_ret1 = []
ghz_ret0 = []
ghz_ret1 = []
W_ret0 = []
W_ret1 = []
dicke_ret0 = []
dicke_ret1 = []

# loop-invariant: every state is 4-qubit (16x16) and rank=16 keeps all eigenpairs,
# so a single model is shared by all states and all time steps
model = DensityMatrixGMEModel(dim_list=[2,2,2,2], num_ensemble=32, rank=16)


def _gme_4qubit(rho:np.ndarray):
    """Load ``rho`` into the shared 4-qubit model and return the minimized loss."""
    model.set_density_matrix(rho)
    return numqi.optimize.minimize(model, num_repeat=3, tol=1e-10, print_every_round=0).fun


for t in tqdm(time_list):
    x = np.exp(-gamma*t)

    cluster_ret0.append(_gme_4qubit(_dephase(rho_cluster, x)))
    cluster_ret1.append(_gme_cluster_dephased(x))

    ghz_ret0.append(_gme_4qubit(_dephase(rho_ghz, x)))
    ghz_ret1.append(_gme_ghz_dephased(x))

    W_ret0.append(_gme_4qubit(_dephase(rho_W, x)))
    W_ret1.append(_gme_W_dephased(x))

    dicke_ret0.append(_gme_4qubit(_dephase(rho_dicke, x)))
    dicke_ret1.append(_gme_dicke_dephased(x))

In [None]:
fig,ax = plt.subplots()
# (legend label, numerical curve, analytic curve); colors follow tableau[0], tableau[1], ...
curve_specs = [
    ('Cluster state', cluster_ret0, cluster_ret1),
    ('GHZ state', ghz_ret0, ghz_ret1),
    ('W state', W_ret0, W_ret1),
    ('Dicke state', dicke_ret0, dicke_ret1),
]
for spec_color, (spec_label, numeric_vals, analytic_vals) in zip(tableau, curve_specs):
    # solid line = numerical optimum, hollow markers = analytic formula
    ax.plot(time_list, numeric_vals, '-', label=spec_label, color=spec_color)
    ax.plot(time_list, analytic_vals, 'o', markerfacecolor='none', color=spec_color)
ax.set_xlabel('$t$')
ax.set_ylabel(r'$E_G(t)$')
ax.legend()
#fig.savefig('data/fig_multipartite.pdf')
    