In [1]:
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from panqec.codes import surface_2d
from panqec.error_models import PauliErrorModel
from panqec.decoders import BeliefPropagationOSDDecoder, MatchingDecoder
from fnn_model import FNNDecoder
from fnn_data import make_fnn_dataset
from train_fnn import train_fnn_for_distance


from ldpc.mod2 import nullspace

from panq_functions import (
    GNNDecoder,
    collate,
    generate_syndrome_error_volume,
    adapt_trainset,
    logical_error_rate,
    fraction_of_solved_puzzles,
    surface_code_edges,
    load_model,
    save_model
)

# ------------------------------------------
# Device selection: MPS (Mac GPU) → CPU
# ------------------------------------------
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

# Seed the global RNGs so the synthetic train/test sets and the weight
# initialisation are reproducible from run to run.
# NOTE(review): assumes panqec draws its errors from numpy's global RNG — confirm.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# ==========================================================
# Parameters (must match training!)
# ==========================================================

# Positional PauliErrorModel weights per model name; from the names below the
# three positions correspond to X, Y and Z.
error_model_name = "DP"
error_model_weights = {
    "X": (1, 0.0, 0),
    "Z": (0, 0.0, 1),
    "XZ": (0.5, 0.0, 0.5),
    "DP": (0.34, 0.32, 0.34),
}
if error_model_name not in error_model_weights:
    # Fail here instead of with a NameError on `error_model` deep in training.
    raise ValueError(
        f"Unknown error_model_name {error_model_name!r}; "
        f"expected one of {sorted(error_model_weights)}"
    )
error_model = PauliErrorModel(*error_model_weights[error_model_name])

# GNN architecture
n_node_inputs = 4
n_node_outputs = 4
n_iters = 3
n_node_features = 50
n_edge_features = 50

# Data set sizes / physical error rates
len_test_set = 100        # tiny monitoring set (like original code)
test_err_rate = 0.05

len_train_set = len_test_set * 10
max_train_err_rate = 0.15

# Optimiser
lr = 1e-4
weight_decay = 1e-4

# Message-passing network / regularisation
msg_net_size = 512
msg_net_dropout_p = 0.05
gru_dropout_p = 0.05

epochs = 100
batch_size = 64
criterion = nn.CrossEntropyLoss()

distances = [3, 5, 7, 9, 11]

print("n_iters:", n_iters,
      "n_node_outputs:", n_node_outputs,
      "n_node_features:", n_node_features,
      "n_edge_features:", n_edge_features)
print("msg_net_size:", msg_net_size,
      "msg_net_dropout_p:", msg_net_dropout_p,
      "gru_dropout_p:", gru_dropout_p)
print("learning rate:", lr,
      "weight decay:", weight_decay,
      "len train set:", len_train_set,
      "max train err rate:", max_train_err_rate,
      "len test set:", len_test_set,
      "test err rate:", test_err_rate)

os.makedirs("trained_models", exist_ok=True)

(4, 9)
(4, 9)
Using device: mps
n_iters: 3 n_node_outputs: 4 n_node_features: 50 n_edge_features: 50
msg_net_size: 512 msg_net_dropout_p: 0.05 gru_dropout_p: 0.05
learning rate: 0.0001 weight decay: 0.0001 len train set: 1000 max train err rate: 0.15 len test set: 100 test err rate: 0.05


In [18]:
# ======================================================
# main loop over distances
# ======================================================
# For every distance in `distances`: build the rotated planar code and a fresh
# GNN decoder, train it on a fixed synthetic training set, evaluate it on a
# small fixed test set after every epoch, checkpoint the model with the lowest
# total logical error rate, and save the per-epoch metrics to an .npy file.
#
# NOTE(review): `ler_loss` and `compute_accuracy` are not defined or imported in
# any visible cell -- in the saved run they came from kernel state left over by
# other (out-of-order / deleted) cells. TODO: import them explicitly
# (presumably from panq_functions -- confirm) so Restart & Run All works.
_missing_helpers = [
    name for name in ("ler_loss", "compute_accuracy") if name not in globals()
]
if _missing_helpers:
    raise NameError(
        f"Undefined helper(s) {_missing_helpers}: define or import them "
        "before running the training loop."
    )


def _run_name(dist, d, model_loaded):
    """Return the file-name prefix for this run's training-history file.

    The prefix encodes the distance, the error model and every hyper-parameter
    from the setup cell; it contains ``from_d{d}`` when a warm-start model was
    loaded.
    """
    prefix = f"d{dist}_from_d{d}" if model_loaded else f"d{dist}"
    return (
        f"trained_models/{prefix}_{error_model_name}_"
        f"{n_iters}_{n_node_features}_{n_edge_features}_"
        f"{len_train_set}_{max_train_err_rate}_{len_test_set}_"
        f"{test_err_rate}_{lr}_{weight_decay}_"
        f"{msg_net_size}_{msg_net_dropout_p}_{gru_dropout_p}_"
    )


def _batch_loss(outputs, targets, code, size, error_index):
    """Return the training loss of one batch, averaged over ``outputs``.

    ``outputs`` is iterated along dim 0 (one entry per GNN output). For each
    entry the loss is ``ler_loss + CE(nodes[:error_index]) +
    CE(nodes[error_index:])``, where each sample has ``size`` nodes with
    ``n_node_inputs`` classes.
    """
    targ_view = targets.view(-1, size)
    loss = 0.0
    for out in outputs:
        out_view = out.view(-1, size, n_node_inputs)
        # first `error_index` nodes (presumably stabilizer / syndrome nodes)
        sloss = criterion(
            out_view[:, :error_index].reshape(-1, n_node_inputs),
            targ_view[:, :error_index].flatten(),
        )
        # remaining nodes (presumably data-qubit error nodes)
        eloss = criterion(
            out_view[:, error_index:].reshape(-1, n_node_inputs),
            targ_view[:, error_index:].flatten(),
        )
        loss = loss + ler_loss(out, targets, code) + sloss + eloss
    return loss / outputs.shape[0]


for d in distances:
    dist = d
    print("\n=========================================")
    print("Training distance", d)
    print("=========================================")
    best_model_path = f"trained_models/d{dist}_{error_model_name}_best.pth"
    best_ler_tot = float("inf")

    # ---------------- create code & GNN -------------
    code = surface_2d.RotatedPlanar2DCode(dist)
    gnn = GNNDecoder(
        dist=dist,
        n_node_inputs=n_node_inputs,
        n_node_outputs=n_node_outputs,
        n_iters=n_iters,
        n_node_features=n_node_features,
        n_edge_features=n_edge_features,
        msg_net_size=msg_net_size,
        msg_net_dropout_p=msg_net_dropout_p,
        gru_dropout_p=gru_dropout_p
    ).to(device)

    # Graph edges, null spaces of Hx / Hz and the device are published as
    # GNNDecoder class attributes (presumably read by the model and the
    # panq_functions helpers).
    src, tgt = surface_code_edges(code)
    src_tensor = torch.LongTensor(src)
    tgt_tensor = torch.LongTensor(tgt)
    GNNDecoder.surface_code_edges = (src_tensor, tgt_tensor)

    Hx_dense = np.asarray(code.Hx.toarray(), dtype=np.uint8)
    Hz_dense = np.asarray(code.Hz.toarray(), dtype=np.uint8)
    hx_null = nullspace(Hx_dense)
    hz_null = nullspace(Hz_dense)
    # nullspace() may return a scipy sparse matrix -> make it dense
    if hasattr(hx_null, "toarray"):
        hx_null = hx_null.toarray()
    if hasattr(hz_null, "toarray"):
        hz_null = hz_null.toarray()
    hx_null = np.asarray(hx_null, dtype=np.float32)
    hz_null = np.asarray(hz_null, dtype=np.float32)

    hxperp = torch.tensor(hx_null, dtype=torch.float32, device=device)
    hzperp = torch.tensor(hz_null, dtype=torch.float32, device=device)
    GNNDecoder.hxperp = hxperp
    GNNDecoder.hzperp = hzperp
    GNNDecoder.device = device

    total_params = sum(p.numel() for p in gnn.parameters())
    print("Total params:", total_params)

    # ---------- optional warm-start loading ----------
    fnameload = (
        f"trained_models/d{d}_{error_model_name}_30_500_500_200000_0.15_10000_0.05_gnn.pth 0.0144_0.0044 37"
    )
    model_loaded = False
    if os.path.isfile(fnameload):
        load_model(gnn, fnameload, device)
        model_loaded = True
        print("Loaded pre-trained model:", fnameload)

    # ---------- test set (monitoring) ---------------
    testset = adapt_trainset(
        generate_syndrome_error_volume(
            code, error_model=error_model,
            p=test_err_rate,
            batch_size=len_test_set,
            for_training=False
        ),
        code, num_classes=n_node_inputs, for_training=False
    )
    testloader = DataLoader(testset, batch_size=512,
                            collate_fn=collate, shuffle=False)

    # ---------- output filename base -----------------
    # NOTE(review): `dist` and `d` are always equal here, so the warm-start
    # name reads "d{d}_from_d{d}"; kept as-is to preserve existing file names.
    fnamenew = _run_name(dist, d, model_loaded)

    # ---------- optimizer & scheduler ----------------
    optimizer = torch.optim.AdamW(gnn.parameters(), lr=lr,
                                  weight_decay=weight_decay)

    exploration_samples = 10 ** 7
    lr_reduce_epoch_step = exploration_samples // len_train_set
    max_training_data = 10 ** 8
    end_training_epoch = max_training_data // len_train_set

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=lr_reduce_epoch_step, gamma=0.1
    )

    # Gradient scaler for CUDA mixed precision; None on MPS / CPU.
    scaler = None

    # columns: LER_X, LER_Z, LER_tot, test loss, train loss
    le_rates = np.zeros((epochs, 5), dtype=float)
    start_time = time.time()

    # Node layout per sample: `error_index` leading nodes followed by dist**2
    # further nodes (presumably d**2 - 1 stabilizers, then d**2 data qubits).
    # Uses the local `dist` instead of the GNNDecoder.dist class attribute.
    size = 2 * dist ** 2 - 1
    error_index = dist ** 2 - 1

    # ---------- training data (generated once) -------
    trainset = adapt_trainset(
        generate_syndrome_error_volume(
            code, error_model, p=max_train_err_rate,
            batch_size=len_train_set
        ),
        code, num_classes=n_node_inputs
    )
    trainloader = DataLoader(
        trainset, batch_size=batch_size, collate_fn=collate, shuffle=False
    )

    print("epoch, lr, wd, fract. corr. synd, LER_X, LER_Z, LER_tot, "
          "test loss, train loss, train time")

    # =================================================
    # epoch loop (stops early once end_training_epoch is reached)
    # =================================================
    for epoch in range(min(epochs, end_training_epoch)):
        gnn.train()
        batch_losses = []
        for inputs, targets, src_ids, dst_ids in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            src_ids, dst_ids = src_ids.to(device), dst_ids.to(device)

            # Clear gradients on every path (the AMP branch previously never
            # did, so gradients accumulated across steps).
            optimizer.zero_grad()
            outputs = gnn(inputs, src_ids, dst_ids)
            loss = _batch_loss(outputs, targets, code, size, error_index)

            if scaler is not None:      # CUDA AMP
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:                       # MPS / CPU
                loss.backward()
                optimizer.step()

            batch_losses.append(loss.item())

        epoch_loss = float(np.mean(batch_losses))

        # Evaluate with dropout disabled and without building autograd graphs
        # (harmless if the helpers already switch the model to eval mode).
        gnn.eval()
        with torch.no_grad():
            fraction_solved = fraction_of_solved_puzzles(gnn, testloader, code)
            test_loss = compute_accuracy(gnn, testloader, code)
            lerx, lerz, ler_tot = logical_error_rate(gnn, testloader, code)

        scheduler.step()
        le_rates[epoch] = (lerx, lerz, ler_tot, test_loss, epoch_loss)
        curr_time = time.time() - start_time

        print(epoch, optimizer.param_groups[0]['lr'],
              optimizer.param_groups[0]["weight_decay"],
              fraction_solved, lerx, lerz, ler_tot,
              test_loss, epoch_loss, curr_time)

        # ---- save best model by total LER only ----
        if ler_tot < best_ler_tot:
            best_ler_tot = ler_tot
            save_model(gnn, best_model_path, confirm=False)

        # ---- periodic checkpoint of the training history ----
        if epoch % 10 == 0:
            np.save(fnamenew + 'training_lers_and_losses.npy', le_rates)

    # Final save so the epochs after the last periodic checkpoint are kept.
    np.save(fnamenew + 'training_lers_and_losses.npy', le_rates)


Training distance 3
Total params: 618778
epoch, lr, wd, fract. corr. synd, LER_X, LER_Z, LER_tot, test loss, train loss, train time
0 0.0001 0.0001 0.68 0.14 0.32 0.32 3.086498975753784 3.211421012878418 4.49920392036438
Model saved to trained_models/d3_DP_best.pth.
1 0.0001 0.0001 0.68 0.14 0.32 0.32 2.859165906906128 3.023329973220825 5.198629140853882
2 0.0001 0.0001 0.68 0.14 0.32 0.32 2.6998727321624756 2.862102508544922 5.8637590408325195
3 0.0001 0.0001 0.68 0.14 0.32 0.32 2.560600996017456 2.737417221069336 6.512370824813843
4 0.0001 0.0001 0.68 0.14 0.32 0.32 2.4429099559783936 2.642266035079956 7.150183916091919
5 0.0001 0.0001 0.68 0.14 0.32 0.32 2.358523368835449 2.561842918395996 7.804999113082886
6 0.0001 0.0001 0.68 0.21 1.0 1.0 2.2933504581451416 2.504667282104492 8.433401107788086
7 0.0001 0.0001 0.68 0.14 0.32 0.32 2.2406599521636963 2.4602503776550293 9.07447099685669
8 0.0001 0.0001 0.68 0.14 0.32 0.32 2.210193634033203 2.431593418121338 9.712883949279785
9 0.0001 

In [2]:
# Train the feed-forward (FNN) baseline decoder for every distance, on the
# device chosen in the setup cell instead of a hard-coded "mps" (which fails
# on machines without Apple-silicon GPU support).
# NOTE(review): assumes train_fnn_for_distance accepts a device string such as
# "mps" / "cpu" — confirm against train_fnn.py.
for d in distances:
    dist = d
    print("\n=========================================")
    print("Training distance", d)
    print("=========================================")

    train_fnn_for_distance(dist, str(device))


Training distance 3
[d=3] epoch 1, loss=17.144
[d=3] epoch 2, loss=10.583
[d=3] epoch 3, loss=8.167
[d=3] epoch 4, loss=7.454
[d=3] epoch 5, loss=7.086
[d=3] epoch 6, loss=6.762
[d=3] epoch 7, loss=6.262
[d=3] epoch 8, loss=6.141
[d=3] epoch 9, loss=5.920
[d=3] epoch 10, loss=5.753
[d=3] epoch 11, loss=5.565
[d=3] epoch 12, loss=5.401
[d=3] epoch 13, loss=5.274
[d=3] epoch 14, loss=5.141
[d=3] epoch 15, loss=5.088
[d=3] epoch 16, loss=4.957
[d=3] epoch 17, loss=4.828
[d=3] epoch 18, loss=4.771
[d=3] epoch 19, loss=4.734
[d=3] epoch 20, loss=4.587
[d=3] epoch 21, loss=4.620
[d=3] epoch 22, loss=4.477
[d=3] epoch 23, loss=4.492
[d=3] epoch 24, loss=4.460
[d=3] epoch 25, loss=4.458
[d=3] epoch 26, loss=4.361
[d=3] epoch 27, loss=4.319
[d=3] epoch 28, loss=4.276
[d=3] epoch 29, loss=4.341
[d=3] epoch 30, loss=4.223
[d=3] epoch 31, loss=4.190
[d=3] epoch 32, loss=4.171
[d=3] epoch 33, loss=4.143
[d=3] epoch 34, loss=4.128
[d=3] epoch 35, loss=4.052
[d=3] epoch 36, loss=4.034
[d=3] epoch 37