In [1]:
%reload_ext autoreload
%autoreload 2
import sys
import os
from pathlib import Path

from torch.utils.data import DataLoader
from torch.profiler import profile, ProfilerActivity, record_function
from qecdec import RotatedSurfaceCode_Memory

# --- Experiment configuration ------------------------------------------------
d = 7             # passed as `d` to RotatedSurfaceCode_Memory
p = 0.01          # used for both data_qubit_error_rate and meas_error_rate
rounds = d        # passed as `rounds` to RotatedSurfaceCode_Memory
batch_size = 256  # DataLoader batch size

# Keyword arguments forwarded unchanged to GNNDecoder.
decoder_kwargs = dict(num_iters=5, node_features=16, edge_features=16,
                      mlp_hidden_size=16, mlp_hidden_layers=2,
                      mlp_dropout_p=0.05, gru_dropout_p=0.05)
# Keyword arguments forwarded unchanged to IterativeDecodingLoss.
loss_fn_kwargs = dict(beta=0.8)

# Target for decoder.to(...) and the batch tensors.
device = "cpu"

# --- Paths -------------------------------------------------------------------
# The parent of the working directory is treated as the project root; it must
# be importable because later cells do `from src... import ...`
# (presumably `src` lives directly under this directory -- confirm).
pytorch_dir = Path.cwd().parent
# Guarded so that re-running this cell does not keep appending duplicates.
if str(pytorch_dir) not in sys.path:
    sys.path.append(str(pytorch_dir))

# Profiler traces are written here; created on demand, safe to re-run.
output_dir = pytorch_dir / "runs"
output_dir.mkdir(parents=True, exist_ok=True)


In [2]:
# Build the memory experiment; its `chkmat` / `obsmat` attributes are what the
# decoder and the loss are constructed from in later cells.
experiment_config = dict(
    d=d,
    rounds=rounds,
    basis='Z',
    data_qubit_error_rate=p,
    meas_error_rate=p,
)
expmt = RotatedSurfaceCode_Memory(**experiment_config)

# Report the problem size, one attribute per line.
for label, attr in (
    ("Number of error mechanisms:", "num_error_mechanisms"),
    ("Number of detectors:", "num_detectors"),
    ("Number of observables:", "num_observables"),
):
    print(label, getattr(expmt, attr))

Number of error mechanisms: 512
Number of detectors: 192
Number of observables: 1


In [3]:
# Load the training set from disk and wrap it in a shuffling DataLoader.
from src.dataset import DecodingDataset

# One dataset directory per (d, rounds, p) configuration.
dataset_dir = (
    pytorch_dir
    / "datasets"
    / "rotated_surface_code_memory_Z"
    / f"d={d}_rounds={rounds}_p={p}"
)
train_dataset = DecodingDataset.load_from_file(dataset_dir / "train_dataset.pt")
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [4]:
# Instantiate the GNN decoder and the loss it will be evaluated against.
from src.models import GNNDecoder
from src.training import IterativeDecodingLoss

# The decoder is built from the check matrix only; hyper-parameters come from
# `decoder_kwargs`.
decoder = GNNDecoder(
    expmt.chkmat,
    **decoder_kwargs,
)
# The loss additionally takes the observable matrix; its options come from
# `loss_fn_kwargs`.
loss_fn = IterativeDecodingLoss(
    expmt.chkmat,
    expmt.obsmat,
    **loss_fn_kwargs,
)

In [5]:
def _cpu_profiler():
    """Return a fresh CPU-only profiler context with Python stack capture on."""
    return profile(activities=[ProfilerActivity.CPU], with_stack=True)


def _report_profile(prof, trace_path):
    """Print a separator, export ``prof`` as a Chrome trace, and print a summary.

    Args:
        prof: A finished ``torch.profiler.profile`` object.
        trace_path: Destination of the Chrome trace JSON file.

    The summary is the 10 rows with the largest total CPU time.
    """
    print("=" * 100)
    prof.export_chrome_trace(str(trace_path))
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))


decoder = decoder.to(device)
# Train mode is kept on purpose so the profile reflects a training step.
decoder.train()
# Clear gradients left over from a previous run of this cell; otherwise
# loss.backward() accumulates into existing .grad tensors and the backward
# trace differs between the first run and subsequent runs.
decoder.zero_grad(set_to_none=True)

# Only the first batch is profiled. Fail loudly instead of silently profiling
# nothing when the loader is empty.
first_batch = next(iter(train_dataloader), None)
if first_batch is None:
    raise RuntimeError("train_dataloader yielded no batches; nothing to profile")
syndromes, observables = first_batch
syndromes = syndromes.to(device)
observables = observables.to(device)

# Forward pass: decoder output plus loss evaluation, recorded under one label.
with _cpu_profiler() as prof:
    with record_function("forward pass"):
        llrs = decoder(syndromes)
        loss = loss_fn(llrs, syndromes, observables)
_report_profile(prof, output_dir / "forward_trace.json")

# Backward pass, profiled separately so it gets its own trace file.
with _cpu_profiler() as prof:
    with record_function("backward pass"):
        loss.backward()
_report_profile(prof, output_dir / "backward_trace.json")

------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                              forward pass         4.16%      37.539ms       100.00%     902.297ms     902.297ms             1  
                             aten::dropout         0.06%     510.330us        43.29%     390.626ms      15.625ms            25  
                          aten::bernoulli_        36.95%     333.428ms        36.95%     333.428ms      13.337ms            25  
                              aten::linear         0.06%     546.407us        24.21%     218.427ms       3.360ms            65  
                               aten::addmm        15.39%     138.874ms        23.99%     216.477m