In [None]:
import os
import sys

# Location of the spf checkout so the `spf` package is importable.
# Override with the SPF_REPO_ROOT environment variable; the fallback is the
# original author's local path.
repo_root = os.environ.get(
    "SPF_REPO_ROOT", "/Users/miskodzamba/Dropbox/research/gits/spf/"
)
if repo_root not in sys.path:
    sys.path.append(repo_root)  # make the `spf` package importable

from spf.dataset.fake_dataset import create_fake_dataset, fake_yaml
from spf.notebooks.simple_train import get_parser, simple_train

# Small synthetic, noise-free dataset (n=33); presumably written as
# test_circle_33.zarr, the path passed to --datasets below.
create_fake_dataset(
    filename="test_circle_33", yaml_config_str=fake_yaml, n=33, noise=0.0
)

# Command-line style arguments for simple_train; commented entries are
# alternative settings that are currently disabled.
args_list = [
    "--device",
    "cpu",
    "--seed",
    "0",
    "--nthetas",
    "65",
    "--datasets",
    "test_circle_33.zarr",
    # "--positional",
    # "/Users/miskodzamba/Dropbox/research/gits/spf/spf/min.zarr",
    "--batch",
    "128",
    "--workers",
    "0",
    "--hidden",
    "128",
    "--depth",
    "5",
    "--batch-norm",
    "--act",
    "leaky",
    # "--shuffle",
    "--segmentation-level",
    "downsampled",
    "--type",
    "discrete",
    "--seg-net",
    "conv",
    "--epochs",
    "75",
    "--symmetry",
    # "--skip-segmentation",
    "--no-shuffle",
    "--skip-qc",
    "--no-sigmoid",
    "--val-on-train",
    "--segmentation-lambda",
    "0",
    "--independent",
    # "--block",
    "--wandb-project",
    "test124",
    "--plot-every",
    "75",
    "--lr",
    "0.001",
    "--precompute-cache",
    "/tmp/",
    "--positional",
]
args = get_parser().parse_args(args_list)

train_results = simple_train(args)

In [None]:
import numpy as np

# Mean of the final 10 recorded training losses.
loss_history = np.asarray(train_results["losses"])
loss_history[-10:].mean()

In [None]:
# Display the last 20 entries of the recorded training-loss history.
train_results["losses"][-20:]

In [None]:
# Display the entire recorded training-loss history (may be long).
train_results["losses"]

In [None]:
import numpy as np
from spf.rf import pi_norm

# Four evenly spaced angles from 0 to 2*pi (endpoints included), passed
# through pi_norm.
one_turn = 2 * np.pi
thetas = pi_norm(np.linspace(0, one_turn, num=4))

In [None]:
# Display `thetas`: pi_norm applied to 4 evenly spaced angles over [0, 2*pi].
thetas

In [None]:
# Five evenly spaced angles over one full turn, both endpoints included.
np.linspace(0.0, 2 * np.pi, num=5)

In [None]:
import torch

# Small 2x4 float tensor for experimenting with row-wise normalization.
a = torch.tensor(
    [
        [1.0, 2.0, 3.0, 4.3],
        [1.0, 1.0, 1.0, 1.0],
    ]
)

In [None]:
# Row-wise L1 normalization of `a` (dim=1): each row is divided by the sum of
# its absolute values (floored at a small epsilon by torch).
torch.nn.functional.normalize(a, p=1, dim=1)

In [None]:
import random

import torch
from spf.dataset.spf_dataset import v5_collate_beamsegnet, v5spfdataset

from spf.dataset.fake_dataset import create_fake_dataset, fake_yaml

# Small synthetic, noise-free dataset (n=17) used by the manual training below.
create_fake_dataset(
    filename="test_circle_x", yaml_config_str=fake_yaml, n=17, noise=0.0
)
# Fall back to CPU when no GPU is available; a hard-coded "cuda" makes this
# cell and every later cell fail on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
nthetas = 65
torch_device = torch.device(device)

# Seed both torch and Python's `random` for reproducibility.
torch.manual_seed(100)
random.seed(100)

# loop over and concat datasets here
datasets = [
    v5spfdataset(
        "test_circle_x.zarr",
        precompute_cache="/tmp",
        nthetas=nthetas,
        skip_signal_matrix=False,
        paired=False,
        ignore_qc=True,
    )
]
for ds in datasets:
    ds.get_segmentation()
complete_ds = torch.utils.data.ConcatDataset(datasets)

# Train and validate on the same data (this is an overfitting sanity check).
train_ds = complete_ds
val_ds = complete_ds
print(f"Train-dataset size {len(train_ds)}, Val dataset size {len(val_ds)}")

# A batch size far above the dataset size puts everything in a single batch.
dataloader_params = {
    "batch_size": 10000,
    "shuffle": False,
    "num_workers": 0,
    "collate_fn": v5_collate_beamsegnet,
}
train_dataloader = torch.utils.data.DataLoader(train_ds, **dataloader_params)

In [None]:
from spf.model_training_and_inference.models.beamsegnet import (
    BeamNetDirect,
    BeamNetDiscrete,
    BeamNSegNet,
    ConvNet,
    UNet1D,
)

act = torch.nn.LeakyReLU
# Radios per sample. This single value feeds both the paired net's input
# width and BeamNSegNet, so the two cannot disagree. (The direct branch
# previously read `args.n_radios` from the first cell, which is hidden state.)
n_radios = 1
seg_m = ConvNet(3, 1, hidden=4, act=act, bn=False).to(torch_device)
positional = False


discrete = True
if discrete:
    beam_m = BeamNetDiscrete(
        nthetas=nthetas,
        hidden=128,
        depth=5,
        act=act,
        symmetry=True,  # args.symmetry,
        bn=True,
        positional_encoding=positional,
    ).to(torch_device)
    paired_net = BeamNetDiscrete(
        nthetas=nthetas,
        depth=3,
        hidden=16,
        symmetry=False,
        act=act,
        other=False,
        bn=False,
        no_sigmoid=False,
        block=False,
        rx_spacing_track=-1,
        pd_track=-1,
        mag_track=-1,
        stddev_track=-1,
        inputs=n_radios * beam_m.outputs,
        norm="batch",
    )
else:
    beam_m = BeamNetDirect(
        nthetas=nthetas,
        depth=6,
        hidden=256,
        symmetry=False,
        act=act,
        other=False,
        bn=False,
        no_sigmoid=True,
        block=False,
        inputs=3,  # + (1 if args.rx_spacing else 0),
        norm="batch",
        positional_encoding=False,
    ).to(torch_device)
    paired_net = BeamNetDirect(
        nthetas=nthetas,
        depth=4,
        hidden=64,
        symmetry=False,
        act=act,
        other=False,
        bn=False,
        no_sigmoid=True,
        block=False,
        rx_spacing_track=-1,
        pd_track=-1,
        mag_track=-1,
        stddev_track=-1,
        inputs=n_radios * beam_m.outputs,
        norm="batch",
    )

# paired_net is not moved explicitly; the final .to(torch_device) on the
# combined model presumably moves it as a registered submodule (TODO confirm).
m = BeamNSegNet(
    segnet=seg_m,
    beamnet=beam_m,
    circular_mean=False,
    segmentation_lambda=0,
    independent=True,
    n_radios=n_radios,
    paired_net=paired_net,
    rx_spacing=False,
).to(torch_device)

In [None]:
# AdamW over all model parameters with weight decay disabled.
learning_rate = 5e-4
optimizer = torch.optim.AdamW(
    params=m.parameters(),
    lr=learning_rate,
    weight_decay=0.0,
)

In [None]:
def batch_data_to_x_y_seg(batch_data, segmentation_level, device=None):
    """Unpack a collated batch into model inputs, targets and masks.

    Args:
        batch_data: dict from the DataLoader's collate_fn. It must contain
            "all_windows_stats", "downsampled_segmentation_mask",
            "rx_spacing" and "y_rad".
        segmentation_level: only "downsampled" is supported. The
            downsampled segmentation mask is always the one returned.
        device: device to move the tensors to. Defaults to the
            notebook-global ``torch_device``.

    Returns:
        Tuple ``(x, y_rad, seg_mask, rx_spacing)``. ``x`` and ``y_rad`` are
        cast to float32; ``seg_mask`` and ``rx_spacing`` keep their dtype.

    Raises:
        ValueError: if ``segmentation_level`` is not "downsampled". Previously
            the argument was silently ignored.
        AssertionError: if ``seg_mask`` is not 3-D with size 1 on dim 1.
    """
    if segmentation_level != "downsampled":
        raise ValueError(
            f"unsupported segmentation_level {segmentation_level!r}; "
            "only 'downsampled' is implemented"
        )
    if device is None:
        device = torch_device

    x = batch_data["all_windows_stats"].to(device).type(torch.float32)
    seg_mask = batch_data["downsampled_segmentation_mask"].to(device)

    rx_spacing = batch_data["rx_spacing"].to(device)

    y_rad = batch_data["y_rad"].to(device).type(torch.float32)
    assert seg_mask.ndim == 3 and seg_mask.shape[1] == 1, (
        f"expected seg_mask of shape (N, 1, L), got {tuple(seg_mask.shape)}"
    )
    return x, y_rad, seg_mask, rx_spacing

In [None]:
# Manual training loop for the notebook-level model `m`, optimizer and
# train_dataloader. Only the beamformer loss is optimized.
# NOTE(review): `reduce_theta_to_positive_y` is not imported or defined
# anywhere in this notebook, so this cell raises NameError on a fresh kernel.
# TODO import it from wherever the spf package defines it.
to_log = {
    "loss": [],
    "segmentation_loss": [],
    "beamformer_loss": [],
    "paired_beamformer_loss": [],
}
for epoch in range(50):
    m.train()
    for batch_data in train_dataloader:
        optimizer.zero_grad()

        x, y_rad, seg_mask, rx_spacing = batch_data_to_x_y_seg(
            batch_data, "downsampled"
        )
        # The loss is computed against the reduced targets.
        y_rad_reduced = reduce_theta_to_positive_y(y_rad)

        output = m(x, seg_mask, rx_spacing)

        loss_d = m.loss(output, y_rad_reduced, seg_mask)

        # Only the beamformer term is optimized. The segmentation and paired
        # beamformer terms in loss_d are currently left out.
        loss = loss_d["beamformer_loss"]

        loss.backward()
        loss_value = loss.item()
        print(loss_value)

        optimizer.step()

        to_log["loss"].append(loss_value)
        to_log["beamformer_loss"].append(loss_value)

In [None]:
 # Negative log-likelihood of the last batch's prediction for sample 1.
 # It is scored against y_rad_reduced, the same target the training loss uses.
 -m.beamnet.loglikelihood(output["pred_theta"][[1]], y_rad_reduced[[1]])

In [None]:
# Compare shapes: the beam net applied to the weighted input, the weighted
# input itself, and the targets.
# NOTE(review): this is a real forward pass. If `m` is still in train mode and
# the beam net contains batch-norm layers, it likely also updates their running
# statistics. Consider m.eval() / torch.no_grad() here.
weighted_input = output["weighted_input"]
beam_net_out = m.beamnet.beam_net(weighted_input)
beam_net_out.shape, weighted_input.shape, y_rad.shape

In [None]:
# Scratch cell: uncomment to display the raw predicted-theta tensor from the
# last training step.
# output["pred_theta"]

In [None]:
from matplotlib import pyplot as plt

# Render the discretized predictions for every second sample as an image.
pred_theta_every_other = output["pred_theta"][::2]
pred_image = beam_m.render_discrete_x(pred_theta_every_other).cpu().detach().numpy()
plt.imshow(pred_image)

In [None]:
# Number of theta bins (passed as `nthetas` to the dataset and the beam nets).
nthetas

In [None]:
import numpy as np

# Five evenly spaced positions across the theta-bin index range [0, nthetas - 1].
np.linspace(start=0, stop=nthetas - 1, num=5)

In [None]:
# Discretized reduced targets for every second sample. The x ticks sit at five
# evenly spaced bin positions, labelled -pi ... pi.
reduced_target_image = m.beamnet.render_discrete_y(y_rad_reduced[::2]).cpu()
plt.imshow(reduced_target_image)
ax = plt.gca()
tick_positions = np.linspace(0, nthetas - 1, 5)
tick_labels = ["-pi", "-pi/2", "0", "pi/2", "pi"]
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels)

# White vertical grid lines at the major x ticks.
ax.grid(which="major", color="w", linestyle="-", linewidth=2, axis="x")

In [None]:
# Discretized unreduced targets (y_rad) for every second sample.
unreduced_targets = m.beamnet.render_discrete_y(y_rad[::2])
plt.imshow(unreduced_targets.cpu())

In [None]:
# Shape of column 0 of the weighted input, next to the beam net's `outputs`
# attribute (which also sizes the paired net's inputs).
output["weighted_input"][:, 0].shape, m.beamnet.outputs

In [None]:
from matplotlib import pyplot as plt

# `weighted_input` comes out of the model's forward pass, so it may require
# grad and may live on the GPU. Calling .numpy() directly raises in either
# case, so detach and move to CPU first.
weighted_col0 = output["weighted_input"][:, 0].detach().cpu().numpy()
# Even and odd rows are plotted as separate curves (presumably the two
# interleaved receivers; TODO confirm).
plt.plot(weighted_col0[::2])
plt.plot(weighted_col0[1::2])

In [None]:
from matplotlib import pyplot as plt


def masked_col0_means(x, seg_mask, start):
    """Mean of ``x[idx, 0]`` over positions where ``seg_mask[idx, 0] == 1``.

    Computed for every second sample beginning at ``start``. The result is
    returned as a CPU numpy array, so it also works when the tensors live on
    the GPU (a bare ``.numpy()`` raises there). A sample with an empty mask
    yields NaN.
    """
    per_sample = [
        x[idx, 0, seg_mask[idx, 0] == 1].mean() for idx in range(start, x.shape[0], 2)
    ]
    return torch.hstack(per_sample).detach().cpu().numpy()


plt.plot(masked_col0_means(x, seg_mask, 0))
plt.plot(masked_col0_means(x, seg_mask, 1))

In [None]:
# Scratch cell: uncomment to display column 0 of the last batch's
# window statistics (`x`).
# x[:,0,:]

In [None]:
# Scratch cell: uncomment to re-run the model forward pass on the last batch.
# m(x, seg_mask, rx_spacing)