In [None]:
import os
import shutil

import yaml
import zarr

from spf.dataset.spf_dataset import v5spfdataset
from spf.dataset.v5_data import v5rx_new_dataset
from spf.utils import zarr_open_from_lmdb_store, zarr_shrink


def compare_and_copy_n(prefix, src, dst, n):
    """Recursively copy the leading ``n`` entries of every array in ``src`` into ``dst``.

    Groups are walked key by key, and the child of ``dst`` with the same key is
    the copy target. Arrays whose path is exactly ``/config`` are copied whole;
    a 0-d array at that path is skipped. Every other array gets its first ``n``
    entries along axis 0 copied.

    Args:
        prefix: "/"-joined path of ``src`` from the root ("" at the root).
        src: source zarr group or array.
        dst: destination zarr group or array laid out like ``src``.
        n: number of leading axis-0 entries to copy for non-config arrays.
            Values <= 0 copy nothing. Values larger than ``len(src)`` copy
            everything available.
    """
    if isinstance(src, zarr.hierarchy.Group):
        for key in src.keys():
            compare_and_copy_n(prefix + "/" + key, src[key], dst[key], n)
    elif prefix == "/config":
        if src.shape != ():
            dst[:] = src[:]
    else:
        # A single slice assignment instead of n single-row writes. In chunked
        # zarr storage each row write is a separate partial-chunk update.
        # Clamp so that a negative n stays a no-op instead of becoming a
        # "copy all but the last |n|" slice.
        count = max(0, min(n, src.shape[0]))
        dst[:count] = src[:count]


def partial_dataset(input_fn, output_fn, n):
    """Write a copy of a v5 dataset that keeps only its leading ``n`` entries.

    Opens ``<input_fn>.zarr`` with ``zarr_open_from_lmdb_store`` and reads
    ``<input_fn>.yaml``. The YAML is copied verbatim to ``<output_fn>.yaml``.
    A new ``<output_fn>.zarr`` is created via ``v5rx_new_dataset``, sized from
    the source's ``receivers/r0/signal_matrix`` shape and the YAML receiver
    count. ``compare_and_copy_n`` then fills it. Finally the output store is
    closed and ``zarr_shrink(output_fn)`` is run.

    Args:
        input_fn: source dataset path, with or without a trailing ".zarr".
        output_fn: destination path without extension.
        n: number of leading entries to copy (passed to ``compare_and_copy_n``).
    """
    # Accept both "x" and "x.zarr". Only a trailing suffix is stripped, so
    # directories that merely contain ".zarr" in their name are left intact.
    if input_fn.endswith(".zarr"):
        input_fn = input_fn[: -len(".zarr")]
    z = zarr_open_from_lmdb_store(input_fn + ".zarr")
    # The first and last dims of r0's signal_matrix give the timestep count
    # and buffer size used to size the new dataset.
    timesteps, _, buffer_size = z["receivers/r0/signal_matrix"].shape
    input_yaml_fn = input_fn + ".yaml"
    output_yaml_fn = output_fn + ".yaml"
    with open(input_yaml_fn, "r") as yaml_file:
        yaml_config = yaml.safe_load(yaml_file)
    shutil.copyfile(input_yaml_fn, output_yaml_fn)
    new_z = v5rx_new_dataset(
        filename=output_fn + ".zarr",
        timesteps=timesteps,
        buffer_size=buffer_size,
        n_receivers=len(yaml_config["receivers"]),
        chunk_size=512,
        compressor=None,
        config=yaml_config,
        remove_if_exists=False,
    )
    try:
        compare_and_copy_n("", z, new_z, n)
    finally:
        # Close the output store even if the copy fails, so it is not left open.
        new_z.store.close()
    new_z = None
    zarr_shrink(output_fn)


# Sample-dataset parameters; n and noise are part of the file names below.
n = 128
noise = 0.3
nthetas = 65
orbits = 2
# Number of leading entries copied into the partial dataset.
n_partial = 25

# Working directory. Set the SPF_NOTEBOOK_DIR environment variable to use a
# different location instead of editing this user-specific default.
tmpdirname = os.environ.get(
    "SPF_NOTEBOOK_DIR", "/home/mouse9911/gits/spf/spf/notebooks/test"
)
ds_fn = f"{tmpdirname}/sample_dataset_for_ekf_n{n}_noise{noise}"
ds_fn_out = f"{tmpdirname}/sample_dataset_for_ekf_n{n}_noise{noise}_partial"

partial_dataset(ds_fn, ds_fn_out, n_partial)

In [None]:
# Open the partial dataset written by the previous cell and show its
# valid_entries.
ds = v5spfdataset(
    ds_fn_out,
    nthetas=65,
    precompute_cache=tmpdirname,
    paired=True,
    skip_signal_matrix=True,
    gpu=False,
    ignore_qc=True,
    temp_file=True,
    temp_file_suffix="",
)
ds.valid_entries

In [None]:
# Refresh the dataset, then show its length, valid_entries and the r0
# system timestamps. Only a cell's last expression is displayed, so the
# timestamps are bound to a name and included in the final tuple.
ds.refresh()
r0_timestamps = ds.z["receivers/r0/system_timestamp"][:]
len(ds), ds.valid_entries, r0_timestamps

In [None]:
# Mean windowed_beamformer of element 0 of sample 19 in the partial dataset.
sample_19 = ds[19]
sample_19[0]["windowed_beamformer"].mean()

In [None]:
# Open the full (untruncated) dataset at ds_fn for comparison with the partial
# copy. Note that this one is opened with gpu=True.
ds_og = v5spfdataset(
    ds_fn,
    gpu=True,
    ignore_qc=True,
    paired=True,
    skip_signal_matrix=True,
    nthetas=65,
    precompute_cache=tmpdirname,
)
ds_og.valid_entries

In [None]:
# Same statistic as for the partial dataset, computed on the full dataset.
sample_19_og = ds_og[19]
sample_19_og[0]["windowed_beamformer"].mean()

In [None]:
# Inspect the "receivers" group of the full dataset. ds_og opens the same
# ds_fn that `datasets` (defined in a later cell) does, so it is used here to
# keep the notebook runnable top to bottom.
ds_og.z["receivers"]

In [None]:
import tempfile
from spf.model_training_and_inference.models.particle_filter import (
    plot_single_theta_dual_radio,
    plot_single_theta_single_radio,
    plot_xy_dual_radio,
    run_single_theta_single_radio,
)

from spf.dataset.fake_dataset import create_fake_dataset, fake_yaml
from spf.dataset.spf_dataset import v5spfdataset

# Sample-dataset parameters; n and noise are part of the file name below.
n = 128
noise = 0.3
nthetas = 65
orbits = 2

# Working directory. Set the SPF_NOTEBOOK_DIR environment variable to use a
# different location instead of editing this user-specific default.
tmpdirname = os.environ.get(
    "SPF_NOTEBOOK_DIR", "/home/mouse9911/gits/spf/spf/notebooks/test"
)
ds_fn = f"{tmpdirname}/sample_dataset_for_ekf_n{n}_noise{noise}"

full_p_fn = f"{tmpdirname}/full_p.pkl"
# NOTE(review): this opens ds_fn as it currently exists on disk. The
# create_fake_dataset cell below regenerates it, so in a fresh directory that
# cell presumably has to run first — confirm.
datasets = [
    v5spfdataset(
        prefix,
        precompute_cache=tmpdirname,
        nthetas=nthetas,
        skip_signal_matrix=True,
        paired=True,
        ignore_qc=True,
        gpu=False,
        temp_file=True,
        temp_file_suffix="",
    )
    for prefix in [ds_fn]
]

In [None]:
# Create the fake sample dataset at ds_fn from fake_yaml and the parameters
# defined in the setup cell.
fake_dataset_kwargs = dict(
    filename=ds_fn,
    yaml_config_str=fake_yaml,
    n=n,
    noise=noise,
    orbits=orbits,
)
create_fake_dataset(**fake_dataset_kwargs)

In [None]:
# The "r0" entry of the first dataset's mean_phase.
r0_mean_phase = datasets[0].mean_phase["r0"]
r0_mean_phase

In [None]:
# get_mean_phase(0, 10) on the first dataset. The argument meaning is defined
# by v5spfdataset — see its implementation.
first_dataset = datasets[0]
first_dataset.get_mean_phase(0, 10)

In [None]:
# Keys available in the first entry of the first dataset's cached_keys.
cached_keys_0 = datasets[0].cached_keys[0]
cached_keys_0.keys()

In [None]:
# Number of samples in the first dataset.
n_samples = len(datasets[0])
n_samples

In [None]:
# Single-radio theta particle filter (N=4096, presumably the particle count).
# NOTE(review): full_p_fn is (re)written by the heatmap cell further down, so
# on a fresh kernel this call presumably relies on a full_p.pkl already on
# disk — confirm.
args = dict(
    ds_fn=ds_fn,
    precompute_fn=tmpdirname,
    full_p_fn=full_p_fn,
    N=4 * 1024,
    theta_err=0.01,
    theta_dot_err=0.01,
)
run_single_theta_single_radio(**args)

In [None]:
import pickle

from spf.model_training_and_inference.models.create_empirical_p_dist import (
    apply_symmetry_rules_to_heatmap,
    get_heatmap,
)

# Build the empirical heatmap from `datasets` (50 bins), apply the symmetry
# rules, and pickle it to full_p_fn (the path the particle-filter cells are
# given as full_p_fn).
heatmap = get_heatmap(datasets, bins=50)
heatmap = apply_symmetry_rules_to_heatmap(heatmap)

full_p_fn = f"{tmpdirname}/full_p.pkl"
# The context manager closes (and therefore flushes) the file before any later
# cell opens it.
with open(full_p_fn, "wb") as full_p_file:
    pickle.dump({"full_p": heatmap}, full_p_file)

In [None]:
from spf.model_training_and_inference.models.particle_filter import (
    plot_single_theta_dual_radio,
    plot_single_theta_single_radio,
    plot_xy_dual_radio,
    run_single_theta_single_radio,
)

# Single-radio theta filter (N=4096, presumably the particle count), then the
# single-radio plot for the first dataset.
args = dict(
    ds_fn=ds_fn,
    precompute_fn=tmpdirname,
    full_p_fn=full_p_fn,
    N=4 * 1024,
    theta_err=0.01,
    theta_dot_err=0.01,
)
run_single_theta_single_radio(**args)
plot_single_theta_single_radio(datasets[0], full_p_fn)
# plot_single_theta_dual_radio(datasets[0], full_p_fn)

# plot_xy_dual_radio(datasets[0], full_p_fn)

In [None]:
# Re-run the single-radio theta filter with the current `args`.
single_radio_out = run_single_theta_single_radio(**args)
single_radio_out

In [None]:
# run_single_theta_single_radio(**args)

In [None]:
from spf.model_training_and_inference.models.particle_filter import run_single_theta_dual_radio

# Dual-radio theta filter, reusing the single-radio `args`.
dual_radio_out = run_single_theta_dual_radio(**args)
dual_radio_out

In [None]:
from spf.model_training_and_inference.models.particle_filter import run_xy_dual_radio

# xy filter settings use a separate name. Overwriting `args` here would mean
# that re-running any theta-filter cell above afterwards passes these xy keys
# (pos_err / vel_err) instead of theta_err / theta_dot_err.
xy_args = {
    "ds_fn": ds_fn,
    "precompute_fn": tmpdirname,
    "full_p_fn": full_p_fn,
    "N": 1024,
    "pos_err": 50,
    "vel_err": 0.1,
}

run_xy_dual_radio(**xy_args)