# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import binom
import math

class EntangledPhotonSource:
    """
    Entangled photon-pair source used by the E91 simulations.

    The only parameter modelled is the pair generation rate ν_s.
    """

    def __init__(self, pair_rate):
        """
        Args:
            pair_rate (float): Entangled photon pairs emitted per second (ν_s).
        """
        self.pair_rate = pair_rate

    def pair_generation_rate(self):
        """Return the configured pair generation rate ν_s (pairs per second)."""
        rate = self.pair_rate
        return rate

from scipy.stats import binom
import numpy as np

class Channel:
    """
    Represents a quantum channel: either Optical Fiber or Free Space Optical (FSO).

    Exposes the channel transmission efficiency (``efficiency``) and the
    per-window noise probabilities the channel contributes (``p_stray`` for
    stray light, ``p_raman`` for Raman scattering).
    """

    def __init__(self, base_efficiency, distance=0, attenuation=None,
                 channel_type='fiber', d_t=0.01, d_r=0.03, divergence=0.025e-3):
        """
        Initialize the channel with type and properties.

        Args:
            base_efficiency (float): Base channel efficiency (0-1)
            distance (float): Distance in km
            attenuation (float): Attenuation in dB/km (fiber) or atmospheric loss
                factor in dB/km (FSO). Defaults to 0.2 (fiber) / 0.1 (FSO).
            channel_type (str): 'fiber' or 'fso' (case-insensitive)
            d_t (float): Transmitter aperture diameter (m) [FSO only]
            d_r (float): Receiver aperture diameter (m) [FSO only]
            divergence (float): Beam divergence angle in rad [FSO only]

        Raises:
            ValueError: If channel_type is not 'fiber' or 'fso'.
        """
        self.base_efficiency = base_efficiency
        self.distance = distance
        self.channel_type = channel_type.lower()
        if self.channel_type not in ('fiber', 'fso'):
            raise ValueError("Unsupported channel type. Use 'fiber' or 'fso'.")

        # FSO-specific parameters
        self.d_t = d_t
        self.d_r = d_r
        self.divergence = divergence

        # Default attenuation (dB/km) per channel type when none is supplied.
        if attenuation is None:
            attenuation = 0.2 if self.channel_type == 'fiber' else 0.1
        self.attenuation = attenuation

        # Linear-in-distance misalignment error model (see calculate_misalignment_error).
        self.misalignment_base = 0.015
        self.misalignment_factor = 0.0002

        self.efficiency = self.calculate_efficiency()

        # Channel noise: fiber contributes Raman noise, FSO contributes stray light.
        if self.channel_type == 'fiber':
            self.p_stray = 0.0
            self.p_raman = 5e-5
        else:
            self.p_stray = 5e-6
            self.p_raman = 0.0

    def calculate_efficiency(self):
        """
        Compute transmission efficiency based on channel type.

        Fiber: base_efficiency * 10^(-attenuation * distance / 10).
        FSO: base_efficiency * geometric_loss * atmospheric_loss. The geometric
        loss is the fraction of the diverged beam collected by the receiver
        aperture. It is capped at 1 because the receiver cannot collect more
        light than was transmitted.

        Returns:
            float: Transmission efficiency, a value in [0, base_efficiency] for
            non-negative distance and attenuation.

        Raises:
            ValueError: If channel_type is unsupported.
        """
        if self.channel_type == 'fiber':
            attenuation_db = self.distance * self.attenuation
            return self.base_efficiency * 10 ** (-attenuation_db / 10)

        if self.channel_type == 'fso':
            L_m = self.distance * 1000  # convert km to meters
            beam_diameter = self.d_t + self.divergence * L_m
            # When the receiver aperture is larger than the beam it captures all
            # of it. Clamp so efficiency stays a valid probability
            # (binom.pmf in transmission_probability needs p in [0, 1]).
            geometric_loss = min(1.0, (self.d_r / beam_diameter) ** 2)
            # dB/km -> natural-log attenuation: 10 / ln(10) ≈ 4.343.
            atmospheric_loss = np.exp((-self.attenuation / 4.343) * self.distance)
            return self.base_efficiency * geometric_loss * atmospheric_loss

        raise ValueError("Unsupported channel type. Use 'fiber' or 'fso'.")

    def update_distance(self, distance):
        """Set a new channel length (km) and recompute ``efficiency``."""
        self.distance = distance
        self.efficiency = self.calculate_efficiency()

    def transmission_probability(self, sent_photons, received_photons):
        """
        Binomial probability that exactly ``received_photons`` of
        ``sent_photons`` survive the channel.

        Returns:
            float: 0.0 if more photons are received than sent, otherwise
            Binomial(sent_photons, efficiency).pmf(received_photons).
        """
        if received_photons > sent_photons:
            return 0.0
        return binom.pmf(received_photons, sent_photons, self.efficiency)

    def calculate_misalignment_error(self):
        """Misalignment error growing linearly with distance, capped at 0.1."""
        return min(0.1, self.misalignment_base + self.misalignment_factor * self.distance)


class Detector:
    """
    Single-photon detector model.

    Holds the detection efficiency, the dark-count-derived noise probability
    for one detection window, and fixed afterpulsing / timing-jitter figures.
    """

    def __init__(self, efficiency, dark_count_rate, time_window):
        """
        Args:
            efficiency (float): Detector efficiency (0-1)
            dark_count_rate (float): Dark counts per second
            time_window (float): Detection window length in seconds
        """
        self.efficiency = efficiency
        self.dark_count_rate = dark_count_rate
        self.time_window = time_window

        # Probability of at least one dark count in the window: 1 - exp(-rate * window).
        expected_dark = dark_count_rate * time_window
        self.p_dark = 1 - np.exp(-expected_dark)

        self.afterpulsing_prob = 0.02      # afterpulsing probability
        self.timing_jitter_error = 0.01    # timing jitter, expressed as an error probability

    def detect_probability(self, photons):
        """
        Probability that the detector clicks for ``photons`` incident photons.

        Args:
            photons (int): Number of photons arriving at the detector.

        Returns:
            float: 0 for no photons. Otherwise 1 - (1 - efficiency)^photons,
            scaled up by 2% per extra photon for multi-photon pulses
            (crosstalk / non-linearity model), capped at 1.0.
        """
        if photons <= 0:
            return 0

        miss_all = (1 - self.efficiency) ** photons
        hit = 1 - miss_all

        if photons > 1:
            scaled = hit * (1.0 + 0.02 * (photons - 1))
        else:
            scaled = hit * 1.0
        return min(1.0, scaled)

    def dark_count_probability(self):
        """Return the per-window dark count probability computed at construction."""
        return self.p_dark
    
    
class E91:
    """
    Lightweight E91 QKD protocol configuration class.
    Stores physical and system parameters for use in simulators.
    """

    def __init__(self, detector_efficiency, channel_base_efficiency,
                 dark_count_rate, time_window, distance=0, attenuation=0.2, pair_rate=0.64e6):
        """
        Args:
            detector_efficiency (float): Single-photon detector efficiency (0-1).
            channel_base_efficiency (float): Base channel efficiency (0-1).
            dark_count_rate (float): Dark counts per second.
            time_window (float): Detection window in seconds.
            distance (float): Alice-Bob distance in km.
            attenuation (float): Channel attenuation in dB/km.
            pair_rate (float): Entangled pair generation rate ν_s (pairs/s).
        """
        self.distance = distance
        self.attenuation = attenuation
        self.time_window = time_window
        self.pair_rate = pair_rate  # ν_s

        # Detection and optical parameters
        self.eta_d = detector_efficiency          # Detector efficiency
        self.eta_c = 0.6                          # Collection efficiency
        self.eta_total = self.eta_d * self.eta_c  # Total effective efficiency

        # Channel (default 'fiber' type)
        self.channel = Channel(channel_base_efficiency, distance, attenuation)

        # Detector
        self.detector = Detector(detector_efficiency, dark_count_rate, time_window)

        # Noise sources, read from the channel/detector models so each value is
        # defined in exactly one place (matches E91Simulator).
        self.p_raman = self.channel.p_raman
        self.p_dark = self.detector.p_dark
        self.p_stray = self.channel.p_stray
        self.p_noise = self.p_raman + self.p_dark + self.p_stray

        # Measurement basis angles (in radians)
        self.theta_A = [0, np.pi / 4]             # Alice's settings
        self.theta_B = [-np.pi / 8, np.pi / 8]    # Bob's settings
        self.phi = np.pi                          # Bell-state phase used in the correlation model

    def get_total_efficiency(self):
        """Return detector x collection efficiency."""
        return self.eta_total

    def get_noise_probability(self):
        """Return the summed per-window noise probability (Raman + dark + stray)."""
        return self.p_noise

    def get_measurement_angles(self):
        """Return (Alice's angles, Bob's angles) in radians."""
        return self.theta_A, self.theta_B

    def get_channel_efficiency(self):
        """Return the channel's current transmission efficiency."""
        return self.channel.efficiency

    def update_distance(self, new_distance):
        """Set a new distance (km) and recompute the channel efficiency."""
        self.distance = new_distance
        self.channel.update_distance(new_distance)

    def summary(self):
        """Print a human-readable summary of the configuration."""
        print("E91 Setup Summary:")
        print(f"Distance: {self.distance} km")
        print(f"Detector Efficiency: {self.eta_d}")
        print(f"Collection Efficiency: {self.eta_c}")
        print(f"Total Efficiency: {self.eta_total}")
        print(f"Channel Efficiency: {self.channel.efficiency:.4f}")
        print(f"Noise Probability: {self.p_noise:.2e}")

def update_distance(self, distance):
    """
    Set a new Alice-Bob separation and push it down to the channel model.

    Args:
        distance (float): New separation in kilometers.
    """
    self.distance = distance
    channel = self.channel
    channel.update_distance(distance)

def update_pair_rate(self, new_rate):
    """
    Replace the entangled-pair generation rate stored on ``self``.

    Args:
        new_rate (float): Pair generation rate in Hz (pairs per second).
    """
    rate = new_rate
    self.pair_rate = rate

def calculate_sifted_key_rate(self, T):
    """
    Sifted key rate (bits/s) for E91.

    The total transmittance is split evenly between the two arms (sqrt(T) each).
    Each arm is multiplied by the detection x collection efficiency, the arm
    product is squared to get the coincidence probability, and a 1/3
    basis-matching fraction is applied.

    Args:
        T (float): Total channel transmittance over the full distance.

    Returns:
        float: Sifted key rate in bits per second.
    """
    eta_arm = self.eta_d * self.eta_c
    arm_factor = np.sqrt(T) * eta_arm
    coincidences = arm_factor ** 2
    return self.pair_rate * coincidences * (1 / 3)


def calculate_quantum_bit_error_rate(self):
    """
    QBER in percent, derived from the CHSH parameter S.

    Uses QBER = 1/2 - S / (4*sqrt(2)), floored at zero.

    Returns:
        float: QBER as a percentage (>= 0).
    """
    bell_s = self.calculate_bell_parameter_S()
    tsirelson_scale = 4 * np.sqrt(2)
    qber_fraction = 0.5 - bell_s / tsirelson_scale
    return max(0.0, qber_fraction * 100)

def calculate_statistical_error(self, qber_percent, n_samples, confidence=0.95):
    """
    Estimate the statistical uncertainty of a measured QBER.

    Uses the normal-approximation (Wald) confidence interval for a binomial
    proportion. The returned value is the interval half-width

        z * sqrt(q * (1 - q) / n),

    where q is the QBER fraction, n the number of sampled bits, and z the
    two-sided standard-normal quantile for ``confidence`` (≈1.96 for 0.95).

    Args:
        qber_percent (float): QBER as a percentage (0–100).
        n_samples (int): Number of key bits measured.
        confidence (float): Two-sided confidence level, strictly in (0, 1).

    Returns:
        float: Confidence-interval half-width, as a percentage. 0.0 when the
        QBER is exactly 0 or 100% or when n_samples <= 0.

    Raises:
        ValueError: If confidence is not strictly between 0 and 1.
    """
    from statistics import NormalDist

    qber = qber_percent / 100  # convert to 0–1

    if qber == 0 or qber == 1 or n_samples <= 0:
        return 0.0

    if not 0 < confidence < 1:
        raise ValueError("confidence must be strictly between 0 and 1")

    # Two-sided quantile: e.g. confidence 0.95 -> inv_cdf(0.975) ≈ 1.96.
    z = NormalDist().inv_cdf(0.5 + confidence / 2)
    std_error = np.sqrt(qber * (1 - qber) / n_samples)
    return z * std_error * 100  # return as percentage

def error_correction_efficiency(self, error_rate, f_ec=1.15):
    """
    Fraction of sifted bits consumed by error correction.

    Computed as f_ec * H(error_rate), where H is the binary entropy
    (``self.h_binary``) and f_ec >= 1 is the reconciliation inefficiency
    relative to the Shannon limit.

    Args:
        error_rate (float): Error rate δ as a fraction (0-1), not a percentage.
        f_ec (float): Error-correction inefficiency factor (default 1.15).

    Returns:
        float: Fraction of bits lost in error correction; 0 when error_rate <= 0.
    """
    if error_rate <= 0:
        return 0

    # r_ec = f_ec × h_binary(error_rate)
    return f_ec * self.h_binary(error_rate)

def privacy_amplification_efficiency(self, S):
    """
    Calculate the privacy amplification cost based on Bell parameter S.

    Eve's information is bounded by h((1 + sqrt((S/2)^2 - 1)) / 2), where h is
    the binary entropy (``self.h_binary``).

    Args:
        S (float): CHSH Bell parameter (should be > 2 for secure key).

    Returns:
        float: Privacy amplification cost in [0, 1]. 1.0 when there is no
        Bell violation (S <= 2); 0 at the Tsirelson bound S = 2*sqrt(2).
    """
    if S <= 2:
        # No Bell violation, so maximum privacy leakage
        return 1.0

    # Quantum correlations cannot exceed the Tsirelson bound 2*sqrt(2). Clamp S,
    # and clamp the entropy argument to 1. At S == 2*sqrt(2) floating-point
    # rounding yields an argument slightly above 1, which would make h_binary
    # take log2 of a negative number and return NaN.
    S = min(S, 2 * np.sqrt(2))
    arg = min(1.0, (1 + np.sqrt((S / 2) ** 2 - 1)) / 2)

    # Return the binary entropy of the argument
    return self.h_binary(arg)

def h_binary(self, p):
    """
    Binary (Shannon) entropy in bits: H(p) = -p*log2(p) - (1-p)*log2(1-p).

    Args:
        p (float): Probability in [0, 1].

    Returns:
        float: H(p); exactly 0 at the endpoints p == 0 and p == 1.
    """
    if p in (0, 1):
        return 0
    q = 1 - p
    return -(p * np.log2(p)) - q * np.log2(q)

def calculate_skr(self):
    """
    Secret key rate from the sifted rate minus error-correction and
    privacy-amplification costs.

    SKR = (1/3) * nu_s * T * (1 - r_ec - r_pa), floored at 0. Returns 0
    outright when the Bell parameter does not exceed 2.

    NOTE(review): this module-level function takes ``self`` but relies on
    ``transmittance``, ``compute_N``, ``compute_S``, ``compute_QBER`` and
    ``nu_s``, none of which are defined in the visible code (E91Simulator uses
    ``_compute_N`` / ``pair_rate`` instead). Confirm which object is passed here.
    It presumably expects ``compute_QBER`` to return a fraction in [0, 1], not
    a percentage. Verify this, because error_correction_efficiency feeds it to
    a binary entropy.
    """
    transmittance = self.transmittance(self.distance)
    n_factor = self.compute_N(transmittance)
    bell_s = self.compute_S(n_factor, self.phi)
    error_rate = self.compute_QBER(bell_s)

    if bell_s <= 2:
        return 0

    ec_cost = self.error_correction_efficiency(error_rate)   # depends on QBER
    pa_cost = self.privacy_amplification_efficiency(bell_s)  # depends on S

    sifted = (1/3) * self.nu_s * transmittance
    secret = sifted * (1 - ec_cost - pa_cost)
    return max(0, secret)


class E91Simulator:
    """
    Analytical model of an E91 entanglement-based QKD link.

    A source placed midway between Alice and Bob sends entangled pairs over a
    fiber or FSO channel to two single-photon detectors. From the channel
    efficiency and noise model the simulator estimates the CHSH Bell
    parameter S, the QBER, the sifted key rate and the secret key rate (SKR).
    """

    def __init__(self,
                 pair_rate,
                 detector_efficiency,
                 dark_count_rate,
                 time_window,
                 channel_base_efficiency,
                 distance=0,
                 attenuation=0.2,
                 entanglement_visibility=0.98,
                 channel_type='fso'
                 ):
        """
        Args:
            pair_rate (float): Entangled pairs generated per second.
            detector_efficiency (float): Single-photon detector efficiency (0-1).
            dark_count_rate (float): Detector dark counts per second.
            time_window (float): Detection time window in seconds.
            channel_base_efficiency (float): Base channel efficiency (0-1).
            distance (float): Alice-Bob distance in km.
            attenuation (float): Channel attenuation in dB/km.
            entanglement_visibility (float): Factor multiplying the ideal Bell parameter.
            channel_type (str): 'fiber' or 'fso'.
        """
        self.pair_rate = pair_rate
        self.detector_efficiency = detector_efficiency
        self.dark_count_rate = dark_count_rate
        self.time_window = time_window
        self.channel_base_efficiency = channel_base_efficiency
        self.distance = distance
        self.alpha_db = attenuation
        self.entanglement_visibility = entanglement_visibility

        # Full-distance channel; all rate / Bell calculations below use this one.
        self.channel = Channel(
            base_efficiency=channel_base_efficiency,
            distance=distance,
            attenuation=attenuation,
            channel_type=channel_type,
        )

        # Per-window noise click probability: dark counts + stray light + Raman.
        self.p_dark = 1 - np.exp(-dark_count_rate * time_window)
        self.p_stray = self.channel.p_stray
        self.p_raman = self.channel.p_raman
        self.P_nc = self.p_dark + self.p_stray + self.p_raman

        # Entangled photon source
        self.source = EntangledPhotonSource(pair_rate)

        # Detectors
        self.detector_alice = Detector(detector_efficiency, dark_count_rate, time_window)
        self.detector_bob = Detector(detector_efficiency, dark_count_rate, time_window)

        # Center-source model: each party is half the total distance away.
        # NOTE(review): channel_alice / channel_bob are built but not used by the
        # calculations in this class, which all read self.channel.
        half_distance = distance / 2
        self.channel_alice = Channel(channel_base_efficiency, half_distance,
                                     attenuation, channel_type=channel_type)
        self.channel_bob = Channel(channel_base_efficiency, half_distance,
                                   attenuation, channel_type=channel_type)

        # Detection efficiencies
        self.eta = detector_efficiency
        self.eta_c = 0.6                  # collection efficiency
        self.eta_t = self.eta * self.eta_c

        # CHSH measurement angles (radians) and Bell-state phase
        self.theta_1A = 0
        self.theta_3A = np.pi / 4
        self.theta_1B = -np.pi / 8
        self.theta_3B = np.pi / 8
        self.phi = np.pi

    def _compute_N(self, T):
        """
        Correlation scaling factor N in [0, 1].

        Ratio of the probability that both pair photons arrive and are detected
        to the total coincidence probability, including the cases where one or
        both clicks come from noise (P_nc).

        NOTE(review): T is the full-distance channel efficiency and is used here
        as a per-photon transmittance (p_s = T**2). Confirm this against the
        source model, since each photon only traverses half the distance.
        """
        p_s = T ** 2              # both photons arrive
        p_1 = 2 * T * (1 - T)     # exactly one photon arrives
        p_0 = (1 - T) ** 2        # neither photon arrives
        click_prob = self.eta_t + 2 * self.P_nc * (1 - self.eta_t)
        numerator = p_s * self.eta_t ** 2
        denominator = (p_s * click_prob ** 2 +
                       2 * p_1 * self.P_nc * click_prob +
                       4 * p_0 * self.P_nc ** 2)
        return numerator / denominator

    def _E(self, theta_A, theta_B, N, phi):
        """Correlation E(θA, θB) scaled by N for Bell-state phase phi."""
        from math import cos, sin
        return N * (-cos(2 * theta_A) * cos(2 * theta_B) +
                    cos(phi) * sin(2 * theta_A) * sin(2 * theta_B))

    def calculate_bell_parameter_S(self):
        """CHSH parameter S = |E1 + E2 - E3 + E4|, scaled by entanglement visibility."""
        T = self.channel.efficiency
        N = self._compute_N(T)
        E1 = self._E(self.theta_1A, self.theta_1B, N, self.phi)
        E2 = self._E(self.theta_1A, self.theta_3B, N, self.phi)
        E3 = self._E(self.theta_3A, self.theta_1B, N, self.phi)
        E4 = self._E(self.theta_3A, self.theta_3B, N, self.phi)
        S = abs(E1 + E2 - E3 + E4)
        return S * self.entanglement_visibility

    def calculate_qber(self):
        """QBER in percent: (1/2 - S / (4*sqrt(2))) * 100, floored at 0."""
        S = self.calculate_bell_parameter_S()
        return max(0.0, (0.5 - S / (4 * np.sqrt(2))) * 100)

    def calculate_sifted_key_rate(self, T=None):
        """
        Sifted key rate (bits/s): pair_rate * T * (eta_d * eta_c)^2 / 3.

        Args:
            T (float, optional): Channel transmittance; defaults to self.channel.efficiency.
        """
        if T is None:
            T = self.channel.efficiency
        T_side = np.sqrt(T)
        detection_efficiency = self.detector_efficiency * self.eta_c
        coincidence_prob = (T_side * detection_efficiency) ** 2
        return self.pair_rate * coincidence_prob * (1/3)

    def error_correction_efficiency(self, error_rate):
        """Fraction of sifted bits spent on error correction: 1.15 * H(error_rate)."""
        if error_rate <= 0:
            return 0
        return 1.15 * self.h_binary(error_rate)

    def privacy_amplification_efficiency(self, S):
        """
        Privacy amplification cost h((1 + sqrt((S/2)^2 - 1)) / 2).

        Returns 1.0 for S <= 2 (no Bell violation). S is clamped to the
        Tsirelson bound 2*sqrt(2), and the entropy argument is clamped to 1. At
        the bound, floating-point rounding would otherwise push the argument
        above 1 and make h_binary return NaN.
        """
        if S <= 2:
            return 1.0
        S = min(S, 2 * np.sqrt(2))
        arg = min(1.0, (1 + np.sqrt((S / 2) ** 2 - 1)) / 2)
        return self.h_binary(arg)

    def h_binary(self, p):
        """Binary entropy H(p) in bits; 0 at p == 0 or p == 1."""
        if p == 0 or p == 1:
            return 0
        return -p * np.log2(p) - (1 - p) * np.log2(1 - p)

    def calculate_skr(self):
        """
        Secret key rate (bits/s): sifted rate * (1 - r_ec - r_pa), floored at 0.

        The sifted rate comes from calculate_sifted_key_rate, so detector and
        collection efficiencies are included and the SKR can never exceed the
        sifted key rate. Returns 0 when S <= 2 (no Bell violation).
        """
        T = self.channel.efficiency
        S = self.calculate_bell_parameter_S()

        if S <= 2:
            return 0

        Q = 0.5 - S / (4 * np.sqrt(2))  # QBER as a fraction
        r_ec = self.error_correction_efficiency(Q)
        r_pa = self.privacy_amplification_efficiency(S)

        sifted_rate = self.calculate_sifted_key_rate(T)
        final_skr = sifted_rate * (1 - r_ec - r_pa)

        return max(0, final_skr)

