# Advanced Features Tutorial

This notebook demonstrates advanced features of the 2D Pseudomode Framework:

1. Custom spectral densities
2. Adaptive truncation schemes
3. GPU acceleration
4. Batch processing for materials screening
5. Memory usage estimation

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pseudomode_py as pm
import time

%matplotlib inline

## 1. Custom Spectral Densities

Build custom spectral densities by combining multiple components.

In [None]:
# Frequency grid on which every spectral-density component is evaluated.
omega = np.linspace(0, 5, 1000)

# Three independent bath contributions: an acoustic term, a flexural (ZA)
# term, and a single Lorentzian peak used here for the optical mode.
J_acoustic = pm.SpectralDensity2D.acoustic(
    omega, alpha=0.5, omega_c=1.0, q=1.5
)
J_flexural = pm.SpectralDensity2D.flexural(
    omega, alpha_f=0.3, omega_f=0.5, s_f=0.5, q=2.0
)
J_optical = pm.SpectralDensity2D.lorentzian_peak(
    omega, Omega_j=2.5, lambda_j=0.8, Gamma_j=0.15
)

# The composite density is the element-wise sum of the components; each one
# is converted to an ndarray first so that `+` adds rather than concatenates.
J_total = np.asarray(J_acoustic) + np.asarray(J_flexural) + np.asarray(J_optical)

# Draw every component with the same line style, then overlay the total.
fig, ax = plt.subplots(figsize=(12, 6))
component_curves = (
    (J_acoustic, 'Acoustic'),
    (J_flexural, 'Flexural (ZA)'),
    (J_optical, 'Optical Peak'),
)
for curve, curve_label in component_curves:
    ax.plot(omega, curve, label=curve_label, linewidth=2)
ax.plot(omega, J_total, 'k--', label='Total', linewidth=2.5)

ax.set_xlabel('Energy (eV)', fontsize=14)
ax.set_ylabel('Spectral Density J(ω)', fontsize=14)
ax.set_title('Custom Composite Spectral Density', fontsize=16)
ax.legend(fontsize=12)
ax.grid(True, alpha=0.3)
plt.show()

## 2. Adaptive Truncation Analysis

Explore how the occupation threshold sets the adaptive Hilbert-space truncation level (`n_max`), and therefore the size of the space that must be simulated.

In [None]:
# Sweep the occupation threshold and record the adaptive n_max that
# pm.Utils.compute_adaptive_n_max reports for each value.
truncation_thresholds = [0.001, 0.01, 0.05, 0.1]

# The pseudomode set does not depend on the threshold, so it is built once
# rather than on every loop iteration.
modes = [pm.PseudomodeParams(omega_eV=1.0, gamma_eV=0.1, g_eV=0.5)]

n_max_values = []
for threshold in truncation_thresholds:
    n_max = pm.Utils.compute_adaptive_n_max(
        modes, temperature_K=300.0, occupation_threshold=threshold
    )
    n_max_values.append(n_max)
    print(f"Threshold {threshold:.3f}: n_max = {n_max}")

# Plot threshold vs Hilbert space size. The thresholds span two orders of
# magnitude, hence the logarithmic x-axis.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(truncation_thresholds, n_max_values, 'o-', linewidth=2, markersize=10)
ax.set_xlabel('Occupation Threshold', fontsize=14)
ax.set_ylabel('Adaptive n_max', fontsize=14)
ax.set_title('Adaptive Truncation: Threshold vs Hilbert Space Size', fontsize=16)
ax.set_xscale('log')
ax.grid(True, alpha=0.3)
plt.show()

## 3. GPU vs CPU Performance Comparison

In [None]:
def _timed_simulation(use_gpu, material, system_params, omega_grid, time_grid,
                      max_pseudomodes=5):
    """Configure a framework, run one simulation and measure its duration.

    Parameters
    ----------
    use_gpu : bool
        Value assigned to ``SimulationConfig.use_gpu``.
    material : str
        Material name forwarded to ``simulate_material``.
    system_params, omega_grid, time_grid
        Forwarded unchanged to ``simulate_material``.
    max_pseudomodes : int, optional
        Value assigned to ``SimulationConfig.max_pseudomodes``.

    Returns
    -------
    tuple
        ``(config, framework, result, elapsed_seconds)``. Only the
        ``simulate_material`` call is timed; framework construction is not.
    """
    config = pm.SimulationConfig()
    config.use_gpu = use_gpu
    config.max_pseudomodes = max_pseudomodes
    framework = pm.PseudomodeFramework2D(config)

    # perf_counter is monotonic and high-resolution, so the measured interval
    # cannot be distorted by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    result = framework.simulate_material(material, system_params, omega_grid, time_grid)
    elapsed = time.perf_counter() - start
    return config, framework, result, elapsed


# Setup: identical physical system and grids for both backends.
system_params = pm.System2DParams()
system_params.omega0_eV = 1.4
system_params.temperature_K = 300.0

omega_grid = np.linspace(0, 5, 500)
time_grid = np.linspace(0, 20, 200)

# CPU benchmark
config_cpu, framework_cpu, result_cpu, time_cpu = _timed_simulation(
    False, "graphene", system_params, omega_grid, time_grid
)

# GPU benchmark
# NOTE(review): assumes a GPU-enabled build is available; what use_gpu=True
# does without one is not visible from this notebook - confirm.
config_gpu, framework_gpu, result_gpu, time_gpu = _timed_simulation(
    True, "graphene", system_params, omega_grid, time_grid
)

# Compare
print(f"CPU time: {time_cpu:.3f} s")
print(f"GPU time: {time_gpu:.3f} s")
if time_gpu > 0:
    print(f"Speedup: {time_cpu/time_gpu:.2f}x")
else:
    # Guard the ratio against a zero-length measured interval.
    print("Speedup: n/a (GPU time below timer resolution)")

# Visualize
fig, ax = plt.subplots(figsize=(8, 6))
ax.bar(['CPU', 'GPU'], [time_cpu, time_gpu], color=['#3498db', '#2ecc71'])
ax.set_ylabel('Computation Time (s)', fontsize=14)
ax.set_title('CPU vs GPU Performance', fontsize=16)
ax.grid(True, axis='y', alpha=0.3)
plt.show()

## 4. Batch Materials Screening

Process multiple materials in parallel for high-throughput screening.

In [None]:
# Materials to screen
materials = ['graphene', 'MoS2', 'WS2', 'hBN']

# One parameter object per material, all with the same transition energy and
# temperature. The loop variable is deliberately not called `sys`, which
# would shadow the standard-library module name.
systems = []
for _ in materials:
    material_params = pm.System2DParams()
    material_params.omega0_eV = 1.4
    material_params.temperature_K = 300.0
    systems.append(material_params)

# Batch simulation
config = pm.SimulationConfig()
config.max_pseudomodes = 4
config.total_time_ps = 20.0

framework = pm.PseudomodeFramework2D(config)

print("Running batch simulation...")
# NOTE(review): n_parallel_jobs=-1 presumably means "use all available
# workers" - confirm against the batch_simulate documentation.
results = framework.batch_simulate(materials, systems, n_parallel_jobs=-1)

# Extract and compare coherence times
T1_values = [r.coherence_times.T1_ps for r in results]
T2_values = [r.coherence_times.T2_star_ps for r in results]

# The plot and the ranking below pair results with materials by position;
# zip() would silently truncate on a mismatch, so fail loudly instead.
assert len(T2_values) == len(materials), (
    f"expected {len(materials)} results, got {len(T2_values)}"
)

# Plot comparison: grouped bars, T1 left and T2* right of each tick.
x = np.arange(len(materials))
width = 0.35

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(x - width/2, T1_values, width, label='T₁', alpha=0.8)
ax.bar(x + width/2, T2_values, width, label='T₂*', alpha=0.8)

ax.set_xlabel('Material', fontsize=14)
ax.set_ylabel('Coherence Time (ps)', fontsize=14)
ax.set_title('Materials Screening: Coherence Times Comparison', fontsize=16)
ax.set_xticks(x)
ax.set_xticklabels(materials)
ax.legend(fontsize=12)
ax.grid(True, axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

# Print ranking, longest T2* first.
print("\nMaterial Ranking by T₂*:")
ranking = sorted(zip(materials, T2_values), key=lambda pair: pair[1], reverse=True)
for i, (mat, t2) in enumerate(ranking, 1):
    print(f"{i}. {mat}: T₂* = {t2:.2f} ps")

## 5. Memory Usage Estimation

Estimate memory requirements for different system sizes.

In [None]:
# Sweep the number of pseudomodes at a fixed system dimension and n_max,
# recording the estimate from pm.Utils.estimate_memory_usage for each.
BYTES_PER_MB = 1024 * 1024

system_dim = 2  # Qubit
n_max = 5
n_pseudomodes_range = range(1, 11)

memory_usage = []
for n_modes in n_pseudomodes_range:
    # The estimate is converted from bytes to megabytes for display.
    mem_bytes = pm.Utils.estimate_memory_usage(system_dim, n_modes, n_max)
    mem_mb = mem_bytes / BYTES_PER_MB
    memory_usage.append(mem_mb)
    print(f"n_modes = {n_modes}: {mem_mb:.2f} MB")

# Plot scaling on a logarithmic y-axis.
fig, ax = plt.subplots(figsize=(10, 6))
ax.semilogy(list(n_pseudomodes_range), memory_usage, 'o-', linewidth=2, markersize=8)
ax.set_xlabel('Number of Pseudomodes', fontsize=14)
ax.set_ylabel('Memory Usage (MB)', fontsize=14)
ax.set_title('Memory Scaling with System Size', fontsize=16)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()