In [None]:
!kaleido_get_chrome

In [None]:
import sys

# Add ".." (the parent of the current working directory) to the module search path.
sys.path.extend([".."])

In [None]:
import os

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.io as pio

# Register "plotly_white_custom": the "plotly_white" template with large (52pt),
# black, Times New Roman text. Then make it the default for every figure below.
# NOTE(review): relies on pio.templates storing a copy on assignment, so that
# "plotly_white" itself is not modified by the update below — confirm.
pio.templates["plotly_white_custom"] = pio.templates["plotly_white"]
pio.templates["plotly_white_custom"].update(
    layout_font_size=52,
    layout_font_family="Times New Roman",
    layout_font_color="black",
)
pio.templates.default = "plotly_white_custom"

# Shows/hides legends and colorbars, and picks the "_labeled"/"_unlabeled" file suffix.
labeled = True
# Dataset currently loaded into `df`; the plotting cell skips reloading when it matches.
oldname = None

In [None]:
# Log base used for the "Log # ..." columns and for the colorbar tick positions.
b = 5


def _build_hist(df, gb, b, max_epoch=15):
    """Aggregate the raw rows into one row per (Epoch, # Faces, `gb`) group.

    Sums Volume / # Points / Finite / Inradius, counts polyhedrons per group
    ("# Polys"), derives ratio and log-base-`b` columns, and keeps only epochs
    <= `max_epoch`. Returns a flat DataFrame sorted by Epoch ascending, then by
    `gb` descending, with `gb` cast to str.
    """
    keys = ["Epoch", "# Faces", gb]
    value_cols = ["Volume", "Log Volume", "Finite", "Inradius", "# Points"]
    hist = df.groupby(keys)[value_cols].agg(
        {"Volume": "sum", "# Points": "sum", "Finite": "sum", "Inradius": "sum"}
    )
    # Collapse to one row per (Epoch, p, gb) before counting, so "# Polys" counts
    # distinct `p` values (presumably polyhedron ids) rather than raw rows.
    hist["# Polys"] = (
        df.groupby(["Epoch", "p", gb], sort=False).first().groupby(keys)[value_cols].size()
    )
    hist["Finite %"] = (100 * hist["Finite"] / hist["# Polys"]).values.astype(int)
    # NOTE: despite the name, this is the natural *log* of the mean volume.
    hist["Average Volume"] = np.log(hist["Volume"] / hist["# Polys"])
    hist["Average Inradius"] = hist["Inradius"] / hist["# Polys"]
    hist["Log # Points"] = np.log(hist["# Points"]) / np.log(b)
    hist["Log # Polys"] = np.log(hist["# Polys"]) / np.log(b)
    hist = hist.reset_index()
    hist[gb] = hist[gb].astype(str)
    hist = hist.sort_values(["Epoch", gb], ascending=[True, False])
    return hist[hist["Epoch"] <= max_epoch]


def _label_epoch_facets(fig):
    """Rename plotly-express facet titles "Epoch=<n>" to "Epoch <n>".

    Rewrites *every* annotation, so call it before adding any other annotation.
    """
    # Inner single quotes: reusing double quotes inside an f-string needs Python 3.12+.
    fig.for_each_annotation(lambda a: a.update(text=f"Epoch {a.text.split('=')[-1]}"))


def _add_paper_label(fig, text, x, y, **kwargs):
    """Add an arrowless, horizontally centred label at paper coordinates (x, y).

    Serves as a single axis title shared by all facets. Extra kwargs (e.g.
    `textangle`) are forwarded to `fig.add_annotation`.
    """
    fig.add_annotation(
        showarrow=False, xanchor="center", xref="paper", x=x, yref="paper", y=y, text=text, **kwargs
    )


def _powers_up_to(limit, base):
    """Return [1, base, base**2, ...]: every integer power of `base` that is <= `limit`.

    The previous range(ceil(log_base(limit))) dropped the top power when `limit`
    was an exact power of `base` (e.g. limit=5, base=5 gave [1]) and raised on
    limit == 0. Returns [] when limit < 1 (or NaN).
    """
    if base <= 1:
        raise ValueError(f"base must be > 1, got {base!r}")
    ticks, power = [], 1
    while power <= limit:
        ticks.append(power)
        power *= base
    return ticks


for name in ["MNIST"]:  # other datasets: "CIFAR10", "CH Reg"
    print(f"==== {name} ====")
    # Reuse `df` across re-runs of this cell while the dataset name is unchanged.
    if oldname != name:
        print("Loading Data")
        # NOTE: pd.read_pickle can execute arbitrary code; only load trusted result files.
        df = pd.read_pickle(f"../results/data_polys_{name}_Progress.pkl")
        oldname = name
    else:
        print("Using same data")
    print("Rows:", len(df))
    df["Epoch"] = df["Epoch"].astype(int)
    print("Loaded Data")

    gb = "Category"
    hist = _build_hist(df, gb, b)
    # Diagnostic: mean "# Faces" over rows that have zero points.
    print(df[df["# Points"] == 0]["# Faces"].mean())

    figs = {}

    # "Data Points" category only: bar height is the data-point count, colour is the polyhedron count.
    d_hist = hist[hist[gb] == "Data Points"]
    polys_fig = px.bar(
        d_hist,
        x="# Faces",
        y="# Points",
        log_y=False,
        hover_data=[gb, "Log # Points", "# Polys"],
        color="# Polys",
        width=2000,
        height=1400,
        opacity=1,
        facet_col="Epoch",
        facet_col_wrap=4,
        color_continuous_scale="Darkmint_r",
    )
    _label_epoch_facets(polys_fig)
    polys_fig.update_layout(
        legend_title_text="# Polys",
        yaxis_title="",
        xaxis_title="",
        showlegend=labeled,
        coloraxis={
            "colorbar": {
                "title": {"text": "Number of<br>Polyhedrons<br><sup>&nbsp;</sup>", "side": "top"},
                "yanchor": "top",
                "y": 1.1,
                "len": 1.12,
            },
            "showscale": labeled,
        },
    )
    _add_paper_label(polys_fig, "Number of Data Points", x=-0.09, y=0.5, textangle=-90)
    figs["FvP Polys"] = polys_fig

    # All categories stacked: bar height is the polyhedron count (log y), colour is log_b(# Points).
    points_fig = px.bar(
        hist,
        x="# Faces",
        y="# Polys",
        log_y=True,
        hover_data=[gb, "# Points"],
        color="Log # Points",
        width=2000,
        height=1400,
        barmode="relative",
        opacity=1,
        facet_col="Epoch",
        facet_col_wrap=4,
    )
    _label_epoch_facets(points_fig)
    points_fig.update_yaxes(dtick=1)

    # The colour axis is in log_b units; show raw counts at every power of b.
    ticks = _powers_up_to(hist["# Points"].max(), b)
    points_fig.update_layout(
        legend_title_text="# Points",
        yaxis_title="",
        xaxis_title="",
        showlegend=labeled,
        coloraxis={
            "colorbar": {
                "title": {"text": "Number<br>of Points<br><sup>&nbsp;</sup>", "side": "top"},
                "yanchor": "top",
                "y": 1.1,
                "len": 1.12,
                "tickvals": [np.log(tick) / np.log(b) for tick in ticks],
                "ticktext": ticks,
            },
            "cmax": np.log(hist["# Points"].max()) / np.log(b),
            "cmin": 0,
            "showscale": labeled,
        },
    )
    _add_paper_label(points_fig, "Number of Polyhedrons", x=-0.09, y=0.5, textangle=-90)
    figs["FvP Points Single"] = points_fig

    out_dir = f"../figures/{name}"
    os.makedirs(out_dir, exist_ok=True)
    suffix = "_labeled" if labeled else "_unlabeled"
    for figname, fig in figs.items():
        # Clear per-facet axis titles; one shared label per axis is drawn instead.
        fig.for_each_xaxis(lambda axis: axis.update({"title": "", "tickvals": [5, 10, 15]}))
        fig.for_each_yaxis(lambda axis: axis.update({"title": ""}))
        _add_paper_label(fig, "Number of Neighbors", x=0.5, y=-0.11)
        fig.show()
        out_path = f"{out_dir}/{figname}_{name}{suffix}_progress.svg"
        fig.write_image(out_path)
        print("Wrote image to", out_path)
