# Full code for simulations of satellite-based high-dimensional entanglement distribution 

In [None]:
# Initialise the libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import integrate
from scipy.integrate import quad
from scipy.integrate import trapz
import scipy.optimize
from random import random
from sympy import *
import math
from tqdm import tqdm

### The following code computes the transmission probability, $p_T$. It is taken from Janice van Dam's Master's thesis.

In [None]:
'''data'''

# Physical and link constants shared by the geometry / transmission functions below.
R = 6371*10**3 #m, Earth's radius
theta_e = 2*np.pi/(24*60*60) #angular velocity of Earth in radians per second
G = 6.674*10**(-11) # gravitational constant
M = 5.9736 * 10**(24) #mass of the Earth
v = 23 #(m/s) wind speed (enters the Cn2 turbulence profile)
A = 1.7*10**(-14) #(m^(-2/3)) ground-level turbulence strength in the Cn2 profile
#w_0 = 0.25 #good satellite
w_0 = 0.0308 #micius; beam waist entering the Rayleigh range pi*w_0**2/landa
Tr= 0.98
tauw = -np.log(Tr)
c = 2.998e8 #(m/s) speed of light in vacuum


# alpha: orbital-plane tilt in degrees, converted to radians inside the geometry functions
alpha=beta=0

# One day sampled every second.
max_time = 24*60*60+1
time = np.linspace(0,24*60*60,max_time)
h=[567,894,1262,1681,2162,2772,3385,4182,5165]
landa=810*10**(-9) # wavelength (m)
#rg = 1.0 #good satellite
rg = 0.6 #micius
#rs = 0.5 #good satellite
rs = 0.2 #micius

# Fixed link losses as transmittances: 10**(-0.05) corresponds to 0.5 dB, 10**(-0.01) to 0.1 dB.
L_T = 10**(-0.05) #emitter loss, data from Janice
L_R = 10**(-0.05) #receiver loss, data from Janice
L_j = 10**(-0.01) #jitter loss
L_p = 10**(-0.01) #pointing loss


In [None]:
'''function that defines distances between ground stations'''


def distance_bt_cities(x1,y1,x2,y2):
    """Great-circle (haversine) distance between two points on the sphere of radius R.

    ``x`` is the latitude and ``y`` the longitude, both in degrees. The
    module-level radius ``R`` is in metres, so the result (R * angle / 1000)
    is in kilometres.
    """
    lat_a, lon_a, lat_b, lon_b = (deg*np.pi/180 for deg in (x1, y1, x2, y2))

    haversine = np.sin((lat_b-lat_a)/2)**2+np.cos(lat_a)*np.cos(lat_b)*np.sin((lon_b-lon_a)/2)**2
    central_angle = 2*np.arctan2(np.sqrt(haversine), np.sqrt(1-haversine))

    return R*central_angle/1000
# Station A sweeps along the equator (latitude x1 = 0) over the longitudes in y1;
# station B stays fixed at longitude y2 = 54.
# Sweeps used previously:
#   500 - 1200 km: np.linspace(49.50335,43.2,20)
#   50 - 400 km:   np.linspace(53.55,50.4,20)
#   y1 = np.linspace(53.55,50.402,20)
y1 = np.linspace(52.18,42.75,25)
y2 = 54
x1 = 0
x2 = 0

# Ground distance (km) between A and B for every longitude in the sweep.
distance = [distance_bt_cities(x1, lon_a, x2, y2) for lon_a in y1]


dist = distance
print(dist)


In [None]:
'''#transmission probability'''

def zenith(x,y,landa,h,alpha,rs,t,delta):
    """Return the zenith angle (radians) of the satellite seen from a ground station.

    The station sits at latitude ``x`` and longitude ``y`` (degrees) on a sphere
    of radius ``R`` that rotates at ``theta_e`` rad/s; ``t`` is the elapsed time
    (s). The satellite follows a circular orbit of altitude ``h`` (m), with
    orbital phase advancing over ``t + delta``. Its orbital plane is tilted by
    ``alpha`` degrees about the x axis.

    ``landa`` and ``rs`` are not used. They are kept so the signature mirrors
    ``trans_probability_down``.
    """

    def xground(x,y,t):
        # Cartesian station position; (90 - x) is the colatitude.
        px = R*np.sin((90-x)*np.pi/180)*np.cos(y*np.pi/180+ theta_e*t)
        py = R*np.sin((90-x)*np.pi/180)*np.sin(y*np.pi/180+ theta_e*t)
        pz = R*np.cos((90-x)*np.pi/180)
        return np.array([px,py,pz])

    def theta_s(h):
        # Orbital angular velocity from the period 2*pi*sqrt((R+h)**3/(G*M)).
        T = 2*np.pi*np.sqrt((R + h)**3/(G*M))
        return 2*np.pi/T

    def Rx(angle):
        # Rotation matrix about the x axis.
        return np.array([[1, 0, 0],
                         [0, np.cos(angle), -np.sin(angle)],
                         [0, np.sin(angle), np.cos(angle)]])

    def x_sat(h,t,alpha,delta):
        # Circular orbit in the xy plane, then tilted by alpha about x.
        phase = theta_s(h)*(t+delta)
        basis = np.array([(R + h)*np.cos(phase),(R+ h)*np.sin(phase),0])
        return np.dot(Rx(alpha),basis)

    alpha = alpha*np.pi/180

    ground = xground(x,y,t)
    line_of_sight = np.subtract(x_sat(h,t,alpha,delta),ground)

    x_ground_norm = ground/np.linalg.norm(ground,2)
    dist_norm = line_of_sight/np.linalg.norm(line_of_sight,2)

    # Round-off can push the cosine of two unit vectors slightly outside [-1, 1],
    # where arccos returns NaN; clip so the angle is always defined.
    zenit = np.arccos(np.clip(np.dot(x_ground_norm,dist_norm), -1.0, 1.0))

    return zenit
    
    


def trans_probability_down(x,y,landa,h,alpha,rs,t,delta):
    """Downlink transmission probability from the satellite to one ground station.

    Geometry is the same as in ``zenith``:
    - station at latitude ``x`` / longitude ``y`` (degrees);
    - satellite at altitude ``h`` (m) on an orbit tilted by ``alpha`` degrees;
    - Earth rotation over ``t``, orbital phase over ``t + delta`` (s).

    ``landa`` is the wavelength (m). ``rs`` is the receiver aperture radius
    (presumably in m, like the beam widths).

    The result is the Gaussian-beam fraction collected by the aperture,
    ``1 - exp(-2 rs**2 / W_eff**2)``. It is multiplied by the atmospheric
    extinction ``0.88**(1/cos(zenith))`` and the fixed losses
    L_T * L_p * L_j * L_R.

    Returns 0 when the zenith angle is 1.22 rad or more (satellite treated as
    not visible).
    """

    def xground(x,y,t):
        # Cartesian station position; (90 - x) is the colatitude.
        px = R*np.sin((90-x)*np.pi/180)*np.cos(y*np.pi/180+ theta_e*t)
        py = R*np.sin((90-x)*np.pi/180)*np.sin(y*np.pi/180+ theta_e*t)
        pz = R*np.cos((90-x)*np.pi/180)
        return np.array([px,py,pz])

    def theta_s(h):
        # Orbital angular velocity from the period 2*pi*sqrt((R+h)**3/(G*M)).
        T = 2*np.pi*np.sqrt((R + h)**3/(G*M))
        return 2*np.pi/T

    def Rx(angle):
        # Rotation matrix about the x axis.
        return np.array([[1, 0, 0],
                         [0, np.cos(angle), -np.sin(angle)],
                         [0, np.sin(angle), np.cos(angle)]])

    def x_sat(h,t,alpha,delta):
        # Circular orbit in the xy plane, then tilted by alpha about x.
        phase = theta_s(h)*(t+delta)
        basis = np.array([(R + h)*np.cos(phase),(R+ h)*np.sin(phase),0])
        return np.dot(Rx(alpha),basis)

    alpha = alpha*np.pi/180

    ground = xground(x,y,t)
    l = np.subtract(x_sat(h,t,alpha,delta),ground)
    l_mod = np.linalg.norm(l,2)  # slant range station -> satellite

    # Clip guards against round-off pushing the cosine outside [-1, 1] (NaN in arccos).
    cos_zenith = np.dot(ground/np.linalg.norm(ground,2), l/l_mod)
    zenit = np.arccos(np.clip(cos_zenith, -1.0, 1.0))

    if zenit >= 1.22:
        return 0

    z_R = np.pi*w_0**2/landa  # Rayleigh range of the emitted beam

    def w2(z):
        # Squared (diffraction-limited) beam radius after propagating a distance z.
        return w_0**2*(1+(z/z_R)**2)

    def Cn2(r):
        # Hufnagel-Valley refractive-index structure parameter at altitude r.
        high_layer = 0.00594*(v/27)**2*(10**(-5)*r)**10*np.exp(-r/1000)
        mid_layer = 2.7*10**(-16)*np.exp(-r/1500)
        ground_layer = A*np.exp(-r/100)
        return high_layer+mid_layer+ground_layer

    def Tdown(landa,zenit,h,z):
        # Turbulence correction factor multiplying w2 for the downlink.
        k=2*np.pi/landa
        Landa = 2*z/(k*w2(z))
        func = lambda r: Cn2(r)*((r)/h)**(5/3)
        result = quad(func,0,h,epsabs = 1e-16)[0]
        return 4.35*Landa**(5/6)*k**(7/6)*h**(5/6)*(1/np.cos(zenit))**(11/6)*result

    def weff2down(z,landa,zenit,h):
        # Effective squared beam radius at the receiver, including turbulence.
        return w2(z)*(1+Tdown(landa,zenit,h,z))

    LATM = 0.88**(1/np.cos(zenit)) #regular extinction, scaled by the air mass 1/cos(zenith)
    E = (1-np.exp(-2*rs**2/weff2down(l_mod,landa,zenit,h)))*LATM

    return E*L_T*L_p*L_j*L_R

In [None]:
'''length sat-ground station'''

def length_sat_ground(x,y,landa,h,alpha,t,delta):
    """Return the station-to-satellite distance, or 0 when the satellite is not visible.

    Geometry is the same as in ``zenith``:
    - station at latitude ``x`` / longitude ``y`` (degrees);
    - satellite at altitude ``h`` on an orbit tilted by ``alpha`` degrees;
    - Earth rotation over ``t``, orbital phase over ``t + delta`` (s).

    The distance is in the units of ``R`` and ``h`` (metres in this notebook).
    It is returned only while the zenith angle is below 1.22 rad; otherwise
    the function returns 0. ``landa`` is not used.
    """

    def xground(x,y,t):
        # Cartesian station position; (90 - x) is the colatitude.
        px = R*np.sin((90-x)*np.pi/180)*np.cos(y*np.pi/180+ theta_e*t)
        py = R*np.sin((90-x)*np.pi/180)*np.sin(y*np.pi/180+ theta_e*t)
        pz = R*np.cos((90-x)*np.pi/180)
        return np.array([px,py,pz])

    def theta_s(h):
        # Orbital angular velocity from the period 2*pi*sqrt((R+h)**3/(G*M)).
        T = 2*np.pi*np.sqrt((R + h)**3/(G*M))
        return 2*np.pi/T

    def Rx(angle):
        # Rotation matrix about the x axis.
        return np.array([[1, 0, 0],
                         [0, np.cos(angle), -np.sin(angle)],
                         [0, np.sin(angle), np.cos(angle)]])

    def x_sat(h,t,alpha,delta):
        # Circular orbit in the xy plane, then tilted by alpha about x.
        phase = theta_s(h)*(t+delta)
        basis = np.array([(R + h)*np.cos(phase),(R+ h)*np.sin(phase),0])
        return np.dot(Rx(alpha),basis)

    alpha = alpha*np.pi/180

    ground = xground(x,y,t)
    l = np.subtract(x_sat(h,t,alpha,delta),ground)
    l_mod = np.linalg.norm(l,2)

    # Clip guards against round-off pushing the cosine outside [-1, 1] (NaN in arccos).
    cos_zenith = np.dot(ground/np.linalg.norm(ground,2), l/l_mod)
    zenit = np.arccos(np.clip(cos_zenith, -1.0, 1.0))

    if zenit<1.22:
        return l_mod
    return 0

In [None]:
'''satellite-ground station link-sanity check'''

# For each altitude and each station pair (A at longitude y1[i], B fixed at
# longitude 54, both at latitude 0), scan one day in 1 s steps. Average
# sqrt(p_A * p_B) over the instants at which BOTH stations see the satellite
# (zenith angle below 1.22 rad). The mean A-satellite distance over A's
# visibility gives the round-trip light time T_com_sat_g.
Value_dict = {}
# heights = [400e3,500e3,600e3]
heights = [400e3]

for h in heights:
    values = []

    p_T = []
    length_a = np.zeros(len(dist))
    T_com_sat_g = np.zeros(len(dist))

    # Station B is the same for every pair, so its zenith angles are computed once per altitude.
    Z_B = np.zeros(max_time)
    for j in range(len(time)):
        Z_B[j] = zenith(0,54,landa,h,alpha,rg,time[j],0)

    for i in range(len(dist)):

        Z_A = np.zeros(max_time)
        n_joint = 0          # instants where A and B both see the satellite
        p_t = 0              # running sum of sqrt(p_A * p_B) over those instants
        n_visible_A = 0      # instants where A sees the satellite
        total_length = 0     # summed A-satellite distance over those instants

        for j in range(len(time)):
            Z_A[j] = zenith(0, y1[i], landa , h, alpha, rg, time[j],0)
            if Z_A[j] < 1.22:
                total_length += length_sat_ground(0,y1[i],landa,h,alpha,time[j],0)
                n_visible_A += 1

            # Both stations must see the satellite at the same instant.
            if Z_A[j] < 1.22 and Z_B[j] < 1.22:
                p_A = trans_probability_down(0, y1[i], landa , h, alpha, rg, time[j],0)
                p_B = trans_probability_down(0,54,landa,h,alpha,rg,time[j],0)
                n_joint += 1
                p_t += np.sqrt(p_A*p_B)

        if n_joint > 0:
            p_T.append(p_t/n_joint)
            # n_visible_A >= n_joint > 0 because joint visibility requires A to see the satellite.
            length_a[i] = total_length/n_visible_A
            T_com_sat_g[i] = 2*length_a[i]/c
        else:
            p_T.append(0)

        values.append((dist[i],p_T[i]))
    Value_dict[h] = dict(values)
    print(Value_dict)




In [None]:
# For every distance keep the largest p_T among the simulated satellite altitudes,
# i.e. choose the best satellite height depending on the distance.
maxpt_values = [max(Value_dict[height][d] for height in heights) for d in dist]
print(maxpt_values)

### The following functions implement the simulations done in my thesis.
The list of distances and the corresponding $p_T$ values are computed by the code above, whose parameters can be adjusted to account for several types of loss during transmission.

In [None]:
def Repetition_rate(m,trep_source):
    """Scale the source repetition period ``trep_source`` by the factor 2**m.

    Callers in this notebook use the result as the time per attempt ``trep``.
    """
    scale = 2**m
    return scale*trep_source


def Fqudit(lam,eta,pt,pda,m,t):
    """Fidelity of a heralded 2**m-dimensional pair after a storage time ``t``.

    The numerator is the ideal-pair term ``2**m * eta**2 * pt**2 * lam**2``.
    It is multiplied by a memory factor built from exp(-t/t1), exp(-t/t2) and
    the error weights id, x, y, z (module-level globals).

    The denominator is the full bracket of heralding terms. It is the same
    bracket used by ``p_suc``; the ``(1 - lam**2)**(2**m)`` prefactor cancels
    in the ratio.
    """
    dim = 2**m
    loss = 1-pt*eta

    decay1 = np.exp(-t/t1)
    decay2 = np.exp(-t/t2)
    memory_factor = ((decay1+(1-decay1)*id)*(decay2+(1-decay2)*id)
                     +(1-decay1)*(1-decay2)*(x**2+y**2+z**2))

    # The fourth term uses 2**m (it previously read 2*m), keeping this bracket
    # identical to the one in p_suc.
    heralding = (dim*pt**2*lam**2*eta**2
                 +2**(m-1)*(dim+1)*lam**4*pt**4*eta**4
                 +2**(m+1)*lam**2*pda*pt*eta*loss
                 +dim*lam**2*loss**2*pda**2
                 +dim*(dim+1)*lam**4*pt**2*eta**2*loss**2*(pda+2)
                 +2**(m+1)*(dim+1)*lam**4*pt*eta*loss**3*pda
                 +2**(m+1)*(dim+1)*lam**4*pt**3*eta**3*loss
                 +2**(m-1)*(dim+1)*lam**4*loss**4*pda**2
                 +pda**2)

    return dim*eta**2*pt**2*lam**2*memory_factor/heralding


def p_suc(lam,pt,eta,pda,m):
    """Probability that one attempt on a 2**m-dimensional qudit produces a herald.

    Returns ``(1 - lam**2)**(2**m)`` times the sum of nine heralding terms.
    These combine the parameter ``lam``, the transmission ``pt``, ``eta``
    (presumably the detector efficiency) and ``pda`` (presumably the
    dark-count probability).
    """
    dim = 2**m
    loss = 1-pt*eta
    terms = (
        dim*pt**2*lam**2*eta**2,
        2**(m-1)*(dim+1)*lam**4*pt**4*eta**4,
        2**(m+1)*lam**2*pda*pt*eta*loss,
        dim*lam**2*loss**2*pda**2,
        dim*(dim+1)*lam**4*pt**2*eta**2*loss**2*(pda+2),
        2**(m+1)*(dim+1)*lam**4*pt*eta*loss**3*pda,
        2**(m+1)*(dim+1)*lam**4*pt**3*eta**3*loss,
        2**(m-1)*(dim+1)*lam**4*loss**4*pda**2,
        pda**2,
    )
    bracket = sum(terms)
    prefactor = (1-lam**2)**dim
    return prefactor*bracket


def p_ent(th,trep,lam,pt,eta,pda,m):
    """Per-attempt success probability, weighted by ``pi_00``.

    ``n = ceil(th / trep)`` is the number of attempt periods that fit in the
    time ``th``. The returned value is
    ``p_suc * 1 / (1 + n * (2*sqrt(p_suc) - p_suc))``.
    ``pi_00`` presumably denotes a stationary-state probability; confirm
    against the thesis.
    """
    success = p_suc(lam,pt,eta,pda,m)
    periods = math.ceil(th/trep)
    weight = 1/(1+periods*(2*np.sqrt(success)-success))
    return success*weight


def average_trials(lam,eta,pt,pda,n_pairs,m,multiplex,trep,th):
    """Average number of attempts needed to obtain ``n_pairs`` pairs.

    Args:
        n_pairs: number of pairs to entangle.
        m: number of pairs a single qudit entangles, so ceil(n_pairs / m)
            detectors are needed.
        multiplex: how many times more detectors are present than needed.
            E.g. multiplex=2 with m=4 gives twice the needed detectors and
            8 memories each.

    The sum is built symbolically in the per-attempt success probability and
    evaluated at ``p_ent(th, trep, lam, pt, eta, pda, m)``.
    """
    p_success = Symbol('p_success')
    q = 1-p_success
    needed = math.ceil(n_pairs/m)
    present = multiplex*needed

    exponents = [present+1-k for k in range(1,needed+1)]
    waiting_terms = np.sum([q**e/(1-q**e) for e in exponents])
    expression = waiting_terms+needed

    return expression.evalf(subs={p_success:p_ent(th,trep,lam,pt,eta,pda,m)})

def average_rate(lam,eta,pt,pda,n_pairs,m,multiplex,trep,th):
    """Rate at which ``n_pairs`` pairs are delivered: 1 / (average trials * trep)."""
    trials = average_trials(lam,eta,pt,pda,n_pairs,m,multiplex,trep,th)
    elapsed = trials*trep
    return 1/elapsed


def average_fidelity(lam, eta, pt, pda, n_pairs,m, multiplex,trep,th, precision=1e-2):
    """Average fidelity of ``n_pairs`` pairs delivered with multiplexed qudit memories.

    ceil(n_pairs / m) detectors are needed and ``multiplex`` times that many
    are present. The expression is built symbolically in the per-attempt
    success probability and evaluated at ``p_ent(th, trep, lam, pt, eta, pda, m)``.

    Reads the module-level globals t1, t2, id, x, y, z, and calls Fqudit and p_ent.

    Returns:
        (avg_fid, approximation_invalid): the evaluated average fidelity (a
        SymPy number) and a flag. The flag is True, and a message is printed,
        when the probability term ``prob`` evaluates below ``1 - precision``.
    """
    p_success = Symbol('p_success')
    q = 1-p_success
    # a, b, c weight the exp(-th/t1), exp(-th/t2) and exp(-th/t1 - th/t2) decay
    # terms; a1, b1, c1 are the matching decay factors over one period trep;
    # d is the non-decaying part.
    a = (id-id**2-x**2-y**2-z**2)*np.exp(-th/t1)
    a1 = np.exp(-trep/t1)
    b = (id-id**2-x**2-y**2-z**2)*np.exp(-th/t2)
    b1 = np.exp(-trep/t2)
    c = (1-2*id+id**2+x**2+y**2+z**2)*np.exp(-th*(t1+t2)/(t1*t2))
    c1 = np.exp(-trep*(t1+t2)/(t1*t2))
    d = id**2+x**2+y**2+z**2
    coef1 = [a1,b1,c1]
    coef = [a,b,c]

    detectors_needed = math.ceil(n_pairs/m)
    detectors_present = multiplex*detectors_needed
    first_free = detectors_present-detectors_needed+1

    # Prefactor shared by the fidelity numerator (fid) and the normalisation (prob).
    selection = math.factorial(detectors_present)/math.factorial(detectors_present-detectors_needed)*\
        p_success**detectors_needed*\
        q**(np.sum([detectors_present-detectors_needed+i for i in range(detectors_needed)]))
    fid = selection*Fqudit(lam, eta, pt, pda, m=m, t=0)

    # np.prod replaces np.product, which was removed in NumPy 2.0.
    mylist = []
    for weight, decay in zip(coef, coef1):
        product_terms = [1/(1-q**k) for k in range(first_free,detectors_present+1)]
        # Factors are replaced cumulatively: after step j the first j+1 factors
        # carry the decay factor, and each step contributes one term.
        for j in range(0,detectors_needed-1):
            product_terms[j] = 1/(1-decay*q**(j+first_free))
            mylist.append(weight*decay**(j+1)*np.prod(product_terms))
    sum_terms = np.sum(mylist)+(a+b+c+detectors_needed*d)*np.prod([1/(1-q**k) for k in range(first_free,detectors_present+1)])

    prob = selection*np.prod([1/(1-q**(detectors_present+1-i)) for i in range(1,detectors_needed+1)])

    # Evaluate p_ent once and reuse it for every substitution.
    p_value = p_ent(th,trep,lam,pt,eta,pda,m)
    prob_value = prob.evalf(subs={p_success:p_value})
    if prob_value < 1-precision:
        print('Prob multiplex is bad',prob_value, lam)
        approximation_invalid = True
    else:
        approximation_invalid = False
    avg_fid = fid*sum_terms/(detectors_needed*prob)
    return avg_fid.evalf(subs={p_success:p_value}), approximation_invalid


def plot_rate_dist(amount,m,list_multiplex,min_fid,error,save_plot=False,save_data=False):
    """Plot (and optionally save) the achievable rate versus distance.

    For every target number of pairs in ``amount`` (paired with ``m[k]``) and
    every multiplexing value in ``list_multiplex``:
    - scan ``lam`` over np.arange(0, 0.1, 1e-4);
    - at each distance in the global ``dist``, take the largest ``lam`` whose
      average fidelity exceeds ``min_fid``;
    - plot ``average_rate`` at that ``lam`` (one curve per multiplexing value).

    Reads the module-level globals dist, maxpt_values, trep_source, eta, pda,
    c, t1 and heights.

    Args:
        amount: list of target numbers of pairs (one figure each).
        m: list of qudit exponents, indexed in parallel with ``amount``.
        list_multiplex: multiplexing factors to compare.
        min_fid: fidelity threshold.
        error: tolerance passed to average_fidelity as ``precision``.
        save_plot: save each figure as a PNG when truthy.
        save_data: save parameters and rates as a CSV when truthy.
    """
    lambdas=np.arange(0,0.1,1e-4)

    # n_pairs is one entry of `amount`; the parameter itself is no longer shadowed.
    for m_counter, n_pairs in enumerate(tqdm(amount)):
        qudit_multiplex_plot = []
        n=m[m_counter]
        trep = Repetition_rate(n,trep_source)

        for multiplexing in tqdm(list_multiplex):
            rate_qudit_multiplex = [0]*len(dist)

            for index, L in enumerate(tqdm(dist)):
                th=L*10**3/c  # light travel time over L km (c in m/s)
                pt = maxpt_values[index]

                qudit_fid_multiplex = []
                approximation_invalid = []
                for lam in lambdas:
                    qudit_fid,approximation = average_fidelity(lam,eta=eta,pt=pt,pda=pda,m=n,\
                                                            n_pairs=n_pairs,multiplex=multiplexing,\
                                                                trep=trep,th=th,precision=error)
                    qudit_fid_multiplex.append(qudit_fid)
                    approximation_invalid.append(approximation)

                # Search from the largest lambda down. If none exceeds min_fid,
                # index 0 is the one whose approximation flag is inspected.
                chosen = 0
                for k in range(len(lambdas)-1, -1, -1):
                    if qudit_fid_multiplex[k] > min_fid:
                        chosen = k
                        rate_qudit_multiplex[index]=average_rate(lambdas[k],eta=eta,\
                                                                pt=pt,pda=pda,n_pairs=n_pairs,\
                                                                m=n,multiplex=multiplexing,\
                                                                trep=trep,th=th)
                        break
                # Larger distances are skipped (their rate stays 0) once the
                # approximation is flagged invalid or the rate is zero.
                if approximation_invalid[chosen]:
                    print('warning approximation invalid ',approximation_invalid[chosen])
                    break
                if rate_qudit_multiplex[index]==0:
                    break

            qudit_multiplex_plot.append(rate_qudit_multiplex)

        plt.figure(dpi=300)
        for i in range(len(qudit_multiplex_plot)):
            plt.plot(dist,qudit_multiplex_plot[i],label='multiplex = '+str(list_multiplex[i]))
        plt.xlabel('Distance (km)')
        plt.ylabel("Rate (Hz)")
        plt.title('Rate for '+str(n_pairs)+' pairs with fidelity > '+str(min_fid)+' for m = '+str(n))
        plt.legend(loc='best')
        plt.tight_layout()
        if save_plot:
            plt.savefig('minfid'+str(min_fid)+'_'+str(t1)+'s_'+str(n_pairs)+'pairs_m'+str(n)+'maxmultiplex'\
                    +str(max(list_multiplex))+'.png',bbox_inches='tight',dpi=600)
        plt.show()

        # NOTE: 'max_fid' is taken from the fidelities of the last distance /
        # multiplexing value evaluated.
        qudit_data_dict = {'n_pairs':n_pairs,'t':t1,'m=':n,\
                    'min_fid':min_fid,'max_fid':max(qudit_fid_multiplex),\
                    'pda':pda,'eta':eta}
        dist_pt_dict = {'distance':dist,'p_t':maxpt_values}
        heights_dict = {'heights':heights}
        multiplex_dict = {'multiplex':list_multiplex}
        qudit_plot_dict = dict(zip(list_multiplex,qudit_multiplex_plot))
        df_data = pd.DataFrame([qudit_data_dict])
        df_plot = pd.DataFrame(qudit_plot_dict)
        df_data = pd.concat([df_data,pd.DataFrame(heights_dict),pd.DataFrame(multiplex_dict),\
                            pd.DataFrame(dist_pt_dict),df_plot],axis=1)
        if save_data:
            df_data.to_csv('minfid'+str(min_fid)+'_'+str(t1)+'s'+str(n_pairs)+'pairs_m'+str(n)\
                       +'_'+str(round(max(dist)))+'km_trep'+str(trep_source)+'maxmultiplex'+str(max(list_multiplex))\
                        +'eta'+str(eta)+'pda'+str(pda)+'data.csv')


In [None]:
# ----- Parameters for the rate-vs-distance simulation -----
# Module-level globals read by Fqudit, average_fidelity and plot_rate_dist.

# Memory decoherence times entering exp(-t/t1) and exp(-t/t2) (s)
t1 = 10
t2 = 10

# Speed of light (m/s)
c = 2.998e8

# Repetition period of the source (s), scaled by Repetition_rate
trep_source = 1e-7

# Error-model weights (depolarising: all equal). `id` shadows the Python
# builtin, but Fqudit and average_fidelity read it under this name.
x = 1/4
y = 1/4
z = 1/4
id = 1/4

# eta and pda (presumably detector efficiency and dark-count probability)
eta = 0.5
pda = 1.6e-5
# pda = 1e-6

# Target numbers of pairs n_p (one figure per entry)
amount = [5, 10]

# Multiplexing factors (one curve per entry)
list_multiplex = [1, 10, 25, 50]

# Minimum fidelity required
min_fid = 0.95

# Tolerance of the multiplexing approximation
error = 1e-2

save_p = False
save_d = True

# Qudit exponent m = 5 for every target number of pairs
m = [5]*len(amount)
plot_rate_dist(amount, m, list_multiplex, min_fid=min_fid, error=error, save_plot=save_p, save_data=save_d)

