In [1]:
%load_ext autoreload
%autoreload 2

### Basic Imports

In [2]:
# Import optimization functions from SciPy
from scipy.optimize import differential_evolution, minimize, OptimizeResult, brentq
import numpy as np  
import numdifftools as nd  
import matplotlib.pyplot as plt  
import matplotlib.animation as animation  
import time  
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile, assemble
from qiskit_aer import Aer  
from qiskit.circuit.library.standard_gates import HGate, SGate, SdgGate, RYGate, CXGate

from qiskit.circuit.library import XGate, RXGate, CCXGate, CRXGate, RZGate, CRZGate, SwapGate, QFT
from qiskit.quantum_info import Operator, SparsePauliOp, Statevector, partial_trace, DensityMatrix
from tqdm import tqdm  

from itertools import product  

import datetime  

### Custom Imports

In [3]:
from ansatz_circuit import get_full_variational_quantum_circuit
from ComparisonMaker import get_less_than_operations, get_less_than_logic_explanation
from CustomOperations import * #imports CCRX, AND, OR, CQFT, and their inverses

# How many qubits?

In [4]:
# Register sizes and run parameters shared by every cell below.
QUBITS_NUM = 8 #number of Position Register qubits (longer runtime scales exponentially with number of qubits)
# ancilla[0] is the control of the harmonic-potential phases and the target of the
# coupling rotation; ancilla[1..3] are used as scratch qubits by the less-than logic.
ANCILLA_QUBITS = 4
# NOTE(review): only comparison[0] and comparison[1] are referenced in the visible
# cells -- confirm whether the third comparison qubit is still needed.
COMPARISON_QUBITS = 3

# Total width of the circuit (position + ancilla + comparison registers).
total_qubits = QUBITS_NUM + ANCILLA_QUBITS + COMPARISON_QUBITS

start_time = 0.0 #start time
timestep = 10.0 #Trotter step size
stop_time = 2001.0 #stop time (the main loop runs while time < stop_time)
trial = 1.0 #run label; only used inside the output filenames
diabaticity = 0.1 #vertical offset added to potential 1; also embedded in the filenames as "alpha"
# NOTE(review): 0.01*sqrt(pi/5) evaluates to ~0.0079267 -- confirm the value below.
IntegratedCoupling = 0.007923 #integral from -inf to inf of 0.01*e^(-5x^2)

# Breakpoints on the position grid. The visible cells only use numbers[0] and
# numbers[1]; numbers[2] is not referenced there.
numbers = [127, 128, 129] #where are your breakpoints?

In [5]:
# Wall-clock timestamp (seconds since the epoch). main() assigns its own local
# `current_time`, so this module-level value is not read by the visible cells.
current_time = int(time.time())
# NOTE(review): every helper below requests 'statevector_simulator' itself;
# `simulator_backend` is not referenced in the visible cells -- confirm it is still needed.
simulator_backend = Aer.get_backend('qasm_simulator')

# **TROTTER-SUZUKI** Parameters

In [6]:
# Physical and grid parameters for the split-operator (Trotter-Suzuki) evolution.
# Later values are derived from earlier ones, so the order of assignment matters.
m = 1818.18 #mass term
L = 20.0 #box size
d = L/2 #Box is defined from (-d to d)

N = 2**QUBITS_NUM #number of states

x_0 = 0 #horizontal offset of initialized Gaussian wavepacket
# p_0 shifts the momentum grid inside apply_kinetic_term; initialize_gaussian_wavepacket
# accepts it but does not apply a momentum phase to the amplitudes.
p_0 = -1.0 #momentum of initialized Gaussian wavepacket
δ = 1.0/3.0 #width of initialized Gaussian wavepacket

Δ_x = L / N #position grid spacing
Δ_p  = 2*np.pi / (N * Δ_x) #momentum grid spacing

Nyquist = np.pi / Δ_x #max / min momentum

center = 0 #center of box

# Prefactors of the two quadratic potentials (they enter as γ = strength * Δ_x**2 / ħ).
V1_strength = 0.015
V2_strength = 0.015

offset = center #offset for first quadratic potential
offset_2 = center - 1.5 #offset for second quadratic potential
# NOTE(review): the helpers below build their own local `statevector`; this
# module-level list is not read by the visible cells.
statevector = []

ħ = 1.0 #atomic units for hbar

# Rotation angle passed to ccrx in bigger_step_coupling: the integrated coupling is
# divided by the width of (numbers[1] - numbers[0] + 1) grid cells and multiplied by
# the Trotter step.
step_phase = (IntegratedCoupling / ((Δ_x)*(numbers[1] - numbers[0] + 1)))* timestep #strength of coupling

#Filenames
# Each name embeds the date, qubit count and run parameters so runs do not overwrite
# one another across days or parameter choices.
current_date = datetime.datetime.now().strftime("%Y-%m-%d")
ancilla_filename = f"{current_date}_{QUBITS_NUM}q_ancilla_states_marcus_timestep{timestep}_x0{x_0}_p0{p_0}_t{trial}_alpha{diabaticity}.txt"
statevector_filename = f"{current_date}_{QUBITS_NUM}q_statevectors_marcus_timestep{timestep}_x0{x_0}_p0{p_0}_t{trial}_alpha{diabaticity}.txt"
momentum_filename = f"{current_date}_{QUBITS_NUM}q_momentum_after_cqft_marcus_timestep{timestep}_x0{x_0}_p0{p_0}_t{trial}_alpha{diabaticity}.txt"

#The same ancilla qubit can be used in the entire circuit. A second one helps with clarity, but is unnecessary
ancilla_register = QuantumRegister(ANCILLA_QUBITS, name="ancilla")
# NOTE(review): the per-register circuits (ancilla_circuit, comparison_circuit,
# position_circuit) are not referenced by the visible cells; main() builds one
# circuit directly from the three registers.
ancilla_circuit = QuantumCircuit(ancilla_register)

#One qubit is the objective qubit, whose rotations are controlled by the position register
#n comparator qubits are used to store positions for n breakpoints (this example, three breakpoints)
comparison_register = QuantumRegister(COMPARISON_QUBITS, name="comparison")
comparison_circuit = QuantumCircuit(comparison_register)

#A recreation of the initial circuit, to define a register that represents qubit positions
position_register = QuantumRegister(QUBITS_NUM, name="position")
position_circuit = QuantumCircuit(position_register)

# Helper Functions

In [7]:
def record_ancilla_state_z_basis_single_tau(main_circuit, ancilla_register, tau, filename="ancilla_states.txt"):
    """Append the |0> population of ancilla_register[0] at time `tau` to `filename`.

    The whole circuit is simulated on Aer's statevector backend, <Z> is evaluated
    on the first ancilla qubit, and (<Z> + 1) / 2 is written as one tab-separated
    line: "<tau>\\t<population>".

    Args:
        main_circuit: circuit to simulate; must contain `ancilla_register`.
        ancilla_register: register whose element 0 is measured.
        tau: time label written in the first column.
        filename: output file, opened in append mode.
    """
    backend = Aer.get_backend('statevector_simulator')
    job_result = backend.run(transpile(main_circuit, backend)).result()
    state = Statevector(job_result.get_statevector())

    # Pauli-Z as a dense 2x2 operator acting on the single ancilla qubit.
    z_operator = Operator(np.array([[1, 0], [0, -1]]))
    target_qubit = main_circuit.find_bit(ancilla_register[0]).index

    z_mean = np.real(state.expectation_value(z_operator, [target_qubit]))
    # <Z> = P(0) - P(1) and P(0) + P(1) = 1, so P(0) = (<Z> + 1) / 2.
    population_fraction = (z_mean + 1) / 2

    with open(filename, "a") as out:
        out.write(f"{tau}\t{population_fraction}\n")

# Trace out the Ancilla and Comparison registers to isolate the Position Register

In [8]:
def isolate_position_density_matrix(statevector, ancilla_indices, comparison_indices):
    """Return the reduced density matrix left after tracing out the helper qubits.

    Args:
        statevector: full state of the circuit (anything `DensityMatrix` accepts).
        ancilla_indices: iterable of qubit indices of the ancilla register.
        comparison_indices: iterable of qubit indices of the comparison register.

    Returns:
        The `DensityMatrix` of the remaining qubits (the position register when
        the two index collections cover every other qubit).
    """
    # Unpack both collections into one list so tuples, lists and ranges can be
    # mixed freely; `tuple + list` would raise a TypeError.
    traced_out = [*ancilla_indices, *comparison_indices]
    return partial_trace(DensityMatrix(statevector), traced_out)

# Record Statevectors over Time for Animation

In [9]:
def record_statevectors_over_time(main_circuit, tau, filename="statevectors_over_time.txt"):
    """Append the position-register probability distribution at time `tau` to `filename`.

    The circuit is simulated on Aer's statevector backend, the ancilla and
    comparison qubits are traced out, and the diagonal of the reduced density
    matrix is written as one tab-separated line: "<tau>\\t<p_0>\\t<p_1>...".

    The qubit layout is taken from the module-level QUBITS_NUM, ANCILLA_QUBITS and
    total_qubits: position qubits first, then ancilla, then comparison.

    Args:
        main_circuit: circuit to simulate.
        tau: time label written in the first column.
        filename: output file, opened in append mode.
    """
    simulator = Aer.get_backend('statevector_simulator')
    transpiled_circuit = transpile(main_circuit, simulator)
    result = simulator.run(transpiled_circuit).result()
    statevector = Statevector(result.get_statevector())

    # Indices of the qubits to trace out, following the register order above.
    ancilla_indices = tuple(range(QUBITS_NUM, QUBITS_NUM + ANCILLA_QUBITS))
    comparison_indices = tuple(range(QUBITS_NUM + ANCILLA_QUBITS, total_qubits))

    reduced_rho = isolate_position_density_matrix(
        statevector, ancilla_indices, comparison_indices
    )

    # The diagonal of the reduced density matrix holds the basis-state probabilities.
    probabilities = np.real(np.diag(reduced_rho.data))

    with open(filename, "a") as file:
        file.write(f"{tau}\t")
        file.write("\t".join(f"{component:.6f}" for component in probabilities))
        file.write("\n")

# Record Momentum Distribution after CQFT

In [10]:
def record_momentum_distribution(circuit, position_register, tau, filename="momentum_distribution.txt"):
    """Append the position-register probabilities of `circuit` at time `tau` to `filename`.

    main() calls this right after cqft, so the recorded values are the
    distribution over the transformed (momentum) basis. The probabilities are
    written in reversed index order as one tab-separated line.

    Args:
        circuit: circuit to simulate.
        position_register: accepted for interface symmetry; not used here.
        tau: time label written in the first column.
        filename: output file, opened in append mode.
    """
    backend = Aer.get_backend('statevector_simulator')
    job_result = backend.run(transpile(circuit, backend)).result()
    full_state = Statevector(job_result.get_statevector())

    # Helper qubits follow the position qubits: ancilla first, then comparison.
    first_ancilla = QUBITS_NUM
    first_comparison = QUBITS_NUM + ANCILLA_QUBITS
    reduced = isolate_position_density_matrix(
        full_state,
        tuple(range(first_ancilla, first_comparison)),
        tuple(range(first_comparison, total_qubits)),
    )

    # Diagonal = basis-state probabilities; stored highest index first.
    reversed_probabilities = np.real(np.diag(reduced.data))[::-1]

    columns = "\t".join(f"{prob:.6f}" for prob in reversed_probabilities)
    with open(filename, "a") as out:
        out.write(f"{tau}\t" + columns + "\n")

# Calculate Potentials for Visualization

In [11]:
def calculate_potential_from_bits(x_values, QUBITS_NUM, d, Δ_x, V1_strength, offset, ħ, vertical_off = diabaticity):
    """Evaluate the quadratic potential through its bitwise expansion on the grid.

    Each position is mapped to a grid index k in [0, 2**QUBITS_NUM - 1]; with the
    bits k_j of k the potential is

        sum_j sum_l γ (k_j 2**j + β)(k_l 2**l + β) + vertical_off
            = γ (k + QUBITS_NUM*β)**2 + vertical_off,

    which equals (V1_strength/ħ) * (x_k - offset)**2 + vertical_off for the cell
    centre x_k = -d + (k + 1/2) Δ_x.

    Args:
        x_values: 1-D sequence of positions.
        QUBITS_NUM: number of position qubits (bits of the grid index).
        d: half-width of the box; valid positions lie in [-d, d].
        Δ_x: grid spacing.
        V1_strength: prefactor of the quadratic potential.
        offset: horizontal shift of the potential minimum.
        ħ: reduced Planck constant.
        vertical_off: constant added to every in-range value.

    Returns:
        Float array shaped like `x_values`; positions outside [-d, d] are NaN.
    """
    β = (-d + (Δ_x / 2) - offset)/(Δ_x * QUBITS_NUM)
    γ = V1_strength * (Δ_x)**2 / ħ
    n_states = 2**QUBITS_NUM

    # Force a float result: an integer x_values array would otherwise truncate
    # the potentials and reject the NaN marker.
    potentials = np.zeros_like(x_values, dtype=float)

    for idx, x in enumerate(x_values):
        # Convert position to binary representation
        normalized_pos = (x + d) / (2 * d)  # Map from [-d,d] to [0,1]
        if 0 <= normalized_pos <= 1:
            # x == d maps to n_states, whose low QUBITS_NUM bits are all zero and
            # would alias onto the first grid cell; clamp to the last cell instead.
            binary_val = min(int(normalized_pos * n_states), n_states - 1)
            binary_seq = [(binary_val >> j) & 1 for j in range(QUBITS_NUM)]

            # Bitwise double sum, mirroring the decomposition used for the circuit phases.
            potential = 0
            for j in range(QUBITS_NUM):
                for l in range(QUBITS_NUM):
                    k_j = binary_seq[j]
                    k_l = binary_seq[l]
                    term = γ * (k_j * 2**j + β) * (k_l * 2**l + β)
                    potential += term
            potentials[idx] = potential + vertical_off
        else:
            potentials[idx] = np.nan  # Outside valid range

    return potentials

# Initialize Gaussian Wavepacket

In [12]:
def initialize_gaussian_wavepacket(x_0, p_0, δ):
    """Return a unit-norm Gaussian amplitude vector on the N-point position grid.

    The grid points are x_k = -d + k*Δ_x for k = 0..N-1 (module-level N, d, Δ_x).
    Diagnostic norms are printed along the way.

    Args:
        x_0: centre of the Gaussian.
        p_0: accepted for interface compatibility but not used here; no momentum
            phase is applied to the amplitudes (the module's kinetic term uses
            p_0 to shift its momentum grid instead).
        δ: width of the Gaussian.

    Returns:
        Complex array of length N with sum(|psi|**2) == 1.
    """
    # Continuum normalization factor of the Gaussian wavepacket.
    norm_factor = (1/(2*np.pi*δ**2))**(1/4)
    print(f"Initial norm factor: {norm_factor}")

    # Evaluate the Gaussian envelope on the whole grid at once.
    grid = -d + np.arange(N) * Δ_x
    psi = norm_factor * np.exp(-(grid - x_0)**2 / (4*δ**2)).astype(complex)

    # Norm with the integration measure Δ_x (diagnostic only).
    intermediate_norm = np.sqrt(np.sum(np.abs(psi)**2 * Δ_x))
    print(f"Norm after applying norm_factor: {intermediate_norm}")

    # Discrete normalization for quantum circuit
    final_norm = np.sqrt(np.sum(np.abs(psi)**2))
    psi = psi / final_norm

    verification_norm = np.sqrt(np.sum(np.abs(psi)**2))
    print(f"Final verification norm: {verification_norm}")

    print(f"Max amplitude: {np.max(np.abs(psi))}")
    # Clamp the diagnostic index: x_0 == d would index one past the end, and
    # x_0 < -d would silently wrap around through negative indexing.
    centre_index = min(max(int((x_0 + d)/Δ_x), 0), N - 1)
    print(f"Wavefunction at x_0: {psi[centre_index]}")

    return psi

# Debugging function: record high bit state

In [13]:
def record_high_bit_state(circuit, position_register, tau, filename="high_bit_states.txt"):
    """Append P(|1>) of the highest position qubit at time `tau` to `filename`.

    Debugging helper: simulates `circuit` on Aer's statevector backend, evaluates
    <Z> on position_register[QUBITS_NUM - 1] and writes "<tau>\\t<probability>".

    Args:
        circuit: circuit to simulate; must contain `position_register`.
        position_register: register whose last (index QUBITS_NUM - 1) qubit is probed.
        tau: time label written in the first column.
        filename: output file, opened in append mode.
    """
    backend = Aer.get_backend('statevector_simulator')
    job_result = backend.run(transpile(circuit, backend)).result()
    state = Statevector(job_result.get_statevector())

    z_operator = Operator(np.array([[1, 0], [0, -1]]))
    top_qubit = circuit.find_bit(position_register[QUBITS_NUM - 1]).index
    z_mean = np.real(state.expectation_value(z_operator, [top_qubit]))

    # <Z> = P(0) - P(1), hence P(1) = (1 - <Z>) / 2.
    prob_one = (1 - z_mean) / 2

    with open(filename, "a") as out:
        out.write(f"{tau}\t{prob_one}\n")

# Get Expectation Values Over Time

In [14]:
def calculate_expectation_value(probabilities, positions, box_size=20):
    """Return the sign-flipped mean position for one probability distribution.

    The positions are regenerated as len(positions) evenly spaced points on
    [-d, d] (module-level d); only the length of `positions` is used.

    Args:
        probabilities: probability of each grid point.
        positions: sequence whose length sets the number of grid points.
        box_size: kept for backward compatibility; the box is always [-d, d].

    Returns:
        -sum(probabilities * x).
        NOTE(review): the leading minus sign is preserved from the original
        implementation; it looks like it matches the orientation of the density
        panel in the animation -- confirm before reusing the value elsewhere.
    """
    real_positions = np.linspace(-d, d, len(positions))
    return -np.sum(probabilities * real_positions)

# Trial: First Potential Surface, Applied Alone, Not controlled by Ancilla

In [15]:
def zeroth_order_operations(angle, circuit, ancilla, target):
    """Append two rounds of controlled-phase + CNOT from `ancilla` onto `target`.

    The second CNOT undoes the first, so `target` ends in its original state;
    the controlled phase has then been applied once for each value of `target`.

    Args:
        angle: phase angle passed to both `cp` calls.
        circuit: circuit-like object providing `cp` and `cx`.
        ancilla: control qubit.
        target: qubit used as the partner of the controlled operations.
    """
    for _ in range(2):
        circuit.cp(angle, ancilla, target)
        circuit.cx(ancilla, target)

In [16]:
def first_order_operations(angle, circuit, ancilla, position_register):
    """Append the ancilla-controlled phases that are linear in the position index.

    Position qubit j (j = 0..QUBITS_NUM-1) receives a controlled phase of
    angle * 2**j, so the phases add up to angle * k for grid index k.

    Args:
        angle: phase per unit of the position index.
        circuit: circuit to append to.
        ancilla: control qubit.
        position_register: position qubits, least significant first.
    """
    for bit in range(QUBITS_NUM):
        weight = 2**bit
        circuit.cp(angle * weight, ancilla, position_register[bit])

In [17]:
def second_order_operations(angle, circuit, ancilla, intermediate, position_register):
    """Append the ancilla-controlled phases that are quadratic in the position index.

    For grid index k with bits k_j the accumulated phase is
    angle * sum_{j,l} k_j k_l 2**(j+l) = angle * k**2:
      * diagonal terms (j == l) are single controlled phases of angle * 4**j;
      * cross terms need ancilla AND k_j as control, which is computed into
        `intermediate` with a Toffoli and uncomputed afterwards.

    The Toffoli pair is emitted once per control qubit and shared by all of its
    targets. The controlled phases between the pair are diagonal and leave
    `intermediate` untouched, so this is the same unitary as wrapping every
    single phase in its own pair, with far fewer gates.

    Args:
        angle: phase per unit of k**2.
        circuit: circuit to append to.
        ancilla: control qubit selecting whether the phases apply.
        intermediate: scratch qubit, expected in |0>; restored on exit.
        position_register: position qubits, least significant first.
    """
    for control in range(QUBITS_NUM):
        control_qubit = position_register[control]

        # Diagonal term: k_j * k_j = k_j, weight 2**(2j).
        circuit.cp(angle * 2**(2 * control), ancilla, control_qubit)

        targets = [target for target in range(QUBITS_NUM) if target != control]
        if not targets:
            # Single-qubit register: no cross terms, so no Toffoli pair either.
            continue

        # intermediate <- ancilla AND k_control
        circuit.ccx(ancilla, control_qubit, intermediate)
        for target in targets:
            circuit.cp(angle * 2**(control + target), intermediate, position_register[target])
        # Uncompute the scratch qubit.
        circuit.ccx(ancilla, control_qubit, intermediate)

# Kinetic Term

In [18]:
def apply_kinetic_term(circuit, position_register, QUBITS_NUM, τ):
    """Append the kinetic phases for one step of length `τ` and return the circuit.

    With grid index k on the register, the appended phases sum to
    -τ * γ * (k + β)**2, where β = (-Nyquist - p_0) / Δ_p and
    γ = Δ_p**2 / (2 m ħ) come from module-level constants.

    Args:
        circuit: circuit to append to.
        position_register: qubits holding the index k, least significant first.
        QUBITS_NUM: number of qubits of `position_register` to act on.
        τ: time step.

    Returns:
        The same `circuit` object.
    """
    shift = (-Nyquist - p_0)/(Δ_p)
    scale = (Δ_p)**2 / (2*m*ħ)

    constant_phase = -τ*(scale*shift**2)
    linear_phase = -2*τ*scale*shift
    quadratic_phase = -τ*scale

    # Constant term: a phase on both |0> and |1> of qubit 0 (phase, flip, phase, flip).
    anchor = position_register[0]
    for _ in range(2):
        circuit.p(constant_phase, anchor)
        circuit.x(anchor)

    # Linear term: bit j contributes 2**j to k.
    for bit in range(QUBITS_NUM):
        circuit.p(linear_phase * 2**bit, position_register[bit])

    # Quadratic term: k**2 = sum_{j,l} k_j k_l 2**(j+l).
    for control in range(QUBITS_NUM):
        circuit.p(quadratic_phase * 2**(2*control), position_register[control])
        for target in range(QUBITS_NUM):
            if target == control:
                continue
            circuit.cp(
                quadratic_phase * 2**(control + target),
                position_register[control],
                position_register[target],
            )

    return circuit

# Marcus Model: Two Harmonic Potentials

In [19]:
def HarmonicPotential(circuit, position_register, ancilla_register, τ, potentialnum):
    """Append one step of an ancilla-controlled harmonic potential; return the circuit.

    The potential phases are controlled by ancilla_register[0]. For potential 1
    the ancilla is wrapped in X gates, so its phases apply when the ancilla is
    |0>; potential 2 applies when the ancilla is |1>. ancilla_register[1] is used
    as the scratch qubit of the quadratic term.

    Args:
        circuit: circuit to append to.
        position_register: qubits holding the grid index.
        ancilla_register: register providing the control (index 0) and scratch
            (index 1) qubits.
        τ: time step.
        potentialnum: 1 for the surface with vertical offset `diabaticity`
            centred at `offset`; 2 for the surface centred at `offset_2`.

    Returns:
        The same `circuit` object.

    Raises:
        ValueError: if `potentialnum` is neither 1 nor 2.
    """
    if potentialnum == 1:
        VertOffset, HorizOffset, Strength = diabaticity, offset, V1_strength
    elif potentialnum == 2:
        VertOffset, HorizOffset, Strength = 0.0, offset_2, V2_strength
    else:
        # Fail before touching the circuit instead of hitting an unbound local below.
        raise ValueError(f"potentialnum must be 1 or 2, got {potentialnum!r}")

    ancilla_control = ancilla_register[0]

    if potentialnum == 1:
        # Flip the control so the phases below apply to the ancilla's |0> branch.
        circuit.x(ancilla_control)

    # V(x_k) = Strength * (x_k - HorizOffset)**2 + VertOffset, expanded in the grid index k.
    β = (-d - HorizOffset + (Δ_x)/2)/(Δ_x)
    γ = Strength * (Δ_x)**2 / (ħ)

    θ_1 = -τ*(γ*β**2 + VertOffset)
    θ_2 = -2*τ*γ*β
    θ_3 = -τ*γ

    # Constant term
    zeroth_order_operations(θ_1, circuit, ancilla_control, position_register[0])

    # Term linear in the grid index
    first_order_operations(θ_2, circuit, ancilla_control, position_register)

    # Term quadratic in the grid index
    second_order_operations(θ_3, circuit, ancilla_control, ancilla_register[1], position_register)

    if potentialnum == 1:
        # Restore the ancilla.
        circuit.x(ancilla_control)

    return circuit

# Coupling Potential Terms

In [20]:
def bigger_step_coupling(circuit, position_register, ancilla_register, comparison_register):
    """Append the doubly-controlled RX coupling onto ancilla_register[0].

    comparison_register[0] is wrapped in X gates, so the rotation by the
    module-level `step_phase` fires when comparison[0] is |0> and comparison[1]
    is |1>.

    Args:
        circuit: circuit to append to.
        position_register: not used here; kept for a uniform signature.
        ancilla_register: register whose element 0 is rotated.
        comparison_register: register providing the two control flags.
    """
    negated_flag, plain_flag = comparison_register[0], comparison_register[1]
    rotated_qubit = ancilla_register[0]

    circuit.x(negated_flag)
    ccrx(circuit, step_phase, negated_flag, plain_flag, rotated_qubit)
    circuit.x(negated_flag)

# Define Breakpoint Checks

In [21]:
def lessthanlogic0(position_register, ancilla_register, comparison_register):
    # Builds the operation list that logics() registers under numbers[0].
    # The list is assembled from the OR / invOR / AND helpers; its entries are
    # later unpacked as (gate, qubits) pairs by check_all_breakpoints.
    # The indices are hard-coded for position_register[0..7] (8 position qubits),
    # with ancilla_register[1..3] as scratch and comparison_register[0] as output.
    # NOTE(review): OR/invOR/AND come from an import outside this cell
    # (presumably CustomOperations) -- their exact gate content is not visible here.
    logic = []
    # Combine the position bits pairwise into the scratch ancillas, releasing
    # each scratch qubit (invOR) once its value has been folded into the next one.
    logic.extend(OR(position_register[0], position_register[1], ancilla_register[1]))
    logic.extend(OR(position_register[2], position_register[3], ancilla_register[2]))
    logic.extend(OR(ancilla_register[1], ancilla_register[2], ancilla_register[3]))
    logic.extend(invOR(position_register[2], position_register[3], ancilla_register[2]))
    logic.extend(invOR(position_register[0], position_register[1], ancilla_register[1]))
    logic.extend(OR(position_register[4], position_register[5], ancilla_register[1]))
    logic.extend(OR(ancilla_register[1], ancilla_register[3], ancilla_register[2]))
    logic.extend(invOR(position_register[4], position_register[5], ancilla_register[1]))
    logic.extend(OR(position_register[6], ancilla_register[2], ancilla_register[1]))
    
    # Write the result: position bit 7 AND the accumulated scratch value.
    logic.extend(AND(position_register[7], ancilla_register[1], comparison_register[0]))
    
    # The remaining calls revisit the scratch ancillas in reverse order,
    # apparently to uncompute them -- TODO confirm against the OR/invOR definitions.
    logic.extend(invOR(position_register[6], ancilla_register[2], ancilla_register[1]))
    logic.extend(OR(position_register[4], position_register[5], ancilla_register[1]))
    logic.extend(invOR(ancilla_register[1], ancilla_register[3], ancilla_register[2]))
    logic.extend(invOR(position_register[4], position_register[5], ancilla_register[1]))
    logic.extend(OR(position_register[0], position_register[1], ancilla_register[1]))
    logic.extend(OR(position_register[2], position_register[3], ancilla_register[2]))
    logic.extend(invOR(ancilla_register[1], ancilla_register[2], ancilla_register[3]))
    logic.extend(invOR(position_register[2], position_register[3], ancilla_register[2]))
    logic.extend(invOR(position_register[0], position_register[1], ancilla_register[1]))  
    return logic

def lessthanlogic1(position_register, ancilla_register, comparison_register):
    # Builds the operation list that logics() registers under numbers[1].
    # Same layout as lessthanlogic0 but with AND / invAND on the scratch
    # ancillas and a final OR into comparison_register[1]. Hard-coded for
    # position_register[0..7].
    # Unlike lessthanlogic0, no calls follow the output step, so the scratch
    # ancillas are not released inside this list.
    # NOTE(review): presumably this relies on uncheck_all_breakpoints replaying
    # the whole list in reverse -- confirm.
    logic = []
    logic.extend(AND(position_register[0], position_register[1], ancilla_register[1]))
    logic.extend(AND(position_register[2], position_register[3], ancilla_register[2]))
    logic.extend(AND(ancilla_register[1], ancilla_register[2], ancilla_register[3]))
    logic.extend(invAND(position_register[2], position_register[3], ancilla_register[2]))
    logic.extend(invAND(position_register[0], position_register[1], ancilla_register[1]))
    logic.extend(AND(position_register[4], position_register[5], ancilla_register[1]))
    logic.extend(AND(ancilla_register[1], ancilla_register[3], ancilla_register[2]))
    logic.extend(invAND(position_register[4], position_register[5], ancilla_register[1]))
    logic.extend(AND(position_register[6], ancilla_register[2], ancilla_register[1]))
    # Write the result: position bit 7 OR the accumulated scratch value.
    logic.extend(OR(position_register[7], ancilla_register[1], comparison_register[1]))
    return logic

In [22]:
def logics(position_register, ancilla_register, comparison_register):
    """Return the breakpoint -> operation-list table for the first two breakpoints.

    Keys are the module-level numbers[0] and numbers[1]; values are the lists
    built by lessthanlogic0 and lessthanlogic1 on the given registers.
    """
    registers = (position_register, ancilla_register, comparison_register)
    lower_logic = lessthanlogic0(*registers)
    upper_logic = lessthanlogic1(*registers)
    return dict(zip((numbers[0], numbers[1]), (lower_logic, upper_logic)))

def create_breakpoint_operations(number, logics, position_register, ancilla_register, comparison_register):
    """Look up the operation list stored for breakpoint `number`.

    Args:
        number: breakpoint value used as the lookup key.
        logics: mapping from breakpoint value to its list of operations.
        position_register, ancilla_register, comparison_register: unused;
            kept so the signature matches the other breakpoint helpers.

    Returns:
        The stored list, or a new empty list when `number` has no entry.
    """
    if number in logics:
        return logics[number]
    return []

def create_combined_operations(numbers, n_bits, logics, position_register, ancilla_register, comparison_register):
    """Assemble the gate list that evaluates the first two breakpoints.

    For numbers[0]: X gates on the listed bit positions, the breakpoint logic,
    then the same X gates again. For numbers[1]: X gates and the breakpoint
    logic only (the trailing X gates are not repeated).

    Args:
        numbers: sequence whose first two entries are the breakpoints.
        n_bits: bit width passed to get_less_than_operations.
        logics: mapping from breakpoint to its operation list.
        position_register: qubits the X gates act on.
        ancilla_register, comparison_register: forwarded to the lookup helper.

    Returns:
        (operations, explanations): the list of (gate, qubits) pairs and a dict
        mapping each breakpoint to the "explanation" entry of its comparison result.
    """
    all_operations = []
    explanations = {}

    for slot in (0, 1):
        number = numbers[slot]
        comparison = get_less_than_operations(number, n_bits)
        explanations[number] = comparison["explanation"]

        bit_flips = [
            (XGate(), [position_register[bit_pos]])
            for bit_pos in comparison["x_gate_positions"]
        ]
        breakpoint_ops = create_breakpoint_operations(
            number, logics, position_register, ancilla_register, comparison_register
        )

        all_operations.extend(bit_flips)
        all_operations.extend(breakpoint_ops)

        if slot == 0:
            # Only the first breakpoint restores the flipped position bits here.
            all_operations.extend(
                (XGate(), [position_register[bit_pos]])
                for bit_pos in comparison["x_gate_positions"]
            )

    return all_operations, explanations

def check_all_breakpoints(main_circuit, operations):
    """Append every (gate, qubits) pair of `operations` to `main_circuit`, in order."""
    for operation in operations:
        gate, qubits = operation
        main_circuit.append(gate, qubits)

def uncheck_all_breakpoints(main_circuit, operations):
    """Undo `operations` on `main_circuit`.

    The pairs are replayed last-to-first and each gate is replaced by its
    inverse, which is what reversing a gate sequence requires. For self-inverse
    gates this is the same as replaying the gate itself.

    Args:
        main_circuit: circuit to append to.
        operations: sequence of (gate, qubits) pairs previously appended in
            order; every gate must provide `inverse()`.
    """
    for gate, qubits in reversed(operations):
        main_circuit.append(gate.inverse(), qubits)

In [23]:
def apply_x_gates(main_circuit, bit_positions, position_register):
    """Append one X gate per entry of `bit_positions` onto the matching position qubit."""
    for qubit in (position_register[bit_pos] for bit_pos in bit_positions):
        main_circuit.append(XGate(), [qubit])

def apply_comparison_logic(main_circuit, number, logic):
    """Append the operations stored for breakpoint `number`; do nothing if absent.

    Args:
        main_circuit: circuit-like object providing `append(gate, qubits)`.
        number: breakpoint value used as the lookup key.
        logic: mapping from breakpoint value to a list of (gate, qubits) pairs.
    """
    if number not in logic:
        return
    for entry in logic[number]:
        gate, qubits = entry
        main_circuit.append(gate, qubits)

def handle_comparison(main_circuit, number, position_register, logic, undo_x_gates=False):
    """Append the X gates and comparison logic for one breakpoint.

    Args:
        main_circuit: circuit to append to.
        number: breakpoint to compare against.
        position_register: qubits the X gates act on.
        logic: mapping from breakpoint to its operation list.
        undo_x_gates: when true, the X gates are appended a second time after
            the comparison logic.

    Returns:
        (explanation, x_positions) taken from get_less_than_operations(number, QUBITS_NUM).
    """
    comparison = get_less_than_operations(number, QUBITS_NUM)
    flipped_bits = comparison["x_gate_positions"]

    apply_x_gates(main_circuit, flipped_bits, position_register)
    apply_comparison_logic(main_circuit, number, logic)
    if undo_x_gates:
        apply_x_gates(main_circuit, flipped_bits, position_register)

    return comparison["explanation"], flipped_bits

def comparator_circuit(numbers, logic, position_register, ancilla_register, comparison_register, main_circuit):
    """Append one check -> coupling -> uncheck sequence to `main_circuit`.

    The breakpoint operations are built for `numbers`, appended, followed by the
    coupling rotation, and finally replayed by uncheck_all_breakpoints.

    Returns:
        The explanations dict produced by create_combined_operations.
    """
    registers = (position_register, ancilla_register, comparison_register)
    operations, explanations = create_combined_operations(
        numbers, QUBITS_NUM, logic, *registers
    )

    check_all_breakpoints(main_circuit, operations)
    bigger_step_coupling(main_circuit, *registers)
    uncheck_all_breakpoints(main_circuit, operations)

    return explanations

# Composed Comparator Circuit (Wrapper for breakpoints and piecewise function)

# Create Animation

In [24]:
def create_quantum_evolution_animation(
    statevector_filename, 
    ancilla_filename, 
    momentum_filename,
    QUBITS_NUM, 
    timestep,
    trial
):
    """Build, save (GIF) and show a four-panel animation of the recorded evolution.

    Panels, top to bottom: position distribution with potential 1 overlaid,
    position density over time with the expectation line, momentum
    distribution, and the ancilla population fraction over time.

    Args:
        statevector_filename: tab-separated file, one row per time step:
            time followed by the position probabilities.
        ancilla_filename: tab-separated file with rows "time, population".
        momentum_filename: tab-separated file, one row per time step:
            time followed by the momentum-basis probabilities.
        QUBITS_NUM: number of position qubits (title, potential, momentum axis).
        timestep: Trotter step, used in the output filename.
        trial: run label, used in the output filename.

    Returns:
        The `FuncAnimation` object; keep a reference to it while it is displayed.

    Raises:
        ValueError: if the files contain no common time step.
    """
    # ndmin=2 keeps a file holding a single time step as a one-row table, so the
    # column slicing below works for any number of recorded rows.
    position_data = np.loadtxt(statevector_filename, delimiter='\t', ndmin=2)
    population_data = np.loadtxt(ancilla_filename, delimiter='\t', ndmin=2)
    momentum_data = np.loadtxt(momentum_filename, delimiter='\t', ndmin=2)

    # The files may hold different numbers of rows; animate the common prefix only.
    num_frames = min(len(position_data), len(population_data), len(momentum_data))
    if num_frames == 0:
        raise ValueError("No recorded time steps to animate; run the simulation first.")

    times = position_data[:num_frames, 0]
    positions = position_data[:num_frames, 1:]
    population_fraction = population_data[:num_frames, 1]
    momentum_distributions = momentum_data[:num_frames, 1:]

    # Calculate position expectation values
    position_values = np.linspace(-d, d, N)
    expectation_values = [calculate_expectation_value(positions[i], position_values) 
                         for i in range(num_frames)]

    # Setup plots (note the axis order: ax2 is the bottom panel)
    fig, (ax1, ax3, ax4, ax2) = plt.subplots(4, 1, figsize=(8, 16))

    # Calculate potentials
    potential_energies_1 = calculate_potential_from_bits(
        position_values, 
        QUBITS_NUM, 
        d, 
        Δ_x, 
        V1_strength,
        offset,
        ħ,
        vertical_off = diabaticity
    )
    # Plot potentials
    potential_line_1 = ax1.plot(position_values, potential_energies_1, 
                             'r-', alpha=0.7, label='Potential 1')

    # The y-range follows the probabilities, not the potential.
    ax1.set_ylim(np.min(positions), np.max(positions)*1.1)

    # Add a legend with better positioning
    ax1.legend(loc='best', fontsize=8)

    # First subplot - Position distribution
    ax1.set_xlabel('Position')
    ax1.set_ylabel('Probability Amplitude / Coupling')
    bar_width = (position_values[1] - position_values[0]) * 0.8
    bars = ax1.bar(position_values, positions[0], width=bar_width, color='blue', alpha=0.6)
    ax1.set_xlim(-d, d)
    ax1.set_title(f'{QUBITS_NUM} Qubits, Quantum Evolution')
    time_box = ax1.text(
        0.02, 0.95, '', transform=ax1.transAxes, fontsize=14,
        verticalalignment='top', horizontalalignment='left',
        color='white', bbox=dict(facecolor='black', edgecolor='none', boxstyle='round,pad=0.5')
    )

    # Population fraction subplot
    ax2.set_xlim(np.min(times), np.max(times))
    ax2.set_ylim(np.min(population_fraction)-0.05, np.max(population_fraction)+0.05)
    ax2.set_xlabel('Time (a.u.)')
    ax2.set_ylabel('Population Fraction $V_0$')
    line, = ax2.plot([], [], 'o-', color='purple', markersize=5)

    # Density plot with expectation value
    # NOTE(review): the vertical orientation of this panel and the sign used in
    # calculate_expectation_value appear to be coupled -- confirm before changing either.
    ax3.set_xlabel('Time')
    ax3.set_ylabel('Position')
    ax3.invert_yaxis()
    density_data = np.zeros_like(positions.T)

    density_plot = ax3.imshow(density_data,
        aspect='auto',
        # Use the box half-width instead of a literal so the axis follows L.
        extent=[np.min(times), np.max(times), -d, d],
        cmap='gray_r',
        interpolation='nearest',
        vmin=0,
        vmax=np.max(positions)
    )
    plt.colorbar(density_plot, ax=ax3, label='Probability Amplitude')
    expectation_line, = ax3.plot([], [], 'r-', linewidth=2)

    # Momentum distribution subplot (using quantum data)
    ax4.set_xlabel('Momentum')
    ax4.set_ylabel('Probability Density')
    # Calculate momentum range based on Nyquist frequency
    momentum_values = np.linspace(-Nyquist, Nyquist, 2**QUBITS_NUM)
    ax4.set_xlim(np.min(momentum_values), np.max(momentum_values))
    ax4.set_ylim(0, np.max(momentum_distributions) * 1.1)
    momentum_line, = ax4.plot(momentum_values, momentum_distributions[0], color='purple')

    def update(frame):
        # Update position distribution
        for i, bar in enumerate(bars):
            bar.set_height(positions[frame, i])

        time_box.set_text(f't = {times[frame]:.1f}')

        # Update population fraction
        line.set_data(times[:frame + 1], population_fraction[:frame + 1])

        # Update expectation value line
        expectation_line.set_data(times[:frame + 1], expectation_values[:frame + 1])

        # Update density plot
        current_density = positions.T.copy()
        current_density[:, frame+1:] = np.nan  # Set future times to NaN
        density_plot.set_array(current_density)

        # Update momentum distribution using quantum data
        momentum_line.set_ydata(momentum_distributions[frame])

        # With blit=True only the returned artists are redrawn, so every artist
        # modified above must be listed, including the density image.
        return ([bar for bar in bars]
                + [density_plot, line, time_box, expectation_line, momentum_line]
                + list(potential_line_1))

    # Create animation
    ani = animation.FuncAnimation(
        fig, update, frames=num_frames, blit=True, interval=125
    )

    # Defined before the try block so the fallback writer can always use it.
    animation_filename = f"{current_date}_mp4_{QUBITS_NUM}q_quantum_evolution_timestep{timestep}_t{trial}_alpha{diabaticity}.gif"

    # Try different animation writers
    try:
        print("Attempting to save with Pillow writer...")
        ani.save(animation_filename, writer='pillow', fps=4)
    except Exception as e:
        print(f"Error with Pillow writer: {e}")
        try:
            print("Attempting to save with imagemagick writer...")
            ani.save(animation_filename, writer='imagemagick', fps=4)
        except Exception as e:
            # Saving is best effort: fall through and still display the figure.
            print(f"Error with imagemagick writer: {e}")
            print("Could not save animation. Displaying only...")

    plt.show()
    return ani

# Main Function

In [1]:
def main():
    """Run the Trotterized evolution and record observables after every step.

    One circuit is grown step by step: cqft, kinetic phases, ciqft, the two
    harmonic potentials and the comparator/coupling block. After each step the
    whole circuit is re-simulated and the momentum distribution, ancilla
    population and position distribution are appended to the module-level
    output files (which are truncated at the start).

    Returns:
        List with the reduced position density matrix (as an array) of each step.
    """
    # Start from empty output files; the recorders append to them.
    for filename in [statevector_filename, ancilla_filename, momentum_filename]:
        with open(filename, 'w') as f:
            f.write('')

    τ = timestep
    max_time = stop_time
    total_qubits = QUBITS_NUM + ANCILLA_QUBITS + COMPARISON_QUBITS

    # Embed the wavepacket in the full register: the amplitudes occupy the first
    # 2**QUBITS_NUM entries, i.e. all ancilla and comparison qubits start in |0>.
    wavepacket = initialize_gaussian_wavepacket(x_0=x_0, p_0=p_0, δ=δ)
    psi = np.zeros(2**total_qubits, dtype=complex)
    psi[:2**QUBITS_NUM] = wavepacket

    # Create single circuit for entire evolution
    circuit = QuantumCircuit(position_register, ancilla_register, comparison_register)
    all_qubits = list(position_register) + list(ancilla_register) + list(comparison_register)
    circuit.initialize(psi, all_qubits)

    # Loop-invariant qubit indices of the registers that get traced out.
    ancilla_indices = tuple(range(QUBITS_NUM, QUBITS_NUM + ANCILLA_QUBITS))
    comparison_indices = tuple(range(QUBITS_NUM + ANCILLA_QUBITS, total_qubits))
    statevectors_evolved = []

    simulator = Aer.get_backend('statevector_simulator')

    def simulate_current_circuit():
        # Transpile before running, exactly like the record_* helpers do, so the
        # backend only receives instructions it supports.
        job_result = simulator.run(transpile(circuit, simulator)).result()
        return Statevector(job_result.get_statevector())

    step = 0
    current_time = start_time

    # Check initial position register normalization
    initial_rho = isolate_position_density_matrix(
        simulate_current_circuit(),
        ancilla_indices=ancilla_indices,
        comparison_indices=comparison_indices
    )
    initial_pos_prob = np.sum(np.real(np.diag(initial_rho.data)))
    print(f"Time {current_time:.1f}: Initial position register probability sum = {initial_pos_prob:.10f}")

    breakpoint_logic = logics(position_register, ancilla_register, comparison_register)

    while current_time < max_time:
        # Kinetic part: the momentum distribution is recorded inside the
        # transformed basis, before the kinetic phases are applied.
        cqft(circuit, position_register, QUBITS_NUM)
        record_momentum_distribution(circuit, position_register, current_time, filename=momentum_filename)
        apply_kinetic_term(circuit, position_register, QUBITS_NUM, τ)
        ciqft(circuit, position_register, QUBITS_NUM)

        # Potential part: both surfaces, selected by ancilla[0].
        HarmonicPotential(circuit, position_register, ancilla_register, τ, 1)
        HarmonicPotential(circuit, position_register, ancilla_register, τ, 2)

        # Coupling between the surfaces inside the breakpoint window.
        comparator_circuit(numbers, breakpoint_logic, position_register, ancilla_register, comparison_register, circuit)

        # Get intermediate state without breaking circuit
        current_state = simulate_current_circuit()

        record_ancilla_state_z_basis_single_tau(circuit, ancilla_register, current_time, filename=ancilla_filename)

        # Qubit QUBITS_NUM is ancilla[0] in the register order used above.
        ancilla_state = current_state.probabilities([QUBITS_NUM])
        #Printout to track evolution by eye. 
        print(f"Time {current_time:.1f}: After comparator Ancilla[0] probabilities - |0⟩: {ancilla_state[0]:.6f}, |1⟩: {ancilla_state[1]:.6f}")

        rho_position = isolate_position_density_matrix(
            current_state,
            ancilla_indices=ancilla_indices,
            comparison_indices=comparison_indices
        )

        state_vector = np.real(np.diag(rho_position.data))
        with open(statevector_filename, 'a') as f:
            vector_string = '\t'.join([f"{x:.10f}" for x in state_vector])
            f.write(f"{current_time:.1f}\t{vector_string}\n")

        statevectors_evolved.append(rho_position.data)

        # Derive the time from the step count so rounding errors do not
        # accumulate over many steps.
        step += 1
        current_time = start_time + step * τ

    return statevectors_evolved

# Run the simulation, then animate the files it wrote.
# The configuration and helper cells must have been executed in this kernel
# first: the filenames and registers used here are module-level names.
if __name__ == '__main__':
    statevectors_evolved = main()
    # Keep a reference to the animation: Matplotlib animations stop as soon as
    # the object is garbage-collected.
    evolution_animation = create_quantum_evolution_animation(
        statevector_filename=statevector_filename,
        ancilla_filename=ancilla_filename,
        momentum_filename=momentum_filename,
        QUBITS_NUM=QUBITS_NUM,
        timestep=timestep,
        trial=trial
    )

NameError: name 'statevector_filename' is not defined