In [2]:
# %load Modules/plot_densities.py
"""
Created on Tue Aug 31

@author: kruu

Plot script for go-arounds generation marginal densities
"""

# %%
import pandas as pd
import altair as alt

alt.data_transformers.disable_max_rows()

# Data
df_real = pd.read_pickle("Data/distributions_along_lines.pkl")
df_generated = pd.read_pickle("Data/generated_GM_and_sampling.pkl")

df_real.iloc[:, 21:25]*=-1

def marginals(df_real, df_generated):

    data_real = (
        df_real.melt(ignore_index=False)
        .rename(columns=dict(variable="step"))
        .reset_index(level=0, inplace=False)
    )
    data_gen = (
        df_generated.melt(ignore_index=False)
        .rename(columns=dict(variable="step"))
        .reset_index(level=0, inplace=False)
    )

    # Brush for selection
    selec1 = alt.selection_multi(
        encodings=["color"], init=[{"index": data_real.index.values[0]}]
    )
    color1 = alt.condition(
        selec1, alt.Color("index:N", legend=None), alt.value("lightgrey")
    )
    opacity1 = alt.condition(selec1, alt.value(1), alt.value(0.1))

    selec2 = alt.selection_multi(
        encodings=["color"], init=[{"index": data_gen.index.values[0]}]
    )
    color2 = alt.condition(
        selec2, alt.Color("index:N", legend=None), alt.value("lightgrey")
    )
    opacity2 = alt.condition(selec2, alt.value(1), alt.value(0.1))

    chart1 = (
        alt.Chart(data_real, title="Real Data")
        .mark_point()
        .encode(alt.Row("step"), alt.X("value"), color=color1, opacity=opacity1)
        .add_selection(selec1)
    )

    chart2 = (
        alt.Chart(data_gen, title="Generated Data")
        .mark_point()
        .encode(alt.Row("step"), alt.X("value"), color=color2, opacity=opacity2)
        .add_selection(selec2)
    )

    return chart1 | chart2


def ridgeline(df):

    data = df.melt().rename(columns=dict(variable="step"))

    step = 20
    overlap = 1

    chart = (
        alt.Chart(data, height=step)
        .mark_area(
            interpolate="monotone", fillOpacity=0.8, stroke="lightgray", strokeWidth=0.5
        )
        .encode(
            alt.Row(
                "step:N",
                title=None,
                header=alt.Header(
                    labelAngle=0,
                    labelAlign="left",
                    labelFont="Ubuntu",
                    labelFontSize=14,
                ),
            ),
            alt.X("bin_min:Q", axis=None),
            alt.Y("count:Q", axis=None, scale=alt.Scale(range=[step, -step * overlap])),
            alt.Fill(
                "std:Q",
                legend=None,
                scale=alt.Scale(domain=(5000, 0), scheme="redyellowblue"),
            ),
        )
        .transform_joinaggregate(
            std="stdev(value)",
            groupby=["step"],
        )
        .transform_bin(["bin_max", "bin_min"], "value", bin=alt.Bin(maxbins=30))
        .transform_joinaggregate(
            count="count()",
            groupby=["step", "bin_min", "bin_max"],
        )
        .transform_impute(
            impute="count", groupby=["step", "std"], key="bin_min", value=0
        )
        .properties(title="Ridgeline plot of marginal distributions", bounds="flush")
        .configure_facet(
            spacing=0,
        )
        .configure_view(stroke=None)
        .configure_title(anchor="start", font="Ubuntu", fontSize=16)
        .configure_axis()
    )

    return chart


# %%

#marginals(df_real, df_generated).save(
#    "marginals.html", embed_options={"renderer": "svg"}
#)
#ridgeline(df_real).save("ridgeline_true.html", embed_options={"renderer": "svg"})


In [3]:
marginals(df_real, df_generated)

In [24]:
ridgeline(df_real)

In [57]:
data = df_real.melt().rename(columns=dict(variable="step"))

step = 20
overlap = 1


data_real = (
    df_real.melt(ignore_index=False)
    .rename(columns=dict(variable="step"))
    .reset_index(level=0, inplace=False)
    .assign(bin_min="value")
)


chart = (
    alt.layer(
        alt.Chart(height=step)
        .mark_area(
            interpolate="monotone", fillOpacity=0.8, stroke="lightgray", strokeWidth=0.5
        )
        .encode(
            alt.X(
                "bin_min:Q",
                axis=None,
                scale=alt.Scale(domain=[-4000, 12000], clamp=True),
            ),
            alt.Y("count:Q", axis=None, scale=alt.Scale(range=[step, -step * overlap])),
            alt.Fill(
                "std:Q",
                legend=None,
                scale=alt.Scale(domain=(5000, 0), scheme="redyellowblue"),
            ),
        )
        .transform_joinaggregate(
            std="stdev(value)",
            groupby=["step"],
        )
        .transform_bin(["bin_max", "bin_min"], "value", bin=alt.Bin(maxbins=30))
        .transform_joinaggregate(
            count="count()",
            groupby=["step", "bin_min", "bin_max"],
        )
        .transform_impute(
            impute="count", groupby=["step", "std"], key="bin_min", value=0
        ),
        alt.Chart()
        .mark_point(color="red")
        .transform_filter("datum.index==23 ")
        .encode(
            alt.X("v"),
        ),
        data=data.assign(v=data.value),
    )
    .facet(row="step:N")
    .properties(title="Ridgeline plot of marginal distributions", bounds="flush")
    .configure_facet(spacing=0)
    .configure_header(
        labelAngle=0,
        labelAlign="left",
        labelFont="Ubuntu",
        labelFontSize=14,
        titleFontSize=0,
    )
    .configure_view(stroke=None)
    .configure_title(anchor="start", font="Ubuntu", fontSize=16)
    .configure_axis()
)
chart

In [56]:
chart.data

Unnamed: 0,step,value,v
0,0,-25.045851,-25.045851
1,0,18.036940,18.036940
2,0,-35.431216,-35.431216
3,0,-152.014982,-152.014982
4,0,204.968214,204.968214
...,...,...,...
12205,29,0.200111,0.200111
12206,29,2.740724,2.740724
12207,29,33.514343,33.514343
12208,29,3.705007,3.705007


In [62]:
data_real = (
    df_real.melt(ignore_index=False)
    .rename(columns=dict(variable="step"))
    .reset_index(level=0, inplace=False)
    .assign(bin_min="value")
)
(
    alt.Chart(data_real)
    .mark_point()
    .transform_filter("datum.index==305")
    .encode(alt.X("value"), color="index:N")
)

In [42]:
%load_ext lab_black

In [74]:
chart = (
    (
        alt.Chart(data, height=step)
        .mark_area(
            interpolate="monotone", fillOpacity=0.8, stroke="lightgray", strokeWidth=0.5
        )
        .encode(
            alt.Row(
                "step:N",
                title=None,
                header=alt.Header(
                    labelAngle=0,
                    labelAlign="left",
                    labelFont="Ubuntu",
                    labelFontSize=14,
                ),
            ),
            alt.X(
                "bin_min:Q",
                axis=None,
                scale=alt.Scale(domain=[-4000, 12000], clamp=True),
            ),
            alt.Y("count:Q", axis=None, scale=alt.Scale(range=[step, -step * overlap])),
            alt.Fill(
                "std:Q",
                legend=None,
                scale=alt.Scale(domain=(5000, 0), scheme="redyellowblue"),
            ),
        )
        .transform_joinaggregate(
            std="stdev(value)",
            groupby=["step"],
        )
        .transform_bin(["bin_max", "bin_min"], "value", bin=alt.Bin(maxbins=30))
        .transform_joinaggregate(
            count="count()",
            groupby=["step", "bin_min", "bin_max"],
        )
        .transform_impute(
            impute="count", groupby=["step", "std"], key="bin_min", value=0
        )
        .properties(title="Ridgeline plot of marginal distributions", bounds="flush")
        | alt.Chart(
            data_real,
        )
        .mark_point(color="#e45756")
        .transform_filter("datum.index==21")
        .encode(
            alt.Row(
                "step",
                title=None,
                header=alt.Header(
                    labelAngle=0,
                    labelAlign="left",
                    labelFont="Ubuntu",
                    labelFontSize=14,
                ),
            ),
            alt.X(
                "value",
                axis=None,
                scale=alt.Scale(domain=[-4000, 12000], clamp=True),
            ),
        )
    )
    .configure_facet(spacing=0)
    .configure_view(stroke=None)
    .configure_title(anchor="start", orient="top", font="Ubuntu", fontSize=20, dy=-10)
    .configure_axis()
)
chart

In [70]:
!mamba  install -y altair_saver


                  __    __    __    __
                 /  \  /  \  /  \  /  \
                /    \/    \/    \/    \
███████████████/  /██/  /██/  /██/  /████████████████████████
              /  / \   / \   / \   / \  \____
             /  /   \_/   \_/   \_/   \    o \__,
            / _/                       \_____/  `
            |/
        ███╗   ███╗ █████╗ ███╗   ███╗██████╗  █████╗
        ████╗ ████║██╔══██╗████╗ ████║██╔══██╗██╔══██╗
        ██╔████╔██║███████║██╔████╔██║██████╔╝███████║
        ██║╚██╔╝██║██╔══██║██║╚██╔╝██║██╔══██╗██╔══██║
        ██║ ╚═╝ ██║██║  ██║██║ ╚═╝ ██║██████╔╝██║  ██║
        ╚═╝     ╚═╝╚═╝  ╚═╝╚═╝     ╚═╝╚═════╝ ╚═╝  ╚═╝

        mamba (0.8.2) supported by @QuantStack

        GitHub:  https://github.com/mamba-org/mamba
        Twitter: https://twitter.com/QuantStack

█████████████████████████████████████████████████████████████


Looking for: ['altair_saver']

pkgs/r/osx-64            [=>                  ] (--:--) No change
pkgs/r/noarch   

In [75]:
chart.save("marginal.svg")

WARN row encoding should be discrete (ordinal / nominal / binned).


In [1]:
%load_ext lab_black

In [2]:
import json

metrics = json.load(open("Data/metrics.json"))
metrics

In [3]:
metrics = {
    "MVN_raw": [8.8, 61178.30201811782, 60.000635726534604, 10.864821434952292],
    "MVN_reduced": [11.8, 47332.59438142815, 34.41678462089916, 12.700828078599761],
    "GM_raw": [7.4, 11564.296236698683, 59.78776294398959, 10.789795105968635],
    "GM_reduced": [9.2, 7404.061489430264, 34.445612065143926, 12.683910931466047],
    "Vines_raw": [11.6, 18743.199396777567, 328.6718772137646, 10.122038700233203],
    "Vines_reduced": [13.6, 10302.834056618314, 84.08505582999118, 12.573063150351294],
}

In [2]:
a = pd.DataFrame(metrics)
b = a.div(a.max(axis=1), axis=0)
norm = b.to_dict(orient="list")
norm

In [9]:
df = pd.DataFrame.from_records(metrics).T.rename(
    columns={
        0: "Number of turns",
        1: "e-distance",
        2: "Mahalanobis to mean",
        3: "Mean Mahalanobis",
    },
    index={
        "GM_raw": "GMM",
        "GM_reduced": "GMM with DR",
        "MVN_raw": "MVN",
        "MVN_reduced": "MVN with DR",
        "Vines_raw": "Vine Copula",
        "Vines_reduced": "Vine Copula with DR",
    },
)
df

Unnamed: 0,Number of turns,e-distance,Mahalanobis to mean,Mean Mahalanobis
GMM,7.4,11564.296237,59.787763,10.789795
GMM with DR,9.2,7404.061489,34.445612,12.683911
MVN,8.8,61178.302018,60.000636,10.864821
MVN with DR,11.8,47332.594381,34.416785,12.700828
Vine Copula,11.6,18743.199397,328.671877,10.122039
Vine Copula with DR,13.6,10302.834057,84.085056,12.573063


In [12]:
chart = (
    alt.Chart(df.melt(ignore_index=False).reset_index())
    .encode(
        alt.X("value", title=None),
        alt.Y("index", title=None, axis=None),
        alt.Facet(
            "variable",
            title=None,
            header=alt.Header(
                # orient="top",
                labelAnchor="start",
                labelFont="Fira Sans",
                labelFontWeight="bold",
                labelFontSize=18,
            ),
        ),
        alt.Color("index", title="Method", scale=alt.Scale(scheme="tableau20")),
    )
    .mark_bar()
    .configure_axis(
        labelAngle=0,
        labelAlign="left",
        labelFont="Ubuntu",
        labelFontSize=14,
    )
    .configure_legend(
        labelFont="Ubuntu",
        labelFontSize=16,
        titleFont="Fira Sans",
        titleFontSize=18,
        orient="bottom",
    )
    .configure_header(labelOrient="top")
    .configure_facet(columns=2)
    .resolve_scale(x="independent")
)
chart

In [13]:
chart.save("metrics.svg")