# This Notebook Trains and Evaluates MCCD

In [None]:
from mccd.dataset import *
from mccd import get_available_device
from torch.utils.data import DataLoader
import torch
import time
from pathlib import Path
import pandas as pd


def get_collate_function(if_final_round_syndrome=False):
    """Build a one-argument collate callable suitable for a ``DataLoader``.

    The returned callable accepts a batch and forwards it to ``collate_fn``
    together with the ``if_final_round_syndrome`` flag captured here.
    """
    def _collate(batch):
        return collate_fn(batch, if_final_round_syndrome)

    return _collate


def collate_fn(batch, if_final_round_syndrome=False):
    """Combine a list of sample dicts into a single batched dict.

    Args:
        batch: Sequence of dicts, each holding tensors under ``'syndromes'``
            and ``'label'`` and an arbitrary object under ``'circuit'``.
            When ``if_final_round_syndrome`` is true, each dict must also
            hold a tensor under ``'final_round_syndromes'``.
        if_final_round_syndrome: Whether to batch ``'final_round_syndromes'``
            as well.

    Returns:
        Dict with the per-sample tensors joined along a new leading
        dimension. ``'circuit'`` is taken from the first sample only.
    """
    def _batched(key):
        # Give every sample a leading axis of size 1, then join along it.
        return torch.cat([sample[key][None] for sample in batch], dim=0)

    collated = {
        'syndromes': _batched('syndromes'),
        'label': _batched('label'),
        # Only the first sample's circuit is kept for the whole batch.
        'circuit': batch[0]['circuit'],
    }
    if if_final_round_syndrome:
        collated['final_round_syndromes'] = _batched('final_round_syndromes')
    return collated

def validate(model, dataloader, device):
    """Compute the label accuracy of ``model`` over every batch of ``dataloader``.

    Each item yielded by ``dataloader`` must be a dict with the keys
    ``'syndromes'``, ``'label'``, ``'final_round_syndromes'`` and
    ``'circuit'``. The three tensors have their leading dimension squeezed
    (``squeeze(0)``) before being moved to ``device``.

    The model is put in evaluation mode and run without gradient tracking
    for the duration of the call; its previous training mode is restored
    afterwards, even if an exception is raised.

    Args:
        model: ``torch.nn.Module`` called as
            ``model(syndromes, circuit, final_round_syndromes)``. The first
            element of its return value is used as the per-class scores.
        dataloader: Iterable of batch dicts as described above.
        device: Device the input tensors are moved to.

    Returns:
        Fraction of label entries predicted correctly, as a float.

    Raises:
        ValueError: If ``dataloader`` yields no label entries, so no
            accuracy can be computed.
    """
    correct = 0
    total = 0

    was_training = model.training
    model.eval()
    try:
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            for data in dataloader:
                # Original author's note (translated): this is already the
                # average over 20 random circuits.
                syndromes = data['syndromes'].squeeze(0).float().to(device)
                labels = data['label'].long().squeeze(0).to(device)
                final_round_syndromes = (
                    data['final_round_syndromes'].squeeze(0).float().to(device)
                )
                circuit = data['circuit']

                outputs = model(syndromes, circuit, final_round_syndromes)

                # The model returns a sequence; its first entry holds the scores.
                scores = outputs[0]
                _, predicted = torch.max(scores, -1)

                # Count every label entry, matching what the sum below counts.
                total += labels.numel()
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError('validate() received no label entries from dataloader')

    return correct / total


def evaluate_model(model_name: str,
                   data_dir, distance, circuit_index, depth_list, batch_size=1000,
                   if_final_round_syndrome=True):
    """Benchmark a saved model at several circuit depths.

    Loads ``./trained_models/<model_name>/model_0.pt`` once, then for every
    depth in ``depth_list`` builds a ``CachedSyndromeDataset``, runs
    ``validate`` on it and records the accuracy together with the elapsed
    time. The timing covers dataset construction as well as validation.

    NOTE(review): this function reads a module-level ``device`` that is not
    assigned in the visible cells — confirm where it is defined.

    Args:
        model_name: Sub-directory of ``./trained_models`` holding the model.
        data_dir: Passed to ``CachedSyndromeDataset`` as ``root_dir``.
        distance: Passed to ``CachedSyndromeDataset`` as ``code_distance``
            and stored in the result rows.
        circuit_index: Passed to ``CachedSyndromeDataset`` and stored in the
            result rows as ``circuit_type_index``.
        depth_list: Iterable of depths to evaluate, one result row each.
        batch_size: Passed to ``CachedSyndromeDataset``; the ``DataLoader``
            itself always uses ``batch_size=1``.
        if_final_round_syndrome: Forwarded to ``get_collate_function``.

    Returns:
        Long-format ``pandas.DataFrame`` with the id columns ``decoder``,
        ``distance``, ``depth`` and ``circuit_type_index`` plus ``metric``
        (``'walltime_seconds'`` or ``'logical_accuracy'``) and ``value``.
        Empty (but with those columns) when ``depth_list`` is empty.
    """
    model_path = Path('./trained_models') / model_name / 'model_0.pt'
    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be loaded on whatever `device` is available here.
    model = torch.load(model_path, weights_only=False, map_location=device).to(device)
    cf = get_collate_function(if_final_round_syndrome=if_final_round_syndrome)
    bench_result = []

    for depth in depth_list:
        # perf_counter is monotonic, so the interval cannot be distorted by
        # system clock adjustments.
        begin = time.perf_counter()
        dataset = CachedSyndromeDataset(
            root_dir=data_dir,
            code_distance=distance,
            circuit_index=circuit_index,
            batch_size=batch_size,
            depth=depth,
        )
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=cf)
        acc = validate(model, dataloader, device)
        walltime_seconds = time.perf_counter() - begin
        print(f"Depth: {depth}, Accuracy: {acc:.2f}, Walltime: {walltime_seconds:.2f}s")

        bench_result.append(dict(
            decoder='MCCD', distance=distance, circuit_type_index=circuit_index,
            logical_accuracy=acc, walltime_seconds=walltime_seconds, depth=depth))

    id_columns = ['decoder', 'distance', 'depth', 'circuit_type_index']
    metric_columns = ['walltime_seconds', 'logical_accuracy']
    # Naming the columns explicitly keeps the frame well-formed (and melt
    # working) even when no depth was evaluated.
    df = pd.DataFrame(bench_result, columns=id_columns + metric_columns)
    df = pd.melt(df, id_vars=id_columns,
                 value_vars=metric_columns,
                 var_name='metric',
                 value_name='value')

    return df

In [39]:
# Shared settings used by the evaluation cells of this notebook.
noise_model = 'average_depolarizing_noise'
if_final_round_syndrome = True

# Depths evaluated for circuit index 3 (2, 4, ..., 18) and 4 (4, 8, ..., 36).
type3_depth_list = list(range(2, 19, 2))
type4_depth_list = list(range(4, 37, 4))

## Training

### Distance=3

In [12]:
# Run train_1q_modules_cleaned.py for code distance 3 on logical circuit
# index 3, using the fig4 data directory for both training (depths 2-8) and
# validation (depths 2-18). The log below reports the model being saved
# under trained_models/fig4_c3_d3.
!python train_1q_modules_cleaned.py \
        --code_distance 3 \
        --logical_circuit_index 3 \
        --if_final_round_syndrome \
        --batch_size 1000 \
        --run_index 0 \
        --model_save_path fig4_c3_d3 \
        --train_data_dir data/bench/average_depolarizing_noise/fig4/mccd \
        --val_data_dir data/bench/average_depolarizing_noise/fig4/mccd \
        --train_depth_list 2 4 6 8 \
        --val_depth_list 2 4 6 8 10 12 14 16 18


Namespace(code_distance=3, logical_circuit_index='3', batch_size=1000, if_final_round_syndrome=True, if_large_lstm_2q=False, load_model=None, run_index='0', model_save_path='fig4_c3_d3', train_data_dir='data/bench/average_depolarizing_noise/fig4/mccd', val_data_dir='data/bench/average_depolarizing_noise/fig4/mccd', train_depth_list=[2, 4, 6, 8], val_depth_list=[2, 4, 6, 8, 10, 12, 14, 16, 18], save_model_every_n_batches=500, validate_every_n_batches=100)
Number of params in model: 270240
Model save to: fig4_c3_d3
Save to trained_models/fig4_c3_d3
0it [00:00, ?it/s][Step 1000] Train Acc: 9.00% | Val Acc: 27.78%
Model checkpoint saved.
80it [00:08,  9.74it/s]
Final model saved.


In [42]:
# Run train_2q_module_cleaned.py for code distance 3 on logical circuit
# index 4, using the fig5 data directory (training depths 4-16, validation
# depths 4-36). --load_model points at the distance-3 output of the
# train_1q_modules_cleaned.py run; the log below reports saving under
# trained_models/fig5_c4_d3.
!python train_2q_module_cleaned.py \
        --code_distance 3 \
        --logical_circuit_index 4 \
        --if_final_round_syndrome \
        --batch_size 1000 \
        --run_index 0 \
        --model_save_path fig5_c4_d3 \
        --train_data_dir data/bench/average_depolarizing_noise/fig5/mccd \
        --val_data_dir data/bench/average_depolarizing_noise/fig5/mccd \
          --train_depth_list 4 8 12 16 \
        --val_depth_list 4 8 12 16 20 24 28 32 36 \
        --load_model trained_models/fig4_c3_d3


Namespace(code_distance=3, logical_circuit_index='4', batch_size=1000, if_final_round_syndrome=True, if_large_lstm_2q=False, load_model='trained_models/fig4_c3_d3', run_index='0', model_save_path='fig5_c4_d3', train_data_dir='data/bench/average_depolarizing_noise/fig5/mccd', val_data_dir='data/bench/average_depolarizing_noise/fig5/mccd', train_depth_list=[4, 8, 12, 16], val_depth_list=[4, 8, 12, 16, 20, 24, 28, 32, 36], save_model_every_n_batches=500, validate_every_n_batches=100)
Loading model from: trained_models/fig4_c3_d3
Number of params in model: 108544
Model save to: fig5_c4_d3
Save to trained_models/fig5_c4_d3
[Step 1000] Train Acc: 67.30% | Val Acc: 53.14%
Model checkpoint saved.
Final model saved.


### Distance=5

In [51]:
# Run train_1q_modules_cleaned.py for code distance 5 on logical circuit
# index 3, using the fig4 data directory for both training (depths 2-8) and
# validation (depths 2-18). The log below reports the model being saved
# under trained_models/fig4_c3_d5.
!python train_1q_modules_cleaned.py \
        --code_distance 5 \
        --logical_circuit_index 3 \
        --if_final_round_syndrome \
        --batch_size 1000 \
        --run_index 0 \
        --model_save_path fig4_c3_d5 \
        --train_data_dir data/bench/average_depolarizing_noise/fig4/mccd \
        --val_data_dir data/bench/average_depolarizing_noise/fig4/mccd \
        --train_depth_list 2 4 6 8 \
        --val_depth_list 2 4 6 8 10 12 14 16 18


Namespace(code_distance=5, logical_circuit_index='3', batch_size=1000, if_final_round_syndrome=True, if_large_lstm_2q=False, load_model=None, run_index='0', model_save_path='fig4_c3_d5', train_data_dir='data/bench/average_depolarizing_noise/fig4/mccd', val_data_dir='data/bench/average_depolarizing_noise/fig4/mccd', train_depth_list=[2, 4, 6, 8], val_depth_list=[2, 4, 6, 8, 10, 12, 14, 16, 18], save_model_every_n_batches=500, validate_every_n_batches=100)
Number of params in model: 2399032
Model save to: fig4_c3_d5
Save to trained_models/fig4_c3_d5
0it [00:00, ?it/s][Step 1000] Train Acc: 12.70% | Val Acc: 35.73%
Model checkpoint saved.
80it [00:27,  2.96it/s]
Final model saved.


In [53]:
# Run train_2q_module_cleaned.py for code distance 5 on logical circuit
# index 4, using the fig5 data directory (training depths 4-16, validation
# depths 4-36). --load_model points at the distance-5 output of the
# train_1q_modules_cleaned.py run; the log below reports saving under
# trained_models/fig5_c4_d5.
!python train_2q_module_cleaned.py \
        --code_distance 5 \
        --logical_circuit_index 4 \
        --if_final_round_syndrome \
        --batch_size 1000 \
        --run_index 0 \
        --model_save_path fig5_c4_d5 \
        --train_data_dir data/bench/average_depolarizing_noise/fig5/mccd \
        --val_data_dir data/bench/average_depolarizing_noise/fig5/mccd \
          --train_depth_list 4 8 12 16 \
        --val_depth_list 4 8 12 16 20 24 28 32 36 \
        --load_model trained_models/fig4_c3_d5


Namespace(code_distance=5, logical_circuit_index='4', batch_size=1000, if_final_round_syndrome=True, if_large_lstm_2q=False, load_model='trained_models/fig4_c3_d5', run_index='0', model_save_path='fig5_c4_d5', train_data_dir='data/bench/average_depolarizing_noise/fig5/mccd', val_data_dir='data/bench/average_depolarizing_noise/fig5/mccd', train_depth_list=[4, 8, 12, 16], val_depth_list=[4, 8, 12, 16, 20, 24, 28, 32, 36], save_model_every_n_batches=500, validate_every_n_batches=100)
Loading model from: trained_models/fig4_c3_d5
Number of params in model: 964608
Model save to: fig5_c4_d5
Save to trained_models/fig5_c4_d5
[Step 1000] Train Acc: 60.50% | Val Acc: 51.23%
Model checkpoint saved.
Final model saved.


## Evaluation

### Distance = 3

In [62]:
# Benchmark the distance-3 model saved as 'fig4_c3_d3' on circuit index 3
# over every depth in type3_depth_list (fig4 data).
fig4_d3 = evaluate_model(
    model_name='fig4_c3_d3',
    data_dir=f'data/bench/{noise_model}/fig4/mccd',
    distance=3,
    circuit_index=3,
    depth_list=type3_depth_list,
    batch_size=1000,
    if_final_round_syndrome=if_final_round_syndrome,
)


Depth: 2, Accuracy: 0.91%, Walltime: 0.36s
Depth: 4, Accuracy: 0.84%, Walltime: 0.48s
Depth: 6, Accuracy: 0.78%, Walltime: 0.69s
Depth: 8, Accuracy: 0.74%, Walltime: 0.84s
Depth: 10, Accuracy: 0.70%, Walltime: 1.02s
Depth: 12, Accuracy: 0.67%, Walltime: 1.54s
Depth: 14, Accuracy: 0.64%, Walltime: 1.38s
Depth: 16, Accuracy: 0.61%, Walltime: 1.54s
Depth: 18, Accuracy: 0.60%, Walltime: 1.80s


In [None]:
# Benchmark the distance-3 model saved as 'fig5_c4_d3' on circuit index 4
# over every depth in type4_depth_list (fig5 data).
fig5_d3 = evaluate_model(
    model_name='fig5_c4_d3',
    data_dir=f'data/bench/{noise_model}/fig5/mccd',
    distance=3,
    circuit_index=4,
    depth_list=type4_depth_list,
    batch_size=1000,
    if_final_round_syndrome=if_final_round_syndrome,
)

Depth: 4, Accuracy: 0.68%, Walltime: 0.95s
Depth: 8, Accuracy: 0.57%, Walltime: 1.66s
Depth: 12, Accuracy: 0.53%, Walltime: 2.34s
Depth: 16, Accuracy: 0.51%, Walltime: 2.82s
Depth: 20, Accuracy: 0.50%, Walltime: 3.45s
Depth: 24, Accuracy: 0.50%, Walltime: 4.50s
Depth: 28, Accuracy: 0.50%, Walltime: 5.16s
Depth: 32, Accuracy: 0.50%, Walltime: 5.41s
Depth: 36, Accuracy: 0.50%, Walltime: 7.14s


### Distance = 5

In [64]:
# Benchmark the distance-5 model saved as 'fig4_c3_d5' on circuit index 3
# over every depth in type3_depth_list (fig4 data).
fig4_d5 = evaluate_model(
    model_name='fig4_c3_d5',
    data_dir=f'data/bench/{noise_model}/fig4/mccd',
    distance=5,
    circuit_index=3,
    depth_list=type3_depth_list,
    batch_size=1000,
    if_final_round_syndrome=if_final_round_syndrome,
)


Depth: 2, Accuracy: 0.86%, Walltime: 0.83s
Depth: 4, Accuracy: 0.77%, Walltime: 0.92s
Depth: 6, Accuracy: 0.69%, Walltime: 1.21s
Depth: 8, Accuracy: 0.65%, Walltime: 1.56s
Depth: 10, Accuracy: 0.61%, Walltime: 2.08s
Depth: 12, Accuracy: 0.58%, Walltime: 2.23s
Depth: 14, Accuracy: 0.56%, Walltime: 2.71s
Depth: 16, Accuracy: 0.54%, Walltime: 2.91s
Depth: 18, Accuracy: 0.53%, Walltime: 3.22s


In [65]:

# Benchmark the distance-5 model saved as 'fig5_c4_d5' on circuit index 4
# over every depth in type4_depth_list (fig5 data).
fig5_d5 = evaluate_model(
    model_name='fig5_c4_d5',
    data_dir=f'data/bench/{noise_model}/fig5/mccd',
    distance=5,
    circuit_index=4,
    depth_list=type4_depth_list,
    batch_size=1000,
    if_final_round_syndrome=if_final_round_syndrome,
)

Depth: 4, Accuracy: 0.60%, Walltime: 1.98s
Depth: 8, Accuracy: 0.51%, Walltime: 3.28s
Depth: 12, Accuracy: 0.50%, Walltime: 4.50s
Depth: 16, Accuracy: 0.50%, Walltime: 6.02s
Depth: 20, Accuracy: 0.50%, Walltime: 7.39s
Depth: 24, Accuracy: 0.50%, Walltime: 8.76s
Depth: 28, Accuracy: 0.50%, Walltime: 10.38s
Depth: 32, Accuracy: 0.50%, Walltime: 11.62s
Depth: 36, Accuracy: 0.50%, Walltime: 12.84s


## Save Dataframes

In [66]:
# Stack the distance-3 and distance-5 results of each figure into one table.
fig4 = pd.concat([fig4_d3, fig4_d5], axis=0, ignore_index=True)
fig5 = pd.concat([fig5_d3, fig5_d5], axis=0, ignore_index=True)

# Write each table to CSV. The result directory is created first because
# to_csv does not create missing parent directories.
for result_frame, result_path in (
    (fig4, Path(f'data/bench/{noise_model}/fig4/result/fig4-mccd.csv')),
    (fig5, Path(f'data/bench/{noise_model}/fig5/result/fig5-mccd.csv')),
):
    result_path.parent.mkdir(parents=True, exist_ok=True)
    result_frame.to_csv(result_path, index=False)
