# Quantum Walk Simulation

During our research, we identified a use case for quantum walks: modeling how risk propagates through networked systems. Quantum walks are the quantum-mechanical analog of classical random walks, and the interference patterns they exhibit can be useful for simulating complex dynamic processes.

In this notebook, we implement a discrete-time quantum walk on a cycle. Our model uses:
- **One coin qubit:** Determines the direction of the walk.
- **A position register:** Consisting of 2 qubits (for 4 positions).

At each step, we apply:
1. A coin toss (Hadamard gate) on the coin qubit.
2. A conditional shift operator that moves the walker forward or backward depending on the coin state.

Below, you'll find the code that builds the circuit, runs simulations (both QASM and statevector), and prints diagnostic information to help verify that the walk is behaving as expected.

In [None]:
# Quantum walk diagnostic module

import numpy as np
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile
from qiskit.circuit.library import UnitaryGate
from qiskit_aer import Aer
from qiskit.quantum_info import Statevector

def shift_plus(N):
    """Return the N×N cyclic shift-plus matrix: |i> -> |(i+1) mod N>.

    The matrix acts on column vectors, so column ``i`` (the source basis
    state) holds its single 1 in row ``(i+1) % N`` (the destination).

    Args:
        N: Number of positions on the cycle.

    Returns:
        An ``(N, N)`` float permutation matrix.
    """
    # Rolling the identity's rows down by one puts the 1 of column i in
    # row (i+1) % N, i.e. S[(i+1) % N, i] = 1.
    return np.roll(np.eye(N), 1, axis=0)

def shift_minus(N):
    """Return the N×N cyclic shift-minus matrix: |i> -> |(i-1) mod N>.

    The matrix acts on column vectors, so column ``i`` (the source basis
    state) holds its single 1 in row ``(i-1) % N`` (the destination).

    Args:
        N: Number of positions on the cycle.

    Returns:
        An ``(N, N)`` float permutation matrix.
    """
    # Rolling the identity's rows up by one gives S[(i-1) % N, i] = 1.
    return np.roll(np.eye(N), -1, axis=0)

def build_shift_operator(num_position_qubits):
    """
    Build the coin-conditioned shift operator for a quantum walk on a cycle.

    Uses one coin qubit and a position register with dimension
    N = 2^(num_position_qubits). When the coin is |0> the position register is
    transformed by ``shift_plus(N)``; when it is |1>, by ``shift_minus(N)``.

    Qubit ordering: the callers in this file append the resulting gate on
    ``coin + position`` (coin listed first). Qiskit treats the first listed
    qubit as the least-significant bit of the gate matrix, so the coin
    projector must be the right-hand Kronecker factor. The basis index of the
    returned matrix is therefore ``2 * position + coin``.

    Args:
        num_position_qubits: Number of qubits in the position register.

    Returns:
        A ``(2N, 2N)`` array representing the conditional shift.
    """
    N = 2**num_position_qubits
    S_plus = shift_plus(N)
    S_minus = shift_minus(N)
    # Coin projectors selecting the |0> and |1> branches.
    P0 = np.array([[1, 0], [0, 0]])
    P1 = np.array([[0, 0], [0, 1]])
    # Position factor on the left (most significant), coin on the right
    # (least significant) to match the coin-first qubit list used on append.
    return np.kron(S_plus, P0) + np.kron(S_minus, P1)

def build_quantum_walk_circuit(num_steps, num_position_qubits=2, add_measurements=True):
    """
    Build a discrete-time quantum walk circuit on a cycle.

    Registers:
    - 1 coin qubit ('coin') with a 1-bit classical register ('coin_cl').
    - A position register ('pos') with num_position_qubits qubits and a
      classical register ('pos_cl') of the same width.

    The coin first receives an initial Hadamard. Each step then applies a coin
    toss (Hadamard) followed by the shift operator from build_shift_operator.

    Args:
        num_steps: Number of walk steps to apply.
        num_position_qubits: Size of the position register.
        add_measurements: When true, measure the coin and position registers
            into their classical registers at the end of the circuit.

    Returns:
        The constructed QuantumCircuit.
    """
    coin = QuantumRegister(1, 'coin')
    position = QuantumRegister(num_position_qubits, 'pos')
    coin_cl = ClassicalRegister(1, 'coin_cl')
    pos_cl = ClassicalRegister(num_position_qubits, 'pos_cl')

    qc = QuantumCircuit(coin, position, coin_cl, pos_cl)

    # Initialize coin in superposition
    qc.h(coin)

    # The shift matrix and target qubits depend only on the register size,
    # so build them once instead of on every step.
    S = build_shift_operator(num_position_qubits)
    # Order: coin then position.
    # NOTE(review): Qiskit maps the first listed qubit to the least-significant
    # index of the gate matrix; confirm build_shift_operator's tensor ordering
    # agrees with this list.
    walk_qubits = coin[:] + position[:]

    # Apply the quantum walk steps
    for step in range(num_steps):
        qc.h(coin)  # coin toss
        qc.append(UnitaryGate(S, label=f"Shift_{step+1}"), walk_qubits)

    if add_measurements:
        qc.measure(coin, coin_cl)
        qc.measure(position, pos_cl)

    return qc

def run_qasm_simulation(qc, shots=1024):
    """Execute ``qc`` on the 'aer_simulator' backend and return its counts.

    The circuit is transpiled for the backend and run for ``shots``
    repetitions; the result's measurement counts are returned.
    """
    backend = Aer.get_backend('aer_simulator')
    job = backend.run(transpile(qc, backend), shots=shots)
    return job.result().get_counts()

def run_statevector(qc):
    """Return the basis-state probabilities of ``qc`` from a statevector run.

    Final measurements are stripped from a copy of the circuit (the input is
    left untouched) before the statevector is computed.
    """
    stripped = qc.remove_final_measurements(inplace=False)
    return Statevector.from_instruction(stripped).probabilities_dict()

def diagnostic_quantum_walk(num_steps, num_position_qubits=2, shots=1024):
    """
    Build and simulate a quantum walk circuit, printing detailed diagnostics:
    - The final circuit (text format)
    - QASM measurement counts
    - Final statevector probabilities
    - Intermediate statevector distributions after each step.

    Args:
        num_steps: Number of walk steps.
        num_position_qubits: Size of the position register.
        shots: Number of shots for the QASM simulation.

    Returns:
        A tuple ``(qc_full, counts, sv_probs)``: the measured circuit, the
        QASM counts, and the final statevector probabilities.
    """
    # Build the full circuit with measurements
    qc_full = build_quantum_walk_circuit(num_steps, num_position_qubits, add_measurements=True)
    print("=== Final Quantum Walk Circuit ===")
    print(qc_full.draw(output='text'))

    # QASM simulation with measurements
    counts = run_qasm_simulation(qc_full, shots=shots)
    print("\n=== QASM Measurement Counts ===")
    print(counts)

    # Statevector simulation (without measurements)
    sv_probs = run_statevector(qc_full)
    print("\n=== Final Statevector Probabilities (without measurements) ===")
    print(sv_probs)

    # Intermediate diagnostics: replay the walk one step at a time on a
    # measurement-free circuit and inspect the state after every step.
    print("\n=== Intermediate Step Diagnostics ===")
    coin = QuantumRegister(1, 'coin')
    pos = QuantumRegister(num_position_qubits, 'pos')
    qc_base = QuantumCircuit(coin, pos)
    qc_base.h(coin)  # Initial coin superposition

    # The shift matrix is the same for every step; build it once.
    S = build_shift_operator(num_position_qubits)
    walk_qubits = coin[:] + pos[:]

    for step in range(num_steps):
        # Each step is coin toss + shift, exactly as in
        # build_quantum_walk_circuit; without the Hadamard these diagnostics
        # would describe a different evolution than the circuit above.
        qc_base.h(coin)
        qc_base.append(UnitaryGate(S, label=f"Shift_{step+1}"), walk_qubits)
        sv = Statevector.from_instruction(qc_base)
        print(f"\nAfter step {step+1}:")
        print("Statevector:")
        print(sv)
        print("Probability distribution:")
        print(sv.probabilities_dict())

    return qc_full, counts, sv_probs

# Entry point: run the walk diagnostics with the demo settings.
if __name__ == "__main__":
    # 2 position qubits -> 4 positions on the cycle.
    num_steps, num_position_qubits, shots = 5, 2, 1024

    qc_final, qasm_counts, final_sv = diagnostic_quantum_walk(
        num_steps,
        num_position_qubits,
        shots=shots,
    )
    print("\n=== End of Diagnostics ===")
