In [None]:
%reload_ext autoreload
%autoreload 2
import os

import numpy as np
import torch

from qecdec import RotatedSurfaceCode_Memory
from learned_decoders import *
from train_utils import *
# Kept after the wildcard imports on purpose: an explicit import placed last
# cannot be shadowed by a name the wildcard imports happen to export.
from datetime import datetime

# --- Experiment and training hyper-parameters (tune here, then re-run all) ---

# Memory-experiment settings, passed to RotatedSurfaceCode_Memory below.
d, rounds = 5, 5  # `d` (presumably the code distance) and number of `rounds`
p = 1e-2          # used as both data_qubit_error_rate and meas_error_rate

# Decoder and optimisation settings.
num_bp_iters = 5      # `num_iters` given to LearnedDMemBP
learning_rate = 2e-3  # `lr` given to torch.optim.Adam
loss_beta = 1.0       # `beta` given to DecodingLoss

In [None]:
# Set up the Z-basis rotated-surface-code memory experiment. The same rate `p`
# is used for data-qubit errors and for measurement errors.
code_kwargs = dict(
    d=d,
    rounds=rounds,
    basis='Z',
    data_qubit_error_rate=p,
    meas_error_rate=p,
)
expmt = RotatedSurfaceCode_Memory(**code_kwargs)

# Report the dimensions of the resulting decoding problem.
size_report = (
    ("Number of error mechanisms:", expmt.num_error_mechanisms),
    ("Number of detectors:", expmt.num_detectors),
    ("Number of observables:", expmt.num_observables),
)
for caption, size in size_report:
    print(caption, size)

Number of error mechanisms: 186
Number of detectors: 72
Number of observables: 1


In [3]:
# Build the train/validation datasets and every object the training loop needs
# (model, loss, evaluation metric, optimizer).

# One seed drives both the dataset sampling and torch's global RNG.
rng_seed = 42

# NOTE(review): the saved output of this cell reports 25556 / 815 samples, which
# does not obviously follow from train_shots / val_shots -- confirm how
# build_datasets derives the dataset size, or re-run to refresh the output.
train_dataset, val_dataset = build_datasets(
    expmt,
    train_shots=10_000,
    val_shots=1_000,
    seed=rng_seed,
)
print("Size of train_dataset:", len(train_dataset))
print("Size of val_dataset:", len(val_dataset))

# Seed torch before the model is created so that any torch-side randomness
# (parameter initialisation, batch shuffling inside the training loop --
# TODO confirm which of these apply) is repeatable across kernel restarts.
torch.manual_seed(rng_seed)

# Learned decoder, built from the experiment's check matrix and prior.
model = LearnedDMemBP(expmt.chkmat, expmt.prior, num_iters=num_bp_iters)
# Training objective and evaluation metric, both built from the check and
# observable matrices.
loss_fn = DecodingLoss(expmt.chkmat, expmt.obsmat, beta=loss_beta)
metric = DecodingMetric(expmt.chkmat, expmt.obsmat)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Size of train_dataset: 25556
Size of val_dataset: 815


In [4]:
# Learning-rate scheduler settings, forwarded as `scheduler_kwargs`.
# NOTE(review): the keys look like ReduceLROnPlateau arguments -- confirm in
# train_utils which scheduler consumes them.
lr_schedule = dict(factor=0.2, patience=3, threshold=1e-3, threshold_mode="abs")

# Early-stopping helper handed to the training loop.
stopper = EarlyStopper(patience=5, min_delta=1e-3)

# Run training on the CPU for at most 20 epochs.
train_gamma(
    model, train_dataset, val_dataset,
    loss_fn, metric, optimizer,
    num_epochs=20, batch_size=256, device="cpu",
    scheduler_kwargs=lr_schedule,
    early_stopper=stopper,
)

Using cpu device


Epoch 1/20: 100%|██████████| 100/100 [04:25<00:00,  2.66s/it, avg_loss=0.766982, grad_norm=0.249598]


Epoch 1 Summary:
  Avg Train Loss: 0.766982
  Avg Val Loss: 0.829145
  wrong_syndrome_rate: 0.053988
  wrong_observable_rate: 0.011043
  failure_rate: 0.053988
  Learning Rate: 0.002000



Epoch 2/20: 100%|██████████| 100/100 [04:25<00:00,  2.65s/it, avg_loss=0.703902, grad_norm=0.124192]


Epoch 2 Summary:
  Avg Train Loss: 0.703902
  Avg Val Loss: 0.783050
  wrong_syndrome_rate: 0.051534
  wrong_observable_rate: 0.011043
  failure_rate: 0.051534
  Learning Rate: 0.002000



Epoch 3/20: 100%|██████████| 100/100 [04:25<00:00,  2.65s/it, avg_loss=0.677188, grad_norm=0.119595]


Epoch 3 Summary:
  Avg Train Loss: 0.677188
  Avg Val Loss: 0.763699
  wrong_syndrome_rate: 0.045399
  wrong_observable_rate: 0.012270
  failure_rate: 0.045399
  Learning Rate: 0.002000



Epoch 4/20: 100%|██████████| 100/100 [04:22<00:00,  2.63s/it, avg_loss=0.663185, grad_norm=0.155058]


Epoch 4 Summary:
  Avg Train Loss: 0.663185
  Avg Val Loss: 0.754795
  wrong_syndrome_rate: 0.045399
  wrong_observable_rate: 0.012270
  failure_rate: 0.045399
  Learning Rate: 0.002000



Epoch 5/20: 100%|██████████| 100/100 [04:24<00:00,  2.65s/it, avg_loss=0.653930, grad_norm=0.152184]


Epoch 5 Summary:
  Avg Train Loss: 0.653930
  Avg Val Loss: 0.751365
  wrong_syndrome_rate: 0.044172
  wrong_observable_rate: 0.012270
  failure_rate: 0.044172
  Learning Rate: 0.002000



Epoch 6/20: 100%|██████████| 100/100 [04:25<00:00,  2.65s/it, avg_loss=0.647820, grad_norm=0.146551]


Epoch 6 Summary:
  Avg Train Loss: 0.647820
  Avg Val Loss: 0.750128
  wrong_syndrome_rate: 0.038037
  wrong_observable_rate: 0.013497
  failure_rate: 0.039264
  Learning Rate: 0.002000



Epoch 7/20: 100%|██████████| 100/100 [04:27<00:00,  2.67s/it, avg_loss=0.642901, grad_norm=0.133063]


Epoch 7 Summary:
  Avg Train Loss: 0.642901
  Avg Val Loss: 0.750312
  wrong_syndrome_rate: 0.035583
  wrong_observable_rate: 0.013497
  failure_rate: 0.036810
  Learning Rate: 0.002000



Epoch 8/20: 100%|██████████| 100/100 [04:25<00:00,  2.66s/it, avg_loss=0.639216, grad_norm=0.121764]


Epoch 8 Summary:
  Avg Train Loss: 0.639216
  Avg Val Loss: 0.751466
  wrong_syndrome_rate: 0.034356
  wrong_observable_rate: 0.013497
  failure_rate: 0.035583
  Learning Rate: 0.002000



Epoch 9/20: 100%|██████████| 100/100 [04:24<00:00,  2.65s/it, avg_loss=0.636516, grad_norm=0.084022]


Epoch 9 Summary:
  Avg Train Loss: 0.636516
  Avg Val Loss: 0.752080
  wrong_syndrome_rate: 0.031902
  wrong_observable_rate: 0.013497
  failure_rate: 0.033129
  Learning Rate: 0.002000



Epoch 10/20: 100%|██████████| 100/100 [04:24<00:00,  2.65s/it, avg_loss=0.634569, grad_norm=0.113857]


Epoch 10 Summary:
  Avg Train Loss: 0.634569
  Avg Val Loss: 0.751789
  wrong_syndrome_rate: 0.033129
  wrong_observable_rate: 0.012270
  failure_rate: 0.034356
  Learning Rate: 0.000400



Epoch 11/20: 100%|██████████| 100/100 [04:23<00:00,  2.64s/it, avg_loss=0.632954, grad_norm=0.103478]


Epoch 11 Summary:
  Avg Train Loss: 0.632954
  Avg Val Loss: 0.751791
  wrong_syndrome_rate: 0.033129
  wrong_observable_rate: 0.012270
  failure_rate: 0.034356
  Learning Rate: 0.000400



Epoch 12/20: 100%|██████████| 100/100 [04:24<00:00,  2.64s/it, avg_loss=0.632980, grad_norm=0.163386]


Epoch 12 Summary:
  Avg Train Loss: 0.632980
  Avg Val Loss: 0.751724
  wrong_syndrome_rate: 0.033129
  wrong_observable_rate: 0.012270
  failure_rate: 0.034356
  Learning Rate: 0.000400

Early stopping triggered


In [None]:
# Persist the learned `gamma` parameter as a float64 .npy file whose name
# carries the current timestamp.

# Create the output directory first: np.save does not create missing parent
# directories, and failing here would throw away the whole training run.
os.makedirs("learned_params", exist_ok=True)

# .cpu() is a no-op for CPU tensors but keeps this cell working if the model
# was trained on another device (.numpy() only accepts CPU tensors).
gamma = model.gamma.detach().cpu().numpy().astype(np.float64)
now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
np.save(f"learned_params/dmembp_{now}_gamma.npy", gamma)