# What is this?
This notebook is an interactive interface for AEPsych, a Python package for adaptive experimentation in psychophysics and related domains. AEPsych utilizes active learning to efficiently explore parameter spaces, allowing experimenters to find just-noticeable-differences (or other quantities of interest) in far fewer trials than traditional methods. This notebook will allow you to use AEPsych without having to write any code.

# Instructions
1. Run the codeblock below. You will see a set of widgets appear.
2. If you are resuming a previous session, you can use the "Resume Session" button to upload a saved .pkl file and resume with all of your settings and data intact.
3. If you are starting from scratch, use the widgets to change AEPsych's settings. Here is an explanation of the settings:

    **Strategy**: There are three strategies for exploring the parameter space:
    
    *Threshold Finding*: AEPsych will try to find the set of parameter values at which the outcome probability equals some target value.

    *Exploration*: AEPsych will try to model the outcome at every point in the parameter space.

    *Optimization*: AEPsych will try to find the parameter values that maximize the probability of an outcome of 1.

    **Threshold**: Sets the target value for the *Threshold Finding* strategy. It is ignored by the other strategies.

    **Initialization Trials**: Sets the number of initialization trials before the model-based strategy begins. Parameter values during these trials are generated quasi-randomly. After the model has been initialized, it will begin to choose parameter values according to the strategy you have picked.   
 
     **Outcome Labels**: These determine the labels of your outcomes. Currently AEPsych only supports binary outcomes, so one outcome will be coded as a 0 in the data, and the other outcome will be coded as a 1. Pay attention to how you label your outcomes! The *Optimization* strategy and the *Monotonic* parameter settings depend on which outcome is labeled as a 1.
 
    **Parameters**: These settings control the parameter space that AEPsych will explore. Use the "Add Parameter" and "Remove Parameter" buttons to add parameters to, or remove them from, the experiment. For each parameter you can specify its name, bounds, and whether or not it should be monotonically increasing with the probability of an outcome of 1 (in other words, you can specify that increasing this parameter never decreases the probability of an outcome of 1). Currently AEPsych only supports continuous parameters.


4. Click the "Start AEPsych" button. A new set of widgets will appear.
5. You will see the set of parameters AEPsych recommends you try. To see a different set of parameters, click "Next Parameters". It may take a few seconds for the parameters to appear.
6. After testing the parameters, enter the outcome, and click "Update Model" to update the model and see the next set of recommended parameters. You can also enter data at any time with any parameter values; you are not restricted to only using the parameters that AEPsych recommends. 
7. You can also upload data from files using the "Upload Data" button. The data should be stored in .csv files according to the following template: 

```
parametername1,parametername2,outcome
1.1,0.4,1
0.25,1,0
```

8. After you enter data, a table containing each set of parameters and their outcome will appear. You can download this data by clicking the "aepsych_data.csv" link.
9. After the AEPsych model has been initialized, a plot of the model's posterior will appear to the right of the data table. Currently plotting only works for 1 or 2-dimensional problems.
10. To save your work, download the file from the "aepsych_server.pkl" link at the top; you can upload it again later with the "Resume Session" button.
11. If you ever need to start over, simply rerun the code block.


In [1]:
import io
import warnings

import dill
import ipywidgets as widgets
import matplotlib.pyplot as plt
import pandas as pd
from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE
from aepsych.plotting import plot_strat
from aepsych.server import AEPsychServer
from IPython.display import FileLink, clear_output, display

# Default size for every matplotlib figure drawn by this notebook.
plt.rcParams["figure.figsize"] = (10, 10)
# Suppress all Python warnings.
warnings.filterwarnings("ignore")


# Module-level server instance; resume_server() rebinds it when a saved
# session is uploaded.
server = AEPsychServer()
# Number of decimal places used for the widgets' step size and for rounding
# the parameter values suggested by the server.
par_precision = 3
# Number of parameter rows in the settings form; maintained by
# add_param() / rem_param().
dim = 0
# Multiplied by the number of parameters in make_config() to set the models'
# inducing_size.
inducing_scale = 50
# Maps each "Strategy" radio-button label to the acquisition function name
# written into the config.
acq_dict = {
    "Exploration": "MonotonicMCPosteriorVariance",
    "Optimization": "qNoisyExpectedImprovement",
    "Threshold Finding": "MonotonicMCLSE",
}
# Style dict shared by the widgets created in this notebook.
style = {"description_width": "initial"}
# File that display_data() writes the trial table to.
csv_file_name = "aepsych_data.csv"
# File that write_server() pickles the server to.
strat_file_name = "aepsych_server.pkl"


def add_param(b):
    """Append a new parameter row to the settings form.

    The row holds a name field (defaulting to ``par<N>``), lower- and
    upper-bound fields (defaulting to 0.0 and 1.0) and a "Monotonic"
    checkbox. The module-level ``dim`` counter is incremented.

    Args:
        b: The button that triggered the callback, or ``None`` when called
            directly. It is not used.
    """
    global dim
    dim += 1
    step = 10 ** -par_precision
    hb = widgets.HBox(
        [
            widgets.Text(f"par{dim}", description="Name", style=style),
            widgets.FloatText(0.0, description="Lower Bound:", step=step, style=style),
            widgets.FloatText(1.0, description="Upper Bound:", step=step, style=style),
            widgets.Checkbox(value=False, description="Monotonic"),
        ]
    )
    # Assign a fresh tuple so the container registers the new child.
    params_boxes.children = (*params_boxes.children, hb)


def rem_param(b):
    """Remove the last parameter row, always keeping at least one.

    Args:
        b: The button that triggered the callback; unused.
    """
    global dim
    if dim <= 1:
        # Nothing to remove: the form must keep one parameter.
        return
    dim -= 1
    remaining_rows = list(params_boxes.children[:-1])
    params_boxes.children = tuple(remaining_rows)


def start_server(b):
    """Configure the server from the settings form and show the ask/tell view.

    Sends the config built by ``make_config`` to the server, stores the
    outcome labels on it, rebuilds the per-parameter input boxes from the
    server's parameter names and bounds, and requests the first suggestion.

    Args:
        b: The button that triggered the callback; unused.
    """
    setup_message = {"message": {"config_str": make_config()}}
    server.handle_setup_v01(setup_message)
    server.zero_outcome = zero_outcome.value
    server.one_outcome = one_outcome.value

    # Empty the data table and plot areas before showing the new session.
    for out in (data_output, plot_output):
        with out:
            clear_output()

    step = 10 ** -par_precision
    value_boxes = []
    for par, lb, ub in zip(server.parnames, server.strat.lb, server.strat.ub):
        # Each box starts at the parameter's lower bound.
        value_boxes.append(
            widgets.BoundedFloatText(
                lb, description=par, min=lb, max=ub, step=step, style=style
            )
        )
    tell_boxes.children = value_boxes

    outcome_box.options = [
        ("", None),
        (zero_outcome.value, 0),
        (one_outcome.value, 1),
    ]
    # Replace the settings form with the experiment view.
    clear_output()
    display(server_download, params_cont, plot_data_cont)
    get_next(None)


def resume_server(change):
    """Restore a session from an uploaded pickle of the server.

    For each uploaded file, the module-level ``server`` is replaced by the
    unpickled object, the input boxes and outcome labels are rebuilt from
    it, and the data table, plot and next suggestion are shown.

    Args:
        change: The trait-change notification from the upload widget; unused.
    """
    global server
    step = 10 ** -par_precision
    for upload in server_uploader.value.values():
        # NOTE(security): unpickling runs arbitrary code contained in the
        # file, so only session files from a trusted source should be uploaded.
        with io.BytesIO(upload["content"]) as pkl_stream:
            server = dill.load(pkl_stream)

        # When the server is pickled, it deletes these attributes.
        # This is an ugly hack around that.
        server.socket = None
        server.db = None

        value_boxes = []
        for par, lb, ub in zip(server.parnames, server.strat.lb, server.strat.ub):
            value_boxes.append(
                widgets.BoundedFloatText(
                    lb, description=par, min=lb, max=ub, step=step, style=style
                )
            )
        tell_boxes.children = value_boxes

        # Bring the outcome labels saved on the server back into the form.
        zero_outcome.value = server.zero_outcome
        one_outcome.value = server.one_outcome
        outcome_box.options = [
            ("", None),
            (zero_outcome.value, 0),
            (one_outcome.value, 1),
        ]

        clear_output()
        display(server_download, params_cont, plot_data_cont)
        display_data()
        display_plot()
        get_next(None)

    server_uploader.value.clear()


def make_config():
    """Build the AEPsych config string from the current widget values.

    Reads the parameter rows (name, lower bound, upper bound, monotonic
    flag), the threshold, the number of initialization trials and the
    selected strategy, and renders them into an INI-style config.

    Returns:
        str: The config text to pass to the server.
    """
    rows = params_boxes.children
    dim = len(rows)
    pars = [row.children[0].value for row in rows]
    parnames = f"[{','.join(pars)}]"
    lbs = [row.children[1].value for row in rows]
    ubs = [row.children[2].value for row in rows]
    # Read the checkbox's ``value``: testing the widget object itself does not
    # reflect whether the box is ticked, so every parameter used to be
    # reported as monotonic.
    # NOTE(review): when no box is ticked this renders ``monotonic_idxs = []``
    # - confirm the config parser accepts an empty list.
    monotonic = [i for i, row in enumerate(rows) if row.children[3].value]
    target = threshold_box.value
    n_sobol = n_sobol_box.value
    acq = acq_dict[strategy_btns.value]
    # The "Optimization" acquisition function is paired with a different
    # model and generator than the two monotonic acquisition functions.
    is_optimization = acq == "qNoisyExpectedImprovement"
    model = "GPClassificationModel" if is_optimization else "MonotonicRejectionGP"
    generator = (
        "OptimizeAcqfGenerator" if is_optimization else "MonotonicRejectionGenerator"
    )

    config = f"""
        [common]
        parnames = {parnames}
        lb = {lbs}
        ub = {ubs}
        outcome_type = single_probit
        target = {target}
        strategy_names = [init_strat, opt_strat]

        [init_strat]
        n_trials = {n_sobol}
        generator = SobolGenerator

        [opt_strat]
        n_trials = -1
        refit_every = 1
        generator = {generator}

        [experiment]
        acqf = {acq}
        model = {model}
        
        [SobolGenerator]
        n_points = {n_sobol}
        
        [GPClassificationModel]
        inducing_size = {inducing_scale*dim} #TODO: find a better way to scale this

        [MonotonicRejectionGP]
        inducing_size = {inducing_scale*dim} #TODO: find a better way to scale this
        mean_covar_factory = monotonic_mean_covar_factory
        monotonic_idxs = {monotonic}
        
        [metadata]
        experiment_name = myname
        experiment_description = this is a cool experiment
        experiment_id = 21
        somemetadata = this thingy
        somemetadatatwo = that thingy
        """
    return config


def tell_model(b):
    """Send the entered parameters and outcome to the server, then refresh.

    If no outcome is selected, a reminder is printed in the upload message
    area and nothing is sent. Otherwise the input boxes are reset to their
    minimums, the outcome selection is cleared, and the next suggestion,
    data table and plot are shown.

    Args:
        b: The button that triggered the callback; unused.
    """
    if outcome_box.value is None:
        with upload_output:
            clear_output()
            print("Select an outcome for this set of parameters!")
        return

    # Clear any earlier message.
    with upload_output:
        clear_output()

    params = {}
    for box in tell_boxes.children:
        params[box.description] = box.value
    server.tell(outcome_box.value, params)

    # Reset the form for the next trial.
    for box in tell_boxes.children:
        box.value = box.min
    outcome_box.value = None

    get_next(None)
    display_data()
    display_plot()


def _set_ask_tell_disabled(disabled):
    """Enable or disable every ask/tell control in one call."""
    tell_btn.disabled = disabled
    ask_btn.disabled = disabled
    uploader.disabled = disabled
    outcome_box.disabled = disabled
    for child in tell_boxes.children:
        child.disabled = disabled


def get_next(b):
    """Ask the server for the next parameters and show them in the input boxes.

    The ask/tell controls are disabled while the server is working and are
    re-enabled afterwards, even if the server raises, so that a failed
    request does not leave the interface locked. On success the server is
    saved with ``write_server``.

    Args:
        b: The button that triggered the callback, or ``None`` when called
            directly. It is not used.
    """
    _set_ask_tell_disabled(True)
    try:
        if server.strat.x is None and server.strat._count >= n_sobol_box.value:
            # No data has been recorded, yet the trial counter has reached
            # the number of initialization trials: reconfigure the server
            # with a single initialization trial before asking again.
            # NOTE(review): presumably this keeps the model-based strategy
            # from starting without data - confirm.
            n_sobol_box.value = 1
            server.configure(config_str=make_config())

        next_pars = server.ask()

        for child, value in zip(tell_boxes.children, next_pars.values()):
            child.value = round(value[0], par_precision)
    finally:
        _set_ask_tell_disabled(False)
    write_server()


def write_server():
    """Pickle the current server to the file named by ``strat_file_name``.

    The download link's ``disabled`` attribute is set for the duration of
    the write and reset afterwards.
    """
    server_download.disabled = True
    with open(strat_file_name, "wb") as pkl_file:
        dill.dump(server, pkl_file)
    server_download.disabled = False


def display_data():
    """Write the recorded trials to CSV and show them with a download link.

    Does nothing while the server's strategy holds no data.
    """
    if server.strat.x is None:
        return

    # One column per parameter, followed by the outcome column.
    columns = {}
    for col, par in enumerate(server.parnames):
        columns[par] = server.strat.x[:, col]
    columns["outcome"] = server.strat.y

    table = pd.DataFrame(columns)
    table.to_csv(csv_file_name, index=False)
    with data_output:
        clear_output()
        display(FileLink(csv_file_name), table)


def display_plot():
    """Redraw the plot area for the current strategy.

    Prints a notice instead of plotting when the problem has more than two
    dimensions or while the first strategy in the list is still active.
    """
    with plot_output:
        clear_output()

        if server.strat.dim > 2:
            print("Plotting currently only works for <=2D")
            return

        if not (server.strat._strat_idx > 0):
            print(
                "\n\n\n\n\n Initializing model. Collect more data to plot posterior."
            )
            return

        # The threshold line is only passed on when the active acquisition
        # function is the threshold-finding one.
        if server.strat._strat.generator.acqf == MonotonicMCLSE:
            target_level = threshold_box.value
        else:
            target_level = None

        second_axis = server.parnames[1] if server.strat.dim == 2 else None
        plot_strat(
            server.strat,
            xlabel=server.parnames[0],
            ylabel=second_axis,
            target_level=target_level,
            yes_label=one_outcome.value,
            no_label=zero_outcome.value,
        )


def mass_tell(change):
    """Feed every row of an uploaded CSV file to the server.

    Each row must provide an ``outcome`` column and one column per server
    parameter name. After the rows are told, the next suggestion, data
    table and plot are refreshed. If anything in that sequence raises,
    "Data is improperly formatted!" is printed in the upload message area.
    The upload widget's value is then cleared and the server is saved.

    Args:
        change: The trait-change notification from the upload widget; unused.
    """
    for upload in uploader.value.values():
        with io.BytesIO(upload["content"]) as f:
            try:
                data = pd.read_csv(f)
                for _, row in data.iterrows():
                    server.tell(
                        row["outcome"], {par: row[par] for par in server.parnames}
                    )
                    # Advance the active strategy's trial counter by hand.
                    # NOTE(review): presumably server.tell() does not do
                    # this itself - confirm.
                    idx = server.strat._strat_idx
                    server.strat.strat_list[idx]._count += 1
                with upload_output:
                    clear_output()
                get_next(None)
                display_data()
                display_plot()
            except Exception:
                # Catch Exception rather than using a bare ``except`` so that
                # KeyboardInterrupt and SystemExit are not swallowed.
                with upload_output:
                    clear_output()
                    print("Data is improperly formatted!")
    uploader.value.clear()
    write_server()


# ---- Settings form: shown until a session is started or resumed ----

# Upload control for a previously saved server pickle.
server_uploader = widgets.FileUpload(
    description="Resume Session", accept=".pkl", multiple=False, style=style
)
# Call resume_server whenever the widget's "_counter" trait changes.
server_uploader.observe(resume_server, names="_counter")

# Text fields holding the labels for the outcomes coded as 0 and 1.
outcome_label = widgets.Label(value='Outcome Labels:')
zero_outcome = widgets.Text("No Trial", description="0: ", style=style)
one_outcome = widgets.Text("Yes Trial", description="1: ", style=style)
outcomes_labels = widgets.VBox([outcome_label, zero_outcome, one_outcome])

params_label = widgets.Label(value="Parameters:")
# Container for the parameter rows; add_param()/rem_param() edit its children.
params_boxes = widgets.VBox([])
# Start with one parameter row.
add_param(None)

add_param_btn = widgets.Button(description="Add Parameter")
add_param_btn.on_click(add_param)

rem_param_btn = widgets.Button(description="Remove Parameter")
rem_param_btn.on_click(rem_param)

btns = widgets.HBox([add_param_btn, rem_param_btn])

# Labels must match the keys of acq_dict, which make_config() indexes with
# the selected value.
strategy_btns = widgets.RadioButtons(
    options=["Threshold Finding", "Exploration", "Optimization"],
    value="Threshold Finding",
    description="Strategy:",
)

threshold_box = widgets.BoundedFloatText(
    value=0.75, min=0, max=1.0, step=0.05, description="Threshold:"
)

n_sobol_box = widgets.BoundedIntText(
    value=10, min=0, description="Initialization Trials:", style=style
)

start_server_btn = widgets.Button(description="Start AEPsych")
start_server_btn.on_click(start_server)

strat_settings = widgets.HBox([strategy_btns, threshold_box, n_sobol_box])

# Configure the server once with the form's default values.
config = make_config()
server.configure(config_str=config)

# ---- Ask/tell view: displayed by start_server() / resume_server() ----

# Filled with one BoundedFloatText per parameter when a session starts.
tell_boxes = widgets.VBox()
# Initial options; start_server()/resume_server() replace them with the
# labels entered by the user.
outcome_box = widgets.Dropdown(
    options=[('No Trial', 0), ('Yes Trial', 1), ('', None)],
    value=None,
    description='Outcome:',
)

ask_btn = widgets.Button(description="Next Parameters")
ask_btn.on_click(get_next)

tell_btn = widgets.Button(description="Update Model")
tell_btn.on_click(tell_model)

# Upload control for CSV trial data; mass_tell runs when its "_counter"
# trait changes.
uploader = widgets.FileUpload(description="Upload Data", accept=".csv", multiple=False)
uploader.observe(mass_tell, names="_counter")

# Link to the pickle file written by write_server().
server_download = FileLink(strat_file_name)

# Area for short messages printed by tell_model() and mass_tell().
upload_output = widgets.Output()

ask_tell_cont = widgets.HBox([ask_btn, tell_btn, uploader, upload_output])

params_cont = widgets.VBox([ask_tell_cont, widgets.HBox([tell_boxes, outcome_box])])

# Areas for the trial table (display_data) and the plot (display_plot).
data_output = widgets.Output()
plot_output = widgets.Output()
plot_data_cont = widgets.HBox([data_output, plot_output])

server_btns = widgets.HBox([start_server_btn, server_uploader])

# Show the settings form.
display(
    server_btns, strat_settings, outcomes_labels, btns, params_boxes,
)


2022-07-05 19:09:13,656 [INFO   ] Found DB at ./databases/default.db, appending!


KeyError: 'type'