## TODO

- [ ] setup mpl widgets
- [ ] ensure that categoricals still work
- [ ] convert everything
- [ ] figure out how to embed controls into a vbox
- [ ] slider format strings
- [ ] play buttons
- [ ] rename functions? 
    - rename module? `pyplot` -> `ipyplot`
    - `from mpl_interactions import ipyplot as iplt`
    - `iplt.plot()`


In [None]:
%matplotlib ipympl
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np

%load_ext autoreload
%autoreload 2
from collections import defaultdict
from collections.abc import Callable
from functools import partial

from mpl_interactions import *

In [None]:
class Controls:
    """
    Container for a set of parameter widgets shared by one or more interactive plots.

    Every parameter handed to ``add_kwargs`` gets an entry in ``params`` (its current
    value) and, when ``kwarg_to_ipywidget`` returns a widget for it, a child in ``vbox``.
    Plot functions subscribe through ``register_function`` and are re-run by
    ``slider_updated`` whenever one of the parameters they depend on changes.

    Parameters
    ----------
    out : context manager or None
        Entered around every update in ``slider_updated`` (``gogogo_controls`` passes an
        ``ipywidgets.Output``).  ``None`` means the update runs without such a context.
    play_buttons : bool, dict, defaultdict or iterable of str
        Which parameters get a play button; see ``add_kwargs``.
    play_button_pos : str
        Forwarded unchanged to ``kwarg_to_ipywidget``.
    **kwargs
        Parameter specifications, forwarded one at a time to ``kwarg_to_ipywidget``.
    """

    def __init__(self, out=None, play_buttons=False, play_button_pos="right", **kwargs):
        # it might make sense to also accept kwargs as a straight up arg
        # to allow for passing the dictionary, but then it would need a different name
        # and we'd have to combine dictionaries which looks like a hassle

        self.out = out
        self.kwargs = kwargs
        # NOTE(review): not read anywhere in this class; add_kwargs hard-codes "{:.2f}".
        self.slider_format_strings = create_slider_format_dict(None, True)
        self.controls = {}  # param name -> widget
        self.params = {}  # param name -> current value
        self.figs = defaultdict(list)  # maybe should only store weakrefs?
        self.indices = defaultdict(lambda: 0)  # param name -> index of the current value
        self._update_funcs = defaultdict(list)  # param name -> [(callback, param names)]
        self.controller_list = []
        self.vbox = widgets.VBox([])
        self.add_kwargs(kwargs, play_buttons, play_button_pos)

    def add_kwargs(self, kwargs, play_buttons=False, play_button_pos="right"):
        """
        Create a control for every entry of *kwargs* and append it to ``vbox``.

        Parameters
        ----------
        kwargs : dict
            Maps parameter name to the specification understood by ``kwarg_to_ipywidget``.
        play_buttons : bool, dict, defaultdict or iterable of str
            - bool: applies to every parameter
            - defaultdict: used as is
            - dict: per-parameter flags, parameters that are missing get ``False``
            - iterable of str: names of the parameters that get a play button
            - anything else is used as is and must support ``obj[name]``
        play_button_pos : str
            Forwarded unchanged to ``kwarg_to_ipywidget``.

        Raises
        ------
        ValueError
            If a name in *kwargs* is already controlled by this object.  The check runs
            before any widget is created, so a failed call leaves the controls untouched.

        Notes
        -----
        need to implement matplotlib widgets
        also a big question is how to dynamically update the display of matplotlib widgets.
        """
        # Local import: the notebook's import cell only brings in ``Callable``.
        from collections.abc import Iterable

        if isinstance(play_buttons, bool):
            has_play_button = defaultdict(lambda: play_buttons)
        elif isinstance(play_buttons, defaultdict):
            has_play_button = play_buttons
        elif isinstance(play_buttons, dict):
            has_play_button = defaultdict(lambda: False, play_buttons)
        elif isinstance(play_buttons, Iterable) and all(isinstance(p, str) for p in play_buttons):
            has_play_button = defaultdict(lambda: False, dict.fromkeys(play_buttons, True))
        else:
            has_play_button = play_buttons

        # Validate up front so we never end up with only some of the widgets registered.
        if any(k in self.params for k in kwargs):
            raise ValueError("can't overwrite an existing param in the controller")

        for k, v in kwargs.items():
            # kwarg_to_ipywidget hands back the initial value and the widget (if any)
            self.params[k], control = kwarg_to_ipywidget(
                k,
                v,
                partial(self.slider_updated, key=k),
                "{:.2f}",
                play_button=has_play_button[k],
                play_button_pos=play_button_pos,
            )
            if control:
                self.controls[k] = control
                # reassign instead of mutating so the VBox sees the change
                self.vbox.children = [*self.vbox.children, control]

    def slider_updated(self, change, key, values):
        """
        Widget callback: store the new value of *key* and re-run the dependent functions.

        Parameters
        ----------
        change : dict
            Change notification; ``change["new"]`` is used as an index into *values*.
        key : str
            Name of the parameter whose control changed.
        values
            Indexable collection of the values the control can take.

        Every function registered for *key* is called as ``f(params=..., indices=...)``
        with only the parameters it registered for, then ``canvas.draw_idle()`` is
        called on each figure registered for *key*.

        gotta also give the indices in order to support hyperslicer without horrifying contortions
        """
        # Local import keeps this class usable without touching the notebook's import cell.
        from contextlib import nullcontext

        # ``out`` defaults to None in __init__; ``with None:`` would raise.
        output_ctx = nullcontext() if self.out is None else self.out
        with output_ctx:
            new_index = change["new"]
            self.params[key] = values[new_index]
            self.indices[key] = new_index
            for f, params in self._update_funcs[key]:
                ps = {k: self.params[k] for k in params}
                idxs = {k: self.indices[k] for k in params}
                f(params=ps, indices=idxs)
            for fig in self.figs[key]:
                fig.canvas.draw_idle()

    def register_function(self, f, fig, params=None):
        """
        Subscribe *f* to changes of *params* and remember *fig* for redrawing.

        Parameters
        ----------
        f : callable
            Called as ``f(params=dict, indices=dict)`` from ``slider_updated``.
        fig
            Object whose ``canvas.draw_idle()`` is called after the callbacks ran.
        params : iterable of str or None
            Parameter names *f* depends on.  If None use the entire current set of params.
        """
        if params is None:
            params = self.params.keys()
        # listify to ensure it's not a reference to dicts keys
        # bc that's mutable
        params = list(params)
        for p in params:
            self._update_funcs[p].append((f, params))
            # several functions may share one figure; one redraw request per update is enough
            if not any(known is fig for known in self.figs[p]):
                self.figs[p].append(fig)  # maybe should use a weakref?

    def __getitem__(self, key):
        """
        Return ``(self, keys)`` where ``keys`` is always a list.

        hack to allow calls like
        interactive_plot(...beta=(0,1), controls = controls["tau"])
        also allows [None] to grab None of the current params
        to imply that we only want tau from the existing set of commands
        """

        # make sure keys is a list
        # bc in gogogo_controls it may get added to another list
        if isinstance(key, str):
            key = [key]
        elif key is None:
            key = []
        return self, key

    def _ipython_display_(self):
        """Show the box holding all controls; ``display`` is the notebook-provided global."""
        display(self.vbox)


#         display(widgets.VBox(list(self.controls.values())))
# kwargs = {'param1':(0,1), 'param2':(1,10), 'tau':(0,2*np.pi)}
# out = widgets.Output()
# controls = Controls(out, **kwargs)
# display(controls)
# display(out)
# def f(params, indices):
#     print(params, indices)
# # controls.register_function(f,['param1'])
# # widgets.HBox([widgets.VBox(controls),out])

In [None]:
def gogogo_controls(kwargs, controls, display_controls, play_buttons, play_button_pos):
    """
    Attach *kwargs* to an existing controls object or build a new one.

    Parameters
    ----------
    kwargs : dict
        Parameter specifications for the new controls.
    controls : Controls, tuple or falsy
        - a controls object: *kwargs* are added to it and all of its params are used
        - ``(controls, extra_keys)`` as produced by ``controls[...]``: *kwargs* are added
          and only the new params plus ``extra_keys`` are used
        - falsy: a new ``Controls`` writing into a fresh ``widgets.Output`` is created
    display_controls : bool
        Whether to ``display`` a newly created controls object.  Existing controls are
        never displayed from here.
    play_buttons, play_button_pos
        Forwarded to ``Controls`` / ``Controls.add_kwargs``.

    Returns
    -------
    controls
        The controls object that now owns the params.
    params : dict
        The params relevant for the caller.  For un-indexed or new controls this is the
        live ``controls.params`` dict, for indexed controls it is a new dict.
    """
    if not controls:
        out = widgets.Output()
        controls = Controls(
            out, play_buttons=play_buttons, play_button_pos=play_button_pos, **kwargs
        )
        if display_controls:
            display(controls)
        return controls, controls.params

    if isinstance(controls, tuple):
        # it was indexed by the user when passed in
        controls, extra_keys = controls[0], controls[1]
        controls.add_kwargs(kwargs, play_buttons, play_button_pos)
        return controls, {k: controls.params[k] for k in [*kwargs, *extra_keys]}

    # forward the play button settings here as well, same as the indexed branch
    controls.add_kwargs(kwargs, play_buttons, play_button_pos)
    return controls, controls.params

In [None]:
# NOTE(review): module-level Output widget; no code visible in this notebook reads this
# name (gogogo_controls creates its own Output) -- confirm whether it is still needed.
out = widgets.Output()


def interactive_imshow(
    X,
    cmap=None,
    norm=None,
    aspect=None,
    interpolation=None,
    alpha=None,
    vmin=None,
    vmax=None,
    origin=None,
    extent=None,
    autoscale_cmap=True,
    filternorm=True,
    filterrad=4.0,
    resample=None,
    url=None,
    ax=None,
    slider_format_string=None,
    title=None,
    display=True,
    force_ipywidgets=False,
    play_buttons=False,
    play_button_pos="right",
    controls=None,
    display_controls=True,
    **kwargs,
):
    """
    Show an image whose data (and optionally color limits) depend on controlled params.

    Parameters
    ----------
    X : array-like or callable
        The image, or a function called as ``X(**params)`` returning the image.
    cmap, norm, aspect, interpolation, alpha, origin, extent, filternorm, filterrad, resample, url
        Passed straight through to ``ax.imshow``.
    vmin, vmax : float, callable or None
        Color limits.  Callables are called as ``vmin(**params)`` on every update.
    autoscale_cmap : bool
        If True, and neither *vmin* nor *vmax* is given, rescale the norm to the new
        data on every update (only for data that is not 3-dimensional).
    ax : matplotlib axes or None
        Handed to ``gogogo_figure`` to determine the figure and axes to draw on.
    title : str or None
        Format string; the title is set to ``title.format(**params)``.
    slider_format_string, display, force_ipywidgets
        Accepted for signature compatibility; not used by this implementation.
        Note that ``display`` shadows the notebook's ``display`` function in here.
    play_buttons, play_button_pos, controls, display_controls
        See ``gogogo_controls``.
    **kwargs
        Parameter specifications turned into controls.

    Returns
    -------
    controls
        The controls object driving this image.
    """
    ipympl = notebook_backend()
    fig, ax = gogogo_figure(ipympl, ax)

    controls, params = gogogo_controls(
        kwargs, controls, display_controls, play_buttons, play_button_pos
    )

    # Draw the first frame with exactly the params ``update`` will receive later, so
    # that params borrowed from existing controls (``controls["tau"]``) reach X as well.
    im = ax.imshow(
        callable_else_value(X, params),
        cmap=cmap,
        norm=norm,
        aspect=aspect,
        interpolation=interpolation,
        alpha=alpha,
        vmin=callable_else_value(vmin, params),
        vmax=callable_else_value(vmax, params),
        origin=origin,
        extent=extent,
        filternorm=filternorm,
        filterrad=filterrad,
        resample=resample,
        url=url,
    )
    # this is necessary to make calls to plt.colorbar behave as expected
    ax._sci(im)
    if title is not None:
        ax.set_title(title.format(**params))

    def update(params, indices):
        """Refresh title, image data and color limits for the new *params*."""
        if title is not None:
            ax.set_title(title.format(**params))

        if isinstance(X, Callable):
            new_data = np.asarray(X(**params))
            im.set_data(new_data)
            if autoscale_cmap and (new_data.ndim != 3) and vmin is None and vmax is None:
                im.norm.autoscale(new_data)
        if isinstance(vmin, Callable):
            im.norm.vmin = vmin(**params)
        if isinstance(vmax, Callable):
            im.norm.vmax = vmax(**params)

    # register only once the image exists, ``update`` closes over ``im``
    controls.register_function(update, fig, params.keys())
    return controls

In [None]:
def interactive_plot2(
    f,
    x=None,
    xlim="stretch",
    ylim="stretch",
    slider_format_string=None,
    plot_kwargs=None,
    title=None,
    ax=None,
    force_ipywidgets=False,
    play_buttons=False,
    play_button_pos="right",
    controls=None,
    display_controls=True,
    **kwargs,
):
    """
    Plot one or more functions as lines that follow the controlled params.

    Parameters
    ----------
    f : callable or sequence of callables
        With *x* given each function is called as ``f(x, **params)`` and returns y.
        Without *x* it is called as ``f(**params)``; the first function is probed once
        and if its result is not 2D (or has a first dimension of 1) the results are
        plotted against ``np.arange(result.size)``, otherwise the result is unpacked
        into the plot call.
    x : 1D array-like or None
    xlim, ylim : "stretch", "auto", other str or limits
        - "stretch": limits only ever grow to contain the data
        - "auto": ``ax.autoscale_view`` on every update
        - any other str: limits are left alone
        - anything else is passed to ``ax.set_xlim`` / ``ax.set_ylim`` once
    plot_kwargs : dict, sequence of dict or None
        Keyword arguments for ``ax.plot``, one dict per function.  The dicts are copied,
        the caller's objects are not modified.
    title : str or None
        Format string; the title is set to ``title.format(**params)``.
    ax : matplotlib axes or None
        Handed to ``gogogo_figure`` to determine the figure and axes to draw on.
    slider_format_string, force_ipywidgets
        Accepted for signature compatibility; not used by this implementation.
    play_buttons, play_button_pos, controls, display_controls
        See ``gogogo_controls``.
    **kwargs
        Parameter specifications turned into controls.

    Returns
    -------
    controls
        The controls object driving this plot.

    Raises
    ------
    ValueError
        If *x* is not 1D or *plot_kwargs* does not match the number of functions.
    """
    ipympl = notebook_backend()
    fig, ax = gogogo_figure(ipympl, ax=ax)
    funcs = np.atleast_1d(f)
    controls, params = gogogo_controls(
        kwargs, controls, display_controls, play_buttons, play_button_pos
    )

    indexed_x = False
    if x is not None:
        x = np.asarray(x)
        if x.ndim != 1:
            raise ValueError(f"x must be None or be 1D but is {x.ndim}D")
    else:
        # call the first function once to determine whether it returns x.
        # (``f`` itself may be a list, which is not callable)
        probe = np.asarray(funcs[0](**params))
        if probe.ndim != 2 or probe.shape[0] == 1:
            # probably should use arange to set the x values
            indexed_x = True
            x = np.arange(probe.size)

    if plot_kwargs is None:
        plot_kwargs = [{} for _ in funcs]
    else:
        # copy so that adding the default label does not leak into the caller's dicts
        plot_kwargs = [dict(kw) for kw in np.atleast_1d(plot_kwargs)]
        if len(plot_kwargs) != len(funcs):
            raise ValueError(
                "If using multiple functions"
                " then plot_kwargs must be a list"
                " of the same length or None."
            )

    # make sure plot labels make sense
    for func, kw in zip(funcs, plot_kwargs):
        # functools.partial objects and callable instances have no __name__
        name = getattr(func, "__name__", None)
        if "label" not in kw and name is not None:
            kw["label"] = name

    def _line_data(func, params):
        """Return the positional arguments for ``plot`` / ``set_data`` of one function."""
        if indexed_x:
            return x, func(**params)
        if x is not None:
            return x, func(x, **params)
        return func(**params)

    def _stretched(current, low, high):
        """Widen the *current* limits just enough to contain ``[low, high]``."""
        return [
            low if low < current[0] else current[0],
            high if high > current[1] else current[1],
        ]

    lines = [ax.plot(*_line_data(func, params), **kw)[0] for func, kw in zip(funcs, plot_kwargs)]

    def update(params, indices):
        """Recompute every line for the new *params* and adjust limits and title."""
        for line, func in zip(lines, funcs):
            line.set_data(*_line_data(func, params))

        cur_xlims = ax.get_xlim()
        cur_ylims = ax.get_ylim()
        ax.relim()  # this may be expensive? don't do if not necessary?
        if ylim == "auto":
            ax.autoscale_view(scalex=False)
        elif ylim == "stretch":
            y0 = ax.dataLim.y0
            ax.set_ylim(_stretched(cur_ylims, y0, y0 + ax.dataLim.height))
        if xlim == "auto":
            ax.autoscale_view(scaley=False)
        elif xlim == "stretch":
            x0 = ax.dataLim.x0
            ax.set_xlim(_stretched(cur_xlims, x0, x0 + ax.dataLim.width))
        if title is not None:
            ax.set_title(title.format(**params))

    controls.register_function(update, fig, params.keys())

    if not isinstance(xlim, str):
        ax.set_xlim(xlim)
    if not isinstance(ylim, str):
        ax.set_ylim(ylim)
    if title is not None:
        ax.set_title(title.format(**params))

    # make sure the home button will work
    if hasattr(fig.canvas, "toolbar") and fig.canvas.toolbar is not None:
        fig.canvas.toolbar.push_current()

    return controls

In [None]:
# Demo: a fresh figure, the x grid handed to the plotted functions, and the arrays
# passed below as the ``tau`` / ``beta`` control specifications.
fig, ax = plt.subplots()

x = np.linspace(0, np.pi, 100)
tau = np.linspace(0.5, 10, 100)
beta = np.linspace(1, 10, 100)


def f1(x, tau, beta):
    """Demo curve: ``sin(x * tau) * x * beta``."""
    wave = np.sin(x * tau)
    return wave * x * beta


def f2(x, tau, beta):
    """Demo curve: ``sin(x * beta) * x * tau`` (the roles of tau and beta swapped vs. f1)."""
    wave = np.sin(x * beta)
    return wave * x * tau


def f3(x, tau, beta):
    """Demo curve: ``3 * sin(x * beta) * x * tau``."""
    scaled_wave = 3 * np.sin(x * beta)
    return scaled_wave * x * tau


def f4(x, tau, beep):
    """
    Demo curve: ``3 * sin(x * tau) * x**beep``.

    Uses the ``tau`` argument for the frequency instead of reaching for the
    module-level ``beta`` array, so moving the tau control actually changes the curve
    and the result no longer depends on ``x`` and ``beta`` having the same length.
    """
    return 3 * np.sin(x * tau) * x**beep


# Create new controls for tau and beta, then draw f3 driven by those same controls
# (no new kwargs, so no new widgets are added by the second call).
controls = interactive_plot2([f1, f2], x=x, tau=tau, beta=beta, play_buttons=True)
heck = interactive_plot2(f3, x=x, controls=controls)
# display(controls)
# last expression of the cell: the Output object the controls were created with
controls.out

In [None]:
# Add a new "beep" param to the existing controls; indexing with "tau" means f4 is
# driven by "beep" plus the existing "tau" only.
heck = interactive_plot2(f4, x=x, beep=(0, 1), controls=controls["tau"])

In [None]:
# 200 x 200 grid for the image demo.  Note: this rebinds the module-level ``x`` that the
# line-plot cells above use.
x = np.linspace(0, np.pi, 200)
y = np.linspace(0, 10, 200)
X, Y = np.meshgrid(x, y)


def f(param1, param2):
    """Demo image on the module-level ``X`` / ``Y`` grid, shaped by the two params."""
    sine_part = np.sin(X) * param2
    return sine_part + np.exp(np.cos(Y * param1)) + param2


# New figure for the image; ``controls[None]`` adds param1/param2 to the existing
# controls without passing any of the existing params to f.
fig2, ax = plt.subplots()
controls = interactive_imshow(f, param1=(-5, 5), param2=(-3, 12), controls=controls[None])

In [None]:
# Line plot and image side by side in one figure, sharing one controls object.
# The first call creates new controls (none are passed in); the second adds
# param1/param2 to them, requesting play buttons for the added params.
fig, (ax1, ax2) = plt.subplots(1, 2)
controls = interactive_plot2([f1, f2], x=x, tau=tau, beta=beta, ax=ax1)
controls = interactive_imshow(
    f, param1=(-5, 5), param2=(-3, 12), controls=controls[None], ax=ax2, play_buttons=True
)