In [9]:
%reload_ext autoreload
%autoreload 2
import numpy as np
import torch

from qecdec import RotatedSurfaceCode_Memory
from learned_decoders import *
from train_utils import *
from datetime import datetime
import pickle

# --- Experiment configuration (read by the cells below) ---------------------

# Code / noise model: passed to RotatedSurfaceCode_Memory as d, rounds, and as
# both the data-qubit and measurement error rate.
d, rounds = 5, 5
p = 0.01

# Decoder: forwarded to LearnedDMemOffNormBP as num_iters / train_offset /
# train_nf; the two flags also decide which parameter files get written at the end.
num_bp_iters = 5
train_nf = True
train_offset = False

# Optimisation: Adam step size and the beta argument of DecodingLoss.
learning_rate = 0.002
loss_beta = 1.0

In [2]:
# Build the memory experiment from the config values. The single rate `p` is
# used for both the data-qubit and the measurement error rate.
expmt = RotatedSurfaceCode_Memory(
    d=d, rounds=rounds, basis='Z',
    data_qubit_error_rate=p, meas_error_rate=p,
)

# Report the problem sizes exposed by the experiment object.
for label, attr in (
    ("Number of error mechanisms:", "num_error_mechanisms"),
    ("Number of detectors:", "num_detectors"),
    ("Number of observables:", "num_observables"),
):
    print(label, getattr(expmt, attr))

Number of error mechanisms: 186
Number of detectors: 72
Number of observables: 1


In [None]:
# One seed for both dataset sampling and PyTorch's global RNG, so that a
# Restart & Run All can reproduce the run.
# NOTE(review): presumably model initialisation and/or batch shuffling inside
# train_gamma draw from torch's global RNG -- confirm in learned_decoders /
# train_utils.
seed = 42
torch.manual_seed(seed)

# Sample the training and validation sets from the experiment.
# NOTE(review): the output stored with this cell reports 25556 / 815 samples and
# the cell has no execution count, so it was probably edited after its last
# run; re-run it before relying on the logged sizes.
train_dataset, val_dataset = build_datasets(
    expmt,
    train_shots=10_000,
    val_shots=1_000,
    seed=seed,
)
print("Size of train_dataset:", len(train_dataset))
print("Size of val_dataset:", len(val_dataset))

# Decoder to train, built from the experiment's check matrix and prior.
model = LearnedDMemOffNormBP(
    expmt.chkmat,
    expmt.prior,
    num_iters=num_bp_iters,
    train_offset=train_offset,
    train_nf=train_nf,
)

# Training loss and evaluation metric both take the check and observable matrices.
loss_fn = DecodingLoss(expmt.chkmat, expmt.obsmat, beta=loss_beta)
metric = DecodingMetric(expmt.chkmat, expmt.obsmat)

# Adam over every parameter the model exposes.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Size of train_dataset: 25556
Size of val_dataset: 815


In [4]:
# Keyword options for the training run, gathered in one place so they are easy
# to scan and tweak.
training_options = {
    "num_epochs": 20,
    "batch_size": 256,
    "device": "cpu",
    # Forwarded as-is to train_gamma's learning-rate scheduler setup.
    "scheduler_kwargs": {
        "factor": 0.2,
        "patience": 3,
        "threshold": 1e-3,
        "threshold_mode": "abs",
    },
    "early_stopper": EarlyStopper(patience=5, min_delta=1e-3),
}

# Run training; per-epoch summaries are printed by train_gamma itself.
train_gamma(
    model, train_dataset, val_dataset,
    loss_fn, metric, optimizer,
    **training_options,
)

Using cpu device


Epoch 1/20: 100%|██████████| 100/100 [07:02<00:00,  4.22s/it, avg_loss=0.706679, grad_norm=0.165694]


Epoch 1 Summary:
  Avg Train Loss: 0.706679
  Avg Val Loss: 0.736395
  wrong_syndrome_rate: 0.047853
  wrong_observable_rate: 0.006135
  failure_rate: 0.047853
  Learning Rate: 0.002000



Epoch 2/20: 100%|██████████| 100/100 [06:42<00:00,  4.02s/it, avg_loss=0.606926, grad_norm=0.134927]


Epoch 2 Summary:
  Avg Train Loss: 0.606926
  Avg Val Loss: 0.685253
  wrong_syndrome_rate: 0.041718
  wrong_observable_rate: 0.006135
  failure_rate: 0.041718
  Learning Rate: 0.002000



Epoch 3/20: 100%|██████████| 100/100 [06:41<00:00,  4.02s/it, avg_loss=0.575189, grad_norm=0.167043]


Epoch 3 Summary:
  Avg Train Loss: 0.575189
  Avg Val Loss: 0.664348
  wrong_syndrome_rate: 0.035583
  wrong_observable_rate: 0.009816
  failure_rate: 0.035583
  Learning Rate: 0.002000



Epoch 4/20: 100%|██████████| 100/100 [07:00<00:00,  4.20s/it, avg_loss=0.560948, grad_norm=0.139897]


Epoch 4 Summary:
  Avg Train Loss: 0.560948
  Avg Val Loss: 0.654378
  wrong_syndrome_rate: 0.033129
  wrong_observable_rate: 0.009816
  failure_rate: 0.033129
  Learning Rate: 0.002000



Epoch 5/20: 100%|██████████| 100/100 [06:49<00:00,  4.09s/it, avg_loss=0.553642, grad_norm=0.119104]


Epoch 5 Summary:
  Avg Train Loss: 0.553642
  Avg Val Loss: 0.648804
  wrong_syndrome_rate: 0.033129
  wrong_observable_rate: 0.011043
  failure_rate: 0.033129
  Learning Rate: 0.002000



Epoch 6/20: 100%|██████████| 100/100 [07:03<00:00,  4.23s/it, avg_loss=0.549233, grad_norm=0.140235]


Epoch 6 Summary:
  Avg Train Loss: 0.549233
  Avg Val Loss: 0.645372
  wrong_syndrome_rate: 0.030675
  wrong_observable_rate: 0.011043
  failure_rate: 0.030675
  Learning Rate: 0.002000



Epoch 7/20: 100%|██████████| 100/100 [07:46<00:00,  4.66s/it, avg_loss=0.546304, grad_norm=0.135438]


Epoch 7 Summary:
  Avg Train Loss: 0.546304
  Avg Val Loss: 0.643604
  wrong_syndrome_rate: 0.028221
  wrong_observable_rate: 0.011043
  failure_rate: 0.028221
  Learning Rate: 0.002000



Epoch 8/20: 100%|██████████| 100/100 [08:20<00:00,  5.01s/it, avg_loss=0.544280, grad_norm=0.115290]


Epoch 8 Summary:
  Avg Train Loss: 0.544280
  Avg Val Loss: 0.643180
  wrong_syndrome_rate: 0.030675
  wrong_observable_rate: 0.011043
  failure_rate: 0.030675
  Learning Rate: 0.002000



Epoch 9/20: 100%|██████████| 100/100 [08:26<00:00,  5.07s/it, avg_loss=0.542905, grad_norm=0.114171]


Epoch 9 Summary:
  Avg Train Loss: 0.542905
  Avg Val Loss: 0.642101
  wrong_syndrome_rate: 0.030675
  wrong_observable_rate: 0.012270
  failure_rate: 0.030675
  Learning Rate: 0.002000



Epoch 10/20: 100%|██████████| 100/100 [09:41<00:00,  5.82s/it, avg_loss=0.541889, grad_norm=0.164322]


Epoch 10 Summary:
  Avg Train Loss: 0.541889
  Avg Val Loss: 0.642250
  wrong_syndrome_rate: 0.029448
  wrong_observable_rate: 0.011043
  failure_rate: 0.029448
  Learning Rate: 0.002000



Epoch 11/20: 100%|██████████| 100/100 [09:03<00:00,  5.44s/it, avg_loss=0.540859, grad_norm=0.140868]


Epoch 11 Summary:
  Avg Train Loss: 0.540859
  Avg Val Loss: 0.641284
  wrong_syndrome_rate: 0.028221
  wrong_observable_rate: 0.011043
  failure_rate: 0.028221
  Learning Rate: 0.002000



Epoch 12/20: 100%|██████████| 100/100 [06:56<00:00,  4.17s/it, avg_loss=0.540219, grad_norm=0.138315]


Epoch 12 Summary:
  Avg Train Loss: 0.540219
  Avg Val Loss: 0.639943
  wrong_syndrome_rate: 0.025767
  wrong_observable_rate: 0.008589
  failure_rate: 0.025767
  Learning Rate: 0.002000



Epoch 13/20: 100%|██████████| 100/100 [07:04<00:00,  4.24s/it, avg_loss=0.539877, grad_norm=0.157961]


Epoch 13 Summary:
  Avg Train Loss: 0.539877
  Avg Val Loss: 0.639934
  wrong_syndrome_rate: 0.026994
  wrong_observable_rate: 0.009816
  failure_rate: 0.026994
  Learning Rate: 0.002000



Epoch 14/20: 100%|██████████| 100/100 [07:02<00:00,  4.23s/it, avg_loss=0.539336, grad_norm=0.107047]


Epoch 14 Summary:
  Avg Train Loss: 0.539336
  Avg Val Loss: 0.638983
  wrong_syndrome_rate: 0.024540
  wrong_observable_rate: 0.008589
  failure_rate: 0.024540
  Learning Rate: 0.002000



Epoch 15/20: 100%|██████████| 100/100 [07:04<00:00,  4.24s/it, avg_loss=0.538892, grad_norm=0.094967]


Epoch 15 Summary:
  Avg Train Loss: 0.538892
  Avg Val Loss: 0.639920
  wrong_syndrome_rate: 0.024540
  wrong_observable_rate: 0.007362
  failure_rate: 0.024540
  Learning Rate: 0.002000



Epoch 16/20: 100%|██████████| 100/100 [07:36<00:00,  4.56s/it, avg_loss=0.538826, grad_norm=0.148664]


Epoch 16 Summary:
  Avg Train Loss: 0.538826
  Avg Val Loss: 0.638933
  wrong_syndrome_rate: 0.026994
  wrong_observable_rate: 0.007362
  failure_rate: 0.026994
  Learning Rate: 0.002000



Epoch 17/20: 100%|██████████| 100/100 [07:10<00:00,  4.30s/it, avg_loss=0.538380, grad_norm=0.125235]


Epoch 17 Summary:
  Avg Train Loss: 0.538380
  Avg Val Loss: 0.638336
  wrong_syndrome_rate: 0.026994
  wrong_observable_rate: 0.007362
  failure_rate: 0.026994
  Learning Rate: 0.002000



Epoch 18/20: 100%|██████████| 100/100 [07:10<00:00,  4.30s/it, avg_loss=0.538267, grad_norm=0.134970]


Epoch 18 Summary:
  Avg Train Loss: 0.538267
  Avg Val Loss: 0.638609
  wrong_syndrome_rate: 0.026994
  wrong_observable_rate: 0.008589
  failure_rate: 0.026994
  Learning Rate: 0.002000



Epoch 19/20: 100%|██████████| 100/100 [07:12<00:00,  4.33s/it, avg_loss=0.538015, grad_norm=0.130692]


Epoch 19 Summary:
  Avg Train Loss: 0.538015
  Avg Val Loss: 0.638925
  wrong_syndrome_rate: 0.025767
  wrong_observable_rate: 0.007362
  failure_rate: 0.025767
  Learning Rate: 0.002000



Epoch 20/20: 100%|██████████| 100/100 [07:11<00:00,  4.31s/it, avg_loss=0.537913, grad_norm=0.128910]


Epoch 20 Summary:
  Avg Train Loss: 0.537913
  Avg Val Loss: 0.639143
  wrong_syndrome_rate: 0.023313
  wrong_observable_rate: 0.007362
  failure_rate: 0.023313
  Learning Rate: 0.000400



In [10]:
from pathlib import Path

# Timestamp shared by every file written in this cell, so one run's outputs
# can be matched up by name.
now = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

# np.save / open(..., "wb") do not create missing directories; make sure the
# output folder exists so a long training run is not lost to FileNotFoundError.
Path("learned_params").mkdir(parents=True, exist_ok=True)

# gamma is stored as a float64 NumPy array. `.cpu()` is a no-op for CPU tensors
# and keeps this working if the model was trained on another device.
gamma = model.gamma.detach().cpu().numpy().astype(np.float64)
np.save(f"learned_params/dmemoffnormbp_{now}_gamma.npy", gamma)

# Offsets are only written when they were trained (see `train_offset`).
if train_offset:
    offset = [x.detach().tolist() for x in model.offset] # list[list[float]]
    with open(f"learned_params/dmemoffnormbp_{now}_offset.pkl", "wb") as f:
        pickle.dump(offset, f)

# Likewise for the `nf` parameters (see `train_nf`).
if train_nf:
    nf = [x.detach().tolist() for x in model.nf] # list[list[float]]
    with open(f"learned_params/dmemoffnormbp_{now}_nf.pkl", "wb") as f:
        pickle.dump(nf, f)
