In [None]:
import logging
import torch

from spf.rf import torch_pi_norm
from spf.dataset.spf_dataset import v5spfdataset, v5spfdataset_manager
import glob

from tqdm import tqdm

from multiprocessing import Pool


def ds_to_metrics(args):
    """Compute phase-error summary metrics for one v5 dataset.

    Intended as a ``multiprocessing.Pool`` worker: it takes a single dict
    argument and never raises. Any failure is logged (with traceback) and
    reported as ``(None, None, None)`` so one bad dataset does not abort
    the whole pool run.

    Args:
        args: dict with keys ``"ds_fn"`` (dataset path),
            ``"precompute_cache"`` and ``"segmentation_version"``; all three
            are forwarded to ``v5spfdataset_manager``.

    Returns:
        ``(routine, carrier_frequency, metrics)`` where ``routine`` is
        ``ds.yaml_config["routine"]``, ``carrier_frequency`` is
        ``ds.carrier_frequencies[0]`` and ``metrics`` is a 1-D tensor
        ``[std of the finite phase errors, fraction of phase errors that are
        finite, number of rows in ds.mean_phase["r0"]]``.
        Returns ``(None, None, None)`` if loading or computing failed.
    """
    try:
        with v5spfdataset_manager(
            args["ds_fn"],
            nthetas=65,
            ignore_qc=True,
            precompute_cache=args["precompute_cache"],
            snapshots_per_session=1,
            skip_fields=["signal_matrix"],
            paired=True,
            segmentation_version=args["segmentation_version"],
        ) as ds:
            # Error between ground-truth phase and the mean phase of both
            # receivers, normalized by torch_pi_norm (presumably wrapping into
            # [-pi, pi) — confirm). Stacking r0/r1 assumes ground_truth_phis
            # has a matching (2, N) layout — TODO confirm.
            diffs = torch_pi_norm(
                ds.ground_truth_phis
                - torch.vstack([ds.mean_phase["r0"], ds.mean_phase["r1"]])
            )
            # Non-finite errors are excluded from the std; the fraction of
            # finite entries is reported separately as a coverage metric.
            mask = diffs.isfinite()
            return (
                ds.yaml_config["routine"],
                ds.carrier_frequencies[0],
                torch.as_tensor(
                    [
                        diffs[mask].std(),
                        mask.to(torch.float).mean(),
                        # Row count, used downstream as an averaging weight.
                        ds.mean_phase["r0"].shape[0],
                    ]
                ),
            )
    except Exception as e:
        # Best-effort worker: log with traceback and signal failure to the
        # caller. args.get keeps this handler from raising if "ds_fn" is
        # missing from args.
        logging.exception(
            f"Failed to load... {args.get('ds_fn')} with exception: {e}"
        )
        return None, None, None


inputs = glob.glob("/mnt/4tb_ssd/nosig_data/*.zarr")
segmentation_version = 3.2
precompute_cache = "/mnt/4tb_ssd/precompute_cache_new"
jobs = [
    {
        "ds_fn": fn,
        "segmentation_version": segmentation_version,
        "precompute_cache": precompute_cache,
    }
    for fn in inputs
]


# Evaluate every dataset in parallel. Each result is
# (routine, carrier_frequency, metrics), or (None, None, None) on failure.
with Pool(8) as p:
    metrics_list = list(tqdm(p.imap(ds_to_metrics, jobs), total=len(jobs)))

# Group per-dataset metric rows as results[frequency][routine] -> [rows].
# Failed loads are skipped: they would otherwise form a bogus
# results[None][None] bucket and torch.vstack cannot stack None.
# NOTE(review): if carrier_frequencies[0] is a tensor rather than a Python
# number, dict keys compare by identity and grouping would not merge — confirm.
results = {}
n_failed = 0
for routine, frequency, metrics in metrics_list:
    if metrics is None:
        n_failed += 1
        continue
    results.setdefault(frequency, {}).setdefault(routine, []).append(metrics)
if n_failed:
    logging.warning(
        "%d of %d datasets failed to load and were skipped", n_failed, len(jobs)
    )

# Collapse each group to row-count-weighted averages. Metric columns are
# [phase-error std, finite fraction, row count]; column 2 is the weight.
# NOTE(review): this is a weighted mean of per-dataset stds, not a pooled std.
for frequency, by_routine in results.items():
    for routine, rows in by_routine.items():
        metrics = torch.vstack(rows)
        total_weight = metrics[:, 2].sum()
        std = ((metrics[:, 0] * metrics[:, 2]) / total_weight).sum()
        notnan = ((metrics[:, 1] * metrics[:, 2]) / total_weight).sum()
        # Replacing the value of an existing key is safe during iteration.
        by_routine[routine] = {"std": std, "notnan": notnan}

print(results)

In [None]:
import torch

# Evaluate both carrier frequencies in one expression: a notebook cell only
# displays its last expression, so separate statements would hide all but the
# final value.
# NOTE(review): the purpose of the /20 scaling is not evident here — confirm.
log_freqs = torch.tensor([2.417e9, 0.900e9]).log10() / 20
log_freqs

In [None]:
# Display the aggregated metrics: results[frequency][routine] -> {"std", "notnan"}.
results