In [1]:
import argparse
import sys
import os
import time
import datetime
import ast
from collections import Counter
from pathlib import Path
import hydra

import numpy as np
import torch

# Ensure local src/ is importable
sys.path.insert(0, str(Path(os.getcwd()).parent / "src"))

from my_genQC.inference.eval_metrics import UnitaryFrobeniusNorm, UnitaryInfidelityNorm
from my_genQC.inference.evaluation_helper import get_unitaries, get_srvs
from my_genQC.inference.sampling import generate_compilation_tensors, generate_tensors, decode_tensors_to_backend
from my_genQC.pipeline.diffusion_pipeline import DiffusionPipeline
from my_genQC.platform.simulation import Simulator, CircuitBackendType
from my_genQC.platform.tokenizer.circuits_tokenizer import CircuitTokenizer
from my_genQC.utils.misc_utils import infer_torch_device, get_entanglement_bins
from my_genQC.dataset import circuits_dataset
from my_genQC.models.config_model import ConfigModel
from my_genQC.utils.config_loader import load_config, store_tensor, load_tensor

In [2]:
def load_dataset(dataset_path: Path, device: torch.device):
    config_path = os.path.join(dataset_path, "config.yaml")

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found at {config_path}")

    cfg_data = load_config(config_path)
    target = cfg_data.get("target", "")
    if target.endswith("MixedCircuitsConfigDataset"):
        dataset_cls = circuits_dataset.MixedCircuitsConfigDataset
    else:
        dataset_cls = circuits_dataset.CircuitsConfigDataset

    dataset = dataset_cls.from_config_file(
        config_path=config_path,
        device=device,
        save_path=os.path.join(dataset_path, "dataset", "ds")
    )
    
    return dataset

ds = load_dataset("../datasets/paper_quditkit/srv_8q_dataset", device="cuda")
samples = ds.x.shape[0]

[INFO]: Loading tensor from `../datasets/paper_quditkit/srv_8q_dataset/dataset/ds_x.safetensors` onto device: cuda.
[INFO]: Loading tensor from `../datasets/paper_quditkit/srv_8q_dataset/dataset/ds_y.safetensors` onto device: cuda.
[INFO]: Instantiated config_dataset from given config on cuda.


In [3]:
tensor_out = load_tensor("../scripts/inference/8q_599936_samples.pt", device="cuda")

[INFO]: Loading tensor from `../scripts/inference/8q_599936_samples.pt` onto device: cuda.


In [4]:
# idx_shuffled = torch.randperm(len(tensor_out))
# tensor_out_shuffled = tensor_out[idx_shuffled]
# tensor_out = tensor_out[:10000]

In [4]:
vocabulary = {gate: idx for idx, gate in enumerate(ds.gate_pool)}
tokenizer = CircuitTokenizer(vocabulary)
simulator = Simulator(CircuitBackendType.QISKIT) #.QUDITKIT)
pqdm(loop_set, f, n_jobs=n_jobs)



In [None]:
from my_genQC.utils.async_fn import run_parallel_jobs
from my_genQC.platform.simulation import Simulator 
from my_genQC.platform.tokenizer.base_tokenizer import BaseTokenizer
from my_genQC.pipeline.pipeline import Pipeline

from typing import Optional, Sequence
import time
from tqdm import tqdm

from pqdm.threads import pqdm

def exists(val): 
    return val is not None

def _chunk_iterable(it, size: int):
    """Yield fixed-size chunks from any iterable without materializing the whole thing."""
    chunk = []
    for item in it:
        chunk.append(item)
        if len(chunk) == size:
            yield chunk
            chunk = []
    if chunk:
        yield chunk


def run_parallel_jobs(f: callable,
                      loop_set,
                      n_jobs: int = 1,
                      prefer: str = "processes",
                      batch_size: int | None = None):
    """
    Run a function in parallel over an iterable.

    - prefer="processes" avoids the GIL for Python-heavy work (default).
    - batch_size groups items to amortize scheduling/pickling overhead.
    """
    # fast path
    if n_jobs == 1:
        return [f(x) for x in loop_set]

    if batch_size:
        def _run_batch(batch):
            return [f(x) for x in batch]

        batches = _chunk_iterable(loop_set, batch_size)
        res_batches = Parallel(n_jobs=n_jobs,
                               prefer=prefer,
                               batch_size="auto")(delayed(_run_batch)(b) for b in batches)
        # flatten
        return [item for batch in res_batches for item in batch]

    return Parallel(n_jobs=n_jobs,
                    prefer=prefer,
                    batch_size="auto")(delayed(f)(x) for x in loop_set)


def decode_tensors_to_backend(simulator: Simulator, 
                              tokenizer: BaseTokenizer, 
                              tensors: torch.Tensor, 
                              params: Optional[torch.Tensor] = None, 
                              silent: bool = True,
                              n_jobs: int = 1,
                              filter_errs: bool = True,
                              return_tensors: bool = False,
                    ) -> tuple[Sequence[any], int] | tuple[Sequence[any], int, torch.Tensor]:
    tensors = tensors.cpu()

    if exists(params):
        params  = params.cpu()
        iter_pack = zip(tensors, params)
        _decode   = lambda x, p: tokenizer.decode(x, p)
        
    else:
        iter_pack = zip(tensors, )
        _decode   = lambda x: tokenizer.decode(x)
    
    def _f(iter_vars):
        try:
            instructions = _decode(*iter_vars)
            backend_obj  = simulator.backend.genqc_to_backend(instructions, place_barriers=False)
            return backend_obj
        except Exception as err:
            if silent: return None
            raise err
        
    pot_qcs = run_parallel_jobs(_f, iter_pack, n_jobs)

    if filter_errs:
        backend_obj_list = [pot_qc for pot_qc in pot_qcs if exists(pot_qc)]
        err_cnt          = sum(1 for pot_qc in pot_qcs if not_exists(pot_qc))
        assert len(backend_obj_list) + err_cnt == len(pot_qcs)

        if return_tensors:
            tensors = tensors[torch.tensor([exists(pot_qc) for pot_qc in pot_qcs])]
        
    else:
        backend_obj_list = pot_qcs
        err_cnt = None

    if return_tensors:
        return backend_obj_list, err_cnt, tensors
    return backend_obj_list, err_cnt

In [None]:
decoded_circuits, _ = decode_tensors_to_backend(
    simulator=simulator,
    tokenizer=tokenizer,
    tensors=tensor_out,
    silent=True,
    params=None,
    n_jobs=24,
    filter_errs=False,
)

In [26]:
decoded_circuits[:10]

[<qiskit.circuit.quantumcircuit.QuantumCircuit at 0x78593c65b9d0>,
 <qiskit.circuit.quantumcircuit.QuantumCircuit at 0x7859de3ac990>,
 <qiskit.circuit.quantumcircuit.QuantumCircuit at 0x785937bd8590>,
 <qiskit.circuit.quantumcircuit.QuantumCircuit at 0x78593c63e890>,
 <qiskit.circuit.quantumcircuit.QuantumCircuit at 0x785937fb9550>,
 <qiskit.circuit.quantumcircuit.QuantumCircuit at 0x785937bae350>,
 None,
 <qiskit.circuit.quantumcircuit.QuantumCircuit at 0x785937b7e690>,
 <qiskit.circuit.quantumcircuit.QuantumCircuit at 0x78593c619610>,
 <qiskit.circuit.quantumcircuit.QuantumCircuit at 0x785937b70c50>]

In [27]:
valid = [(idx, qc) for idx, qc in enumerate(decoded_circuits) if qc is not None]
valid_indices = [idx for idx, _ in valid]
backend_circuits = [qc for _, qc in valid]
err_cnt = len(decoded_circuits) - len(backend_circuits)

print("==== genQC Evaluation ====")
print(f"Samples requested: {samples}")
print(f"Decoded circuits : {len(backend_circuits)}")
print(f"Decode failures  : {err_cnt}")

==== genQC Evaluation ====
Samples requested: 599936
Decoded circuits : 926
Decode failures  : 74


In [28]:
system_size = ds.x.shape[1]
max_gates = ds.x.shape[2]
num_qubits = 8

In [29]:
def parse_srv_targets(labels: np.ndarray) -> torch.Tensor:
    """Extract SRV vectors from stored prompt strings."""
    srv_list = []
    for label in labels:
        text = str(label)
        start = text.find("[")
        end = text.find("]", start)
        if start == -1 or end == -1:
            raise ValueError(f"Could not parse SRV from label: {text}")
        srv = ast.literal_eval(text[start:end+1])
        srv_list.append(srv)
    return torch.tensor(srv_list, dtype=torch.long)


def entanglement_histogram(srvs: torch.Tensor, num_qubits: int) -> tuple[list[float], list[str], float]:
    """Return histogram over entanglement bins defined in genQC."""
    if srvs.numel() == 0:
        return [], [], 0.0

    bins, labels = get_entanglement_bins(num_qubits)
    mapping = {}
    for idx, bucket in enumerate(bins):
        for vector in bucket:
            mapping[tuple(vector)] = idx

    counts = Counter(mapping.get(tuple(vec.tolist()), -1) for vec in srvs)
    total = srvs.shape[0]
    hist = [counts.get(i, 0) / total for i in range(len(labels))]
    other_ratio = counts.get(-1, 0) / total
    return hist, labels, other_ratio


def get_srvs(simulator: Simulator, backend_obj_list: Sequence, n_jobs: int = 1, **kwargs): 
    """Returns SRVs of a given list of backen objects `backend_obj_list`."""
    def _f(backend_obj):
        return simulator.backend.schmidt_rank_vector(backend_obj, **kwargs)
        
    return run_parallel_jobs(_f, backend_obj_list, n_jobs)

In [30]:
target_srvs = parse_srv_targets(ds.y[:samples])[valid_indices]
predicted_srvs = torch.tensor(
    get_srvs(simulator, backend_circuits, n_jobs=16),  
    dtype=torch.long,
)

if target_srvs.shape != predicted_srvs.shape:
    raise RuntimeError(f"SRV shape mismatch: target {target_srvs.shape} vs predicted {predicted_srvs.shape}")

exact_match = (predicted_srvs == target_srvs).all(dim=1)
per_qubit = (predicted_srvs == target_srvs).float().mean(dim=0)

srv_exact_match_rate = exact_match.float().mean().item()
print(f"SRV exact-match rate : {srv_exact_match_rate:.4f}")

qubit_rank_acc = {i: acc for i, acc in enumerate(per_qubit.tolist())}
print("Per-qubit rank acc   : " + ", ".join(f"q{i}={acc:.3f}" for i, acc in qubit_rank_acc.items()))
# print("Per-qubit rank acc   : " + ", ".join(f"q{i}={acc:.3f}" for i, acc in enumerate(per_qubit.tolist())))

pred_hist, ent_labels, pred_other = entanglement_histogram(predicted_srvs, num_qubits)
targ_hist, _, targ_other = entanglement_histogram(target_srvs, num_qubits)

if ent_labels:
    print("Entanglement-bin distribution (target | pred):")
    for label, t_frac, p_frac in zip(ent_labels, targ_hist, pred_hist):
        print(f"  {label:>20}: {t_frac:6.2%} | {p_frac:6.2%}")
    if targ_other > 0 or pred_other > 0:
        print(f"  {'Other/invalid':>20}: {targ_other:6.2%} | {pred_other:6.2%}")

QUEUEING TASKS | :   0%|          | 0/926 [00:00<?, ?it/s]

PROCESSING TASKS | :   0%|          | 0/926 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/926 [00:00<?, ?it/s]

SRV exact-match rate : 0.8337
Per-qubit rank acc   : q0=0.974, q1=0.968, q2=0.973, q3=0.974, q4=0.970, q5=0.974, q6=0.969, q7=0.976
Entanglement-bin distribution (target | pred):
     0 qubit entangled: 14.79% | 15.55%
     2 qubit entangled: 15.66% | 17.17%
     3 qubit entangled:  8.53% |  9.83%
     4 qubit entangled: 10.91% | 13.93%
     5 qubit entangled: 16.85% | 14.58%
     6 qubit entangled: 13.61% | 12.20%
     7 qubit entangled: 12.96% | 10.91%
     8 qubit entangled:  6.70% |  5.83%


In [39]:
srvs = get_srvs(simulator, backend_circuits, n_jobs=47)

QUEUEING TASKS | :   0%|          | 0/599936 [00:00<?, ?it/s]

PROCESSING TASKS | :   0%|          | 0/599936 [00:00<?, ?it/s]

COLLECTING RESULTS | :   0%|          | 0/599936 [00:00<?, ?it/s]

In [41]:
decoded_circuits[0]

AttributeError("Can't pickle local object 'decode_tensors_to_backend.<locals>._f'")