# Surface Code

Ref: https://www.nature.com/articles/s41586-022-05434-1 - d=3 Implementation
https://www.nature.com/articles/s41586-022-04566-8 - More on implementation
https://arxiv.org/html/2307.14989v4/#S5.F7 - Diagrams


Core reference on the rotated surface code: https://arxiv.org/pdf/2409.14765









In [None]:
from collections import deque
import qsharp
from itertools import combinations
import pymatching

import numpy as np
import random
import networkx as nx


def convert_syndrome(syndrome_results):
    """Turn a sequence of Q# measurement results into a 0/1 ``uint8`` vector.

    Each entry equal to ``qsharp.Result.One`` becomes 1; every other entry becomes 0.
    """
    flags = (1 if outcome == qsharp.Result.One else 0 for outcome in syndrome_results)
    return np.fromiter(flags, dtype=np.uint8)


def getZCorrections(d, initMeasures, xMaps):
    """Decode X-stabilizer outcomes into the data qubits that need a Z (phase-flip) correction.

    Args:
        d: code distance; the parity-check matrix has ``d*d`` columns (data qubits).
        initMeasures: one outcome per X stabilizer, in the same order as ``xMaps``.
            NOTE(review): presumably already 0/1 (e.g. via convert_syndrome) — confirm.
        xMaps: for each X stabilizer, the data-qubit indices it acts on.

    Returns:
        NumPy array of data-qubit indices to apply Z to.

    Raises:
        ValueError: if ``initMeasures`` and ``xMaps`` have different lengths.
    """
    # The check matrix has one row per stabilizer in xMaps. Sizing it from
    # initMeasures instead would IndexError (too few rows) or add empty rows
    # (too many) whenever the two disagree, so reject the mismatch up front.
    if len(initMeasures) != len(xMaps):
        raise ValueError(
            f"got {len(initMeasures)} syndrome bits for {len(xMaps)} X stabilizers"
        )

    num_qubits = d * d
    H_x = np.zeros((len(xMaps), num_qubits), dtype=np.uint8)
    for i, stabilizer in enumerate(xMaps):
        for qubit_idx in stabilizer:
            H_x[i, qubit_idx] = 1

    z_decoder = pymatching.Matching(H_x)
    z_correction = z_decoder.decode(initMeasures)
    z_error_qubits = np.where(z_correction == 1)[0]  # Qubits needing phase flips
    print("Apply phase flips (Z) to qubits:", z_error_qubits)
    return z_error_qubits


def flatten_syndrome_timeline(syndrome_matrix):
    """Collapse a (rounds, ...) syndrome array into one vector.

    A single-round input returns its first row (a view); otherwise the whole
    array is flattened row-major into a new array.
    """
    single_round = syndrome_matrix.shape[0] == 1
    return syndrome_matrix[0] if single_round else syndrome_matrix.flatten()

def getSyndrome(a, b, c, d):
    """Compare two consecutive rounds of X and Z stabilizer outcomes.

    Returns a pair of boolean lists: element-wise ``a != b`` and ``c != d``
    (each truncated to the shorter of its two inputs, as ``zip`` does).
    """
    def _changed(previous, current):
        return [before != after for before, after in zip(previous, current)]

    return _changed(a, b), _changed(c, d)


class PauliTracker:
    '''Per-round log of X and Z stabilizer measurement outcomes.

    Each call to ``track`` appends one round to both histories; no correction
    is applied to the qubits themselves.
    '''

    def __init__(self):
        self.XKEY, self.ZKEY = "X", "Z"
        # One list per stabilizer type, indexed by round.
        self.tracker = {key: [] for key in (self.XKEY, self.ZKEY)}

    def track(self, xMeasures, zMeasures):
        for key, outcomes in ((self.XKEY, xMeasures), (self.ZKEY, zMeasures)):
            self.tracker[key].append(outcomes)

    def getX(self):
        return self.tracker[self.XKEY]

    def getZ(self):
        return self.tracker[self.ZKEY]


def addTimeStep(tracker=None):
    """Run one round of stabilizer measurements in the current Q# session and record it.

    Evaluates ``RotatedSurfaceCode.GenerateLattice(d, qubits)``, which relies on the Q#
    ``d`` and ``qubits`` bindings created earlier in the session. The X and Z
    outcomes are then appended to *tracker*.

    Args:
        tracker: object with a ``track(xMeasures, zMeasures)`` method. Defaults to
            the module-level ``PAULI_TRACKER``, which keeps old call sites working.

    Returns:
        ``(xMaps, xMeasures, zMaps, zMeasures)`` exactly as returned by the Q# call.
    """
    xMaps, initXMeasures, zMaps, initZMeasures = qsharp.eval(
        "RotatedSurfaceCode.GenerateLattice(d, qubits);"
    )
    # Resolve the default only after the Q# call, preserving the original
    # order of side effects.
    if tracker is None:
        tracker = PAULI_TRACKER
    tracker.track(initXMeasures, initZMeasures)
    return xMaps, initXMeasures, zMaps, initZMeasures

def convert_syndrome(syndrome_results):
    """Return a ``uint8`` array holding 1 where a Q# result is ``One`` and 0 elsewhere."""
    bits = []
    for outcome in syndrome_results:
        bits.append(1 if outcome == qsharp.Result.One else 0)
    return np.array(bits, dtype=np.uint8)

def getCorrections(d, p, syndrome, maps):
    num_qubits = d*d
    H = np.zeros((len(maps), num_qubits), dtype=np.uint8)
    
    for i, stabilizer in enumerate(maps):
        for qubit_idx in stabilizer:
            H[i, qubit_idx] = 1

    weights = np.ones(H.shape[1]) * np.log((1 - p) / p)
    decoder = pymatching.Matching(H, weights=weights) 
    correction = decoder.decode(syndrome)
    error_qubits = np.where(correction == 1)[0]  # Qubits needing flips
    # print(error_qubits)
    # print("Apply flips to qubits:", error_qubits)
    return error_qubits

from collections import defaultdict
def decode_surface_code(parity_checks, stabilizer_qubit_map):
    """Naive parity-vote decoder.

    Args:
        parity_checks (List[int]): one 0/1 value per stabilizer (1 = violated).
        stabilizer_qubit_map (List[List[int]]): qubit indices acted on by each stabilizer.

    Returns:
        Set[int]: qubits that appear in an odd number of violated stabilizers.
    """
    violated_supports = [
        stabilizer_qubit_map[index]
        for index, bit in enumerate(parity_checks)
        if bit == 1
    ]

    tally = defaultdict(int)
    for support in violated_supports:
        for qubit in support:
            tally[qubit] += 1

    return {qubit for qubit, hits in tally.items() if hits % 2 == 1}

def applyCorrection(correct, zp, xp, zMeasures, zMaps, xMeasures, xMaps):
    """Decode both stabilizer types and apply the resulting Pauli gates in Q#.

    Does nothing unless *correct* is truthy. Z-stabilizer outcomes are decoded
    (weighted by ``xp``) into X gates; X-stabilizer outcomes are decoded (weighted
    by ``zp``) into Z gates. NOTE(review): reads the module-level ``d``.
    """
    if not correct:
        return
    bit_flips = getCorrections(d, xp, convert_syndrome(zMeasures), zMaps)
    phase_flips = getCorrections(d, zp, convert_syndrome(xMeasures), xMaps)
    print(bit_flips)
    for qubit in bit_flips:
        qsharp.eval(f"X(qubits[{qubit}]);")
    for qubit in phase_flips:
        qsharp.eval(f"Z(qubits[{qubit}]);")


def runLogicalQubit(d, layers, xnoise, znoise, correct=False):
    """Run one noisy memory experiment on a distance-``d`` rotated surface code in Q#.

    Steps:
        1. Start a fresh Q# session and allocate ``d*d`` data qubits.
        2. Configure Pauli noise via ``ConfigurePauliNoise(xnoise, 0.0, znoise)``.
        3. Perform ``layers + 1`` stabilizer rounds through ``addTimeStep``; each
           round is recorded in the module-level ``PAULI_TRACKER``.
        4. Measure every data qubit in the Z basis.

    Args:
        d: code distance.
        layers: number of stabilizer rounds after the initial one.
        xnoise: first (X) probability passed to ``ConfigurePauliNoise``.
        znoise: third (Z) probability passed to ``ConfigurePauliNoise``.
        correct: accepted for API compatibility but currently unused. No decoding
            or correction is applied.

    Returns:
        1 if the Z-parity of all ``d*d`` data-qubit measurements is even, else 0.
    """
    # addTimeStep() records into the *module-level* PAULI_TRACKER. Without this
    # declaration the assignment below would only create a local, and
    # addTimeStep would raise NameError (or record into a stale tracker).
    global PAULI_TRACKER
    TOTAL_DATA_QUBITS = d * d
    PAULI_TRACKER = PauliTracker()

    qsharp.init(project_root=".")
    qsharp.eval("import Std.Diagnostics.ConfigurePauliNoise;")
    qsharp.eval("import Std.Arrays.ForEach;")
    qsharp.eval(f"let d = {d};")
    qsharp.eval(f"let TOTAL_DATA_QUBITS = {TOTAL_DATA_QUBITS};")
    qsharp.eval("use qubits = Qubit[TOTAL_DATA_QUBITS];")
    qsharp.eval("ApplyToEach(X, qubits);")
    qsharp.eval(f"ConfigurePauliNoise({xnoise}, 0.0, {znoise});")
    qsharp.eval("ResetAll(qubits);")

    # Initial stabilizer round, then `layers` further rounds. Outcomes are only
    # recorded in PAULI_TRACKER; nothing here decodes or corrects them yet.
    addTimeStep()
    for _ in range(layers):
        addTimeStep()

    results = [
        qsharp.eval(f"Measure([PauliZ], [qubits[{i}]]);")
        for i in range(TOTAL_DATA_QUBITS)
    ]
    # NOTE(review): this is the parity of *all* data qubits, not of a single
    # logical-operator string. Confirm it matches the intended logical readout.
    return int(convert_syndrome(results).sum() % 2 == 0)
    

# Simulation parameters for the memory experiment below. (The former stray
# `qsharp.eval("ResetAll(qubits);")` was removed: it ran before any Q# session
# had allocated `qubits`, and runLogicalQubit re-initialises Q# on every call.)
d = 3
layers = 4
xnoise = 0.01
znoise = 0.0
SHOTS = 100


def failure(r):
    """Return the fraction of shots in *r* that did not report success.

    *r* holds one 0/1 success flag per shot, as returned by runLogicalQubit.
    The denominator is ``len(r)`` rather than the global ``SHOTS``, so the result
    stays correct for lists of any length. An empty list gives 0.0.
    """
    if not r:
        return 0.0
    return 1 - (sum(r) / len(r))


# NOTE(review): runLogicalQubit does not currently use its `correct` flag, so
# these two runs sample the same experiment — confirm before comparing them.
r = [runLogicalQubit(d, layers, xnoise, znoise, correct=False) for _ in range(SHOTS)]
print(f"Logical Error without correction: {failure(r)}")

r = [runLogicalQubit(d, layers, xnoise, znoise, correct=True) for _ in range(SHOTS)]
print(f"Logical Error with correction: {failure(r)}")




Logical Error without correction: 0.48
Logical Error with correction: 0.54


In [None]:
import numpy as np
import pymatching
import matplotlib.pyplot as plt
import qsharp
import pymatching

import numpy as np
import networkx as nx

def convert_syndrome(syndrome_results):
    """Encode Q# measurement results as a NumPy ``uint8`` vector (One -> 1, anything else -> 0)."""
    encoded = [1 if result == qsharp.Result.One else 0 for result in syndrome_results]
    return np.asarray(encoded, dtype=np.uint8)


class Decoder:
    """PyMatching decoder for one stabilizer type of a distance-``d`` code.

    Builds a parity-check matrix ``H`` with one row per stabilizer in *maps* and one
    column per data qubit (``d*d`` columns). Every column gets the same weight,
    ``log((1 - p) / p)``.
    """

    def __init__(self, d, p, maps):
        self.d = d
        self.num_qubits = d * d

        self.H = np.zeros((len(maps), self.num_qubits), dtype=np.uint8)
        for i, stabilizer in enumerate(maps):
            for qubit_idx in stabilizer:
                self.H[i, qubit_idx] = 1

        weights = self._edge_weights(self.H.shape[1], p)
        self.decoder = pymatching.Matching(self.H, weights=weights)

    @staticmethod
    def _edge_weights(num_edges, p):
        """Uniform weights ``log((1 - p) / p)``, with unit weights when ``p == 0``.

        ``p == 0`` would raise ZeroDivisionError. Any uniform positive weight gives
        the same minimum-weight matching, so unit weights are an equivalent stand-in.

        Raises:
            ValueError: if ``p`` is outside [0, 1).
        """
        if not 0.0 <= p < 1.0:
            raise ValueError(f"p must lie in [0, 1), got {p!r}")
        if p == 0.0:
            return np.ones(num_edges)
        return np.ones(num_edges) * np.log((1 - p) / p)

    def decode(self, syndrome):
        """Decode a single syndrome vector into a per-data-qubit correction."""
        return self.decoder.decode(syndrome)

    def decode_batch(self, syndromes):
        """Decode many syndromes at once (one per row)."""
        return self.decoder.decode_batch(syndromes)

    def getErrorQubits(self, correction):
        """Indices of the data qubits flagged with 1 in *correction*."""
        return np.where(correction == 1)[0]

def qsharp_runner(d, p, shots):
    """Sample ``shots`` runs of the Q# rotated-surface-code program under bit-flip noise.

    Returns:
        ``(zSyndromes, zMaps, trueValues)``:
        - ``zSyndromes``: per shot, the Z-stabilizer outcomes as 0/1 uint8 arrays.
        - ``zMaps``: the Z-stabilizer maps from ``GetMaps``.
        - ``trueValues``: per shot, the remaining outcomes as 0/1 uint8 arrays.

    NOTE(review): each shot is assumed to be laid out as [X-stabilizer outcomes,
    Z-stabilizer outcomes, data-qubit outcomes], inferred from the slicing below.
    Confirm against the Q# program.
    """
    qsharp.init(project_root=".")
    xMaps, zMaps = qsharp.eval(f"RotatedSurfaceCode.GetMaps({d});")

    results = qsharp.run(
        f"RotatedSurfaceCode.RotatedSurfaceCode({d}, {d})",
        shots=shots,
        noise=qsharp.BitFlipNoise(p),
    )

    z_begin = len(xMaps)
    data_begin = z_begin + len(zMaps)
    zSyndromes, trueValues = [], []
    for shot in results:
        zSyndromes.append(convert_syndrome(shot[z_begin:data_begin]))
        trueValues.append(convert_syndrome(shot[data_begin:]))
    return zSyndromes, zMaps, trueValues


def is_logical_error(d, correction, true_logical):
    """Return True if the decoder's correction leaves a logical flip on the readout.

    Along the chosen logical string, the parity of the corrected readout is
    ``parity(true_logical) XOR parity(correction)``. Assuming the logical qubit was
    prepared in |0> (NOTE(review): confirm), a logical error occurs exactly when
    those two parities differ. (Previously this compared ``true_logical`` with
    itself, so it always returned False.)

    Args:
        d: code distance; data qubits are indexed 0 .. d*d-1.
        correction: per-data-qubit 0/1 flags from the decoder.
        true_logical: per-data-qubit 0/1 final measurement outcomes.
    """
    def getParity(d, bits):
        # NOTE(review): indices (i+1)*(d-1) pick the anti-diagonal of a row-major
        # d x d grid; confirm this is a logical operator of the Q# lattice layout.
        return sum(1 for i in range(d) if bits[(i + 1) * (d - 1)]) % 2

    return getParity(d, correction) != getParity(d, true_logical)
    
    

def run_experiment_for_d(d_values, p_values, num_trials):
    """Estimate logical error rates over a grid of code distances and error rates.

    For every ``d`` in *d_values* and ``p`` in *p_values*:
        1. Sample *num_trials* shots with ``qsharp_runner``.
        2. Decode them in one batch with ``Decoder``.
        3. Count the shots that ``is_logical_error`` flags.

    Returns:
        dict mapping each ``d`` to a list of failure fractions, one per ``p``.
    """
    results = {}
    for d in d_values:
        print(f"Running for code distance d={d}")
        rates = []
        for p in p_values:
            print(f"Physical error rate p={p:.3f}")
            syndromes, z_maps, true_data = qsharp_runner(d, p, num_trials)

            corrections = Decoder(d, p, z_maps).decode_batch(syndromes)
            failures = sum(
                is_logical_error(d, corr, true_data[shot])
                for shot, corr in enumerate(corrections)
            )
            rates.append(failures / num_trials)
        results[d] = rates
    return results

# Sweep code distance d and physical bit-flip rate p, then print the resulting
# {d: [logical error rate for each p]} dictionary.
# NOTE(review): d=1 presumably serves as an unencoded baseline — confirm.
print(run_experiment_for_d(d_values=[1, 3, 5], p_values=[0.005, 0.01, 0.05, 0.1, 0.15, 0.2], num_trials=1000))
# Longer 11000-trial variant; the dictionary comment below appears to be a recorded result of it.
# print(run_experiment_for_d(d_values=[1, 3, 5], p_values=[0.005, 0.01, 0.05, 0.1, 0.15, 0.2], num_trials=11000))
#{1: [0.010181818181818183, 0.01818181818181818, 0.098, 0.1809090909090909, 0.2540909090909091, 0.31836363636363635], 3: [0.07636363636363637, 0.14945454545454545, 0.43618181818181817, 0.48918181818181816, 0.5086363636363637, 0.4979090909090909], 5: [0.15163636363636362, 0.2737272727272727, 0.49772727272727274, 0.5010909090909091, 0.49536363636363634, 0.5014545454545455]}

# Alternative sweeps: a fine p grid at d=3, and a 2-trial smoke test.
# print(run_experiment_for_d(d_values=[3], p_values=[0.005, 0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7], num_trials=2000))
# print(run_experiment_for_d(d_values=[3], p_values=[0.005], num_trials=2))

Running for code distance d=1
Physical error rate p=0.005
Physical error rate p=0.010
Physical error rate p=0.050
Physical error rate p=0.100
Physical error rate p=0.150
Physical error rate p=0.200
Running for code distance d=3
Physical error rate p=0.005
Physical error rate p=0.010
Physical error rate p=0.050
Physical error rate p=0.100
Physical error rate p=0.150
Physical error rate p=0.200
Running for code distance d=5
Physical error rate p=0.005
Physical error rate p=0.010
Physical error rate p=0.050
Physical error rate p=0.100
Physical error rate p=0.150
Physical error rate p=0.200
{1: [0.008, 0.014, 0.096, 0.179, 0.267, 0.299], 3: [0.073, 0.149, 0.422, 0.468, 0.474, 0.511], 5: [0.133, 0.22, 0.48, 0.505, 0.513, 0.493]}


In [46]:
# !pip install torch torchvision torchaudio
import torch

# Pick Apple's Metal (MPS) backend when it is available, otherwise fall back to CPU.
mps_available = torch.backends.mps.is_available()
if mps_available:
    print("MPS (Metal) is available!")
else:
    print("MPS not available, using CPU.")
device = torch.device("mps" if mps_available else "cpu")
print(device)

# Smoke test: XOR an all-ones 5x5 "syndrome" into an all-zero lattice on the device.
lattice = torch.zeros((5, 5), dtype=torch.uint8, device=device)
syndrome = torch.ones((5, 5), dtype=torch.uint8, device=device)
lattice ^= syndrome
print(lattice.cpu())

MPS (Metal) is available!
mps
tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]], dtype=torch.uint8)


In [None]:
!pip install pyobjc
import numpy as np
import objc
from Cocoa import NSObject
import Metal
import MetalKit

# 1. Create Metal device
device = Metal.MTLCreateSystemDefaultDevice()
if device is None:
    raise RuntimeError("No Metal-capable GPU device is available")

# 2. Load and compile the shader. PyObjC returns (result, error) for
# Objective-C methods that take an NSError** out-parameter, so the result must
# be unpacked. MTLCompileOptions is nullable, so pass None rather than a dict.
with open("my_shader.metal", encoding="utf-8") as shader_file:
    source = shader_file.read()
library, compile_error = device.newLibraryWithSource_options_error_(source, None, None)
if library is None:
    raise RuntimeError(f"Metal shader compilation failed: {compile_error}")
function = library.newFunctionWithName_("add_one")
if function is None:
    raise RuntimeError("Kernel function 'add_one' not found in my_shader.metal")

# 3. Create compute pipeline
pipeline_state, err = device.newComputePipelineStateWithFunction_error_(function, None)
if pipeline_state is None:
    raise RuntimeError(f"Failed to create compute pipeline: {err}")

# 4. Input/output buffers
array = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
buffer_size = array.nbytes

input_buffer = device.newBufferWithBytes_length_options_(array.tobytes(), buffer_size, 0)
output_buffer = device.newBufferWithLength_options_(buffer_size, 0)

# 5. Create command queue and encoder
command_queue = device.newCommandQueue()
command_buffer = command_queue.commandBuffer()
command_encoder = command_buffer.computeCommandEncoder()

command_encoder.setComputePipelineState_(pipeline_state)
command_encoder.setBuffer_offset_atIndex_(input_buffer, 0, 0)
command_encoder.setBuffer_offset_atIndex_(output_buffer, 0, 1)

# 6. Dispatch threads. Cap the threadgroup at the element count: a full
# maxTotalThreadsPerThreadgroup-sized group for 4 elements would launch threads
# indexing past the buffers. NOTE(review): if thread_count is not a multiple of
# the group size, the kernel must still bounds-check — confirm in my_shader.metal.
thread_count = len(array)
thread_group_size = min(pipeline_state.maxTotalThreadsPerThreadgroup(), thread_count)
thread_groups = (thread_count + thread_group_size - 1) // thread_group_size
command_encoder.dispatchThreadgroups_threadsPerThreadgroup_(
    (thread_groups, 1, 1),
    (thread_group_size, 1, 1),
)

command_encoder.endEncoding()
command_buffer.commit()
command_buffer.waitUntilCompleted()

# 7. Get results. contents() is a raw pointer (objc.varlist); as_buffer() exposes
# buffer_size bytes of it as a memoryview. Copy so the array owns its data.
result_view = output_buffer.contents().as_buffer(buffer_size)
result = np.frombuffer(result_view, dtype=np.float32, count=len(array)).copy()
print("Result:", result)


In [None]:
import numpy as np
import itertools
from typing import List, Dict, Tuple

def generate_all_single_pauli_errors(num_qubits: int) -> List[Tuple[int, np.ndarray]]:
    """Return one single-qubit error pattern for each qubit of the code block.

    Each entry is ``(q, e)``, where ``e`` is a length-``num_qubits`` integer
    vector containing a single 1 at position ``q``. The vector does not encode
    which Pauli it is; it is read as an X or a Z error depending on the stabilizer
    map it is checked against. Y errors are not generated separately.
    """
    errors = []
    for q in range(num_qubits):
        e = np.zeros(num_qubits, dtype=int)
        e[q] = 1
        errors.append((q, e))  # single-qubit X (or Z)
    return errors

def compute_syndrome(error: np.ndarray, stab_map: List[List[int]]) -> np.ndarray:
    """Return the mod-2 parity of *error* over each stabilizer's support.

    Qubit indices in *stab_map* that fall outside *error* are ignored.
    """
    width = len(error)
    parities = [
        sum(error[q] for q in support if q < width) % 2
        for support in stab_map
    ]
    return np.array(parities, dtype=int)

def build_lookup_table(num_qubits: int, stab_map: List[List[int]]) -> Dict[str, np.ndarray]:
    """Map each single-qubit error's syndrome string to that error vector.

    Keys are syndromes rendered as strings of 0/1 digits. The first qubit that
    produces a given syndrome wins. The all-zero syndrome always maps to the
    no-correction vector.
    """
    table = {}
    for _, candidate in generate_all_single_pauli_errors(num_qubits):
        signature = "".join(str(bit) for bit in compute_syndrome(candidate, stab_map))
        table.setdefault(signature, candidate.copy())

    # The trivial syndrome always means "apply nothing".
    table["0" * len(stab_map)] = np.zeros(num_qubits, dtype=int)
    return table

def lookup_decode(syndrome: np.ndarray, lookup_table: Dict[str, np.ndarray]) -> np.ndarray:
    """Return the stored correction for *syndrome*, or an all-zero correction if unknown.

    Entries are normalised with ``int()`` before building the key, so boolean or
    numpy-integer syndromes match the ``"0"/"1"`` keys (``True`` would otherwise
    render as ``"True"`` and never match).

    Raises:
        ValueError: if the syndrome is not in the table and the table is empty,
            since no correction shape can be inferred.
    """
    key = "".join(str(int(bit)) for bit in syndrome)
    match = lookup_table.get(key)
    if match is not None:
        return match
    if not lookup_table:
        raise ValueError("lookup_table is empty; cannot infer correction shape")
    # Unknown syndrome (e.g. a multi-qubit error): apply no correction.
    return np.zeros_like(next(iter(lookup_table.values())))

In [None]:

class SurfaceCodeDecoder:
    """Toy geometric matching decoder over stabilizer-centroid coordinates.

    Data qubit ``q`` sits at grid position ``(q // d, q % d)``. Each flagged
    stabilizer is represented by the centroid of the data qubits it acts on.
    """

    def __init__(self, d):
        self.d = d  # code distance

    def _get_syndrome_graph(self, syndrome):
        """Build the complete graph over defect coordinates, weighted by Manhattan distance."""
        graph = nx.Graph()
        count = len(syndrome)
        for first in range(count):
            for second in range(first + 1, count):
                u, v = syndrome[first], syndrome[second]
                graph.add_edge(u, v, weight=self._manhattan_distance(u, v))
        return graph

    def _manhattan_distance(self, a, b):
        """L1 distance between two 2-D points."""
        row_gap = a[0] - b[0]
        col_gap = a[1] - b[1]
        return abs(row_gap) + abs(col_gap)

    def decode(self, syndrome):
        """Return the minimum-weight matching of defects as a list of ``(u, v)`` pairs.

        Only the pairing is computed; correction paths are not applied here.
        """
        graph = self._get_syndrome_graph(syndrome)
        matching = nx.algorithms.matching.min_weight_matching(graph)
        return [(u, v) for u, v in matching]

    def getSyndrome(self, map, measures):
        """Return centroids of the stabilizers that measured ``One``, plus their indices.

        Both lists are also printed (indices first, then centroids).
        """
        centroids = []
        flagged = []
        for index, outcome in enumerate(measures):
            if outcome == qsharp.Result.One:
                coords = [divmod(qubit, self.d) for qubit in map[index]]
                size = len(coords)
                row_mean = sum(row for row, _ in coords) / size
                col_mean = sum(col for _, col in coords) / size
                centroids.append((row_mean, col_mean))
                flagged.append(index)
        print(flagged)
        print(centroids)
        return centroids, flagged

    def getCoordQubits(self, coord):
        """Row-major qubit index for a ``(row, col)`` grid coordinate."""
        return self.d * coord[0] + coord[1]

    def find_shortest_path(self, start, end):
        """Lattice points visited from *start* to *end*, excluding *start*.

        Both endpoints are rounded with Python's ``round``; note that it rounds
        halves to even. The walk takes diagonal steps first, then the remaining
        row steps, then the remaining column steps. Returns ``[]`` when the rounded
        endpoints coincide.
        """
        row, col = round(start[0]), round(start[1])
        goal_row, goal_col = round(end[0]), round(end[1])
        if (goal_row, goal_col) == (row, col):
            return []

        step_row = (goal_row > row) - (goal_row < row)
        step_col = (goal_col > col) - (goal_col < col)
        span_row = abs(goal_row - row)
        span_col = abs(goal_col - col)
        diagonal = min(span_row, span_col)

        moves = (
            [(step_row, step_col)] * diagonal
            + [(step_row, 0)] * (span_row - diagonal)
            + [(0, step_col)] * (span_col - diagonal)
        )
        path = []
        for d_row, d_col in moves:
            row += d_row
            col += d_col
            path.append((row, col))
        return path


def correctErrors(errors):
    """Apply flagged Pauli corrections to ``qubits`` in the current Q# session.

    *errors* maps an error type to a ``{qubit_index: apply_flag}`` dict. The type
    ``"x"`` applies an X gate; any other type applies a Z gate.
    """
    for error_type, flags in errors.items():
        targets = [qubit for qubit, apply in flags.items() if apply]
        for qubit in targets:
            if error_type == "x":
                qsharp.eval(f"X(qubits[{qubit}]);")
            else:
                qsharp.eval(f"Z(qubits[{qubit}]);")



from collections import defaultdict

def decode_surface_code(parity_checks, stabilizer_qubit_map):
    """Flag qubits touched by an odd number of violated stabilizers.

    Args:
        parity_checks (List[int]): 0/1 per stabilizer, where 1 means a violation.
        stabilizer_qubit_map (List[List[int]]): qubit indices of each stabilizer.

    Returns:
        Set[int]: qubits to apply a correction to.
    """
    hits = {}
    for stabilizer, parity in enumerate(parity_checks):
        if not parity == 1:
            continue
        for qubit in stabilizer_qubit_map[stabilizer]:
            hits[qubit] = hits.get(qubit, 0) + 1

    # Odd membership count -> the qubit is part of the correction.
    return {qubit for qubit, count in hits.items() if count % 2 == 1}


def correct_surface_code(xcorrections, zcorrections):
    """Apply X gates to every index in *xcorrections*, then Z gates to every index in *zcorrections*."""
    commands = [f"X(qubits[{q}]);" for q in xcorrections]
    commands += [f"Z(qubits[{q}]);" for q in zcorrections]
    for command in commands:
        qsharp.eval(command)


def getZCorrections(d, initMeasures, xMaps):
    """Decode X-stabilizer outcomes into the data qubits that need a Z (phase-flip) correction.

    Args:
        d: code distance; the parity-check matrix has ``d*d`` columns (data qubits).
        initMeasures: one outcome per X stabilizer, ordered like ``xMaps``.
            NOTE(review): presumably already 0/1 (e.g. via convert_syndrome) — confirm.
        xMaps: for each X stabilizer, the data-qubit indices it acts on.

    Returns:
        NumPy array of data-qubit indices to apply Z to.

    Raises:
        ValueError: if ``initMeasures`` and ``xMaps`` have different lengths.
    """
    # One check-matrix row per stabilizer in xMaps. A length mismatch with the
    # syndrome would otherwise IndexError or leave edge-less rows.
    if len(initMeasures) != len(xMaps):
        raise ValueError(
            f"got {len(initMeasures)} syndrome bits for {len(xMaps)} X stabilizers"
        )

    num_qubits = d * d
    H_x = np.zeros((len(xMaps), num_qubits), dtype=np.uint8)
    for i, stabilizer in enumerate(xMaps):
        for qubit_idx in stabilizer:
            H_x[i, qubit_idx] = 1

    z_decoder = pymatching.Matching(H_x)
    z_correction = z_decoder.decode(initMeasures)
    print(z_correction)
    z_error_qubits = np.where(z_correction == 1)[0]  # Qubits needing phase flips
    print("Apply phase flips (Z) to qubits:", z_error_qubits)
    return z_error_qubits
