In [1]:
from plotly import express as px

In [107]:
import plotly.graph_objects as go
import polars as pl

from cci.utils import project_dir

# Directory that get_df reads the "<fold>_<group>.csv" metric files from.
# NOTE(review): presumably the output folder of one CNN run, named by a hash — confirm.
res_dir = project_dir() / "results_go/CNN/15431867f2202e8923bccde32d31a8a277c5afcf95242cfb7df373cd89f1b2d3"


def get_df(group, n_folds: int = 5) -> pl.DataFrame:
    """Load the per-fold metric CSVs of one group and stack them into one frame.

    Reads ``res_dir / f"{fold}_{group}.csv"`` for every fold in
    ``range(n_folds)`` and tags each row with the file it came from.

    Args:
        group: Name used in the CSV file names and stored in the ``group``
            column (this notebook passes ``"train"`` and ``"val"``).
        n_folds: Number of fold files to read, starting at fold 0. Defaults
            to 5, the count that was previously hard-coded.

    Returns:
        The rows of all fold files, in fold order, with three extra columns:
        ``fold`` (the fold index), ``group`` (the ``group`` argument) and
        ``identifier`` (``f"{group}_{fold}"``).

    Raises:
        ValueError: If ``n_folds`` is smaller than 1.
    """
    if n_folds < 1:
        raise ValueError(f"n_folds must be at least 1, got {n_folds}")

    # Build every fold frame the same way, then concatenate once instead of
    # growing the frame with a vstack per fold.
    frames = [
        pl.read_csv(res_dir / f"{fold}_{group}.csv").with_columns(
            pl.lit(fold).alias("fold"),
            pl.lit(group).alias("group"),
            pl.lit(f"{group}_{fold}").alias("identifier"),
        )
        for fold in range(n_folds)
    ]
    return pl.concat(frames, how="vertical")


# All folds of the training metrics followed by all folds of the validation metrics.
df = get_df("train").vstack(get_df("val"))


# Disabled alternative view: one line per fold instead of the averaged plot.
# fig = px.line(df, x="epoch", y="loss", line_group="identifier", color="group", height=800)


def avg_plot(metric: str):
    """Show the fold-averaged curve of ``metric`` per epoch with a +/- std band.

    For each of the groups ``"train"`` (blue) and ``"val"`` (red) the frame
    from ``get_df`` is grouped by ``epoch``; the mean and the standard
    deviation (polars ``std`` with its default settings) of ``metric`` are
    computed per epoch. Three traces are added per group: the mean line and
    two zero-width bound lines whose gap is filled to form the band.

    Args:
        metric: Name of the column in the fold CSVs to aggregate and plot.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    fig = go.Figure()
    for group, color in (("train", "blue"), ("val", "red")):
        stats = (
            get_df(group)
            .group_by("epoch", maintain_order=True)
            .agg(pl.col(metric).mean().alias("avg"), pl.col(metric).std().alias("std"))
            .to_pandas()
        )
        epochs = stats["epoch"]

        # Mean curve, drawn in the group's colour so it matches its band.
        fig.add_trace(
            go.Scatter(
                name=f"Avg {group} {metric}",
                x=epochs,
                y=stats["avg"],
                mode="lines",
                line=dict(color=color),
            )
        )
        # Upper edge of the band: zero-width line, hidden from the legend.
        # mode="lines" keeps plotly from drawing markers on short series.
        fig.add_trace(
            go.Scatter(
                name=f"{group} upper bound",
                x=epochs,
                y=stats["avg"] + stats["std"],
                mode="lines",
                line=dict(width=0),
                marker=dict(color=color),
                showlegend=False,
            )
        )
        # Lower edge: "tonexty" fills up to the trace added just before it
        # (the upper bound), which produces the shaded band.
        fig.add_trace(
            go.Scatter(
                name=f"{group} lower bound",
                x=epochs,
                y=stats["avg"] - stats["std"],
                mode="lines",
                line=dict(width=0),
                marker=dict(color=color),
                fill="tonexty",
                showlegend=False,
            ),
        )
    fig.update_layout(hovermode="x", height=800)
    fig.show()


# Fold-averaged "loss" column per epoch, train vs. val.
avg_plot("loss")

In [108]:
# Fold-averaged "bac" column per epoch (presumably balanced accuracy — confirm).
avg_plot("bac")

In [109]:
# Fold-averaged "acc" column per epoch (presumably accuracy — confirm).
avg_plot("acc")

In [110]:
# Fold-averaged "f1" column per epoch (presumably the F1 score — confirm).
avg_plot("f1")

In [111]:
# Fold-averaged "auroc" column per epoch (presumably area under the ROC curve — confirm).
avg_plot("auroc")