In [12]:
import ipywidgets as widgets
from io import BytesIO, StringIO
import base64
from ipywidgets import HTML
from IPython.display import display
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyrolite.geochem
from pyrolite.plot.color import process_color
from pyrolite.util.plot.legend import proxy_line
from pyrolite.util.plot.style import mappable_from_values
from pyrolite.plot.color import get_cmode
from traitlets import Int


#############################


def _import_frame(content):
    """Read uploaded CSV bytes into a DataFrame ready for pyrolite.

    Header labels have surrounding whitespace removed, and the columns that
    pyrolite's ``pyrochem`` accessor reports as compositional are coerced to
    numbers, with unparseable entries becoming NaN.
    """
    frame = pd.read_csv(BytesIO(content))
    frame.columns = [label.strip() for label in frame.columns]
    chem = frame.pyrochem
    # Coerce row by row; errors="coerce" turns bad values into NaN instead of raising.
    numeric = chem.compositional.apply(pd.to_numeric, axis=1, errors="coerce")
    chem.compositional = numeric
    return frame


def get_contents():
    """Parse every file currently held by ``uploader``.

    Returns
    -------
    list of (str, pandas.DataFrame)
        One ``(filename, frame)`` pair per uploaded file.

    Notes
    -----
    ``FileUpload.value`` changed shape in ipywidgets 8.0. In 7.x it is a dict
    keyed by filename whose entries hold the file bytes under ``"content"``.
    In 8.x it is a tuple of dicts carrying ``"name"`` and ``"content"`` (a
    memoryview). Both layouts are accepted so the notebook keeps working
    across the upgrade.
    """
    value = uploader.value
    if isinstance(value, dict):  # ipywidgets 7.x
        uploads = [(fname, record["content"]) for fname, record in value.items()]
    else:  # ipywidgets >= 8.0
        uploads = [(record["name"], bytes(record["content"])) for record in value]
    return [(fname, _import_frame(content)) for fname, content in uploads]


def process_frame(df):
    """Reorder columns and optionally append a compositional transform.

    Non-compositional columns come first, followed by the compositional ones.
    When ``transform_dropdown`` has a non-empty value (a ``pyrocomp`` method
    name such as ``"CLR"``), that transform is applied to the compositional
    columns of ``df`` and its output is joined onto the result.
    """
    compositional = df.pyrochem.list_compositional
    leading = [col for col in df.columns if col not in compositional]
    result = df.loc[:, leading + compositional]

    transform_name = transform_dropdown.value
    if not transform_name:
        return result

    transform = getattr(df.pyrochem.compositional.pyrocomp, transform_name)
    return result.join(transform())


def plot(df, *variables, color=None):
    """Scatter-plot the given columns of ``df`` with pyrolite and return the figure.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data.
    *variables
        Column labels to plot (e.g. an x and a y column).
    color : optional
        Forwarded to ``pyroplot.scatter`` as ``c``.

    Returns
    -------
    The ``figure`` of the axes returned by ``pyroplot.scatter``.
    """
    # Select with a list: pandas treats a tuple key as ONE (MultiIndex-style)
    # column label, so ``df[variables]`` would raise KeyError for plain columns.
    return df[list(variables)].pyroplot.scatter(c=color).figure


#############################


# HTML snippet rendering a button that downloads a base64-encoded CSV.
# Placeholders for str.format:
#   filename   - name the browser saves the file under (``download`` attribute)
#   payload    - base64-encoded CSV text embedded in a data: URL
#   buttonname - label shown on the button
# The template does no escaping; callers should HTML-escape filename/buttonname.
download_button_html = """<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1">
</head>
<body>
<a download="{filename}" href="data:text/csv;base64,{payload}">
<button class="p-Widget jupyter-widgets jupyter-button widget-button mod-success">{buttonname}</button>
</a>
</body>
</html>
"""


def get_tab(name, df, index):
    """Build the Plot/Data accordion widget for a single uploaded file.

    Parameters
    ----------
    name : str
        Filename of the upload; its stem prefixes the exported CSV's name.
    df : pandas.DataFrame
        The parsed upload, as produced by ``get_contents()``.
    index : int
        Position of the file within the upload. Kept for interface
        compatibility; the tab works from ``name``/``df`` directly instead of
        re-parsing every uploaded file via ``get_contents()[index]`` on each
        event.

    Returns
    -------
    widgets.Accordion
        A "Plot" pane (X/Y/Color selectors and the figure) and a "Data" pane
        (download button and table).

    Notes
    -----
    ``create_tabs`` rebuilds every tab whenever ``uploader.value`` changes.
    This tab therefore detaches its handler from the shared
    ``transform_dropdown`` at that point. Without this, every re-upload would
    leave stale handlers behind that keep re-processing and re-plotting data
    for tabs that are no longer displayed.
    """
    from html import escape  # only needed to build the download link

    def get_raw_frame():
        # Hand out a copy so nothing downstream can mutate the captured upload.
        return df.copy()

    def get_processed_frame():
        return process_frame(get_raw_frame())

    def get_export():
        """Return an HBox holding a download button for the transformed data."""
        filename = Path(name).stem + "_" + transform_dropdown.value + ".csv"

        output_buffer = StringIO()
        get_processed_frame().to_csv(output_buffer)
        payload = base64.b64encode(output_buffer.getvalue().encode()).decode()

        # The filename comes from the user's upload; escape it before it is
        # interpolated into the HTML attribute and the button label.
        label = escape(filename)
        return widgets.HBox(
            [
                widgets.Label("Download: "),
                HTML(
                    download_button_html.format(
                        payload=payload, filename=label, buttonname=label
                    )
                ),
            ]
        )

    def update_plot_options(frame):
        """Repopulate the selectors from ``frame``, keeping still-valid choices."""
        previous = (xvar.value, yvar.value, color.value)
        # NOTE(review): hold_sync batches front-end syncs of ``plotbox``; it is
        # not clear that it also batches the dropdowns' own updates.
        with plotbox.hold_sync():
            xvar.options, yvar.options, color.options = (
                (None, *frame.select_dtypes("float").columns),
                (None, *frame.select_dtypes("float").columns),
                (None, *frame.columns),
            )
            # Restore each previous selection if it survived the new options.
            for start, var in zip(previous, (xvar, yvar, color)):
                if start in var.options and var.value != start:
                    var.value = start

    def on_transform_change(change):
        """Re-render the table, download button and selector options."""
        table_output.clear_output()
        figure_output.clear_output()
        download_output.clear_output()
        frame = None
        with table_output:
            if transform_dropdown.value:
                frame = get_processed_frame()
            else:
                frame = get_raw_frame()
            display(frame)
        if frame is None:
            # NOTE(review): under IPython the Output context manager displays and
            # suppresses exceptions, so processing may have failed without
            # raising here. Nothing to export or plot in that case.
            return
        with download_output:
            if transform_dropdown.value:
                display(get_export())
        update_plot_options(frame)

    def on_plotconfig_change(change):
        """Redraw the figure for the current X/Y/Color selection."""
        figure_output.clear_output()
        with figure_output:
            max_legend_length = 12  # max legend entries per legend column
            if xvar.value is not None and yvar.value is not None:
                frame = get_processed_frame()
                plt.close()
                ax = frame.loc[:, [xvar.value, yvar.value]].pyroplot.scatter(
                    c=None if color.value is None else frame[color.value],
                    figsize=(12, 6),
                )
                if color.value is not None:
                    cmode = get_cmode(frame[color.value])
                    if cmode == "categories":
                        categories = frame[color.value].unique()
                        # Skip the legend entirely when there are too many categories.
                        if len(categories) < max_legend_length * 3:
                            proxies = {
                                k: proxy_line(marker="D", lw=0, color=c)
                                for (k, c) in zip(
                                    categories, process_color(c=categories)["c"]
                                )
                            }
                            ax.legend(
                                proxies.values(),
                                proxies.keys(),
                                fontsize=14,
                                markerscale=1.5,
                                title=color.value,
                                title_fontsize=16,
                                ncol=np.ceil(len(categories) / max_legend_length).astype(
                                    int
                                ),
                            )
                    elif cmode == "value_array":
                        ax.figure.colorbar(
                            mappable_from_values(frame[color.value]),
                            ax=ax,
                            label=color.value,
                        )
            else:
                # No axes chosen yet: show an empty placeholder figure.
                _, ax = plt.subplots(1, figsize=(12, 6))

            display(ax.figure)

    table_output = widgets.Output()
    download_output = widgets.Output()
    figure_output = widgets.Output()

    xvar = widgets.Dropdown(options=[None], value=None)
    yvar = widgets.Dropdown(options=[None], value=None)
    color = widgets.Dropdown(options=[None], value=None)

    # events ##########################################################################

    def detach(change):
        """Unsubscribe this (now stale) tab from the shared module-level widgets."""
        transform_dropdown.unobserve(on_transform_change, names="value")
        uploader.unobserve(detach, names="value")

    transform_dropdown.observe(on_transform_change, names="value")
    uploader.observe(detach, names="value")
    for selector in (xvar, yvar, color):
        selector.observe(on_plotconfig_change, names="value")

    # layout ##########################################################################
    plotbox = widgets.HBox(
        [
            widgets.VBox(
                [
                    widgets.Label("X:"),
                    xvar,
                    widgets.Label("Y:"),
                    yvar,
                    widgets.Label("Color:"),
                    color,
                ]
            ),
            figure_output,
        ]
    )

    titles = ("Plot", "Data")
    tabbox = widgets.Accordion(
        children=[plotbox, widgets.VBox([download_output, table_output])],
        titles=titles,
    )
    # ``titles=`` is the ipywidgets 8 API; set_title keeps 7.x working as well.
    for position, title in enumerate(titles):
        tabbox.set_title(position, title)

    # Populate the new tab immediately, including the download button if a
    # transform was already selected before the upload.
    on_transform_change({})
    on_plotconfig_change({})
    return tabbox


def create_tabs(change):
    """Rebuild the per-file tab view from the current uploads.

    Registered as an observer of ``uploader.value``. The traitlets ``change``
    payload is not inspected. Each uploaded file gets its own tab, titled
    with its filename.
    """
    tabs_output.clear_output()

    uploads = get_contents()
    with tabs_output:
        panes = [get_tab(fname, frame, pos) for pos, (fname, frame) in enumerate(uploads)]
        tabs = widgets.Tab(panes)
        for pos, (fname, _) in enumerate(uploads):
            tabs.set_title(pos, fname)
        display(tabs)


# Module-level widgets shared by every tab: the CSV uploader and the
# compositional-transform selector.
# NOTE: the upload counter shown on the button is unreliable in ipywidgets 7.x
# and is expected to be fixed in 8.0.
uploader = widgets.FileUpload(
    accept=".csv",  # only offer CSV files in the picker
    multiple=True,  # allow several files per upload; each gets its own tab
    description="Upload",  # button label (``name`` is not a FileUpload trait)
)

# Option values are pyrocomp method names applied by process_frame;
# the empty string means "no transform".
transform_dropdown = widgets.Dropdown(
    options=[
        (" - ", ""),
        ("Additive Log-Ratio", "ALR"),
        ("Centred Log-Ratio", "CLR"),
        ("Isometric Log-Ratio", "ILR"),
        ("Spherical Coordinates", "sphere"),
    ],
    value="",
)

# Create the output area before wiring the observer that writes into it.
tabs_output = widgets.Output()
uploader.observe(create_tabs, names="value")

topbox = widgets.VBox(
    [
        widgets.HBox(
            [
                widgets.Label("Upload CSV file:"),
                uploader,
                widgets.Label("Select Transform:"),
                transform_dropdown,
            ]
        ),
        tabs_output,
    ]
)


display(topbox)

VBox(children=(HBox(children=(Label(value='Upload CSV file:'), FileUpload(value={}, accept='.csv', description…