# Statistical Summary/Analysis
Use this notebook to calculate summary stats, correlation analyses,
and other useful metrics for training and evaluation data.

In [None]:
import os
from typing import Literal

import pandas as pd

# Names of the analyses that the final cell of this notebook checks for.
Analysis = Literal["summary", "error"]
# Analyses to run. Only declared here; the value is assigned by the
# configuration sections below.
ANALYSES: list[Analysis]

# Configuration: LOL dfs win score prediction
DATA_FILE = os.path.join(
    "/", "fantasy-experiments", "df-hist", "data", "lol-draftkings-CLASSIC-GPP.csv"
)
COLS_TO_DROP = ["slate_id", "link", "style", "type", "date"]
FILTER_QUERY = "slate_id.notna()"
ANALYSES = ["summary"]

# Configuration: prediction results file (eval_results) for error analysis.
# NOTE(review): these assignments run after the section above, so when the
# whole cell is executed DATA_FILE and ANALYSES end up with the values below,
# while COLS_TO_DROP and FILTER_QUERY keep the values set above. Confirm those
# columns/filters are valid for the prediction file, or comment out the
# configuration that is not wanted.
DATA_FILE = os.path.join(
    "/", "fantasy-experiments", "df-hist", "eval_results", "lol-draftkings-CLASSIC-GPP.prediction.csv"
)
ANALYSES = ["error"]

In [None]:
import matplotlib as plt

def load(path: str, cols_to_drop: list[str] | None = None, filter_query: str | None = None):
    """
    filter_query: Rows not matching this query will be dropped
    """
    df = pd.read_csv(DATA_FILE)
    file_len = len(df)
    print(f"Loaded n={file_len} from '{DATA_FILE}'")
    if filter_query:
        df = df.query(filter_query)
        print(f"Filter query dropped {file_len - len(df)} rows, {len(df)} remaining")
    if cols_to_drop is not None:
        print(f"Dropping columns: {cols_to_drop}")
        df = df.drop(columns=cols_to_drop)
    return df


def summarize(df: pd.DataFrame) -> dict[str, pd.DataFrame]:
    """
    Compute summary tables for a dataframe.

    df: Data to summarize

    Returns a dict with two entries:
    'desc' - result of df.describe()
    'corr-cross' - pairwise correlation matrix of the numeric columns
    """
    return {
        "desc": df.describe(),
        # numeric_only=True skips non-numeric columns instead of raising on
        # them (pandas >= 2.0 no longer drops them by default).
        "corr-cross": df.corr(numeric_only=True),
    }

def error_analysis(df: pd.DataFrame):
    assert {"truth", "prediction", "error"} <= set(df.columns)
    raise NotImplementedError()

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    fig.suptitle(f"{desc or 'unknown model'} : {r2=} {rmse=} {mae=}")
    for ax in axs:
        ax.axis("equal")

    min_v = min(df.truth.min(), df.prediction.min())
    max_v = max(df.truth.max(), df.prediction.max())

    axs[0].plot((min_v, max_v), (min_v, max_v), "-g", linewidth=1)
    plot_data.plot(kind="scatter", x="truth", y="prediction", ax=axs[0])

    axs[1].yaxis.set_label_position("right")
    axs[1].plot((min_v, max_v), (0, 0), "-g", linewidth=1)
    plot_data.plot(kind="scatter", x="truth", y="error", ax=axs[1])
    

In [None]:
# Load the configured data file, applying the configured row filter and column drops
df = load(DATA_FILE, filter_query=FILTER_QUERY, cols_to_drop=COLS_TO_DROP)
display(
    f"data n={len(df)}",
    # df.style.hide()
)

if "summary" in ANALYSES:
    summary = summarize(df)
    # The loop variable must not be named `df`: rebinding it would replace the
    # loaded data with the last summary table for any analysis that runs later.
    for name, summary_df in summary.items():
        display(name, summary_df)
        
if "error" in ANALYSES:
    