# 5 - Quantum error correction inspired by classical codes

## Core Task 5.1 - Explore QEC codes inspired by classical codes

[TODO, 70pts]:

Classical error-correcting codes provide a natural and powerful pathway to constructing quantum codes by directly translating classical parity checks into quantum stabilizer measurements. In particular, any linear classical code can be mapped to a quantum code that detects and corrects **bit-flip (X) errors** by promoting each classical parity check to a multi-qubit Z-type stabilizer. In this construction, classical codewords become logical quantum states, and the syndrome extraction process is identical in spirit to classical decoding. This approach is especially well suited for hardware with **strong noise bias**, where one error channel dominates. In our case, biased cat qubits exponentially suppress phase-flip errors, leaving bit flips as the primary failure mode. As a result, we can focus entirely on X-error correction, allowing us to use a much wider and more efficient family of classical codes than would be possible for fully general quantum noise.
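As a minimal sketch of this mapping, the plain-Python snippet below turns each row of a parity-check matrix into a Z-type Pauli string and computes the syndrome of an X error as the parity of its overlap with each row. The Hamming (7,4) matrix is chosen purely for illustration, and the helper names `row_to_stabilizer` and `syndrome` are hypothetical.

```python
# Hamming (7,4) parity-check matrix: one row per classical parity check.
H = [
    [1, 0, 1, 0, 1, 0, 1],
    [0, 1, 1, 0, 0, 1, 1],
    [0, 0, 0, 1, 1, 1, 1],
]

def row_to_stabilizer(row):
    """Write a parity check as a Pauli string: Z where the row has a 1, I elsewhere."""
    return "".join("Z" if bit else "I" for bit in row)

def syndrome(H, x_error):
    """A Z-type check fires iff it overlaps the X-error pattern an odd number of times."""
    return [sum(h * e for h, e in zip(row, x_error)) % 2 for row in H]

stabilizers = [row_to_stabilizer(row) for row in H]
single_flip = [0, 0, 0, 0, 1, 0, 0]  # X error on the qubit with index 4
s = syndrome(H, single_flip)

print(stabilizers)  # ['ZIZIZIZ', 'IZZIIZZ', 'IIIZZZZ']
print(s)            # [1, 0, 1], i.e. column 4 of H
```

The syndrome of a single flip equals the corresponding column of the parity-check matrix, which is what a lookup decoder relies on.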

The final and core challenge is to choose any classical error-correcting code (or family of codes), translate it into its quantum counterpart, and benchmark it against the repetition code that you already implemented. You will simulate the resulting quantum code in **Stim**, extract syndromes, perform decoding, and compare key performance metrics such as logical error rate versus number of physical qubits at a fixed physical error rate, encoding efficiency (k/n), the effective distance of the code, and required hardware connectivity (i.e., which two-qubit gates are needed). This exploration will show how classical coding theory can be directly leveraged to design quantum codes that outperform simple repetition strategies when the noise is strongly biased.
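Two of these metrics, the rate k/n and the distance, can be read off a parity-check matrix directly. The following is a rough sketch under stated assumptions: the function name `code_parameters` is hypothetical, the search is brute force over all 2^n bit strings (so it only suits small codes), and the distance reported is the classical minimum distance, which governs protection against X errors only.

```python
from itertools import product

def code_parameters(H):
    """Return (n, k, d) for the binary linear code with parity-check matrix H."""
    n = len(H[0])
    # A word is a codeword iff it satisfies every parity check.
    codewords = [
        word for word in product((0, 1), repeat=n)
        if all(sum(h * w for h, w in zip(row, word)) % 2 == 0 for row in H)
    ]
    k = len(codewords).bit_length() - 1  # a linear code has 2**k codewords
    d = min(sum(word) for word in codewords if any(word))  # lightest nonzero codeword
    return n, k, d

H_hamming = [
    [1, 0, 1, 0, 1, 0, 1],
    [0, 1, 1, 0, 0, 1, 1],
    [0, 0, 0, 1, 1, 1, 1],
]
# Length-5 repetition code: neighbouring bits must agree.
H_rep5 = [
    [1, 1, 0, 0, 0],
    [0, 1, 1, 0, 0],
    [0, 0, 1, 1, 0],
    [0, 0, 0, 1, 1],
]

for name, H in [("Hamming(7,4)", H_hamming), ("repetition(5)", H_rep5)]:
    n, k, d = code_parameters(H)
    print(f"{name}: n={n}, k={k}, d={d}, rate k/n={k / n:.3f}")
```

On these two inputs the sketch reports distance 3 at rate 4/7 for the Hamming code and distance 5 at rate 1/5 for the repetition code, which illustrates the rate-versus-distance trade-off the benchmark is meant to explore.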

*Optionally*, only if time permits, you may wish to demonstrate a universal, fault-tolerant set of logical gates for your code, starting with the Clifford group and extending to non-Clifford gates.


Please refer to `./2-classical-to-quantum-codes.ipynb` for a step-by-step introduction to translating a classical code into a quantum bit-flip-correcting code, along with a curated (but not exhaustive) list of classical code families to use as inspiration. You should consider this notebook required reading for the core task in this challenge.

This is an open-ended challenge, judged by the criteria specified in the `README.md` doc. 

In [15]:
import stim
import numpy as np
from typing import Dict, Tuple

# Hamming (7,4) parity check matrix
H_hamming = np.array([
    [1,0,1,0,1,0,1],
    [0,1,1,0,0,1,1],
    [0,0,0,1,1,1,1]
], dtype=int)

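def concatenated_hamming_circuit(p: float = 0.05):
    """
    Build seven independent Hamming (7,4) blocks (49 data qubits) for X-error correction.

    Only the inner Hamming checks are measured. No outer-code stabilizers are
    included; the "outer" level is handled by a majority vote in the decoder.
    """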
    # Outer code length
    n_outer = 7
    # Inner code length
    n_inner = 7
    # Total physical qubits
    n_data = n_outer * n_inner  # 49 data qubits
    # Total stabilizers per outer block
    n_stab_per_block = H_hamming.shape[0]  # 3 stabilizers per block
    n_stabilizer = n_stab_per_block * n_outer  # 21 total stabilizers
    
    c = stim.Circuit()
    
    # Qubit layout: interleave data and measure qubits
    # data qubits at even indices, measure qubits at odd
    data_qubits = [2*i for i in range(n_data)]
    measure_qubits = [2*i + 1 for i in range(n_stabilizer)]
    
    # Initialize all qubits to |0>
    c.append("R", data_qubits + measure_qubits)
    
    c.append("TICK")
    
    # Apply X errors on data qubits
    for q in data_qubits:
        c.append("X_ERROR", [q], p)
    
    c.append("TICK")
    
    # Measure stabilizers for each outer code block
    measure_index = 0
    for outer_idx in range(n_outer):
        block_start = outer_idx * n_inner
        for row_idx, row in enumerate(H_hamming):
            # Find which qubits participate in this stabilizer
            participating_data_qubits = [data_qubits[block_start + i] 
                                         for i, bit in enumerate(row) if bit == 1]
            anc = measure_qubits[measure_index]
            
            # Apply CNOTs from data qubits to ancilla
            for q in participating_data_qubits:
                c.append("CX", [q, anc])
            
            measure_index += 1
    
    c.append("TICK")
    
    # Measure all ancilla qubits
    c.append("MR", measure_qubits)
    
    # Add detectors for each stabilizer
    for i in range(n_stabilizer):
        c.append("DETECTOR", [stim.target_rec(-(n_stabilizer) + i)])
    
    c.append("TICK")
    
    # Measure all data qubits
    c.append("M", data_qubits)
    
    # Define logical observable - just use first data qubit of first block as logical Z
    c.append("OBSERVABLE_INCLUDE", [stim.target_rec(-n_data)], 0)
    
    return c, n_data, n_stabilizer


def simulate_circuit(circuit: stim.Circuit, n_data: int, n_stabilizer: int, num_shots=10000) -> Dict[Tuple[str,str], int]:
    """
    Simulate circuit and extract measurement results.
    """
    sampler = circuit.compile_sampler()
    shots = sampler.sample(shots=num_shots)
    results = {}
    
    for shot in shots:
        # Measurements come in order: first n_stabilizer ancillas, then n_data data qubits
        synd_bits = ''.join(str(int(shot[i])) for i in range(n_stabilizer))
        data_bits = ''.join(str(int(shot[n_stabilizer + i])) for i in range(n_data))
        key = (data_bits, synd_bits)
        results[key] = results.get(key, 0) + 1
    
    return results


def decode_hamming_block(data_block: list, syndrome_block: list) -> int:
    """
    Decode a single Hamming (7,4) block using syndrome.
    
    Args:
        data_block: list of 7 bits (measurements of data qubits)
        syndrome_block: list of 3 bits (syndrome measurements)
    
    Returns:
        Decoded logical bit (majority vote of corrected data)
    """
    corrected = data_block.copy()
    syndrome = syndrome_block
    
    # If syndrome is all zeros, no error detected
    if sum(syndrome) == 0:
        return 1 if sum(corrected) > 3 else 0
    
    # Find error location by matching syndrome to H columns
    # Each column of H represents the syndrome for an error at that position
    for pos in range(7):
        column = H_hamming[:, pos]
        if np.array_equal(syndrome, column):
            # Found the error position - flip it
            corrected[pos] ^= 1
            break
    
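    # Return majority vote.
    # Caveat: a two-flip pattern is miscorrected to a weight-3 codeword,
    # which this vote still reads as 0, so such miscorrections go uncounted.
    return 1 if sum(corrected) > 3 else 0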


def decode_hamming_concatenated(meas: Tuple[str, str]) -> int:
    """
    Decode concatenated Hamming code:
    - Decode each inner Hamming (7,4) block using its syndrome
    - Perform majority vote on the 7 outer logical qubits
    """
    data_bits, synd_bits = meas
    n_outer = 7
    n_inner = 7
    n_stab_per_block = 3
    
    data = [int(b) for b in data_bits]
    synd = [int(b) for b in synd_bits]
    
    # Decode each of the 7 inner blocks
    outer_logical_bits = []
    for outer_idx in range(n_outer):
        # Extract data and syndrome for this block
        data_start = outer_idx * n_inner
        data_end = data_start + n_inner
        data_block = data[data_start:data_end]
        
        synd_start = outer_idx * n_stab_per_block
        synd_end = synd_start + n_stab_per_block
        syndrome_block = synd[synd_start:synd_end]
        
        # Decode this inner block
        logical_bit = decode_hamming_block(data_block, syndrome_block)
        outer_logical_bits.append(logical_bit)
    
    # Now decode the outer code using majority vote
    # (The outer code is also a Hamming code, but we're just using majority vote)
    return 1 if sum(outer_logical_bits) > n_outer // 2 else 0


def logical_error_rate(results: Dict[Tuple[str,str], int], logical_prepared=0) -> float:
    """
    Compute logical error rate.
    """
    errors = 0
    total = 0
    for (data_bits, synd_bits), count in results.items():
        decoded = decode_hamming_concatenated((data_bits, synd_bits))
        if decoded != logical_prepared:
            errors += count
        total += count
    return errors / total


# Test the code
print("Building concatenated Hamming circuit...")
circuit, n_data, n_stab = concatenated_hamming_circuit(p=0.05)
print(f"Circuit has {n_data} data qubits and {n_stab} stabilizers")
print(f"\nCircuit preview (first 20 lines):")
print('\n'.join(str(circuit).split('\n')[:20]))

print("\nRunning simulation...")
results = simulate_circuit(circuit, n_data, n_stab, num_shots=5000)

print(f"\nNumber of unique measurement outcomes: {len(results)}")
print(f"Most common outcomes:")
for i, (key, count) in enumerate(sorted(results.items(), key=lambda x: -x[1])[:5]):
    data, synd = key
    print(f"  {i+1}. syndrome={synd[:10]}... data={data[:10]}... count={count}")

p_L = logical_error_rate(results, logical_prepared=0)
print(f"\nLogical error rate: {p_L:.6f}")

# Rough analytical estimate
# Inner code can correct 1 error per block (distance 3)
# Probability of >1 error in a block of 7 with p=0.05:
from math import comb  # np.math is deprecated in recent NumPy releases

p = 0.05
p_inner_fail = sum(comb(7, k) * (p**k) * ((1-p)**(7-k)) for k in range(2, 8))
print(f"Expected inner block failure rate: {p_inner_fail:.6f}")
# Outer majority vote fails when at least 4 of the 7 inner blocks fail
p_outer_fail = sum(comb(7, k) * (p_inner_fail**k) * ((1-p_inner_fail)**(7-k)) for k in range(4, 8))
print(f"Expected logical error rate (approx): {p_outer_fail:.6f}")

Building concatenated Hamming circuit...
Circuit has 49 data qubits and 21 stabilizers

Circuit preview (first 20 lines):
R 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 64 66 68 70 72 74 76 78 80 82 84 86 88 90 92 94 96 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33 35 37 39 41
TICK
X_ERROR(0.05) 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38 40 42 44 46 48 50 52 54 56 58 60 62 64 66 68 70 72 74 76 78 80 82 84 86 88 90 92 94 96
TICK
CX 0 1 4 1 8 1 12 1 2 3 4 3 10 3 12 3 6 5 8 5 10 5 12 5 14 7 18 7 22 7 26 7 16 9 18 9 24 9 26 9 20 11 22 11 24 11 26 11 28 13 32 13 36 13 40 13 30 15 32 15 38 15 40 15 34 17 36 17 38 17 40 17 42 19 46 19 50 19 54 19 44 21 46 21 52 21 54 21 48 23 50 23 52 23 54 23 56 25 60 25 64 25 68 25 58 27 60 27 66 27 68 27 62 29 64 29 66 29 68 29 70 31 74 31 78 31 82 31 72 33 74 33 80 33 82 33 76 35 78 35 80 35 82 35 84 37 88 37 92 37 96 37 86 39 88 39 94 39 96 39 90 41 92 41 94 41 96 41
TICK
MR 1 3 5 7 9 11 13 15 17 



## Trying with 100,000 shots instead of 5,000
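To see roughly why the shot count matters, the sketch below estimates the statistical resolution of a sampled error rate. It assumes independent shots; the helper names and the value of `p_hat` are hypothetical and used only for illustration, not taken from the simulation.

```python
import math

def binomial_standard_error(p_hat: float, shots: int) -> float:
    """Standard error of an error-rate estimate from `shots` independent samples."""
    return math.sqrt(p_hat * (1.0 - p_hat) / shots)

def rule_of_three_upper_bound(shots: int) -> float:
    """Approximate 95% upper bound on the true rate when zero failures were seen."""
    return 3.0 / shots

p_hat = 1e-3  # hypothetical logical error rate, for illustration only
for shots in (5_000, 100_000):
    se = binomial_standard_error(p_hat, shots)
    bound = rule_of_three_upper_bound(shots)
    print(f"shots={shots}: standard error ~ {se:.2e}, zero-failure bound ~ {bound:.2e}")
```

If a run reports zero logical errors, the zero-failure bound is the more honest number to quote than 0.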

In [16]:
import stim
import numpy as np
import math
from typing import Dict, Tuple

# Hamming (7,4) parity check matrix
H_hamming = np.array([
    [1,0,1,0,1,0,1],
    [0,1,1,0,0,1,1],
    [0,0,0,1,1,1,1]
], dtype=int)

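def concatenated_hamming_circuit(p: float = 0.05):
    """
    Build seven independent Hamming (7,4) blocks (49 data qubits) for X-error correction.

    Only the inner Hamming checks are measured. No outer-code stabilizers are
    included; the "outer" level is handled by a majority vote in the decoder.
    """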
    n_outer = 7
    n_inner = 7
    n_data = n_outer * n_inner  # 49 data qubits
    n_stab_per_block = H_hamming.shape[0]  # 3 stabilizers per block
    n_stabilizer = n_stab_per_block * n_outer  # 21 total stabilizers
    
    c = stim.Circuit()
    
    # Qubit layout: interleave data and measure qubits
    data_qubits = [2*i for i in range(n_data)]
    measure_qubits = [2*i + 1 for i in range(n_stabilizer)]
    
    # Initialize all qubits to |0>
    c.append("R", data_qubits + measure_qubits)
    
    c.append("TICK")
    
    # Apply X errors on data qubits
    for q in data_qubits:
        c.append("X_ERROR", [q], p)
    
    c.append("TICK")
    
    # Measure stabilizers for each outer code block
    measure_index = 0
    for outer_idx in range(n_outer):
        block_start = outer_idx * n_inner
        for row_idx, row in enumerate(H_hamming):
            # Find which qubits participate in this stabilizer
            participating_data_qubits = [data_qubits[block_start + i] 
                                         for i, bit in enumerate(row) if bit == 1]
            anc = measure_qubits[measure_index]
            
            # Apply CNOTs from data qubits to ancilla
            for q in participating_data_qubits:
                c.append("CX", [q, anc])
            
            measure_index += 1
    
    c.append("TICK")
    
    # Measure all ancilla qubits
    c.append("MR", measure_qubits)
    
    # Add detectors for each stabilizer
    for i in range(n_stabilizer):
        c.append("DETECTOR", [stim.target_rec(-(n_stabilizer) + i)])
    
    c.append("TICK")
    
    # Measure all data qubits
    c.append("M", data_qubits)
    
    # Define logical observable - just use first data qubit of first block as logical Z
    c.append("OBSERVABLE_INCLUDE", [stim.target_rec(-n_data)], 0)
    
    return c, n_data, n_stabilizer


def simulate_circuit(circuit: stim.Circuit, n_data: int, n_stabilizer: int, num_shots=10000) -> Dict[Tuple[str,str], int]:
    """
    Simulate circuit and extract measurement results.
    """
    sampler = circuit.compile_sampler()
    shots = sampler.sample(shots=num_shots)
    results = {}
    
    for shot in shots:
        # Measurements come in order: first n_stabilizer ancillas, then n_data data qubits
        synd_bits = ''.join(str(int(shot[i])) for i in range(n_stabilizer))
        data_bits = ''.join(str(int(shot[n_stabilizer + i])) for i in range(n_data))
        key = (data_bits, synd_bits)
        results[key] = results.get(key, 0) + 1
    
    return results


def decode_hamming_block(data_block: list, syndrome_block: list) -> int:
    """
    Decode a single Hamming (7,4) block using syndrome.
    
    Args:
        data_block: list of 7 bits (measurements of data qubits)
        syndrome_block: list of 3 bits (syndrome measurements)
    
    Returns:
        Decoded logical bit (majority vote of corrected data)
    """
    corrected = data_block.copy()
    syndrome = syndrome_block
    
    # If syndrome is all zeros, no error detected
    if sum(syndrome) == 0:
        return 1 if sum(corrected) > 3 else 0
    
    # Find error location by matching syndrome to H columns
    # Each column of H represents the syndrome for an error at that position
    for pos in range(7):
        column = H_hamming[:, pos]
        if np.array_equal(syndrome, column):
            # Found the error position - flip it
            corrected[pos] ^= 1
            break
    
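    # Return majority vote.
    # Caveat: a two-flip pattern is miscorrected to a weight-3 codeword,
    # which this vote still reads as 0, so such miscorrections go uncounted.
    return 1 if sum(corrected) > 3 else 0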


def decode_hamming_concatenated(meas: Tuple[str, str], verbose=False) -> int:
    """
    Decode concatenated Hamming code:
    - Decode each inner Hamming (7,4) block using its syndrome
    - Perform majority vote on the 7 outer logical qubits
    """
    data_bits, synd_bits = meas
    n_outer = 7
    n_inner = 7
    n_stab_per_block = 3
    
    data = [int(b) for b in data_bits]
    synd = [int(b) for b in synd_bits]
    
    # Decode each of the 7 inner blocks
    outer_logical_bits = []
    for outer_idx in range(n_outer):
        # Extract data and syndrome for this block
        data_start = outer_idx * n_inner
        data_end = data_start + n_inner
        data_block = data[data_start:data_end]
        
        synd_start = outer_idx * n_stab_per_block
        synd_end = synd_start + n_stab_per_block
        syndrome_block = synd[synd_start:synd_end]
        
        # Decode this inner block
        logical_bit = decode_hamming_block(data_block, syndrome_block)
        outer_logical_bits.append(logical_bit)
        
        if verbose and outer_idx < 2:
            print(f"  Block {outer_idx}: data={data_block}, synd={syndrome_block}, logical={logical_bit}")
    
    if verbose:
        print(f"  Outer logical bits: {outer_logical_bits}")
    
    # Now decode the outer code using majority vote
    result = 1 if sum(outer_logical_bits) > n_outer // 2 else 0
    if verbose:
        print(f"  Final result: {result}")
    return result


def logical_error_rate(results: Dict[Tuple[str,str], int], logical_prepared=0, verbose=False) -> float:
    """
    Compute logical error rate.
    """
    errors = 0
    total = 0
    error_cases = []
    
    for (data_bits, synd_bits), count in results.items():
        decoded = decode_hamming_concatenated((data_bits, synd_bits))
        if decoded != logical_prepared:
            errors += count
            if len(error_cases) < 3:  # Save first few error cases
                error_cases.append((data_bits, synd_bits, count))
        total += count
    
    if verbose and error_cases:
        print(f"\nExample error cases:")
        for i, (data, synd, count) in enumerate(error_cases):
            print(f"\nError case {i+1} (occurred {count} times):")
            print(f"Data: {data[:20]}...")
            print(f"Synd: {synd}")
            decode_hamming_concatenated((data, synd), verbose=True)
    
    return errors / total


# Test with more shots
print("Building concatenated Hamming circuit...")
p = 0.05
circuit, n_data, n_stab = concatenated_hamming_circuit(p=p)
print(f"Circuit has {n_data} data qubits and {n_stab} stabilizers")

print("\nRunning simulation with 100,000 shots...")
results = simulate_circuit(circuit, n_data, n_stab, num_shots=100_000)

print(f"\nNumber of unique measurement outcomes: {len(results)}")
print(f"Most common outcomes:")
for i, (key, count) in enumerate(sorted(results.items(), key=lambda x: -x[1])[:5]):
    data, synd = key
    num_errors = sum(int(b) for b in data)
    num_syndromes = sum(int(b) for b in synd)
    print(f"  {i+1}. #errors={num_errors}, #synd={num_syndromes}, count={count}")

p_L = logical_error_rate(results, logical_prepared=0, verbose=True)
print(f"\n{'='*60}")
print(f"Logical error rate: {p_L:.6f}")

# Analytical estimate
p_inner_fail = sum(math.comb(7, k) * (p**k) * ((1-p)**(7-k)) for k in range(2, 8))
print(f"Expected inner block failure rate: {p_inner_fail:.6f}")
p_outer_fail = sum(math.comb(7, k) * (p_inner_fail**k) * ((1-p_inner_fail)**(7-k)) for k in range(4, 8))
print(f"Expected logical error rate (approx): {p_outer_fail:.6f}")
print(f"{'='*60}")

# Compare to repetition code
print(f"\nFor comparison:")
print(f"Physical error rate: {p:.6f}")
print(f"Repetition code (n=49): ~{sum(math.comb(49, k) * (p**k) * ((1-p)**(49-k)) for k in range(25, 50)):.6f}")

Building concatenated Hamming circuit...
Circuit has 49 data qubits and 21 stabilizers

Running simulation with 100,000 shots...

Number of unique measurement outcomes: 36415
Most common outcomes:


NameError: name 'errors' is not defined
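
The analytical estimate in the cells above counts every pattern with two or more flips as an inner-block failure. As a hedged cross-check, the sketch below restates the inner decoding rule in plain Python and enumerates all 128 flip patterns of one block to see which ones actually decode to 1. It assumes noiseless syndrome measurement and logical 0 prepared, as in the circuit; the names `decode_block` and `exact_block_failure` are hypothetical.

```python
from itertools import product

H = [
    [1, 0, 1, 0, 1, 0, 1],
    [0, 1, 1, 0, 0, 1, 1],
    [0, 0, 0, 1, 1, 1, 1],
]

def decode_block(bits):
    """Plain-Python restatement of decode_hamming_block with syndrome = H @ bits mod 2."""
    corrected = list(bits)
    synd = [sum(h * b for h, b in zip(row, bits)) % 2 for row in H]
    if any(synd):
        # Flip the position whose column of H matches the syndrome.
        for pos in range(7):
            if [row[pos] for row in H] == synd:
                corrected[pos] ^= 1
                break
    return 1 if sum(corrected) > 3 else 0  # majority-vote readout

def exact_block_failure(p):
    """Sum the probability of every flip pattern that decodes to 1 instead of 0."""
    total = 0.0
    for pattern in product((0, 1), repeat=7):
        if decode_block(pattern) == 1:
            w = sum(pattern)
            total += p**w * (1 - p) ** (7 - w)
    return total

# Count failing patterns by Hamming weight.
failing_by_weight = {}
for pattern in product((0, 1), repeat=7):
    if decode_block(pattern) == 1:
        w = sum(pattern)
        failing_by_weight[w] = failing_by_weight.get(w, 0) + 1

print(failing_by_weight)
print(f"Exact block failure at p=0.05: {exact_block_failure(0.05):.6f}")
```

Under this readout no weight-2 pattern is counted as a failure, so the binomial estimate with `k` starting at 2 should be read as an upper bound on what the simulated decoder reports, not as a prediction of it.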