In [1]:
from panqec.codes import surface_2d
from panqec.error_models import PauliErrorModel

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np
import time
import os
#import tool
from panq_functions import GNNDecoder, collate, fraction_of_solved_puzzles, compute_accuracy, logical_error_rate, \
    surface_code_edges, generate_syndrome_error_volume, adapt_trainset, ler_loss, save_model, load_model

from ldpc.mod2 import *

# ------------------------------------------
# Device selection: MPS (Mac GPU) → CPU
# ------------------------------------------
if torch.backends.mps.is_available():
    device = torch.device("mps")
    use_amp = False  # AMP on MPS is unstable
    amp_data_type = torch.float32
else:
    device = torch.device("cpu")
    use_amp = False
    amp_data_type = torch.float32

print("Using device:", device)

Using device: mps


In [2]:
# ============================================
# Parameters (surface code d=3)
# ============================================
d = 3
error_model_name = "DP"  # depolarizing

# DEPOLARIZING ERROR MODEL
error_model = PauliErrorModel(0.34, 0.32, 0.34)

# GNN Hyperparameters
n_node_inputs = 4
n_node_outputs = 4
n_iters = 3
n_node_features = 50
n_edge_features = 50

len_test_set = 500
test_err_rate = 0.10

len_train_set = 5000
max_train_err_rate = 0.15

lr = 1e-4
weight_decay = 1e-4

msg_net_size = 512
msg_net_dropout_p = 0.05
gru_dropout_p = 0.05

print("GNN hyperparameters:",
      "iters:", n_iters, "node features:", n_node_features,
      "edge features:", n_edge_features)

GNN hyperparameters: iters: 3 node features: 50 edge features: 50


In [4]:
# ============================================
# Build Surface Code + GNN
# ============================================

dist = d
code = surface_2d.RotatedPlanar2DCode(dist)

gnn = GNNDecoder(
    dist=dist,
    n_node_inputs=n_node_inputs,
    n_node_outputs=n_node_outputs,
    n_iters=n_iters,
    n_node_features=n_node_features,
    n_edge_features=n_edge_features,
    msg_net_size=msg_net_size,
    msg_net_dropout_p=msg_net_dropout_p,
    gru_dropout_p=gru_dropout_p,
)

gnn.to(device)

# Tanner graph edges
src, tgt = surface_code_edges(code)
GNNDecoder.surface_code_edges = (
    torch.LongTensor(src),
    torch.LongTensor(tgt),
)

# -------------------------------------------
# Degeneracy nullspaces (using ldpc.mod2)
# -------------------------------------------

hx_null = nullspace(code.Hx.toarray())
hz_null = nullspace(code.Hz.toarray())

# --- convert sparse → dense ---
if not isinstance(hx_null, np.ndarray):
    hx_null = hx_null.toarray()

if not isinstance(hz_null, np.ndarray):
    hz_null = hz_null.toarray()

GNNDecoder.hxperp = torch.tensor(hx_null, dtype=torch.float32, device=device)
GNNDecoder.hzperp = torch.tensor(hz_null, dtype=torch.float32, device=device)

GNNDecoder.device = device

print("Total model parameters:", sum(p.numel() for p in gnn.parameters()))


# ============================================
# Create Test Dataset
# ============================================

testset = adapt_trainset(
    generate_syndrome_error_volume(
        code,
        error_model=error_model,
        p=test_err_rate,
        batch_size=len_test_set,
        for_training=False,
    ),
    code,
    num_classes=n_node_inputs,
    for_training=False,
)

testloader = DataLoader(testset, batch_size=512, collate_fn=collate, shuffle=False)



Total model parameters: 618778


In [5]:
# ============================================
# Training Setup
# ============================================

os.makedirs("trained_models", exist_ok=True)

fnamenew = (
    f"trained_models/d{dist}_{error_model_name}_"
    f"{n_iters}_{n_node_features}_{n_edge_features}_"
)

optimizer = optim.AdamW(gnn.parameters(), lr=lr, weight_decay=weight_decay)

epochs = 200
batch_size = 64
criterion = nn.CrossEntropyLoss()

le_rates = np.zeros((epochs, 5))
start_time = time.time()

size = 2 * dist**2 - 1
error_index = dist**2 - 1

min_ler_tot = 1.0


# ============================================
# Prepare training set (regenerated each epoch)
# ============================================

trainset = adapt_trainset(
    generate_syndrome_error_volume(
        code,
        error_model=error_model,
        p=max_train_err_rate,
        batch_size=len_train_set,
    ),
    code,
    num_classes=n_node_inputs,
)

trainloader = DataLoader(trainset, batch_size=batch_size, collate_fn=collate, shuffle=False)


In [7]:
# ============================================
# Main Training Loop
# ============================================

print("Starting training...")

for epoch in range(epochs):

    gnn.train()
    epoch_losses = []

    for inputs, targets, src_ids, dst_ids in trainloader:

        inputs, targets = inputs.to(device), targets.to(device)
        src_ids, dst_ids = src_ids.to(device), dst_ids.to(device)

        loss = 0.0


    # ---- Forward pass (NO autocast on MPS) ----
    if device.type == "cuda" and use_amp:
        with torch.autocast(device_type="cuda", dtype=amp_data_type):
            outputs = gnn(inputs, src_ids, dst_ids)
    else:
        # MPS or CPU — no autocast
        outputs = gnn(inputs, src_ids, dst_ids)

    loss = 0.0
    for out in outputs:
        eloss = criterion(
            out.view(-1, size, n_node_inputs)[:, error_index:].reshape(-1, n_node_inputs),
            targets.view(-1, size)[:, error_index:].flatten(),
        )
        sloss = criterion(
            out.view(-1, size, n_node_inputs)[:, :error_index].reshape(-1, n_node_inputs),
            targets.view(-1, size)[:, :error_index].flatten(),
        )
        loss = loss + ler_loss(out, targets, code) + sloss + eloss

    loss = loss / outputs.shape[0]

    if device.type == "cuda" and use_amp:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()

    optimizer.zero_grad()
    epoch_losses.append(loss.item())

    # ------------- END BATCH LOOP -------------

    # Evaluate every epoch
    gnn.eval()
    with torch.no_grad():
        fraction = fraction_of_solved_puzzles(gnn, testloader, code)
        test_loss = compute_accuracy(gnn, testloader, code)
        lerx, lerz, ler_tot = logical_error_rate(gnn, testloader, code)

    le_rates[epoch] = [lerx, lerz, ler_tot, test_loss, np.mean(epoch_losses)]

    print(f"Epoch {epoch} | LER_tot={ler_tot:.5f} | LERx={lerx:.5f} | LERz={lerz:.5f}")

    # Save best model
    if ler_tot < min_ler_tot:
        min_ler_tot = ler_tot
        save_model(gnn, fnamenew + f"gnn_best_{epoch}.pth {lerx}_{lerz}", confirm=False)

    if epoch % 10 == 0:
        np.save(fnamenew + "training_progress.npy", le_rates)

print("Training finished.")

Starting training...
Epoch 0 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 1 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 2 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 3 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 4 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 5 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 6 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 7 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 8 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 9 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 10 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 11 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 12 | LER_tot=1.00000 | LERx=0.87800 | LERz=0.62000
Epoch 13 | LER_tot=0.99800 | LERx=0.87800 | LERz=0.98000
Model saved to trained_models/d3_DP_3_50_50_gnn_best_13.pth 0.878_0.98.
Epoch 14 | LER_tot=0.99800 | LERx=0.87800 | LERz=0.98000
Epoch 15 | LER_tot=0.99800 | LERx=0.87800 | LERz=0.98000
Epoch