In [None]:
# Make the local `spf` checkout importable. Set the SPF_REPO_ROOT environment
# variable to point at a different checkout instead of editing this cell.
import os
import sys

repo_root = os.environ.get(
    "SPF_REPO_ROOT", "/Users/miskodzamba/Dropbox/research/gits/spf/"
)

if repo_root not in sys.path:
    sys.path.append(repo_root)  # make the repo's top-level `spf` package importable

In [None]:
from torch.utils.data import Dataset
import yaml
import torch
from spf.rf import precompute_steering_vectors
from spf.utils import zarr_open_from_lmdb_store
from spf.dataset.v5_data import v5rx_f64_keys, v5rx_2xf64_keys
import numpy as np
from spf.rf import speed_of_light
from multiprocessing.pool import ThreadPool
import os
import pickle
from multiprocessing import Pool
import time
import tqdm

#         return pickle.load(open(results_fn, "rb"))

In [None]:
from spf.dataset.spf_dataset import v5spfdataset


# Open the recording analysed in this notebook. The commented-out paths are
# other recordings from the same mission that can be swapped in.
ds = v5spfdataset(
    # "/Volumes/SPFData/missions/april5/wallarrayv3_2024_04_05_22_13_07_nRX2_rx_circle"
    # "/Volumes/SPFData/missions/april5/wallarrayv3_2024_05_10_05_03_21_nRX2_rx_circle_tag_rand10_90_rand30_100"
    # "/Volumes/SPFData/missions/april5/wallarrayv3_2024_04_10_05_08_55_nRX2_rx_circle"
    "/Volumes/SPFData/missions/april5/wallarrayv3_2024_05_06_19_04_15_nRX2_bounce",
    nthetas=11,
)

In [None]:
# Session index inspected by the cells below. (An earlier choice,
# 150 * 2 + 100 == 400, was dead code: it was immediately overwritten.)
session_idx = 1887

In [None]:
from spf.rf import beamformer_given_steering_nomean, get_phase_diff, simple_segment
import matplotlib.pyplot as plt

# Inspect one session: raw magnitudes, per-sample phase difference, and the
# windows found by `simple_segment` (red = "signal", orange = any other type).
data = ds[session_idx]
# signal_matrix = load_zarr_to_numpy(z.receivers["r0"].signal_matrix[session_idx])
n = 2 * 4 * 50000  # requested number of samples to inspect
offset = 0
v = data["signal_matrix"][:, offset : offset + n].numpy()
# The session may hold fewer than `offset + n` samples; size the x-axis from the
# actual slice so every scatter call gets matching lengths.
n = v.shape[1]
pd = get_phase_diff(v)
sample_idxs = np.arange(n)

fig, axs = plt.subplots(2, 1, figsize=(12, 6))

axs[0].scatter(sample_idxs, np.abs(v[0]), alpha=0.1, s=1)
axs[0].scatter(sample_idxs, np.abs(v[1]), alpha=0.1, s=1)
axs[0].set_title("Raw signal")
axs[0].set_xlabel("Sample# (time)")
axs[1].set_xlabel("Sample# (time)")
axs[1].set_title("Phase estimates")
axs[1].scatter(sample_idxs, pd, s=1, alpha=0.1)

# NOTE(review): both receivers' steering vectors are applied to the same signal
# `v`, and only entry [0] is used below — confirm this is intended.
beam_sds = [
    beamformer_given_steering_nomean(
        steering_vectors=ds.steering_vectors[receiver_idx],
        signal_matrix=v,
    )
    for receiver_idx in range(2)
]

segmentation_windows = simple_segment(
    v,
    window_size=2500,
    stride=2500,
    trim=20,
    mean_diff_threshold=0.2,
    max_stddev_threshold=0.5,  # just eyeballed this
    drop_less_than_size=3000,
    min_abs_signal=40,
)["simple_segmentation"]

window_sds = []
for window in segmentation_windows:
    # Draw each window's mean phase diff as a horizontal segment.
    axs[1].plot(
        [window["start_idx"], window["end_idx"]],
        [window["mean"], window["mean"]],
        color="red" if window["type"] == "signal" else "orange",
    )
    # Average receiver-0 beamformer output over the window's samples.
    window_sds.append(
        beam_sds[0][:, window["start_idx"] : window["end_idx"]].mean(axis=1)
    )
window_sds = np.array(window_sds)
fig.tight_layout()

In [None]:
# The dataset's own "simple_segmentation" entry for this session, for comparison
# with the windows computed by `simple_segment` in the previous cell.
ds[session_idx]["simple_segmentation"]

In [None]:
# Per receiver, average the per-window mean phase difference of every segmented
# chunk, then plot the first `first_n` chunks over time.
segmentation = ds.get_segmentation()
mean_phase_results = {
    receiver: np.array(
        [
            # A chunk with no segmented windows yields nan (mean of an empty array).
            np.array([x["mean"] for x in result["simple_segmentation"]]).mean()
            for result in results
        ]
    )
    for receiver, results in segmentation["segmentation_by_receiver"].items()
}


first_n = 250 * 2
fig, ax = plt.subplots(1, 1)
for rx_idx in range(2):
    rx_mean_phase = mean_phase_results[f"r{rx_idx}"][:first_n]
    # Size the x-axis from the slice so recordings with fewer than `first_n`
    # chunks still plot instead of raising a length mismatch.
    ax.scatter(range(len(rx_mean_phase)), rx_mean_phase, s=3, label=f"Rx{rx_idx}")
ax.legend()
ax.axvline(x=115)  # NOTE(review): hand-picked chunk index — document its meaning
ax.set_title("Mean segmented phase diff")
ax.set_xlabel("Chunk (time)")
ax.set_ylabel("Mean phase diff of seg. chunk")

In [None]:
# segmentation_by_receiver.keys()

In [None]:
# Only the last expression is displayed; the first call's return value is
# discarded (NOTE(review): presumably called for a caching side effect — confirm).
ds.get_segmentation_mean_phase()
ds.get_estimated_thetas()

In [None]:
from spf.dataset.spf_dataset import pi_norm
# NOTE(review): unused in this cell; rebinds the `speed_of_light` name imported
# earlier from spf.rf.
from spf.rf import c as speed_of_light


# Plot estimate rows 0-2 for each receiver across all chunks.
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

estimated_thetas = ds.get_estimated_thetas()
for rx_idx in [0, 1]:
    ax = axs[rx_idx]
    rx_thetas = estimated_thetas[f"r{rx_idx}"]
    for row_idx in range(3):
        row_thetas = rx_thetas[row_idx]
        ax.scatter(
            range(row_thetas.shape[0]),
            pi_norm(row_thetas),
            s=0.4,
            label=f"estimate[{row_idx}]",
        )
    ax.set_xlabel("Chunk")
    ax.set_ylabel("estimated theta")
    ax.set_title(f"Receiver (Rx) {rx_idx}")
    ax.legend()

In [None]:
from spf.dataset.spf_dataset import pi_norm


# Compare the first two estimate rows (labelled peak1/peak2) against ground truth
# for the first `first_n` chunks of each receiver.
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

first_n = 250 * 2
estimated_thetas = ds.get_estimated_thetas()
for rx_idx in [0, 1]:
    ax = axs[rx_idx]
    expected_theta = ds.ground_truth_thetas[rx_idx]
    ax.plot(expected_theta[:first_n], alpha=1, color="red", label="ground truth")

    for peak_idx in (0, 1):
        # Slice first, then size the x-axis from the slice, so recordings with
        # fewer than `first_n` chunks still plot.
        peak_thetas = pi_norm(estimated_thetas[f"r{rx_idx}"][peak_idx])[:first_n]
        ax.scatter(
            range(len(peak_thetas)),
            peak_thetas,
            s=3,
            label=f"Rx{rx_idx}_peak{peak_idx + 1}",
        )
    ax.set_xlabel("Chunk")
    ax.set_ylabel("estimated theta")
    ax.legend()
    ax.set_title(f"Receiver (Rx) {rx_idx}")

In [None]:
repo_root = "/Users/miskodzamba/Dropbox/research/gits/spf/"
import sys

import matplotlib.pyplot as plt
import torch

if repo_root not in sys.path:
    sys.path.append(repo_root)  # go to parent dir

from spf.dataset.spf_dataset import v5spfdataset


ds = v5spfdataset(
    "/Volumes/SPFData/missions/april5/wallarrayv3_2024_05_06_19_04_15_nRX2_bounce",
    nthetas=11,
)

from functools import cache
import gc

from spf.dataset.spf_dataset import v5_collate_beamsegnet, v5_thetas_to_targets
from spf.model_training_and_inference.models.beamsegnet import (
    BeamNSegNetDirect,
    BeamNSegNetDiscrete,
    # BeamNetDirect,
    UNet1D,
    ConvNet,
)

torch_device = torch.device("cpu")
nthetas = 11
# This value is what gets recorded in the wandb run config, so it must match the
# learning rate the Adam optimizer further down is built with (1e-5).
lr = 0.00001


dataloader_params = {
    "batch_size": 4,
    "shuffle": True,
    "num_workers": 0,
    "collate_fn": v5_collate_beamsegnet,
}
torch.manual_seed(1337)
train_dataloader = torch.utils.data.DataLoader(ds, **dataloader_params)

import random

w = False  # set True to log this run to Weights & Biases
if w:

    import wandb

    # start a new wandb run to track this script
    wandb.init(
        # set the wandb project where this run will be logged
        project="projectspf",
        # track hyperparameters and run metadata
        config={
            "learning_rate": lr,
            "architecture": "beamsegnet1",
        },
    )

@cache
def mean_guess(shape):
    """Uniform weights along dim 1 for a tensor of the given ``shape``.

    Builds a tensor of ones and L1-normalises it along dim 1, so every entry
    equals 1 / shape[1]. Results are memoised per ``shape`` (which must be
    hashable, e.g. a tuple): repeated calls return the *same* tensor object, so
    callers must not modify it in place.
    """
    ones = torch.ones(shape)
    uniform = torch.nn.functional.normalize(ones, p=1, dim=1)
    return uniform


# Fetch one batch from the shuffled loader; the training loop below fits only this batch.
X, Y_rad, segmentation = next(iter(train_dataloader))


def batch_to_gt_segmentation(X, Y_rad, segmentation, window_size=2048):
    """Downsample per-sample segmentation windows to a per-window 0/1 target.

    Args:
        X: batch tensor; only its shape is used, read as
            (batch, channels, samples_per_session).
        Y_rad: unused; kept so the signature mirrors the batch tuple
            ``(X, Y_rad, segmentation)``.
        segmentation: indexable, one entry per batch row; each entry's
            ``"simple_segmentation"`` lists windows with ``"start_idx"`` and
            ``"end_idx"`` sample indices.
        window_size: samples per output window. Windows do not overlap (the
            stride equals the window size). Defaults to 2048.

    Returns:
        Float tensor of shape (batch, 1, samples_per_session // window_size),
        set to 1 on window indices [start_idx // window_size,
        end_idx // window_size) of every segmented region, 0 elsewhere.

    Raises:
        ValueError: if ``window_size`` is not positive.
        AssertionError: if samples_per_session is not a multiple of window_size.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")
    n, _, samples_per_session = X.shape
    assert samples_per_session % window_size == 0
    n_windows = samples_per_session // window_size
    window_status = torch.zeros(n, n_windows)
    for row_idx in range(len(segmentation)):
        for window in segmentation[row_idx]["simple_segmentation"]:
            # NOTE(review): both bounds round down, so a partially covered first
            # window is marked while a partially covered last window is not.
            window_status[
                row_idx,
                window["start_idx"] // window_size : window["end_idx"] // window_size,
            ] = 1
    return window_status[:, None]


def segmentation_mask(X, segmentations):
    """Expand segmentation windows into a dense per-sample 0/1 mask.

    Args:
        X: tensor whose dim 0 is the batch and dim 2 the sample axis; only its
            shape and device are used.
        segmentations: indexable per batch row; each item's
            ``"simple_segmentation"`` lists windows with ``"start_idx"`` /
            ``"end_idx"`` bounds (end exclusive, as a slice).

    Returns:
        Float tensor of shape (X.shape[0], 1, X.shape[2]) on ``X.device``, 1
        inside any window of that row and 0 elsewhere.
    """
    n_rows, n_samples = X.shape[0], X.shape[2]
    mask = torch.zeros(n_rows, n_samples, device=X.device)
    spans = (
        (row_idx, window["start_idx"], window["end_idx"])
        for row_idx in range(n_rows)
        for window in segmentations[row_idx]["simple_segmentation"]
    )
    for row_idx, start, end in spans:
        mask[row_idx, start:end] = 1
    return mask.unsqueeze(1)


# m = BeamNSegNetDiscrete(nthetas=nthetas, symmetry=False).to(torch_device)
# m = BeamNSegNetDirect(nthetas=nthetas, symmetry=False).to(torch_device)
# print("ALL", segmentation[0]["all_windows_stats"].shape)
m = UNet1D().to(torch_device).double()
# m = ConvNet(in_channels=3, out_channels=1, hidden=32)
optimizer = torch.optim.Adam(m.parameters(), lr=0.00001, weight_decay=0)
step = 0
sigmoid = torch.nn.Sigmoid()
X = X.double().to(torch_device)
# X[:, :2] /= 500

# Overfit sanity check: every iteration trains on the single batch fetched
# above. The target mask depends only on X's shape and `segmentation`, neither
# of which changes, so build it once.
seg_mask = segmentation_mask(X, segmentation)
# Samples shown in the diagnostic plots, clamped to the session length.
first_n = min(3000, X.shape[2])

for epoch in range(10000):
    # for X, Y_rad, segmentation in train_dataloader:
    optimizer.zero_grad()

    # full
    model_input = X.clone()  # X is already on torch_device
    output = m(model_input)
    if step == 0:
        # Print shapes once rather than on each of the 10000 iterations.
        print(model_input.shape, output.shape, seg_mask.shape)

    # downsampled
    # input = torch.Tensor(
    #     np.vstack(
    #         [
    #             segmentation[idx]["all_windows_stats"].transpose()[None]
    #             for idx in range(len(segmentation))
    #         ]
    #     )
    # )
    # input[:, 2] /= 50
    # output = m(input)
    # seg_mask = batch_to_gt_segmentation(X, Y_rad, segmentation)

    loss = ((output - seg_mask) ** 2).mean()
    loss.backward()
    optimizer.step()

    to_log = {"loss": loss.item()}

    if step % 1000 == 0:
        print(loss.item())
        # Copy to host only on the steps that actually plot.
        _input = model_input.cpu()
        _output = output.cpu().detach().numpy()
        fig, axs = plt.subplots(1, 3, figsize=(8, 3))
        s = 0.3
        axs[0].set_title("input (track 0/1)")
        axs[0].scatter(range(first_n), _input[0, 0, :first_n], s=s)
        axs[0].scatter(range(first_n), _input[0, 1, :first_n], s=s)
        axs[1].set_title("input (track 2)")
        axs[1].scatter(range(first_n), _input[0, 2, :first_n], s=s)
        # mw = mask_weights.cpu().detach().numpy()

        axs[2].set_title("output vs gt")
        axs[2].scatter(range(first_n), _output[0, 0, :first_n], s=s)
        axs[2].scatter(
            range(first_n), seg_mask.cpu().detach().numpy()[0, 0, :first_n], s=s
        )
        to_log["fig"] = fig
    if w:
        wandb.log(to_log)
        if "fig" in to_log:
            # Already uploaded; close it so open figures don't pile up.
            plt.close(to_log["fig"])
    step += 1


# [optional] finish the wandb run, necessary in notebooks. Guarded like
# wandb.init above: `wandb` is only imported when w is True.
if w:
    wandb.finish()

In [None]:
# Shape of the batch fitted by the UNet1D training cell above.
X.shape

In [None]:
# Model output vs. target mask shapes from the last training step.
output.shape, seg_mask.shape

In [None]:
repo_root = "/Users/miskodzamba/Dropbox/research/gits/spf/"
import sys
import torch

if repo_root not in sys.path:
    sys.path.append(repo_root)  # go to parent dir

from spf.dataset.spf_dataset import v5spfdataset
import matplotlib.pyplot as plt

# Prefer the Apple-silicon GPU backend (MPS) when present; fall back to CPU so
# the cell also runs on machines without it.
torch_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
nthetas = 11
lr = 0.001  # recorded in the wandb config below; keep in sync with the optimizer
batch_size = 32

ds = v5spfdataset(
    "/Volumes/SPFData/missions/april5/wallarrayv3_2024_05_06_19_04_15_nRX2_bounce",
    nthetas=nthetas,
)

from functools import cache
import gc

from spf.dataset.spf_dataset import v5_collate_beamsegnet, v5_thetas_to_targets
from spf.model_training_and_inference.models.beamsegnet import (
    BeamNSegNet,
    BeamNetDirect,
    BeamNetDiscrete,
    ConvNet,
    UNet1D,
)


dataloader_params = {
    "batch_size": batch_size,
    "shuffle": True,
    "num_workers": 0,
    "collate_fn": v5_collate_beamsegnet,
}
torch.manual_seed(1337)
train_dataloader = torch.utils.data.DataLoader(ds, **dataloader_params)
w = True  # log this run to Weights & Biases
if w:
    import wandb

    # start a new wandb run to track this script
    wandb.init(
        # set the wandb project where this run will be logged
        project="projectspf",
        # track hyperparameters and run metadata
        config={
            "learning_rate": lr,
            "architecture": "beamsegnet1",
        },
    )


# A single batch, reused by every iteration of the training loop below.
batch_data = next(iter(train_dataloader))

# "full": train on batch_data["x"]; "downsampled": on batch_data["all_windows_stats"].
segmentation_level = "downsampled"
if segmentation_level == "full":
    first_n = 10000
    seg_m = UNet1D().to(torch_device)
elif segmentation_level == "downsampled":
    first_n = 256
    seg_m = ConvNet(3, 1, 32).to(torch_device)
else:
    # Fail here instead of later with an undefined `seg_m` / `first_n`.
    raise NotImplementedError(segmentation_level)

beam_m = BeamNetDirect(nthetas=nthetas, hidden=16, symmetry=True).to(torch_device)
m = BeamNSegNet(segnet=seg_m, beamnet=beam_m).to(torch_device)

# Use the configured learning rate so the value logged to wandb is the one used.
optimizer = torch.optim.AdamW(m.parameters(), lr=lr, weight_decay=0)

step = 0

# The loop repeatedly trains on the single `batch_data` fetched above (an
# overfit sanity check), so move its tensors to the device once, up front.
if segmentation_level == "full":
    x = batch_data["x"].to(torch_device)
    y_rad = batch_data["y_rad"].to(torch_device)
    seg_mask = batch_data["segmentation_mask"].to(torch_device)
elif segmentation_level == "downsampled":
    x = batch_data["all_windows_stats"].to(torch_device)
    y_rad = batch_data["y_rad"].to(torch_device)
    seg_mask = batch_data["downsampled_segmentation_mask"].to(torch_device)
else:
    raise NotImplementedError

assert seg_mask.ndim == 3 and seg_mask.shape[1] == 1

for epoch in range(10000):
    # for X, Y_rad in train_dataloader:
    optimizer.zero_grad()

    # run beamformer and segmentation
    output = m(x)

    # x to beamformer loss (indirectly including segmentation)
    x_to_beamformer_loss = -beam_m.loglikelihood(output["pred_theta"], y_rad)
    # Check against the real batch size: a DataLoader batch can be smaller than
    # `batch_size` (last batch, or a dataset shorter than one batch).
    assert x_to_beamformer_loss.shape == (x.shape[0], 1)
    x_to_beamformer_loss = x_to_beamformer_loss.mean()

    # x to segmentation loss: the prediction is rescaled by the per-row sum of
    # the ground-truth mask before the squared error.
    # NOTE(review): presumably output["segmentation"] sums to 1 along axis 2 — confirm.
    output_segmentation_upscaled = output["segmentation"] * seg_mask.sum(
        axis=2, keepdim=True
    )
    x_to_segmentation_loss = (output_segmentation_upscaled - seg_mask) ** 2
    assert x_to_segmentation_loss.ndim == 3 and x_to_segmentation_loss.shape[1] == 1
    x_to_segmentation_loss = x_to_segmentation_loss.mean()

    loss = x_to_beamformer_loss + 10 * x_to_segmentation_loss  # 10 = segmentation weight

    loss.backward()
    optimizer.step()

    to_log = {
        "loss": loss.item(),
        "segmentation_loss": x_to_segmentation_loss.item(),
        "beam_former_loss": x_to_beamformer_loss.item(),
    }
    if step % 200 == 0:
        # beam outputs, interleaved row by row: prediction, ground truth, ...
        img_beam_output = (
            (beam_m.render_discrete_x(output["pred_theta"]) * 255).cpu().byte()
        )
        img_beam_gt = (beam_m.render_discrete_y(y_rad) * 255).cpu().byte()
        train_target_image = torch.stack([img_beam_output, img_beam_gt], dim=1).reshape(
            img_beam_output.shape[0] * 2, img_beam_output.shape[1]
        )
        if w:
            output_image = wandb.Image(
                train_target_image, caption="train vs target (interleaved)"
            )
            to_log["output"] = output_image

        # segmentation output
        _x = x.detach().cpu().numpy()
        _seg_mask = seg_mask.detach().cpu().numpy()
        _output_seg = output_segmentation_upscaled.detach().cpu().numpy()
        # Clamp so inputs shorter than `first_n` still plot with matching lengths.
        plot_n = min(first_n, _x.shape[2], _seg_mask.shape[2], _output_seg.shape[2])
        fig, axs = plt.subplots(1, 3, figsize=(8, 3))
        s = 0.3
        idx = 0
        axs[0].set_title("input (track 0/1)")
        axs[0].scatter(range(plot_n), _x[idx, 0, :plot_n], s=s)
        axs[0].scatter(range(plot_n), _x[idx, 1, :plot_n], s=s)
        axs[1].set_title("input (track 2)")
        axs[1].scatter(range(plot_n), _x[idx, 2, :plot_n], s=s)
        # mw = mask_weights.cpu().detach().numpy()

        axs[2].set_title("output vs gt")
        axs[2].scatter(range(plot_n), _output_seg[idx, 0, :plot_n], s=s)
        axs[2].scatter(range(plot_n), _seg_mask[idx, 0, :plot_n], s=s)
        if w:
            to_log["fig"] = fig
    if w:
        wandb.log(to_log)
        if "fig" in to_log:
            # Already uploaded; close it so open figures don't pile up.
            plt.close(to_log["fig"])
    else:
        print(loss.item())
    step += 1

# [optional] finish the wandb run, necessary in notebooks
if w:
    wandb.finish()

In [None]:
# output_segmentation_upscaled = output["segmentation"] * seg_mask.sum()
# x_to_segmentation_loss = (output_segmentation_upscaled - seg_mask) ** 2
# Scratch check: per-row sum of the rescaled segmentation output from the last
# training step (compare with the per-row ground-truth sums in the next cell).
(output["segmentation"] * seg_mask.sum(axis=2, keepdim=True)).sum(axis=2)

In [None]:
# Per-row sum of the ground-truth segmentation mask.
seg_mask.sum(axis=2, keepdim=True)

In [None]:
# Scatter the raw segmentation output for batch row 0.
z = output["segmentation"].detach().cpu().numpy()[0, 0]
# =_p_seg_mask[0,0]
# z=_output_seg[0,0]
plt.scatter(range(len(z)), z)

In [None]:
output["segmentation"].shape

In [None]:
# NOTE(review): `X` is the batch from the earlier UNet1D cell, not the `x` used
# by the BeamNSegNet training loop.
X[:, 1, :].mean(), X[:, 1, :].std()

In [None]:
Y_rad

In [None]:
# NOTE(review): the BeamNSegNet cell indexes `output` like a dict
# (output["segmentation"]); `.shape` only works on the UNet1D cell's tensor output.
output.shape, Y_rad.shape

In [None]:
# `segmentation` here is the batch fetched by the UNet1D cell.
segmentation[0]["all_windows_stats"].shape

In [None]:
segmentation_mask(X, segmentation)

In [None]:
# NOTE(review): `m` was last bound to BeamNSegNet, which was trained on
# batch_data["all_windows_stats"]; passing the full-resolution `X` may not match
# its expected input — confirm.
m(X)

In [None]:
# Manually replay the forward pass of a model exposing `beam_net`, `softmax` and
# `unet1d` on the batch `X`, to inspect intermediate values.
# NOTE(review): the most recently built `m` (BeamNSegNet) may not have these
# attributes; this cell presumably targets an earlier model class — confirm.
# NOTE(review): `X` was cast to float64 in the UNet1D cell; check that
# torch_device supports that dtype.
_X = X.clone().to(torch_device)
_X[:, :2] /= 500  # rescale channels 0-1 (mirrors the commented-out X[:, :2] /= 500)
# Dedicated names so this cell does not clobber the `batch_size` config global.
n_batch, input_channels, session_size = _X.shape
beam_former_input = _X.transpose(1, 2).reshape(n_batch * session_size, input_channels)
# Print the shape rather than the whole (samples x channels) tensor.
print(_X.device, beam_former_input.shape)
beam_former = m.beam_net(beam_former_input).reshape(
    n_batch, session_size, 5  # mu, o1, o2, k1, k2
)
mask_weights = m.softmax(m.unet1d(_X)[:, 0])

In [None]:
# Flattened (batch * samples, channels) input built in the manual forward-pass cell.
beam_former_input

In [None]:
ds[0]

In [None]:
# NOTE(review): `seg_mask` is 3-D (batch, 1, ...) in the training cells, so
# axis=1 sums over the singleton dim — axis=2 may have been intended.
seg_mask.sum(axis=1)

In [None]:
# Sum of batch row 0's ground-truth mask.
seg_mask.cpu().detach().numpy()[0].sum()

In [None]:
# Plot batch row 0: input tracks 0/1, track 2, and model output vs. ground truth.
first_n = min(40000, X.shape[2])  # clamp so shorter sessions still plot
# Separate name so the training cells' `x` is not clobbered.
x0 = X[0].cpu()

fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].set_title("input (track 0/1)")
axs[0].scatter(range(first_n), x0[0, :first_n], s=0.3)
axs[0].scatter(range(first_n), x0[1, :first_n], s=0.3)
axs[1].set_title("input (track 2)")
axs[1].scatter(range(first_n), x0[2, :first_n], s=0.3)
# mw = mask_weights.cpu().detach().numpy()
with torch.no_grad():  # inference only; no autograd graph needed
    mw = m(X).cpu().numpy()[0]
axs[2].set_title("output vs gt")
axs[2].scatter(range(first_n), mw[0, :first_n], s=0.3)
# NOTE(review): `seg_mask` is whatever the last training cell left (possibly the
# downsampled mask), which may be shorter than `first_n`.
axs[2].scatter(range(first_n), seg_mask.cpu().detach().numpy()[0, :first_n], s=0.3)

In [None]:
# Mask weights for batch row 0, from the manual forward-pass cell.
mask_weights[0]

In [None]:
from spf.model_training_and_inference.models.beamsegnet import BeamNSegNetDirect


# Fresh BeamNSegNetDirect (rebinds `m`) and optimizer for a single-example
# gradient check in the cells below.
m = BeamNSegNetDirect(nthetas=nthetas)

optimizer = torch.optim.AdamW(m.parameters(), lr=0.01)

# First-layer weight gradient; None until a backward pass has run.
m.beam_net.beam_net[0].weight.grad

In [None]:
# NOTE(review): other cells consume a batch either as the 3-tuple
# `X, Y_rad, segmentation` or as a dict (`batch_data["x"]`, ...); unpacking into
# two names fails for both — confirm what v5_collate_beamsegnet returns now.
x, y = next(iter(train_dataloader))

In [None]:
# Take batch row 0 (list indexing keeps the batch dim and copies, so `x` is not
# modified), then force channel 2 to be non-positive: -sign(k) * k == -|k|.
k = x[[0]]
k_y = y[[0]]
k[:, 2] = -k[:, 2].sign() * k[:, 2]
# k[:, 2] = k[:, 2].sign() * k[:, 2]

In [None]:
k[:, 2]

In [None]:
# Clear gradients and switch to training mode before the manual backward pass.
optimizer.zero_grad()
m.train()
m.beam_net.beam_net[0].weight.grad

In [None]:
X.max()

In [None]:
# Single-example forward/backward pass with an MSE loss. The optimizer step is
# commented out, so only the resulting gradients are inspected.
output = m(k)

loss_fn = torch.nn.MSELoss()
l = loss_fn(output, k_y)
l.backward()
# mean_loss = output
# optimizer.step()

In [None]:
output

In [None]:
# First-layer weight gradient after the backward pass above.
m.beam_net.beam_net[0].weight.grad

In [None]:
# NOTE(review): `Y` is never defined in this notebook's visible cells (perhaps
# `Y_rad` or `k_y` was meant); on a fresh kernel this raises NameError.
plt.imshow(Y.to("cpu"))