In [None]:
%load_ext autoreload
%autoreload 2
import altair as alt
import pandas as pd
from IPython.display import clear_output, display
from vega_datasets import data

from src.explorer import Explorer, ExplorerConfig, TrainResult
from src.generator import Generator
from src.oracle import Oracle, OracleWeight

# Load the movies dataset and keep only columns that are either non-object
# typed, or object typed with fewer than 10 distinct values.
df = data.movies()
df = df.loc[
    :,
    [
        df[name].dtype != "object" or df[name].nunique() < 10
        for name in df.columns
    ],
]


In [None]:
# One TrainResult is appended here per training epoch by the training cell.
train_results: list[TrainResult] = []

# Give every oracle criterion the same weight of 1.0.
oracle_weight = OracleWeight(
    **dict.fromkeys(
        ("specificity", "interestingness", "diversity", "coverage", "conciseness"),
        1.0,
    )
)
oracle = Oracle(df, oracle_weight)
gen = Generator(df)

# NOTE(review): positional arguments — the config exposes `n_epoch` (read by
# the training loop), but which value maps to which field is not visible here;
# confirm against ExplorerConfig and consider passing them by keyword.
expl_config = ExplorerConfig(1000, 100, 0.1)
expl = Explorer(df, expl_config)

In [None]:
# Training loop with a live dashboard: after every epoch the cell output is
# cleared and redrawn with per-epoch metric curves plus the explorer's current
# result.

# Metrics drawn in the top row / bottom row of the dashboard.
SUMMARY_METRICS = ("scores", "maxs", "n_charts")
COMPONENT_METRICS = (
    "specificity",
    "interestingness",
    "coverage",
    "diversity",
    "conciseness",
)


def _metric_row(epoch_index: int, result: TrainResult) -> dict:
    """Collapse one training result into a single row of summary values.

    Args:
        epoch_index: Position of ``result`` in ``train_results``; used as the
            x-axis value.
        result: The training result to summarise.

    Returns:
        A dict with the epoch index, the mean and max of ``result.scores``,
        and the mean of ``n_charts`` and of each component metric.
    """
    row = {
        "epoch": epoch_index,
        "scores": result.scores.mean(),
        "maxs": result.scores.max(),
        "n_charts": result.n_charts.mean(),
    }
    for name in COMPONENT_METRICS:
        row[name] = getattr(result, name).mean()
    return row


def _metric_panel(base: alt.Chart, metric: str) -> alt.Chart:
    """Return ``base`` with ``metric`` on the y-axis, not forced to include zero."""
    return base.encode(y=alt.Y(metric, scale=alt.Scale(zero=False)))


# Summarise any results already in `train_results` (e.g. from an earlier run of
# this cell) so the epoch axis always has one entry per stored result. Indexing
# by `range(epoch + 1)` instead breaks with a length mismatch on re-run.
history_rows = [_metric_row(i, r) for i, r in enumerate(train_results)]

for epoch in range(expl_config.n_epoch):
    # NOTE(review): the meaning of the third argument is defined by
    # Explorer.train, which is not visible here.
    train_result = expl.train(gen, oracle, ["IMDB_Votes", "boxplot"])
    train_results.append(train_result)
    if expl.result is None:
        break

    # Only the new result is summarised; earlier rows are reused rather than
    # recomputed on every epoch.
    history_rows.append(_metric_row(len(history_rows), train_result))

    clear_output(wait=True)
    print(f"Epoch {epoch}")

    # Named `history` (not `data`) so the `vega_datasets.data` import is not
    # shadowed for the rest of the notebook.
    history = pd.DataFrame(history_rows)
    base = (
        alt.Chart(history)
        .mark_line()
        .encode(x="epoch")
        .properties(width=150, height=100)
    )
    display(
        alt.vconcat(
            alt.hconcat(*(_metric_panel(base, m) for m in SUMMARY_METRICS)),
            alt.hconcat(*(_metric_panel(base, m) for m in COMPONENT_METRICS)),
        )
    )
    display(expl.result.display_dashboard(100, 4))
    print(expl.result)