In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Any

# Import ReadSpyn components
import sys
sys.path.append('../src')
from readout_simulator import (
    QuantumDotSystem, 
    RLC_sensor, 
    JAXReadoutSimulator,
    OU_noise, 
    OverFNoise
)

# Set JAX random key for reproducibility
key = jax.random.PRNGKey(42)

print("ReadSpyn White Noise Example")
print("=" * 40)


ReadSpyn White Noise Example


In [2]:
# Define quantum dot system (2 dots, 1 sensor)
Cdd = jnp.array([[1.0, 0.1], [0.1, 1.0]])  # 2x2 dot-dot capacitance matrix
Cds = jnp.array([[0.6], [0.3]])  # 2x1 dot-sensor coupling matrix
dot_system = QuantumDotSystem(Cdd, Cds)

print(f"Quantum dot system: {dot_system.num_dots} dots, {dot_system.num_sensors} sensors")

# Configure sensor parameters
params_resonator = {
    'Lc': 800e-9,      # Inductance (H)
    'Cp': 0.5e-12,     # Parasitic capacitance (F)
    'RL': 40,          # Load resistance (Ω)
    'Rc': 100e6,       # Coupling resistance (Ω)
    'Z0': 50           # Characteristic impedance (Ω)
}

params_coulomb_peak = {
    'g0': 1/50/1e6,    # Maximum conductance (S)
    'eps0': 0.5,       # Operating point (relative to eps_width)
    'eps_width': 1.0   # Energy width (eV)
}

# Create noise models
eps_noise = OverFNoise(n_fluctuators=3, S1=1e-6, sigma_couplings=0.1,
                       ommax=1e6, ommin=1e3, equally_dist=True)
c_noise = OU_noise(sigma=1e-12, gamma=1e5)

# Create sensor
sensor = RLC_sensor(params_resonator, params_coulomb_peak, c_noise, eps_noise)

print(f"Sensor resonant frequency: {sensor.f0/1e9:.2f} GHz")
print(f"Sensor resonant period: {sensor.T0*1e9:.2f} ns")


Dot-sensor coupling strength (Δε/ε_w): [[0.5757576  0.24242423]]
Quantum dot system: 2 dots, 1 sensors
[RLC_sensor] Initialized with:
  Lc = 8.000e-07 H
  Cp = 5.000e-13 F
  Self-capacitance = 0.000e+00 F
  Total capacitance = 5.000e-13 F
  RL = 40 Ω
  Rc = 1.000e+08 Ω
  Z0 = 50 Ω
  R0 = 5.000e+07 Ω
  g0 = 2.000e-08 S
  eps_w = 1.000e+00 eV
  Resonant frequency = 2.516e+08 Hz
  Resonant period = 3.974e-09 s
  Capacitance noise model: OU_noise
  Energy noise model: OverFNoise
Sensor resonant frequency: 0.25 GHz
Sensor resonant period: 3.97 ns


In [3]:
# Create simulator
simulator = JAXReadoutSimulator(dot_system, [sensor])

# Define charge states to simulate
charge_states = jnp.array([
    [0, 0],  # Both dots empty
    [1, 0],  # First dot occupied
    [0, 1],  # Second dot occupied
    [1, 1]   # Both dots occupied
])

print(f"Charge states to simulate:")
for i, state in enumerate(charge_states):
    print(f"  State {i}: {state}")

# Define simulation parameters
t_end = 1000 * sensor.T0  # 1000 resonant periods
dt = 0.5e-9  # 0.5 ns time step
times = jnp.arange(0, t_end, dt)

print(f"\nSimulation parameters:")
print(f"  End time: {t_end*1e6:.2f} μs")
print(f"  Time step: {dt*1e9:.1f} ns")
print(f"  Number of time points: {len(times)}")


Charge states to simulate:
  State 0: [0 0]
  State 1: [1 0]
  State 2: [0 1]
  State 3: [1 1]

Simulation parameters:
  End time: 3.97 μs
  Time step: 0.5 ns
  Number of time points: 7948


In [4]:
# Calculate average separation between charge states
avg_separation = simulator.calculate_average_separation(charge_states, sensor_idx=0)

print(f"Average separation between charge states: {avg_separation:.6f}")

# Let's also look at individual separations
n_states = charge_states.shape[0]
print(f"\nIndividual separations between charge state pairs:")
for i in range(n_states):
    for j in range(i + 1, n_states):
        # Calculate separation for this pair
        energy_offsets = simulator._calculate_energy_offsets(charge_states[[i, j]], sensor, 0)
        conductances = simulator._calculate_conductance(energy_offsets, sensor)
        separation = abs(conductances[0] - conductances[1])
        print(f"  States {i} vs {j}: {separation:.6f}")


Average separation between charge states: 0.000000

Individual separations between charge state pairs:


TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.

In [None]:
# Precompute noise trajectories
n_realizations = 10
key, subkey = jax.random.split(key)
simulator.precompute_noise(subkey, times, n_realizations, eps_noise)

# Test different SNR values
snr_values = [0.5, 1.0, 2.0, 5.0]
results_dict = {}

for snr in snr_values:
    print(f"\nRunning simulation with SNR = {snr}")
    
    # Define simulation parameters
    params = {
        'eps0': 0.0,
        'snr': snr,
        't_end': t_end
    }
    
    # Run simulation
    key, subkey = jax.random.split(key)
    results = simulator.run_simulation(charge_states, times, params, subkey)
    results_dict[snr] = results
    
    print(f"  Simulation completed for SNR = {snr}")


In [None]:
# Plot IQ trajectories for different SNR values
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
axes = axes.flatten()

colors = ['red', 'blue', 'green', 'orange']
state_labels = ['(0,0)', '(1,0)', '(0,1)', '(1,1)']

for idx, snr in enumerate(snr_values):
    ax = axes[idx]
    
    # Get integrated IQ data
    I_integrated, Q_integrated = simulator.get_integrated_IQ(sensor_idx=0)
    
    # Plot trajectories for each charge state
    for state_idx in range(len(charge_states)):
        # Use first realization for plotting
        I_traj = I_integrated[state_idx, 0, :]
        Q_traj = Q_integrated[state_idx, 0, :]
        
        ax.plot(I_traj, Q_traj, color=colors[state_idx], alpha=0.7, 
                label=f'State {state_labels[state_idx]}', linewidth=1)
        
        # Mark start and end points
        ax.scatter(I_traj[0], Q_traj[0], color=colors[state_idx], s=50, marker='o')
        ax.scatter(I_traj[-1], Q_traj[-1], color=colors[state_idx], s=50, marker='s')
    
    ax.set_xlabel('I (integrated)')
    ax.set_ylabel('Q (integrated)')
    ax.set_title(f'SNR = {snr}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.axis('equal')

plt.tight_layout()
plt.show()


In [None]:
# Calculate effective SNR for each simulation
effective_snr_values = []
theoretical_snr_values = []

for snr in snr_values:
    # Get results for this SNR
    results = results_dict[snr]
    sensor_results = results['sensor_results'][0]
    
    # Get integrated IQ data
    I_integrated, Q_integrated = simulator.get_integrated_IQ(sensor_idx=0)
    
    # Calculate centroids for each charge state
    n_states = len(charge_states)
    centroids = []
    
    for state_idx in range(n_states):
        I_mean = jnp.mean(I_integrated[state_idx, :, -1])  # Final integrated values
        Q_mean = jnp.mean(Q_integrated[state_idx, :, -1])
        centroids.append([float(I_mean), float(Q_mean)])
    
    centroids = jnp.array(centroids)
    
    # Calculate average separation between centroids
    separations = []
    for i in range(n_states):
        for j in range(i + 1, n_states):
            separation = jnp.linalg.norm(centroids[i] - centroids[j])
            separations.append(separation)
    
    avg_centroid_separation = float(jnp.mean(jnp.array(separations)))
    
    # Calculate noise level (standard deviation of final integrated values)
    all_final_I = I_integrated[:, :, -1].flatten()
    all_final_Q = Q_integrated[:, :, -1].flatten()
    noise_level = float(jnp.sqrt(jnp.var(all_final_I) + jnp.var(all_final_Q)))
    
    # Effective SNR = signal separation / noise level
    effective_snr = avg_centroid_separation / noise_level if noise_level > 0 else 0
    
    effective_snr_values.append(effective_snr)
    theoretical_snr_values.append(snr)
    
    print(f"SNR = {snr}:")
    print(f"  Average centroid separation: {avg_centroid_separation:.6f}")
    print(f"  Noise level: {noise_level:.6f}")
    print(f"  Effective SNR: {effective_snr:.3f}")
    print()

# Plot theoretical vs effective SNR
plt.figure(figsize=(8, 6))
plt.plot(theoretical_snr_values, effective_snr_values, 'bo-', label='Effective SNR')
plt.plot(theoretical_snr_values, theoretical_snr_values, 'r--', label='Theoretical SNR')
plt.xlabel('Theoretical SNR')
plt.ylabel('Effective SNR')
plt.title('Theoretical vs Effective SNR')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
