In [3]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from qutip import *
import matplotlib.pyplot as plt
from visualisation_tools import *
from model_generator import *
from qsq_protocol import *
from csv_utils import *
from optimizer_utils import *
import datetime
import itertools
import random
# Notebook-wide matplotlib defaults, applied one key at a time.
plt.rcParams["figure.figsize"] = (10, 6)
# LaTeX text rendering is kept off; switching it on requires a working LaTeX
# installation and makes figure rendering noticeably slower.
plt.rcParams["text.usetex"] = False
plt.rcParams["font.family"] = 'Times New Roman'
plt.rcParams["figure.dpi"] = 100
plt.rcParams["font.size"] = 14
import os
import csv

In [4]:
def average_fidelity_gauge(rho, gate, M, avg_fidelity=None):
    """
    Return the best average fidelity found after a search over gauge angles.

    Two global optimizers ('basinhopping' and 'differential_evolution') are run
    through ``optimize_method`` on ``objective_function``. Each optimizer's
    final objective value ``fun`` is converted to a fidelity as ``1 - fun``
    (presumably the objective is an infidelity -- confirm against
    ``objective_function``). The largest of those two values and the
    uncorrected fidelity is returned, so the result is never below
    ``avg_fidelity``.

    Parameters
    ----------
    rho, gate, M
        Model components forwarded unchanged to ``average_fidelity`` and
        ``optimize_method``. ``M`` must provide ``eigenstates()``.
    avg_fidelity : optional
        Uncorrected average fidelity. When ``None`` it is computed with
        ``average_fidelity(rho, gate, M)``.

    Returns
    -------
    The maximum of the two optimized fidelities and ``avg_fidelity``.
    """
    baseline = average_fidelity(rho, gate, M) if avg_fidelity is None else avg_fidelity

    # Build the starting point from the first two eigenstates of M; the
    # eigenvalues themselves are not needed.
    _, basis = M.eigenstates()
    seed_unitary = basis[0] @ plus_state.dag() - 1j * basis[1] @ minus_state.dag()
    start_angles = decompose_U(seed_unitary)
    angle_bounds = [(-np.pi, np.pi), (0, np.pi/2), (-np.pi, np.pi)]

    # Keep the original call order: basinhopping first, then differential evolution.
    outcomes = [
        optimize_method(method, objective_function, start_angles, rho, gate, M, angle_bounds)
        for method in ('basinhopping', 'differential_evolution')
    ]
    optimized = [1 - outcome.fun for outcome in outcomes]

    return max(*optimized, baseline)


In [None]:
def fill_model_gauge_fidelity(csv_file, flush_every=10):
    """
    Fill the 'model_average_fidelity_gauge' column of a model CSV in place.

    Every row whose 'model_average_fidelity_gauge' value is missing or empty
    is processed: the 'model' cell is evaluated into three items that are
    wrapped in ``Qobj`` (rho, gate, M), and the result of
    ``average_fidelity_gauge`` -- seeded with the row's
    'model_average_fidelity' -- is stored in the row. Rows that already hold a
    value are left untouched, so an interrupted run can simply be restarted.

    Parameters
    ----------
    csv_file : str
        Path of the CSV file to read and rewrite.
    flush_every : int, default 10
        The file is rewritten at every ``flush_every``-th row, provided new
        values were computed since the previous write.

    Raises
    ------
    ValueError
        If ``flush_every`` is smaller than 1.

    Notes
    -----
    Values computed since the last write are also saved when the loop is left
    through an exception (including ``KeyboardInterrupt``); the exception is
    then re-raised.
    """
    gauge_column = 'model_average_fidelity_gauge'

    if flush_every < 1:
        raise ValueError("flush_every must be a positive integer")

    # Step 1: read everything
    with open(csv_file, mode='r', newline='') as file:
        reader = csv.DictReader(file)
        fieldnames = list(reader.fieldnames or [])
        rows = list(reader)

    # DictWriter rejects row keys that are not listed in fieldnames, so make
    # sure the output column exists even if the file was created without it.
    if gauge_column not in fieldnames:
        fieldnames.append(gauge_column)

    total = len(rows)
    print(f"{total} rows to be processed")

    # Step 2: process rows in place, writing only when something changed
    unsaved_changes = False
    try:
        for i, row in enumerate(rows, start=1):
            if not row.get(gauge_column):
                # SECURITY: eval() executes arbitrary code found in the CSV cell.
                # Only run this on files produced by this project.
                model_data = eval(row['model'])
                rho, gate, M = [Qobj(m) for m in model_data]
                avg_fidelity = float(row['model_average_fidelity'])

                row[gauge_column] = average_fidelity_gauge(
                    rho, gate, M, avg_fidelity
                )
                unsaved_changes = True

            # Flush FULL dataset
            if unsaved_changes and i % flush_every == 0:
                _rewrite_csv(csv_file, fieldnames, rows)
                unsaved_changes = False
                print(f"[{i}/{total}] flushed ({100*i/total:.1f}%)")
    finally:
        # Final flush; also reached on errors so finished rows are not lost.
        if unsaved_changes:
            _rewrite_csv(csv_file, fieldnames, rows)


import time
import os
import csv

def _rewrite_csv(csv_file, fieldnames, rows, retries=5, delay=0.2):
    tmp_file = csv_file + ".tmp"

    with open(tmp_file, mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    for attempt in range(retries):
        try:
            os.replace(tmp_file, csv_file)
            return
        except PermissionError:
            if attempt == retries - 1:
                raise
            time.sleep(delay)



In [16]:
# Path of the CSV file that holds the generated models.
csv_file = 'data.csv'
# Alternative generation mode, currently disabled:
# generate_model_data(csv_file, 'random',100)
# NOTE(review): the recorded run of the following cell reports 1260 rows, so
# generate_model_data presumably adds to an existing file rather than
# overwriting it -- confirm before re-running this cell.
generate_model_data(csv_file, 'perturbed',1000)
# Read the file back into `data`.
data = load_model_data(csv_file)

# Access the data by variable name, for example:
# print(data['models'])     # presumably the models as Qobj instances -- TODO confirm



In [17]:
# Use the path defined in the data-generation cell so both cells always target
# the same file; the CSV is rewritten after every 5 processed rows.
fill_model_gauge_fidelity(csv_file, flush_every=5)

1260 rows to be processed
Flush up to row 5
Flush up to row 10
Flush up to row 15
Flush up to row 20
Flush up to row 25
Flush up to row 30
Flush up to row 35
Flush up to row 40
Flush up to row 45
Flush up to row 50
Flush up to row 55
Flush up to row 60
Flush up to row 65
Flush up to row 70
Flush up to row 75
Flush up to row 80
Flush up to row 85
Flush up to row 90
Flush up to row 95
Flush up to row 100
Flush up to row 105
Flush up to row 110
Flush up to row 115
Flush up to row 120
Flush up to row 125
Flush up to row 130
Flush up to row 135
Flush up to row 140
Flush up to row 145
Flush up to row 150
Flush up to row 155
Flush up to row 160
Flush up to row 165
Flush up to row 170
Flush up to row 175
Flush up to row 180
Flush up to row 185
Flush up to row 190
Flush up to row 195
Flush up to row 200
Flush up to row 205
Flush up to row 210
Flush up to row 215
Flush up to row 220
Flush up to row 225
Flush up to row 230
Flush up to row 235
Flush up to row 240
Flush up to row 245
Flush up to ro

KeyboardInterrupt: 