In [8]:
from plotly import express as px

In [2]:
import plotly.graph_objects as go
import polars as pl
import torch

from cci.utils import project_dir

res_dir = project_dir() / "results/CNN/69"
# model = torch.load(res_dir / "0_model.pt", map_location=torch.device("cpu"))
# print()

In [14]:
test_df = pl.read_csv(res_dir / "*_test.csv").drop("epoch")
test_df

acc,loss,f1,auroc,precision,recall,specificity,bac,TP,FP,FN,TN
f64,f64,f64,f64,f64,f64,f64,f64,i64,i64,i64,i64
0.700935,0.635369,0.809524,0.641954,0.839506,0.781609,0.35,0.565805,13,13,19,68
0.757009,0.633239,0.855556,0.615517,0.827957,0.885057,0.2,0.542529,16,16,10,77
0.682243,0.573625,0.797619,0.636207,0.82716,0.770115,0.3,0.535057,14,14,20,67
0.728972,0.550672,0.834286,0.63046,0.829545,0.83908,0.25,0.54454,15,15,14,73
0.663551,0.613958,0.775,0.62931,0.849315,0.712644,0.45,0.581322,11,11,25,62


In [20]:
import numpy as np

from cci.metrics import confusion_matrix

for tp, fp, fn, tn in test_df.select(["TP", "FP", "FN", "TN"]).rows():
    fig = confusion_matrix(np.array([[tp, fp], [fn, tn]]))

UFuncTypeError: ufunc 'add' did not contain a loop with signature matching types (dtype('<U4'), dtype('int64')) -> None

In [10]:
def get_df(group) -> pl.DataFrame:
    df = pl.read_csv(res_dir / f"0_{group}.csv").with_columns(
        pl.lit(0).alias("fold"), pl.lit(group).alias("group"), pl.lit(f"{group}_0").alias("identifier")
    )

    for i in range(1, 5):
        df = df.vstack(
            pl.read_csv(res_dir / f"{i}_{group}.csv").with_columns(
                pl.lit(i).alias("fold"), pl.lit(group).alias("group"), pl.lit(f"{group}_{i}").alias("identifier")
            )
        )

    return df


df = get_df("train")

df = df.vstack(get_df("val"))


# fig = px.line(df, x="epoch", y="loss", line_group="identifier", color="group", height=800)


def avg_plot(metric: str):
    fig = go.Figure()
    for x, color in (["train", "blue"], ["val", "red"]):
        df = get_df(x)
        loss_df = (
            df.group_by("epoch", maintain_order=True)
            .agg(pl.col(metric).mean().alias("avg"), pl.col(metric).std().alias("std"))
            .to_pandas()
        )
        fig.add_trace(
            go.Scatter(
                name=f"Avg {x} {metric}",
                x=loss_df["epoch"],
                y=loss_df["avg"],
                mode="lines",
            )
        )
        fig.add_trace(
            go.Scatter(
                name=f"{x}upper bound",
                x=loss_df["epoch"],
                y=loss_df["avg"] + loss_df["std"],
                line=dict(width=0),
                marker=dict(color=color),
                showlegend=False,
            )
        )
        fig.add_trace(
            go.Scatter(
                name=f"{x}lower bound",
                x=loss_df["epoch"],
                y=loss_df["avg"] - loss_df["std"],
                line=dict(width=0),
                marker=dict(color=color),
                fill="tonexty",
                showlegend=False,
            ),
        )
    fig.update_layout(hovermode="x", height=800)
    fig.show()


avg_plot("loss")

In [11]:
avg_plot("bac")

In [12]:
avg_plot("acc")

In [13]:
avg_plot("f1")

In [14]:
avg_plot("auroc")