In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import os
import pandas as pd
import numpy as np
from plotly import graph_objects as go
from plotly.subplots import make_subplots
import xarray as xr
import plotly.io as pio

from agage_archive import Paths
from agage_archive.definitions import instrument_type_definition, unit_translator

# Project path configuration; paths.output is the directory whose
# per-species sub-directories hold the NetCDF files browsed below
paths = Paths()

In [None]:
#pio.renderers.default = "jupyterlab"

In [None]:
# Each sub-directory of the output folder is named after a species
species = sorted(
    entry.name for entry in paths.output.iterdir() if entry.is_dir()
)

# Instrument-type lookup tables from the archive definitions; plot() uses
# instrument_number to map numeric instrument_type codes back to names
instrument_number, instrument_number_string = instrument_type_definition()

In [None]:
def load_data():
    """Load every dataset listed in the ``filenames`` text area.

    Each non-blank line of ``filenames.value`` is treated as a file name inside
    ``paths.output / <selected species>`` and opened with xarray.

    Returns:
        list of xarray.Dataset, one per listed file, fully loaded into memory
        (the underlying files are closed again). Empty if no files are listed.
    """
    species_dir = paths.output / species_dropdown.value

    data = []
    for line in filenames.value.splitlines():
        file = line.strip()
        # Skip blank lines: an empty text area or trailing newline would
        # otherwise resolve to the species directory itself and fail to open
        if not file:
            continue
        # load_dataset reads the data and closes the file, so repeated
        # plotting does not leak open file handles
        data.append(xr.load_dataset(species_dir / file))

    return data

def plot(*args):
    """Plot mole-fraction time series for every file listed in ``filenames``.

    Datasets come from ``load_data()``. A dataset with an ``instrument_type``
    variable gets one trace per instrument type; otherwise it gets a single
    trace. ``*args`` is accepted and ignored so the function can also be used
    directly as a widget callback.
    """

    datasets = load_data()

    # Clear previous figures
    clear_output(wait=True)

    # Nothing selected: say so instead of failing on datasets[0] below
    if not datasets:
        print("No files selected: choose at least one network and site.")
        return

    # Colourblind-friendly palette (hex codes); traces take colours in order
    # and wrap around when there are more traces than colours
    colours = ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7"]
    colour_counter = 0

    def next_colour():
        """Return the next palette colour, cycling through ``colours``."""
        nonlocal colour_counter
        colour = colours[colour_counter % len(colours)]
        colour_counter += 1
        return colour

    def key_for_value(mapping, value, default):
        """Return the first key in ``mapping`` whose value equals ``value``, else ``default``."""
        return next((key for key, val in mapping.items() if val == value), default)

    def plot_combined(ds, fig):
        """Add one trace per instrument type present in ``ds``."""
        # np.unique yields the codes in sorted order, so colour assignment
        # is reproducible between runs
        for instrument_type in np.unique(ds.instrument_type.values):

            # Fall back to the raw code if it is missing from the lookup table
            instrument_type_name = key_for_value(
                instrument_number, instrument_type, str(instrument_type)
            )

            # Boolean mask selecting this instrument's samples
            ind = ds.instrument_type == instrument_type

            fig.add_trace(
                go.Scatter(
                    visible=True,
                    line=dict(color=next_colour(), width=1),
                    name=f"{ds.attrs['site_code']} - {instrument_type_name}",
                    x=ds.time[ind],
                    y=ds.mf[ind],
                )
            )

        return fig

    def plot_single(ds, fig):
        """Add a single trace covering the whole of ``ds``."""
        fig.add_trace(
            go.Scatter(
                visible=True,
                line=dict(color=next_colour(), width=1),
                name=ds.attrs["site_code"],
                x=ds.time,
                y=ds.mf,
            )
        )
        return fig

    # Translate the file's units attribute back to its display name; fall back
    # to the raw attribute when unit_translator has no matching entry
    file_units = datasets[0].mf.attrs.get("units", "")
    unit = key_for_value(unit_translator, file_units, file_units)

    # Create figure
    fig = go.Figure()

    # Set y-axis title to be species and units
    fig.update_yaxes(title_text=f"{species_dropdown.value} ({unit})")

    # Fixed figure size (pixels) with small margins
    fig.update_layout(width=600, height=500, margin=dict(l=20, r=20, t=20, b=20))

    # Place the legend in the top-left corner, inside the plot area
    fig.update_layout(legend=dict(x=0.02, y=1, yanchor="top"))

    # Change theme to simple
    fig.layout.template = "simple_white"

    for ds in datasets:
        # If instrument_type variable is present, split by instrument_type
        if "instrument_type" in ds.variables:
            fig = plot_combined(ds, fig)
        else:
            fig = plot_single(ds, fig)

    fig.show()


In [None]:
# Dropdown for choosing which species sub-directory to plot from
species_dropdown = widgets.Dropdown(
    options=species,
    description='Species:',
    disabled=False,
)

# Start on the first species; guard against an empty output directory,
# where species[0] would raise IndexError
if species:
    species_dropdown.value = species[0]

display(species_dropdown)

def file_search_species(species):
    """Return the paths of all ``*.nc`` files in the output directory for ``species``."""
    return list((paths.output / species).glob('*.nc'))

def networks_sites(files):
    """Split file names of the form ``<network>_<site>_...`` into their parts.

    Args:
        files: iterable of path objects; only each ``.stem`` is inspected.

    Returns:
        (networks, sites): two parallel lists, in the same order as ``files``.
    """
    stem_parts = [path.stem.split('_') for path in files]
    networks = [parts[0] for parts in stem_parts]
    sites = [parts[1] for parts in stem_parts]
    return networks, sites

# When the species changes, repopulate the network/site selector
def update_network_site(*args):
    """Refresh ``network_site`` options for the currently selected species.

    Options are sorted strings of the form ``"<site>, <network>"``; extra
    positional arguments (the observer's change payload) are ignored.
    """
    found_networks, found_sites = networks_sites(
        file_search_species(species_dropdown.value)
    )
    network_site.options = sorted(
        f"{site}, {network}" for network, site in zip(found_networks, found_sites)
    )

# Multi-select of "<site>, <network>" combinations available for the species.
# (indent=True was dropped: it is a Checkbox option that SelectMultiple does
# not define, so it only produced an "unrecognized argument" warning.)
network_site = widgets.SelectMultiple(
    options=[],
    description='Network and Site:',
    disabled=False,
)

species_dropdown.observe(update_network_site, "value")

# The observer only fires on *changes*, and the initial species was set
# before it was registered, so populate the options once now; otherwise the
# list stays empty until the user picks a different species
if species_dropdown.value is not None:
    update_network_site()

display(network_site)

def list_filenames(*args):
    """Build NetCDF file names for every selected ``"<site>, <network>"`` entry.

    Returns:
        list of names ``<network>_<site>_<species>.nc``, in the order the
        entries appear in ``network_site.value``.
    """
    chosen_species = species_dropdown.value
    selected_names = []
    for entry in network_site.value:
        site_code, network_name = entry.split(", ")
        selected_names.append(f"{network_name}_{site_code}_{chosen_species}.nc")
    return selected_names


# Text area listing the files to plot, one per line; load_data() reads it
filenames = widgets.Textarea(value='')


def update_filenames(*args):
    """Rewrite the file list from the current network/site selection."""
    selected_files = list_filenames()
    filenames.value = "\n".join(selected_files)


# Keep the file list in sync with the network/site selection
network_site.observe(update_filenames, "value")

display(filenames)

In [None]:
# Button that triggers plotting of the files listed above
plot_button = widgets.Button(description="Plot")

# Area the figure is rendered into. It must be displayed: anything written
# inside ``with output:`` is otherwise captured by an invisible widget and
# the plot never appears.
output = widgets.Output()

display(plot_button, output)


def plot_to_output(sender):
    """Button callback: clear the output area and redraw the figure inside it."""
    with output:
        clear_output(wait=True)
        plot()


plot_button.on_click(plot_to_output)