In [10]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd 
from vega_datasets import data

df = data.movies()

# name attributes
name_attrs = [
    col
    for col in df.columns    
    if (df[col].dtype == "object" and df[col].nunique() < 10) or df[col].dtype != "object"
]
df = df[name_attrs]
df.info()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3201 entries, 0 to 3200
Data columns (total 10 columns):
 #   Column                  Non-Null Count  Dtype  
---  ------                  --------------  -----  
 0   US_Gross                3194 non-null   float64
 1   Worldwide_Gross         3194 non-null   float64
 2   US_DVD_Sales            564 non-null    float64
 3   Production_Budget       3200 non-null   float64
 4   MPAA_Rating             2596 non-null   object 
 5   Running_Time_min        1209 non-null   float64
 6   Creative_Type           2755 non-null   object 
 7   Rotten_Tomatoes_Rating  2321 non-null   float64
 8   IMDB_Rating             2988 non-null   float64
 9   IMDB_Votes              2988 non-null   float64
dtypes: float64(8), object(2)
memory usage: 250.2+ KB


In [11]:
from src.oracle import ColumbusProbOracle, OracleWeight
from src.ProbColumbus import ProbColumbus, ColumbusConfig, SamplingWeight, chart_type, agg_type

oracle_weight = OracleWeight()
oracle = ColumbusProbOracle(oracle_weight)
prob = ProbColumbus(df, ColumbusConfig())


def func(weight, n_samples=10, n_dashboards=20):
    """Objective for the weight optimizer: the negated mean oracle score.

    Draws ``n_dashboards`` dashboards via ``prob.sample_n(n_samples, weight)``,
    scores each with ``prob.infer(dashboard, oracle, [])`` and returns the
    negative of the mean score. The sign is flipped because skopt's
    ``Optimizer`` minimizes the value passed to ``tell``.

    Parameters
    ----------
    weight : SamplingWeight
        Sampling weights to evaluate.
    n_samples : int, default 10
        First argument to ``prob.sample_n`` for every dashboard
        (presumably the number of charts per dashboard -- TODO confirm).
    n_dashboards : int, default 20
        How many dashboards are averaged. If sampling is random, larger values
        give a less noisy estimate at proportionally higher cost.

    Returns
    -------
    float
        ``-mean(scores)``; lower is better.

    Raises
    ------
    ValueError
        If ``n_dashboards`` < 1 (the mean of no scores would be NaN).
    """
    if n_dashboards < 1:
        raise ValueError("n_dashboards must be >= 1")
    dashboards = [prob.sample_n(n_samples, weight) for _ in range(n_dashboards)]
    # NOTE(review): the meaning of the empty list passed to `infer` is not visible here.
    scores = [prob.infer(dashboard, oracle, []) for dashboard in dashboards]
    return -np.mean(scores)


In [13]:
from skopt import Optimizer
from skopt.space import Space, Real
from IPython.display import clear_output, display

N_CALLS = 30                 # ask/evaluate/tell rounds
N_INITIAL_POINTS = 10        # evaluations before the GP surrogate drives proposals
WEIGHT_BOUNDS = (0.01, 5.0)  # search range of every individual weight
SEED = 0                     # seeds the optimizer only; `func` samples on its own

# attrs[0] is None. The x group gets one weight per real attribute, while the
# y and z groups also get a weight for the None entry.
attrs = [None] + name_attrs

# Sizes of each weight group, in the order they appear in the flat parameter
# vector; the keys double as the SamplingWeight keyword names.
group_sizes = {
    "x": len(attrs) - 1,
    "y": len(attrs),
    "z": len(attrs),
    "ct": len(chart_type),
    "at": len(agg_type),
}


def weight_dims(prefix, n):
    """Return ``n`` Real dimensions named ``{prefix}0`` .. ``{prefix}{n-1}``."""
    return [Real(*WEIGHT_BOUNDS, name=f"{prefix}{i}") for i in range(n)]


def params_to_weight(params):
    """Split a flat parameter vector (x, y, z, ct, at blocks) into a SamplingWeight.

    Raises ValueError if ``params`` does not have one entry per search dimension.
    """
    values = np.array(params)
    expected = sum(group_sizes.values())
    if len(values) != expected:
        raise ValueError(f"expected {expected} parameters, got {len(values)}")
    blocks = np.split(values, np.cumsum(list(group_sizes.values()))[:-1])
    return SamplingWeight(
        **dict(zip(group_sizes, blocks)),
        attr=attrs,
        chart_type=chart_type,
        agg_type=agg_type,
    )


opt_weight = Space([
    dim
    for prefix, size in group_sizes.items()
    for dim in weight_dims(prefix, size)
])

opt = Optimizer(
    opt_weight,
    base_estimator="GP",
    n_initial_points=N_INITIAL_POINTS,
    acq_func="EI",
    random_state=SEED,
)

for i in range(N_CALLS):
    params = opt.ask()
    weight = params_to_weight(params)
    clear_output(wait=True)
    print(i)
    display(weight.visualize())
    opt.tell(params, func(weight))


print(opt.get_result())




29


          fun: -2.2288371643788314
            x: [3.54768930900861, 3.6102686908052832, 3.373735533998444, 2.8654470513109063, 4.1963552918572145, 3.048319294647431, 4.526753488749342, 2.57340252318781, 2.49810466720941, 4.701302058285031, 0.6566559025516141, 3.853023571483312, 4.904664770359111, 3.387761953075547, 2.1166903605823726, 3.5966912002381703, 0.07053707761075764, 1.7676695277990884, 3.095913195947858, 1.2167647238649657, 4.261707259064027, 2.00498455085772, 1.179499373078287, 3.051917236240741, 4.818704013992269, 3.6181817435365895, 2.9864121081363604, 1.6077415551343621, 2.1526362260724743, 1.4480004978548864, 4.145975365548184, 4.3146719201850505, 4.483992503863896, 5.0, 4.76456798297117, 2.238993777276465, 0.31487962540279635, 2.6717730838587435, 1.7954314237528957, 1.7589859077203875, 1.7938349485331375, 4.671849657372013, 4.328405342160824, 1.409802980563086, 2.7124841345327386, 4.782771757837302]
    func_vals: [-2.195e+00 -2.198e+00 ... -2.091e+00 -2.185e+00]
      

In [15]:
# Rebuild and visualize the SamplingWeight of the *best* point found by the
# optimizer (result.x) -- not the last point it happened to try in the loop.
result = opt.get_result()
x = np.array(result.x)

n_attrs = len(attrs)
ct_start = 3 * n_attrs - 1
at_start = ct_start + len(chart_type)
assert len(x) == at_start + len(agg_type), "parameter vector does not match the search space"

weight = SamplingWeight(
    x=x[0 : n_attrs - 1],
    y=x[n_attrs - 1 : 2 * n_attrs - 1],
    z=x[2 * n_attrs - 1 : ct_start],
    ct=x[ct_start:at_start],
    at=x[at_start:],
    attr=attrs,
    chart_type=chart_type,
    agg_type=agg_type,
)
print(f"best objective (negated mean score): {result.fun:.4f}")
display(weight.visualize())

29
