In [None]:
import sys
from pathlib import Path

import asdf
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch as xp
from matplotlib.gridspec import GridSpec
from scipy import stats

import stream_ml.pytorch as sml
import stream_ml.visualization as smlvis

# Add the parent directory to the path
sys.path.append(Path().resolve().parents[1].as_posix())
sys.path.append((Path().resolve().parents[2] / "scripts").as_posix())
# isort: split

import paths  # noqa: E402

In [None]:
# Matplotlib style: apply the style sheet stored alongside the scripts.
paper_style = paths.scripts / "paper.mplstyle"
plt.style.use(paper_style)

In [None]:
# Read everything the figure needs from the mock-data ASDF file while it is
# open:
#   data            -- converted to float32 torch tensors for stream_ml
#   table           -- per-star table (carries a "label" column, used below)
#   stream_table    -- table of the stream stars only (presumably; confirm)
#   n_stream / n_background -- counts used to scale the stream fraction
mock_file = paths.data / "mock" / "data.asdf"
with asdf.open(mock_file) as af:
    n_stream, n_background = af["n_stream"], af["n_background"]
    table, stream_table = af["table"], af["stream_table"]
    data = sml.Data(**af["data"]).astype(xp.Tensor, dtype=xp.float32)

In [None]:
# Figure layout: two equally tall halves. The upper half (gs0) holds three
# equal-height, full-width panels; the lower half is subdivided further down
# for the per-slice panels.
fig = plt.figure(figsize=(14, 13), constrained_layout=True)
gs = GridSpec(nrows=2, ncols=1, figure=fig, height_ratios=(1, 1))
gs0 = gs[0].subgridspec(nrows=3, ncols=1, height_ratios=(5, 5, 5))

# Currently configured default colormap; its two ends colour the histograms.
cmap = plt.get_cmap()

# Weight plot
# Top panel: fraction of stars belonging to the stream as a function of phi1.
ax01 = fig.add_subplot(gs0[0, :])
ax01.set(ylabel="Stream fraction", ylim=(0, 0.5))
ax01.set_xticklabels([])

# Truth (at some bandwidth)
# TODO: replace with histograms.
# The "true" fraction is estimated as KDE(stream stars) / KDE(all stars in
# `table`), rescaled by the overall share n_stream / (n_stream + n_background).
phi1 = stream_table["phi1"].to_value("deg")
bw_method = 0.045
stream_kde = stats.gaussian_kde(phi1, bw_method=bw_method)
total_kde = stats.gaussian_kde(table["phi1"], bw_method=bw_method)

# Evaluate only strictly inside the phi1 range spanned by the stream stars.
phi1_lo, phi1_hi = phi1.min(), phi1.max()
on_stream = (phi1_lo < data["phi1"]) & (data["phi1"] < phi1_hi)

# A line plot joins points in the order given, so sort the abscissa first;
# otherwise unsorted data would be drawn as zig-zag chords across the curve.
x, _ = xp.sort(data["phi1"][on_stream])
stream_fraction = (stream_kde(x) / total_kde(x)) * (
    n_stream / (n_stream + n_background)
)
ax01.plot(
    x,
    stream_fraction,
    c="k",
    label=f"Stream (`true', bw={bw_method})",
)
ax01.legend(loc="upper left")

# Phi2
# Middle panel: phi2 against phi1 for every star in `table` (small gray
# points) with the `stream_table` stars over-plotted (larger black points).
ax02 = fig.add_subplot(gs0[1, :])
ax02.set_xticklabels([])
ax02.set(ylabel=r"$\phi_2$ [$\degree$]")

# (x, y, marker size, colour) for each layer, drawn in this order.
phi2_layers = (
    (table["phi1"].to_value("deg"), table["phi2"].to_value("deg"), 3, "gray"),
    (phi1, stream_table["phi2"].to_value("deg"), 10, "k"),
)
for layer_x, layer_y, size, colour in phi2_layers:
    ax02.scatter(
        layer_x,
        layer_y,
        s=size,
        c=colour,
        alpha=0.5,
        zorder=-100,
        label="Ground Truth",
    )
ax02.legend(loc="upper left")

# Distance
# Bottom panel of the upper half: parallax against phi1 for every star in
# `table` (small gray points) with the `stream_table` stars over-plotted
# (larger black points).
ax03 = fig.add_subplot(gs0[2, :])
# The plotted values are parallaxes converted to mas, so the axis unit is mas
# (not a proper-motion unit); this matches the per-slice parallax panels.
ax03.set(xlabel=r"$\phi_1$ [deg]", ylabel=r"$\varpi$ [mas]")

# Column used as the distance tracer.
k_dist = "parallax"

# (x, y, marker size, colour) for each layer, drawn in this order.
distance_layers = (
    (table["phi1"].to_value("deg"), table[k_dist].to_value("mas"), 3, "gray"),
    (phi1, stream_table[k_dist].to_value("mas"), 10, "k"),
)
for layer_x, layer_y, size, colour in distance_layers:
    ax03.scatter(
        layer_x,
        layer_y,
        s=size,
        c=colour,
        alpha=0.5,
        zorder=-100,
        label="Ground Truth",
    )
ax03.legend(loc="upper left")


# Slice plots
# Lower half of the figure: a 4 x 4 grid, one column per phi1 slice.
gs1 = gs[1].subgridspec(nrows=4, ncols=4)

# Bin the data for plotting
# Five evenly spaced edges across the observed phi1 range define four slices.
# Only the first four edges are handed to np.digitize, so stars at the upper
# end of the range are assigned to the last slice (indices are 1-based).
phi1_data = data["phi1"]
bins = np.linspace(phi1_data.min(), phi1_data.max(), num=5, endpoint=True)
which_bin = np.digitize(phi1_data, bins[:-1])

def _stacked_truth_hist(ax, values, weights):
    """Draw a stacked, density-normalised background/stream histogram.

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on.
    values : 1-D array of plain (unit-free) numbers
        One value per selected star.
    weights : (n, 2) integer array
        0/1 flags; column 0 marks background stars, column 1 stream stars.
    """
    # hist() takes one data column per stacked series, so the values are
    # duplicated and the 0/1 weights decide which stars count in each series.
    ax.hist(
        np.ones((len(values), 2)) * values[:, None],
        bins=50,
        weights=weights,
        color=[cmap(0.01), cmap(0.99)],
        density=True,
        stacked=True,
        label=["", "Ground Truth"],
    )


def _truth_cmd(ax, sel, flag, colours):
    """Scatter r against g for the stars in ``sel``, coloured by ``colours``.

    Points are passed to scatter in ascending order of ``flag``, so the stars
    with flag 1 come last and are not hidden behind the flag-0 stars.
    """
    order = np.argsort(flag)
    ax.scatter(
        data["g"][sel][order],
        data["r"][sel][order],
        c=colours[order],
        cmap="seismic",
        s=1,
        rasterized=True,
    )


# 0/1 truth flags for every star; loop-invariant, so computed once.
stream_flag = np.array(table["label"] == "stream", dtype=int)
background_flag = np.array(table["label"] == "background", dtype=int)

for i, bin_id in enumerate(np.unique(which_bin)):
    sel = which_bin == bin_id

    # np.digitize indices are 1-based: slice `bin_id` spans
    # bins[bin_id - 1] .. bins[bin_id]. Indexing by the bin id (rather than by
    # the column counter `i`) keeps the edges right even if a slice is empty.
    left, right = bins[bin_id - 1], bins[bin_id]

    # Histogram weights shared by the phi2 and parallax panels:
    # column 0 = background, column 1 = stream.
    ws = np.stack((background_flag[sel], stream_flag[sel]), axis=1)

    # ---------------------------------------------------------------------------
    # Phi2

    ax10i = fig.add_subplot(gs1[0, i])

    # Connect to top plot(s): mark the slice edges on the two upper panels and
    # let the helper link the parallax panel (ax03) to this column.
    for ax in (ax01, ax02):
        for edge in (left, right):
            ax.axvline(edge, color="gray", ls="--", zorder=-200)
    smlvis._slices.connect_slices_to_top(  # noqa: SL
        fig, ax03, ax10i, left=left, right=right, color="gray"
    )

    # Plain degrees (no unit attached), matching the axis label.
    _stacked_truth_hist(ax10i, table["phi2"][sel].to_value("deg"), ws)

    ax10i.set_xlabel(r"$\phi_2$ [$\degree$]")
    if i == 0:
        ax10i.set_ylabel("frequency")
        ax10i.legend(loc="upper left")

    # ---------------------------------------------------------------------------
    # Distance

    ax11i = fig.add_subplot(gs1[1, i])

    # Plain milliarcseconds (no unit attached), matching the axis label.
    _stacked_truth_hist(ax11i, table["parallax"][sel].to_value("mas"), ws)

    ax11i.set_xlabel(r"$\varpi$ [mas]")
    if i == 0:
        ax11i.set_ylabel("frequency")
        ax11i.legend(loc="upper left")

    # ---------------------------------------------------------------------------
    # Photometry

    # ------------------------------------------
    # Stream: coloured by stream membership, stream stars drawn last.

    ax12i = fig.add_subplot(gs1[2, i])

    in_stream = stream_flag[sel]
    _truth_cmd(ax12i, sel, in_stream, in_stream)
    ax12i.set_xticklabels([])

    if i == 0:
        ax12i.set_ylabel("r [mag]")
    else:
        ax12i.set_yticklabels([])

    # ------------------------------------------
    # Background: same colour coding (1 - background flag), but with the
    # background stars drawn last.

    ax13i = fig.add_subplot(gs1[3, i])

    in_background = background_flag[sel]
    _truth_cmd(ax13i, sel, in_background, 1 - in_background)
    ax13i.set(xlabel="g [mag]")

    if i == 0:
        ax13i.set_ylabel("r [mag]")
    else:
        ax13i.set_yticklabels([])


# Write the finished figure into the mock-diagnostics figure folder, then
# display it.
output_pdf = paths.figures / "mock" / "diagnostic" / "data.pdf"
plt.savefig(output_pdf)
fig.show();