# Statistical Summary/Analysis
Use this notebook to calculate summary statistics, correlation analyses,
and other useful metrics (e.g. prediction error) for training and evaluation data.

In [None]:
import os
from typing import Literal
from glob import glob

import pandas as pd

Analysis = Literal["summary", "error"]
ANALYSES: list[Analysis]
COLS_TO_DROP = None
FILTER_QUERY = None
PREVIEW_DATA = False
TARGETS = None
"""show a preview of data before analysis"""

# LOL dfs win score prediction
# DATA_FILES = [os.path.join(
#     "/", "fantasy-experiments", "df-hist", "data", "lol-draftkings-CLASSIC-GPP.csv"
# )]
# COLS_TO_DROP = ["slate_id", "link", "style", "type", "date"]
# FILTER_QUERY = "slate_id.notna()"
# ANALYSES = ["summary"]

# DFS score prediction results for
DATA_FILES = [
    filepath
    for filepath in glob(
        os.path.join(
            "/",
            "fantasy-experiments",
            "df-hist",
            "eval_results",
            "*.prediction.csv",
        )
    )
]
ANALYSES = ["error"]

In [None]:
from math import sqrt

from matplotlib import pyplot as plt
from sklearn import metrics


def load(path: str, cols_to_drop: list[str] | None = None, filter_query: str | None = None):
    """
    filter_query: Rows not matching this query will be dropped
    """
    df = pd.read_csv(path)
    file_len = len(df)
    if filter_query:
        df = df.query(filter_query)
        print(f"Filter query dropped {file_len - len(df)} rows, {len(df)} remaining")
    if cols_to_drop is not None:
        print(f"Dropping columns: {cols_to_drop}")
        df = df.drop(columns=cols_to_drop)
    return df


def summarize(df: pd.DataFrame, targets: list[str] | None):
    if targets:
        raise NotImplementedError()
    summary = {
        "desc": df.describe(),
        "corr-cross": df.corr(),
    }
    return summary


def error_analysis(df: pd.DataFrame, desc):
    """Plot prediction quality for one results table and report fit metrics.

    Draws a 1x2 figure: prediction vs. truth (with the identity line) and
    error vs. truth (with a zero line). The title shows the row count plus
    R², RMSE and MAE computed from the ``truth`` and ``prediction`` columns.

    Args:
        df: Results table; must contain ``truth``, ``prediction`` and ``error``.
        desc: Label used in the title (e.g. the source file name); a falsy
            value is shown as ``'unknown model'``.

    Returns:
        The matplotlib Figure that was drawn.
    """
    required = {"truth", "prediction", "error"}
    assert required <= set(df.columns), f"missing columns: {sorted(required - set(df.columns))}"

    # float() so the `{name=}` f-string shows plain numbers instead of numpy
    # scalar reprs such as "np.float64(0.42)".
    r2 = round(float(metrics.r2_score(df.truth, df.prediction)), 4)
    rmse = round(sqrt(float(metrics.mean_squared_error(df.truth, df.prediction))), 4)
    # MAE is already in the target's units; unlike RMSE it must not be square-rooted.
    mae = round(float(metrics.mean_absolute_error(df.truth, df.prediction)), 4)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle(f"{desc or 'unknown model'} : n={len(df)} {r2=} {rmse=} {mae=}")
    for ax in axs:
        ax.axis("equal")

    # Shared value range so both reference lines span all plotted points.
    min_v = min(df.truth.min(), df.prediction.min())
    max_v = max(df.truth.max(), df.prediction.max())

    # Left: perfect predictions would fall on the green identity line.
    axs[0].plot((min_v, max_v), (min_v, max_v), "-g", linewidth=1)
    df.plot(kind="scatter", x="truth", y="prediction", ax=axs[0])

    # Right: error per truth value against a green zero-error line.
    axs[1].yaxis.set_label_position("right")
    axs[1].plot((min_v, max_v), (0, 0), "-g", linewidth=1)
    df.plot(kind="scatter", x="truth", y="error", ax=axs[1])
    # The figure is left open; presumably so the notebook renders it inline.
    return fig

In [None]:
# Run the configured analyses on every data file.
for filepath in DATA_FILES:
    df = load(filepath, filter_query=FILTER_QUERY, cols_to_drop=COLS_TO_DROP)
    if PREVIEW_DATA:
        display(
            f"data n={len(df)}",
            df.style.hide()
        )

    if "summary" in ANALYSES:
        summary = summarize(df, TARGETS)
        # Use a distinct loop name: rebinding `df` here would replace the loaded
        # data with the last summary table before error_analysis runs below.
        for name, table in summary.items():
            display(name, table)

    if "error" in ANALYSES:
        error_analysis(df, os.path.basename(filepath))