In [17]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
from vega_datasets import data
from src.space.DataModel import Attribute

from src.oracle import ColumbusProbOracle, OracleWeight
from src.ProbColumbus import (
    ProbColumbus,
    ColumbusConfig,
    SamplingWeight,
    chart_type,
    agg_type,
    ProbabilisticNode,
)

# Categorical columns with this many or more distinct values are dropped.
MAX_CATEGORIES = 10


def _is_categorical(series: pd.Series) -> bool:
    """Return True for text-like columns (object, pandas string, or categorical dtype).

    A plain ``dtype == "object"`` check misses pandas' dedicated ``StringDtype``
    (the default for text columns in pandas 3) and ``CategoricalDtype``. Those
    columns would otherwise be mislabelled as quantitative and kept regardless
    of cardinality. Numeric and datetime columns are still treated as
    non-categorical, exactly as before.
    """
    dtype = series.dtype
    return pd.api.types.is_object_dtype(dtype) or isinstance(
        dtype, (pd.StringDtype, pd.CategoricalDtype)
    )


movies_raw = data.movies()

# Encoding type for every column of the raw frame: "C" (categorical) or "Q" (otherwise).
dic = {
    col: "C" if _is_categorical(movies_raw[col]) else "Q"
    for col in movies_raw.columns
}
# `attrs` below contains None as a selectable option, so give it an entry too.
dic[None] = None

# Keep every non-categorical column, plus categorical columns with few distinct values.
name_attrs = [
    col
    for col in movies_raw.columns
    if dic[col] == "Q" or movies_raw[col].nunique() < MAX_CATEGORIES
]

df = movies_raw[name_attrs]
# None is listed first as an extra option (presumably "no attribute" for a slot).
attrs = [None] + name_attrs
chart_type

['bar', 'point', 'arc', 'rect', 'tick', 'boxplot']

In [18]:
# Oracle handed to prob._infer to score dashboards; all four criteria weighted equally.
oracle_weight = OracleWeight(specificity=1.0, uniqueness=1.0, coverage=1.0, interestingness=1.0)
oracle = ColumbusProbOracle(oracle_weight)

# Dashboard sampler over the filtered frame `df`, built with the default configuration.
prob = ProbColumbus(df, ColumbusConfig())

In [19]:
from IPython.display import clear_output
from collections import Counter
import altair as alt

# --- search hyper-parameters ---------------------------------------------------
n_epoch = 1000       # upper bound on sampling/update rounds
n_dashboards = 100   # dashboards sampled and scored per epoch
halving = 0.1        # fraction of top-scoring dashboards used to update the priors

# Per-epoch score history; the search loop appends one value to each per epoch.
means, maxs = [], []

# Uniform (all-ones) pseudo-counts for every sampling choice.
# The x slot has one entry fewer than `attrs` because it skips the leading None
# option; y and z keep None as a possible choice.
conjugate_priors = SamplingWeight(
    x=np.ones(len(attrs) - 1),
    y=np.ones(len(attrs)),
    z=np.ones(len(attrs)),
    ct=np.ones(len(chart_type)),
    at=np.ones(len(agg_type)),
    attr=attrs,
    chart_type=chart_type,
    agg_type=agg_type,
)

def mean(l):
    """Arithmetic mean of a sized, non-empty sequence of numbers.

    Raises ZeroDivisionError when ``l`` is empty.
    """
    total, count = sum(l), len(l)
    return total / count

# Elite-selection search (similar in spirit to the cross-entropy method):
# sample dashboards from the current pseudo-counts, keep the top `halving`
# fraction, and add the options those elites used back into the priors.

# First argument to prob._sample_n (presumably charts per dashboard; confirm).
n_charts = 10
# Third argument to prob._infer. A fresh list is passed on every call, as in the
# original, in case _infer mutates it.
INFER_TARGET = ("IMDB_Votes", "rect")


def _count_elite_choices(elites):
    """Tally how often each option appears across the charts of the elite dashboards.

    ``elites`` is a list of ``(charts, score)`` pairs. Each chart is indexed so that
    chart[0..2] are the x/y/z attribute slots (an object exposing ``.name``, or
    None) and chart[3] / chart[4] are the chart type and aggregation, which are
    counted as-is. Returns five Counters in slot order; unseen options count as 0.
    """
    counters = [Counter() for _ in range(5)]
    for charts, _score in elites:
        for chart in charts:
            for slot in range(3):
                attr = chart[slot]
                counters[slot][None if attr is None else attr.name] += 1
            for slot in (3, 4):
                counters[slot][chart[slot]] += 1
    return counters


def _plot_progress(epoch, mean_history, max_history):
    """Build side-by-side line charts of the per-epoch mean and max scores."""
    # Named `history` (not `data`) so the vega_datasets `data` module is not shadowed.
    history = pd.DataFrame(
        {"epoch": range(epoch + 1), "mean": mean_history, "max": max_history}
    )
    line = alt.Chart(history).mark_line().encode(x="epoch")
    return (
        line.encode(y=alt.Y("mean", scale=alt.Scale(zero=False)))
        | line.encode(y=alt.Y("max", scale=alt.Scale(zero=False)))
    )


# Number of elites kept per epoch. Round instead of truncating, because float
# products such as 100 * 0.29 = 28.999... would give 28. Keep at least one, so
# the priors never stop updating silently.
n_elite = max(1, round(n_dashboards * halving))

for epoch in range(n_epoch):
    candidate = [prob._sample_n(n_charts, conjugate_priors) for _ in range(n_dashboards)]
    scores = [prob._infer(dashboard, oracle, list(INFER_TARGET)) for dashboard in candidate]
    candi_n_score = sorted(zip(candidate, scores), key=lambda pair: pair[1], reverse=True)

    means.append(mean(scores))
    maxs.append(max(scores))

    # Keep only the top-scoring dashboards and count the options they used.
    elites = candi_n_score[:n_elite]
    x_counts, y_counts, z_counts, ct_counts, at_counts = _count_elite_choices(elites)

    # Update the conjugate prior with the likelihood (elite option counts).
    # The x slot has no None option, hence attrs[1:].
    conjugate_priors.x += np.array([x_counts[a] for a in attrs[1:]])
    conjugate_priors.y += np.array([y_counts[a] for a in attrs])
    conjugate_priors.z += np.array([z_counts[a] for a in attrs])
    conjugate_priors.ct += np.array([ct_counts[c] for c in chart_type])
    conjugate_priors.at += np.array([at_counts[a] for a in agg_type])

    # Redraw the mean/max score curves in place.
    clear_output(wait=True)
    print(f"Epoch {epoch}")
    display(_plot_progress(epoch, means, maxs))
    

Epoch 74


KeyboardInterrupt: 