In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')
from sub_projects.ray_optimization.configuration import params_to_func
from ray_optim.plot import Plot
import torch
from matplotlib import pyplot as plt

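## Setup

Synthetic ray data: for each of `samples_count` samples, a `(z_count, 1000, 2)` tensor of 2-D points. The compensated and uncompensated sets are offset from the target by 0.1 and 0.4 respectively.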

In [5]:
# Synthetic data: `samples_count` samples, each a (z_count, 1000, 2) tensor of 2-D points
# with spread 0.1. Samples are spaced 1.3 apart; the compensated and uncompensated sets
# are shifted from the target by +0.1 and +0.4 respectively.
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same figures
samples_count = 3
z_count = 3
compensated = [torch.randn(z_count, 1000, 2) * 0.1 + 0.1 + 1.3 * i for i in range(samples_count)]
target = [torch.randn(z_count, 1000, 2) * 0.1 + 1.3 * i for i in range(samples_count)]
without_compensation = [torch.randn(z_count, 1000, 2) * 0.1 + 0.4 + 1.3 * i for i in range(samples_count)]

## Tests
### Footprint

In [None]:
# Footprint of three random samples, each a (3, 1000, 2) tensor.
Plot.plot_data(
    [torch.randn(3, 1000, 2) for _ in range(3)]
)

### Fixed position

In [None]:
# Start from a symmetric ±2 window; `switch_lims_if_out_of_lim` presumably replaces it
# when `target` falls outside — see the Plot helper for the exact rule.
lims_x, lims_y = Plot.switch_lims_if_out_of_lim(
    target, lims_x=(-2.0, 2.0), lims_y=(-2.0, 2.0)
)
Plot.fixed_position_plot(
    compensated,
    target,
    without_compensation,
    lims_x,
    lims_y,
    epoch=42,
    training_samples_count=4,
)

In [None]:
# Single-sample variant: keep only the first sample of each set, with a fixed ±2 window.
Plot.fixed_position_plot(
    compensated[:1],
    target[:1],
    without_compensation[:1],
    (-2.0, 2.0),
    (-2.0, 2.0),
    epoch=12,
    training_samples_count=12,
)

In [None]:
# Compensation plot over all samples; `covariance_ellipse=True` presumably overlays
# covariance ellipses (see Plot.compensation_plot).
Plot.compensation_plot(
    compensated,
    target,
    without_compensation,
    epoch=42,
    training_samples_count=20,
    covariance_ellipse=True,
)

### Parameter comparison

In [None]:

# Search space: each parameter maps to a two-element list (presumably [low, high] bounds).
parameters = {"x_var": [-5.0, 5.0], "y_var": [0.0, 10.0]}
search_space = params_to_func(parameters)()

# Fixed "real" and "predicted" parameter values to compare against the search space.
real_params = params_to_func({"x_var": 0, "y_var": 5})()
predicted_params = params_to_func({"x_var": -2.5, "y_var": 10.0})()

Plot.plot_param_comparison(
    epoch=42,
    real_params=real_params,
    predicted_params=predicted_params,
    search_space=search_space,
    training_samples_count=4,
)

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


# Ray data layout used below: one list per dataset, each holding per-sample tensors,
# e.g. [[(3, 1000, 2), ...], [(3, 1000, 2), ...]].
def get_scatter_xyz(ray_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Flatten one ray sample into x/y/z coordinate vectors for a 3D scatter plot.

    The index along the first axis of ``ray_tensor`` is used as x (0 .. k-1), and the
    two entries of the last axis become y and z. Expected shape is ``(k, n, 2)``
    (``(z_count, 1000, 2)`` in the setup cell).

    Returns:
        ``(x, y, z)``, each a 1-D CPU tensor of length ``k * n``; ``x`` has the same
        dtype as ``ray_tensor``.
    """
    # Detach and move to CPU so plotly can convert the result to arrays even when the
    # input lives on a GPU or requires grad (both make numpy conversion fail).
    points = ray_tensor.detach().cpu()
    slice_count, rays_per_slice = points.shape[0], points.shape[1]
    flat = points.flatten(0, 1)
    # Slice index repeated once per ray: [0]*n + [1]*n + ... (works for k == 0 too).
    x = torch.arange(slice_count, dtype=points.dtype).repeat_interleave(rays_per_slice)
    return x, flat[:, 0], flat[:, 1]


def fancy_ray(data: list[list[torch.Tensor]], labels: list[str] | None = None, max_cols=4):
    """Plot several ray datasets as a grid of interactive 3D scatter subplots.

    Each subplot shows one sample; within it, every entry of ``data`` is drawn as its
    own trace. Traces with the same index in ``data`` share a legend group and a colour
    across all subplots, so one legend entry toggles that dataset everywhere.

    Args:
        data: One list per dataset (e.g. target / compensated / uncompensated). All
            lists should contain the same number of samples; each sample is passed to
            ``get_scatter_xyz`` (``(z_count, 1000, 2)`` tensors in the setup cell).
        labels: Optional legend name per dataset, aligned with ``data``.
        max_cols: Maximum number of subplot columns; further samples wrap onto new rows.

    Returns:
        A plotly figure; an empty ``go.Figure()`` when there are no samples to plot.
    """
    n_samples = len(data[0]) if data else 0
    if n_samples == 0:
        return go.Figure()

    # Reuse matplotlib's colour cycle so these plots match the static ones.
    colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    cols = max(1, min(max_cols, n_samples))
    # Ceiling division: enough rows for every sample and no empty trailing row.
    rows = (n_samples + cols - 1) // cols
    specs = [[{"type": "scatter3d"} for _ in range(cols)] for _ in range(rows)]
    fig = make_subplots(rows=rows, cols=cols, specs=specs)

    for sample_idx in range(n_samples):
        # Wrap on the actual column count; plotly rows/cols are 1-based.
        row, col = divmod(sample_idx, cols)
        for i, dataset in enumerate(data):
            x, y, z = get_scatter_xyz(dataset[sample_idx])
            fig.add_trace(
                dict(
                    type="scatter3d",
                    x=x,
                    y=y,
                    z=z,
                    name=labels[i] if labels is not None else None,
                    mode="markers",
                    legendgroup="group" + str(i),
                    # Only the first subplot contributes legend entries (one per dataset).
                    showlegend=sample_idx == 0,
                    marker=dict(color=colors[i % len(colors)], size=2),
                    opacity=0.3,
                ),
                row=row + 1,
                col=col + 1,
            )
    return fig


# One subplot per sample; each dataset gets its own colour and legend group.
fancy_ray(
    data=[target, compensated, without_compensation],
    labels=["Target", "Compensated", "Uncompensated"],
)

In [7]:
# Library version of the plot above; `z_index` presumably gives the z position of each
# of the z_count slices (see Plot.fancy_ray).
Plot.fancy_ray(
    [target, compensated, without_compensation],
    ["Target", "Compensated", "Uncompensated"],
    z_index=[-1.0, 1.0, 4.0],
)