# QEC Dataset Visualization

Exploration and visualization of the raw surface code data generated for GNN training.

In [None]:
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import SVG, display

import stim
import pymatching

from qec_generator import load_config

# Notebook-wide matplotlib defaults: figure resolution and base font size.
plt.rcParams.update({
    'figure.dpi': 100,
    'font.size': 10,
})

## Configuration & Data Loading

Load the configuration and select a specific setting (distance, rounds, error probability) to visualize.

In [None]:
# Load the data-generation configuration.
# NOTE(review): `cfg` is not read anywhere in this cell; the data root below
# is a literal path — confirm whether it should come from the config instead.
cfg = load_config("../configs/data_generation.yaml")

# Select a setting to visualize. These three values feed the plot titles and
# printouts in later cells AND the data directory below, so edit them here only.
distance = 5
rounds = 5
error_prob = 0.005

# Derive the directory from the selected setting so the loaded data cannot
# disagree with the values shown in titles (previously the path was hard-coded
# and changing the values above had no effect on which data was loaded).
# NOTE(review): the naming scheme "d{distance}_r{rounds}_p{prob with '.'->'_'}"
# is inferred from the existing "d5_r5_p0_005" directory — confirm against the
# generator's naming code.
# format_float_positional avoids scientific notation for small probabilities.
error_prob_tag = np.format_float_positional(error_prob, trim='-').replace('.', '_')
setting_dir = Path('../data/raw_data') / f"d{distance}_r{rounds}_p{error_prob_tag}"
print(f"Data directory: {setting_dir}")

# Fail here, where the setting is chosen, rather than with a less obvious
# file-not-found error in a later cell.
if not setting_dir.is_dir():
    raise FileNotFoundError(
        f"No data directory for d={distance}, r={rounds}, p={error_prob}: "
        f"{setting_dir} does not exist"
    )

## Circuit & Error Model

Load the Stim circuit and detector error model (DEM).

In [None]:
# Read the Stim circuit and the detector error model (DEM) stored alongside
# the sampled data for the selected setting.
circuit_path = setting_dir / "circuit.stim"
dem_path = setting_dir / "model.dem"
circuit = stim.Circuit.from_file(circuit_path)
dem = stim.DetectorErrorModel.from_file(dem_path)

# Print a short size summary of both objects.
num_instructions = len(circuit)
print(f"Circuit: {num_instructions} instructions")
num_detectors, num_observables = dem.num_detectors, dem.num_observables
print(f"DEM: {num_detectors} detectors, {num_observables} observables")
num_dem_errors = dem.num_errors
print(f"DEM errors: {num_dem_errors}")

In [None]:
# Render the circuit as a timeline diagram and embed the resulting SVG markup
# in the notebook output.
timeline_svg = str(circuit.diagram("timeline-svg"))
display(SVG(timeline_svg))

## Detector Graph

Load and inspect the graph structure used for GNN decoding.

In [None]:
# The pre-built graph arrays live in the "graph" subdirectory of the setting.
graph_dir = setting_dir / "graph"

# Load the five arrays in one pass; the order of the file names matches the
# order of the target names on the left.
(
    edge_index,
    edge_error_prob,
    edge_weight,
    node_coords,
    node_is_boundary,
) = (
    np.load(graph_dir / filename)
    for filename in (
        "edge_index.npy",
        "edge_error_prob.npy",
        "edge_weight.npy",
        "node_coords.npy",
        "node_is_boundary.npy",
    )
)

# Graph metadata (node / detector / observable counts) is stored as JSON.
graph_meta = json.loads((graph_dir / "meta.json").read_text())

# edge_index has one column per stored edge.
# NOTE(review): the "unique edges" figure halves that count, i.e. it assumes
# every undirected edge is stored in both directions — confirm with the
# graph builder.
num_stored_edges = edge_index.shape[1]
min_prob, max_prob = edge_error_prob.min(), edge_error_prob.max()
min_weight, max_weight = edge_weight.min(), edge_weight.max()

print("Graph structure:")
print(f"  Nodes: {graph_meta['num_nodes']} ({graph_meta['num_detectors']} detectors + boundary)")
print(f"  Edges: {num_stored_edges} (directed, bidirectional)")
print(f"  Unique edges: {num_stored_edges // 2}")
print(f"  Observables: {graph_meta['num_observables']}")
print("\nEdge statistics:")
print(f"  Error prob: {min_prob:.6f} - {max_prob:.6f}")
print(f"  Weight: {min_weight:.3f} - {max_weight:.3f}")
print(f"\nNode coordinates: {node_coords.shape}")

In [None]:
# Visualize the detector graph using the first two node coordinates (x, y).
fig, ax = plt.subplots(figsize=(8, 8))

# Coerce the boundary flags to a genuine boolean array before negating.
# On an integer-typed flag array `~` is a bitwise NOT, and the result would be
# used as integer fancy-indexing below instead of a boolean selection.
boundary_mask = np.asarray(node_is_boundary).astype(bool)
detector_mask = ~boundary_mask

# Select the edges to draw in one vectorized step:
#   * u < v keeps a single copy of each pair stored in both directions,
#   * both endpoints must be detectors (edges touching a boundary node are
#     skipped).
edge_src, edge_dst = edge_index[0], edge_index[1]
drawable = (edge_src < edge_dst) & detector_mask[edge_src] & detector_mask[edge_dst]

for u, v in zip(edge_src[drawable], edge_dst[drawable]):
    x1, y1 = node_coords[u, :2]
    x2, y2 = node_coords[v, :2]
    ax.plot([x1, x2], [y1, y2], 'gray', alpha=0.3, linewidth=0.5)

# Plot detector nodes (boundary nodes are not drawn).
ax.scatter(node_coords[detector_mask, 0],
           node_coords[detector_mask, 1],
           c='steelblue', s=30, alpha=0.7, label='Detectors')

ax.set_title(f"Detector Graph (d={distance}, r={rounds})")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.axis('equal')
ax.legend()
plt.tight_layout()
plt.show()

## Syndrome Data

Load and analyze sampled syndromes and logical outcomes.

In [None]:
# Read the training-split arrays: detector outcomes per shot and the matching
# logical-observable outcomes.
syndrome, logical = (
    np.load(setting_dir / filename)
    for filename in ("train_syndrome.npy", "train_logical.npy")
)

# Report shapes and dtypes; the leading axis of `syndrome` indexes samples.
num_samples = syndrome.shape[0]
print(f"Syndrome shape: {syndrome.shape} ({syndrome.dtype})")
print(f"Logical shape:  {logical.shape} ({logical.dtype})")
print(f"\nNumber of samples: {num_samples:,}")

In [None]:
# Weight of a sample = sum of its syndrome entries along axis 1
# (the number of fired detectors when the entries are 0/1).
weights = np.sum(syndrome, axis=1)

# Assemble the summary first, then emit it in a single print.
weight_report = [
    "Syndrome statistics:",
    f"  Mean weight: {np.mean(weights):.2f}",
    f"  Median weight: {np.median(weights):.0f}",
    f"  Min weight: {np.min(weights)}",
    f"  Max weight: {np.max(weights)}",
    f"  Std weight: {np.std(weights):.2f}",
]
print("\n".join(weight_report))

# Compare the observed logical flip rate with the configured physical rate.
logical_flip_rate = np.mean(logical)
print(f"\nLogical flip rate: {logical_flip_rate:.4f}")
print(f"Physical error rate: {error_prob:.4f}")

In [None]:
# Plot the syndrome weight distribution (left) and the logical outcome
# counts (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Syndrome weights are integer counts, so use one unit-width bin centred on
# each integer. A fixed bin count (previously 50) does not line up with the
# integers and produces empty bins and uneven bars that are pure binning
# artifacts.
mean_weight = weights.mean()
weight_bins = np.arange(int(weights.min()), int(weights.max()) + 2) - 0.5

axes[0].hist(weights, bins=weight_bins, edgecolor='black', alpha=0.7)
axes[0].axvline(mean_weight, color='red', linestyle='--', label=f'Mean = {mean_weight:.2f}')
axes[0].set_xlabel('Syndrome Weight (# fired detectors)')
axes[0].set_ylabel('Count')
axes[0].set_title('Syndrome Weight Distribution')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Logical outcome counts. Converting to plain Python ints keeps the
# subtraction and the thousands-separator formatting below independent of the
# dtype the array was stored with.
num_shots = len(logical)
num_flips = int(logical.sum())
logical_counts = [num_shots - num_flips, num_flips]
axes[1].bar(['No Error', 'Logical Error'], logical_counts,
            color=['green', 'red'], alpha=0.7, edgecolor='black')
axes[1].set_ylabel('Count')
axes[1].set_title('Logical Outcome Distribution')
axes[1].grid(axis='y', alpha=0.3)

# Annotate each bar with its count and percentage, placed slightly above the
# bar top (offset = 1% of the number of shots).
label_offset = num_shots * 0.01
for i, count in enumerate(logical_counts):
    axes[1].text(i, count + label_offset,
                 f'{count:,}\n({100*count/num_shots:.2f}%)',
                 ha='center', va='bottom')

plt.tight_layout()
plt.show()

## Sample Decoding

Decode a randomly chosen sample with a non-empty syndrome using minimum-weight perfect matching (MWPM), then visualize its fired detectors.

In [None]:
# Candidate shots are those with at least one fired detector (weight > 0).
nonzero_samples = np.nonzero(weights > 0)[0]

# Draw one candidate at random. No seed is set here, so the chosen shot
# changes from run to run.
shot_idx = np.random.choice(nonzero_samples)

# Pull out that shot's syndrome and logical outcome; `fired` is the syndrome
# viewed as a boolean mask.
shot_syndrome, shot_logical = syndrome[shot_idx], logical[shot_idx]
fired = shot_syndrome.astype(bool)

num_fired = fired.sum()
print(f"Sample #{shot_idx}:")
print(f"  Syndrome weight: {num_fired} fired detectors")
print(f"  True logical outcome: {shot_logical}")

In [None]:
# Build a PyMatching decoder from the detector error model and decode the
# selected shot's syndrome.
matcher = pymatching.Matching.from_detector_error_model(dem)
predicted_logical = matcher.decode(shot_syndrome)

# The prediction counts as correct when it equals the recorded logical outcome.
prediction_matches = np.array_equal(predicted_logical, shot_logical)

print(f"MWPM prediction: {predicted_logical}")
print(f"True outcome:    {shot_logical}")
print(f"Correct: {prediction_matches}")

In [None]:
# Draw the matching graph in the (x, y) plane and highlight the detectors
# that fired in the selected shot.
fig, ax = plt.subplots(figsize=(10, 10))

# Detector index -> (x, y). Detectors with fewer than two coordinates are
# left out and are therefore skipped everywhere below.
coord_dict = circuit.get_detector_coordinates()
pos = {
    int(det_id): (float(coords[0]), float(coords[1]))
    for det_id, coords in coord_dict.items()
    if len(coords) >= 2
}

# Draw graph edges, keeping only those whose two endpoints are integer node
# ids with a known position.
G = matcher.to_networkx()
for u, v in G.edges():
    endpoints_are_ints = isinstance(u, int) and isinstance(v, int)
    if not (endpoints_are_ints and u in pos and v in pos):
        continue
    (x1, y1), (x2, y2) = pos[u], pos[v]
    ax.plot([x1, x2], [y1, y2], 'gray', alpha=0.2, linewidth=0.5)

# Background layer: every detector index of the shot that has a position.
located = [pos[i] for i in range(len(shot_syndrome)) if i in pos]
all_x = [xy[0] for xy in located]
all_y = [xy[1] for xy in located]
ax.scatter(all_x, all_y, c='lightblue', s=40, alpha=0.5, label='Quiet detectors')

# Foreground layer: fired detectors, drawn larger and on top (zorder=10).
fired_ids = np.where(fired)[0]
fired_located = [pos[i] for i in fired_ids if i in pos]
fired_x = [xy[0] for xy in fired_located]
fired_y = [xy[1] for xy in fired_located]
ax.scatter(fired_x, fired_y, c='red', s=100, alpha=0.8,
           edgecolors='darkred', linewidths=2, label='Fired detectors', zorder=10)

# Title: shot summary plus the first entry of the predicted and true outcomes.
title_text = (f"Sample #{shot_idx}: {fired.sum()} fired detectors\n"
              f"MWPM: {predicted_logical[0]} | True: {shot_logical[0]} | "
              f"Correct: {np.array_equal(predicted_logical, shot_logical)}")
ax.set_title(title_text)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.axis('equal')
ax.legend()
ax.grid(alpha=0.2)
plt.tight_layout()
plt.show()