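# Training MCCD

This notebook trains MCCD: it runs `train_1q_modules_cleaned.py` on the fig4 data, evaluates the saved checkpoint on the validation depths, and then runs `train_2q_module_cleaned.py` on the fig5 data, starting from the fig4 model.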

In [10]:
ls data/bench/average_depolarizing_noise/fig4/mccd

circuits_d3_c3_D10.txt              labels_d3_c3_D10.pt
circuits_d3_c3_D12.txt              labels_d3_c3_D12.pt
circuits_d3_c3_D14.txt              labels_d3_c3_D14.pt
circuits_d3_c3_D16.txt              labels_d3_c3_D16.pt
circuits_d3_c3_D18.txt              labels_d3_c3_D18.pt
circuits_d3_c3_D2.txt               labels_d3_c3_D2.pt
circuits_d3_c3_D4.txt               labels_d3_c3_D4.pt
circuits_d3_c3_D6.txt               labels_d3_c3_D6.pt
circuits_d3_c3_D8.txt               labels_d3_c3_D8.pt
circuits_d5_c3_D10.txt              labels_d5_c3_D10.pt
circuits_d5_c3_D12.txt              labels_d5_c3_D12.pt
circuits_d5_c3_D14.txt              labels_d5_c3_D14.pt
circuits_d5_c3_D16.txt              labels_d5_c3_D16.pt
circuits_d5_c3_D18.txt              labels_d5_c3_D18.pt
circuits_d5_c3_D2.txt               labels_d5_c3_D2.pt
circuits_d5_c3_D4.txt               labels_d5_c3_D4.pt
circuits_d5_c3_D6.txt               labels_d5_c3_D6.pt
circuits_d5_c3_D8.txt               labels_d5_c3_D8.pt


In [12]:
# Run train_1q_modules_cleaned.py on the fig4 data (code distance 3, logical
# circuit 3): train on depths 2-8, validate on depths 2-18, and save the
# model under the name fig4_c3_d3.
!python train_1q_modules_cleaned.py \
        --code_distance 3 \
        --logical_circuit_index 3 \
        --if_final_round_syndrome \
        --batch_size 1000 \
        --run_index 0 \
        --model_save_path fig4_c3_d3 \
        --train_data_dir data/bench/average_depolarizing_noise/fig4/mccd \
        --val_data_dir data/bench/average_depolarizing_noise/fig4/mccd \
        --train_depth_list 2 4 6 8 \
        --val_depth_list 2 4 6 8 10 12 14 16 18


Namespace(code_distance=3, logical_circuit_index='3', batch_size=1000, if_final_round_syndrome=True, if_large_lstm_2q=False, load_model=None, run_index='0', model_save_path='fig4_c3_d3', train_data_dir='data/bench/average_depolarizing_noise/fig4/mccd', val_data_dir='data/bench/average_depolarizing_noise/fig4/mccd', train_depth_list=[2, 4, 6, 8], val_depth_list=[2, 4, 6, 8, 10, 12, 14, 16, 18], save_model_every_n_batches=500, validate_every_n_batches=100)
Number of params in model: 270240
Model save to: fig4_c3_d3
Save to trained_models/fig4_c3_d3
0it [00:00, ?it/s][Step 1000] Train Acc: 9.00% | Val Acc: 27.78%
Model checkpoint saved.
80it [00:08,  9.74it/s]
Final model saved.


In [24]:
# When True, collated batches also carry the 'final_round_syndromes' tensor.
if_final_round_syndrome = True


In [27]:
from mccd.dataset import MultiDepthCachedSyndromeDataset
from mccd import get_available_device
from torch.utils.data import DataLoader
import torch

def get_collate_function(if_final_round_syndrome=False):
    """Return a one-argument collate function with the flag pre-bound.

    The result is suitable for ``DataLoader(collate_fn=...)``: calling it as
    ``fn(batch)`` is equivalent to
    ``collate_fn(batch, if_final_round_syndrome=if_final_round_syndrome)``.

    A ``functools.partial`` is returned instead of a lambda because lambdas
    cannot be pickled, which breaks ``DataLoader`` worker processes on
    platforms that start workers with "spawn". Note that ``collate_fn`` is
    looked up when this factory is called, not when a batch is collated.

    Args:
        if_final_round_syndrome: Whether the collated batch should include
            the ``'final_round_syndromes'`` entry.

    Returns:
        A callable taking a single ``batch`` argument.
    """
    # Local import so the block does not depend on the notebook's import cell.
    from functools import partial

    return partial(collate_fn, if_final_round_syndrome=if_final_round_syndrome)


def collate_fn(batch, if_final_round_syndrome=False):
    """Merge a list of sample dicts into a single batched dict.

    Args:
        batch: Non-empty sequence of dicts, each holding tensors under
            ``'syndromes'`` and ``'label'``, a ``'circuit'`` entry, and, when
            ``if_final_round_syndrome`` is set, a tensor under
            ``'final_round_syndromes'``.
        if_final_round_syndrome: Also batch ``'final_round_syndromes'``.

    Returns:
        Dict with ``'syndromes'`` and ``'label'`` stacked along a new leading
        dimension, ``'circuit'`` taken from the first sample only, and
        optionally the stacked ``'final_round_syndromes'``.
    """
    def _stacked(key):
        # New leading dimension indexes the samples of the batch.
        return torch.stack([sample[key] for sample in batch])

    collated = {
        'syndromes': _stacked('syndromes'),
        'label': _stacked('label'),
        # Only the first sample's circuit is kept for the whole batch.
        'circuit': batch[0]['circuit'],
    }

    if if_final_round_syndrome:
        collated['final_round_syndromes'] = _stacked('final_round_syndromes')

    return collated

def validate(model, dataloader, device):
    """Return the element-wise classification accuracy of ``model`` in percent.

    Every item yielded by ``dataloader`` must be a dict with ``'syndromes'``,
    ``'label'`` and ``'circuit'`` entries, and may carry
    ``'final_round_syndromes'``. The leading dimension of each tensor is
    squeezed away before the forward pass, so the dataloader is expected to
    add a size-1 batch dimension in front of each item.

    The model is evaluated in eval mode with autograd disabled; its previous
    training mode is restored before returning.

    Args:
        model: Callable invoked as ``model(syndromes, circuit,
            final_round_syndromes)``; element 0 of its return value is taken
            as the class scores, with classes on the last dimension.
        dataloader: Iterable of batch dicts as described above.
        device: Device the input tensors are moved to.

    Returns:
        ``100 * correct / total`` as a float, where ``total`` is the number
        of label elements seen, or ``nan`` if no label elements were seen.
    """
    # Dropout / batch-norm layers must not run in training mode here.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    correct = 0
    total = 0
    try:
        with torch.no_grad():  # no graph is needed to count correct predictions
            for data in dataloader:
                syndromes = data['syndromes'].squeeze(0).float().to(device)
                labels = data['label'].long().squeeze(0).to(device)
                circuit = data['circuit']

                if 'final_round_syndromes' in data:
                    final_round_syndromes = (
                        data['final_round_syndromes'].squeeze(0).float().to(device)
                    )
                    outputs = model(syndromes, circuit, final_round_syndromes)
                else:
                    # Batches collated without the final-round entry.
                    # NOTE(review): assumes the model's third argument is
                    # optional -- confirm against the model definition.
                    outputs = model(syndromes, circuit)

                scores = outputs[0]
                predicted = scores.argmax(dim=-1)

                # numel() matches the number of elements compared below for
                # labels of any rank (same as size(0) * size(1) for 2-D).
                total += labels.numel()
                correct += (predicted == labels).sum().item()
    finally:
        if was_training:
            model.train()

    if total == 0:
        # Nothing was evaluated; avoid a ZeroDivisionError.
        return float('nan')
    return 100 * correct / total


# Collate function honouring the flag set in the earlier cell.
cf = get_collate_function(if_final_round_syndrome=if_final_round_syndrome)

# Same data directory, code distance, circuit index, batch size and depths as
# the validation flags passed to train_1q_modules_cleaned.py.
val_dataset = MultiDepthCachedSyndromeDataset(
    root_dir='data/bench/average_depolarizing_noise/fig4/mccd',
    code_distance=3,
    circuit_index='3',
    batch_size=1000,
    depth_list=[2, 4, 6, 8, 10, 12, 14, 16, 18],
)
# NOTE(review): the dataset is given batch_size=1000, so each item is
# presumably already a full batch; the DataLoader then only adds a size-1
# leading dimension, which validate() squeezes away -- confirm.
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=cf)

device = get_available_device()

# SECURITY: weights_only=False unpickles arbitrary Python objects; only load
# checkpoints produced by a trusted training run.
# map_location lets a checkpoint saved on another device (e.g. a GPU) load
# on whatever device is available here; eval() switches off training-only
# layer behaviour for the evaluation below.
model = torch.load(
    './trained_models/fig4_c3_d3/model_0.pt',
    map_location=device,
    weights_only=False,
).to(device).eval()

In [28]:
# Accuracy (in percent) of the loaded checkpoint over the validation dataset.
testing_acc = validate(model, val_dataloader, device)


In [29]:
testing_acc

72.22444444444444

In [32]:
# Space-separated depths 4, 8, ..., 36 (presumably for pasting into CLI flags).
print(*range(4, 40, 4))

4 8 12 16 20 24 28 32 36


In [33]:
# Run train_2q_module_cleaned.py on the fig5 data (code distance 3, logical
# circuit 4), loading the model saved under trained_models/fig4_c3_d3:
# train on depths 4-16, validate on depths 4-36, save as fig5_c4_d3.
!python train_2q_module_cleaned.py \
        --code_distance 3 \
        --logical_circuit_index 4 \
        --if_final_round_syndrome \
        --batch_size 1000 \
        --run_index 0 \
        --model_save_path fig5_c4_d3 \
        --train_data_dir data/bench/average_depolarizing_noise/fig5/mccd \
        --val_data_dir data/bench/average_depolarizing_noise/fig5/mccd \
        --train_depth_list 4 8 12 16 \
        --val_depth_list 4 8 12 16 20 24 28 32 36 \
        --load_model trained_models/fig4_c3_d3


Namespace(code_distance=3, logical_circuit_index='4', batch_size=1000, if_final_round_syndrome=True, if_large_lstm_2q=False, load_model='trained_models/fig4_c3_d3', run_index='0', model_save_path='fig5_c4_d3', train_data_dir='data/bench/average_depolarizing_noise/fig5/mccd', val_data_dir='data/bench/average_depolarizing_noise/fig5/mccd', train_depth_list=[4, 8, 12, 16], val_depth_list=[4, 8, 12, 16, 20, 24, 28, 32, 36], save_model_every_n_batches=500, validate_every_n_batches=100)
Loading model from: trained_models/fig4_c3_d3
Number of params in model: 108544
Model save to: fig5_c4_d3
Save to trained_models/fig5_c4_d3
[Step 1000] Train Acc: 67.30% | Val Acc: 53.14%
Model checkpoint saved.
Final model saved.
