In [None]:
import numpy as np
import scipy

from filterpy.monte_carlo import systematic_resample

from spf.dataset.spf_dataset import v5spfdataset

# Recording to analyse. An earlier run used the rx_circle recording instead:
#   "/mnt/md0/spf/2d_wallarray_v2_data/june_fix/wallarrayv3_2024_06_10_03_38_21_nRX2_rx_circle.zarr"
ds_fn = "/mnt/md0/spf/2d_wallarray_v2_data/june_fix/wallarrayv3_2024_06_15_11_44_13_nRX2_bounce.zarr"

import pickle
import os


# Prefix for any files written by this notebook, derived from the recording name.
output_prefix = "./" + os.path.basename(ds_fn) + "_"

full_p_fn = "full_p.pkl"
# NOTE: pickle.load can execute arbitrary code -- only load files you produced/trust.
# Use a context manager so the file handle is closed after loading.
with open(full_p_fn, "rb") as full_p_file:
    full_p = pickle.load(full_p_file)["full_p"]

nthetas = 65
ds = v5spfdataset(
    ds_fn,
    nthetas=nthetas,
    ignore_qc=True,
    precompute_cache="/home/mouse9911/precompute_cache_chunk16",
    paired=True,
    skip_signal_matrix=True,
    snapshots_per_session=-1,
)

# flip the order of the antennas (hence the negated spacing)
antenna_spacing = -ds.yaml_config["receivers"][0]["antenna-spacing-m"]
assert antenna_spacing == -ds.yaml_config["receivers"][1]["antenna-spacing-m"], (
    "both receivers are expected to share the same antenna spacing"
)

wavelength = ds.wavelengths[0]
assert wavelength == ds.wavelengths[1], "both receivers are expected to share one wavelength"

# Per-receiver orientation offsets in radians; the config stores them as multiples of pi.
offsets = [
    ds.yaml_config["receivers"][0]["theta-in-pis"] * np.pi,
    ds.yaml_config["receivers"][1]["theta-in-pis"] * np.pi,
]

In [None]:
from spf.model_training_and_inference.models.particle_filter import (
    PFSingleThetaSingleRadio,
)


# Run the single-radio particle filter on receiver index 1 of the dataset loaded above.
# Reuse `full_p_fn` from the setup cell instead of repeating the literal file name,
# so both cells always point at the same file.
pf = PFSingleThetaSingleRadio(ds=ds, rx_idx=1, full_p_fn=full_p_fn)
# NOTE(review): mean/std appear to initialise the particle cloud -- semantics are
# defined by PFSingleThetaSingleRadio.trajectory; confirm there.
trajectory, all_particles = pf.trajectory(
    mean=np.array([[0, 0]]), std=np.array([[2, 0.1]]), return_particles=True
)
pf.metrics(trajectory=trajectory)

In [None]:
import torch
from torch.nn import (
    LayerNorm,
    Linear,
    Sequential,
    TransformerEncoder,
    TransformerEncoderLayer,
)

# Toy experiment: check that a transformer encoder + linear head can overfit a
# small random regression target.
d_model = 512
d_hid = 128
dropout = 0.0
n_heads = 8
n_layers = 9

# Seed so weight init and the random toy data are reproducible across re-runs.
torch.manual_seed(0)

encoder_layers = TransformerEncoderLayer(
    d_model=d_model,
    nhead=n_heads,
    dim_feedforward=d_hid,
    dropout=dropout,
    # activation="gelu",
    batch_first=True,  # inputs are (batch, seq, feature)
)
transformer_encoder = TransformerEncoder(
    encoder_layers,
    n_layers,
    LayerNorm(d_model),  # final norm applied after the stacked layers
)

# Fall back to CPU so the cell still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
m = Sequential(
    transformer_encoder,
    Linear(d_model, 2),  # per-position 2-d regression head -> (batch, seq, 2)
).to(device)

target = torch.randn(7, 2).to(device)
input = torch.randn(7, 10, 512).to(device)  # (batch=7, seq=10, d_model)

In [None]:
import torch

# AdamW for the toy overfitting run; note the deliberately tiny step size.
# NOTE(review): the original lr line carried an unexplained "# 1" comment --
# confirm whether a different learning rate was intended.
lr = 1e-7
weight_decay = 1e-8
optimizer = torch.optim.AdamW(
    params=m.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

In [None]:
# Overfit the fixed toy batch: minimise MSE between the prediction at sequence
# position 0 and `target`, logging the loss every 50 steps.
for step in range(2000):
    # Without this, .backward() accumulates gradients across every iteration and
    # each optimizer.step() uses the running sum instead of the current gradient.
    optimizer.zero_grad()

    # m(input) is (batch, seq, 2); keep only the first sequence position -> (batch, 2).
    output = m(input)[:, 0, :]

    loss = ((target - output) ** 2).mean()
    loss.backward()
    if step % 50 == 0:
        print(loss.item())

    optimizer.step()

In [None]:
output  # predictions from the last training step, for comparison with `target`

In [None]:
target  # random regression target the toy model is trained to reproduce

In [None]:
# Load a second recording (rx_circle, 2024-06-15) with gpu=True and a fresh cache.
# NOTE: this rebinds `ds`, replacing the dataset from the first cell; objects built
# from the earlier `ds` (e.g. `pf`) presumably still reference the old one.
# NOTE(review): paired=2 here vs paired=True in the first cell -- confirm intent.
ds = v5spfdataset(
    prefix=(
        "/mnt/md0/spf/2d_wallarray_v2_data/june_fix/"
        "wallarrayv3_2024_06_15_04_24_24_nRX2_rx_circle.zarr"
    ),
    nthetas=65,
    paired=2,
    ignore_qc=True,
    skip_signal_matrix=True,
    snapshots_per_session=1000,
    gpu=True,
    precompute_cache="/home/mouse9911/precompute_cache_chunk16_fresh/",
)

In [None]:
ds[0]  # inspect item 0 of the reloaded dataset (structure defined by v5spfdataset -- TODO confirm with paired=2)

In [None]:
import torch

from spf.rf import torch_reduce_theta_to_positive_y

# Sanity check: does torch_reduce_theta_to_positive_y keep a float16 input in float16?
y_rad = torch.rand(5, 2, dtype=torch.float16).to("cuda")
y_rad_reduced = torch_reduce_theta_to_positive_y(y_rad).reshape(-1, 1)
print(*(tensor.dtype for tensor in (y_rad, y_rad_reduced)))