In [None]:
import csv
from pathlib import Path

from mqt.problemsolver.resource_estimation.error_budget_optimization import (
    evaluate,
    generate_data,
    plot_results,
    train,
)

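## Data Generation

Choose the data source with the two flags in the next cell:
- `use_paper_data` reads pre-computed logical counts from `logical_counts.csv`.
- `use_zip_file_circuits` uses the circuits in `mqt_bench.zip`.
- If both flags are `False`, the `ae` and `dj` benchmarks at sizes 3–10 are used.

If both flags are `True`, the paper data takes precedence.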

In [None]:
# Total error budget that every generated distribution is built from, and how
# many random distributions generate_data() should produce.
total_error_budget = 0.1
number_of_randomly_generated_distributions = 1000

# Data-source selection. The branches below are checked in order, so if both
# flags are True the paper data (CSV) wins and the zip file is ignored.
use_paper_data = True
use_zip_file_circuits = False

if use_paper_data:
    # Pre-computed logical counts: one dict per CSV row, every value parsed as int.
    csv_path = Path("logical_counts.csv")
    if not csv_path.exists():
        msg = f"Data not found at {csv_path}"
        raise FileNotFoundError(msg)
    # newline="" is what the csv module expects for files it reads.
    with csv_path.open(encoding="utf-8", newline="") as csvfile:
        reader = csv.DictReader(csvfile)
        logical_counts = [{k: int(v) for k, v in row.items()} for row in reader]
    data = generate_data(
        total_error_budget,
        number_of_randomly_generated_distributions,
        logical_counts=logical_counts,
    )

elif use_zip_file_circuits:
    zip_file_path = Path("mqt_bench.zip")
    if not zip_file_path.exists():
        msg = f"Data not found at {zip_file_path}"
        raise FileNotFoundError(msg)
    data = generate_data(
        total_error_budget,
        number_of_randomly_generated_distributions,
        path=zip_file_path,
    )
else:
    # (benchmark name, sizes) pairs; each benchmark gets its own list of sizes
    # 3..10 so the entries never share a mutable object.
    benchmark_defs = [(name, list(range(3, 11))) for name in ("ae", "dj")]
    data = generate_data(
        total_error_budget,
        number_of_randomly_generated_distributions,
        benchmarks_and_sizes=benchmark_defs,
    )

## Training

In [None]:
# Fit the model on the generated data. train() returns three values: the model
# (used via model.predict() in the Evaluation section) and the x_test / y_test
# values that the Evaluation section passes to model.predict() and evaluate().
model, x_test, y_test = train(data)

## Evaluation

In [None]:
# Run the trained model on the held-back inputs.
Y_pred = model.predict(x_test)

# Score the model's predictions and the dataset's own y_test values with the
# same evaluate() call (predictions first, then dataset), so the two results
# can be compared in a single plot.
product_diffs, product_diffs_dataset = (
    evaluate(x_test, y_values, total_error_budget) for y_values in (Y_pred, y_test)
)

# Plot both result sets together; legend and bin_width are forwarded as-is.
plot_results(
    product_diffs,
    product_diffs_dataset,
    legend=True,
    bin_width=4,
)