In [None]:
from typing import Dict, List, Tuple

import git
import matplotlib.pyplot as plt
import panel as pn
import xarray as xr
import yaml
from bokeh.models.formatters import PrintfTickFormatter

import cfspopcon
from cfspopcon.named_options import ConfinementScaling, Impurity
from cfspopcon.unit_handling import ureg

pn.extension(template="fast")
pn.config.comms = "ipywidgets"


def get_git_hash(fallback: str = "unknown") -> str:
    """Return the abbreviated hash of the commit checked out in the enclosing git repository.

    The repository is located with ``search_parent_directories=True``, i.e. by walking up
    from the starting directory until a repository is found.

    Args:
        fallback: value returned when no repository, or no commit on HEAD, can be found
            (for example when the notebook is run outside a git checkout). Defaults to "unknown".

    Returns:
        str: the short commit hash (``git rev-parse --short``), or ``fallback``.
    """
    try:
        repo = git.Repo(search_parent_directories=True)
        sha = repo.head.commit.hexsha
    except (git.InvalidGitRepositoryError, git.NoSuchPathError, ValueError):
        # ValueError is raised by GitPython when HEAD does not point at an existing commit
        # (e.g. a freshly initialised repository). The hash is only used for display, so
        # don't let its absence take down the whole GUI.
        return fallback
    return repo.git.rev_parse(sha, short=True)


PROJECT_NAME = "BARC"
# Relative to the notebook's working directory.
PROJECT_DIR = f"../../example_cases/{PROJECT_NAME}"

# Load the cfspopcon case (input parameters, algorithm, points) and the POPCON plot style,
# validate the inputs against the algorithm, and wrap them in the dataset the GUI mutates.
input_parameters, algorithm, points = cfspopcon.read_case(PROJECT_DIR)
plot_style = cfspopcon.read_plot_style(f"{PROJECT_DIR}/plot_popcon.yaml")
algorithm.validate_inputs(input_parameters)
dataset = xr.Dataset(input_parameters)

# Load the GUI slider configuration and the short commit hash shown in the plot title.
# An explicit encoding keeps the YAML decoding independent of the platform's locale.
with open(f"{PROJECT_DIR}/gui.yaml", "r", encoding="utf-8") as file:
    gui_config = yaml.safe_load(file)
short_sha = get_git_hash()

In [None]:
def build_design_sliders(gui_config: dict) -> Tuple[Dict[str, pn.widgets.FloatSlider], Dict[str, str]]:
    """Build sliders for standard design variables that can be updated via the dataset.

    Side effect: every variable named in ``gui_config["sliders"]`` is cast to float in the
    module-level ``dataset`` (in place) before its slider is built.

    Args:
        gui_config (dict): data from gui.yaml. Each entry of ``gui_config["sliders"]`` must
            provide ``name`` (a dataset variable), ``start``, ``end``, ``step_size`` and ``unit``,
            and may provide ``display_name``.

    Returns:
        Tuple[Dict[str, pn.widgets.FloatSlider], Dict[str, str]]: the sliders and the unit string
        of each slider, both keyed by dataset variable name. Slider values are expressed in that unit.
    """
    sliders = {}
    units = {}
    for slide_conf in gui_config["sliders"]:
        var_name = slide_conf["name"]
        unit = slide_conf["unit"]
        dataset[var_name] = dataset[var_name].astype(float)

        slider = pn.widgets.FloatSlider(
            # Some entries have a human-friendly "display_name"; fall back to the variable name.
            name=slide_conf.get("display_name", var_name),
            start=slide_conf["start"],
            end=slide_conf["end"],
            # Initialise from the dataset, converted into the slider's display unit.
            value=float(dataset[var_name].pint.to(unit).values),
            step=slide_conf["step_size"],
        )
        sliders[var_name] = slider
        units[var_name] = unit
    return sliders, units


def build_impurity_sliders(gui_config: dict) -> Tuple[Dict[str, pn.widgets.FloatSlider], Dict[str, str]]:
    """Build sliders for the impurity concentrations listed in gui.yaml.

    Each slider is initialised from the value currently stored in the module-level
    ``dataset["impurities"]`` for that species.

    Args:
        gui_config (dict): data from gui.yaml. Each entry of
            ``gui_config["impurity_concentration_sliders"]`` must provide ``name`` (an
            ``Impurity`` member name), ``start``, ``end`` and ``step_size``.

    Returns:
        Tuple[Dict[str, pn.widgets.FloatSlider], Dict[str, str]]: the sliders and their unit
        strings, both keyed by impurity name. Every unit is the empty string.
    """

    def build_impurity_slider(imp_config: dict) -> pn.widgets.FloatSlider:
        # One slider for one species, starting at the dataset's current concentration.
        imp_name = imp_config["name"]
        initial_value = dataset["impurities"].sel(dim_species=Impurity[imp_name]).values
        return pn.widgets.FloatSlider(
            name=imp_name,
            start=float(imp_config["start"]),
            end=float(imp_config["end"]),
            step=float(imp_config["step_size"]),
            value=float(initial_value),
            # Show slider values in scientific notation with two decimals.
            format=PrintfTickFormatter(format="%.2e"),
        )

    sliders = {}
    units = {}
    for imp_config in gui_config["impurity_concentration_sliders"]:
        imp_name = imp_config["name"]
        sliders[imp_name] = build_impurity_slider(imp_config)
        units[imp_name] = ""
    return sliders, units


# Sidebar: a confinement-scaling selector, a "Parameters" header, then one slider per
# design variable followed by one slider per impurity concentration.
confinement_names = ConfinementScaling.__members__.keys()
select_confinement = pn.widgets.Select(name="Confinement Scaling", options=[*confinement_names])

# Design sliders first, impurity sliders appended after them; `units` mirrors `sliders` key-for-key.
sliders, units = build_design_sliders(gui_config)
imp_sliders, imp_units = build_impurity_sliders(gui_config)
sliders = dict(sliders)
sliders.update(imp_sliders)
units = dict(units)
units.update(imp_units)

parameters_header = pn.panel("## Parameters", margin=(0, 0, 0, 0))
sliders_column = pn.Column(
    select_confinement,
    parameters_header,
    *sliders.values(),
    sizing_mode="stretch_width",
    margin=(0, 0, 0, 0),
)
sliders_column.servable(target="sidebar")

In [None]:
def make_plot(select_confinement: str, **kwargs):
    """Recompute the POPCON from the current widget values and return it as a Panel pane.

    The module-level ``dataset`` is updated in place:
      * ``energy_confinement_scaling`` is set from ``select_confinement``;
      * each impurity slider value replaces that species' entry in ``dataset["impurities"]``;
      * every other slider value replaces the dataset variable of the same name;
    then ``algorithm.update_dataset`` is run on it and the result is plotted.

    Args:
        select_confinement (str): name of a ``ConfinementScaling`` member.
        **kwargs: slider values keyed by slider name (dataset variable or impurity name),
            expressed in the unit recorded for that slider in ``units``.

    Returns:
        pn.pane.Matplotlib: pane wrapping the freshly drawn POPCON figure.
    """
    dataset["energy_confinement_scaling"] = ConfinementScaling[select_confinement]
    for key, value in kwargs.items():
        quantity = value * ureg(units[key])
        if key in imp_sliders:
            dataset["impurities"].loc[dict(dim_species=Impurity[key])] = quantity
        else:
            dataset[key] = quantity
    algorithm.update_dataset(dataset, in_place=True)
    # NOTE(review): the `points` loaded by read_case are not passed here; confirm this is intended.
    fig, _ = cfspopcon.plotting.make_plot(
        dataset, plot_style, points={}, title=f"{PROJECT_NAME} (commit: {short_sha})", output_dir=None
    )
    # A new figure is drawn on every update. If it is registered with pyplot it would otherwise
    # stay in pyplot's figure registry indefinitely; closing it leaves the pane as its only owner
    # (a closed figure can still be rendered).
    plt.close(fig)
    return pn.pane.Matplotlib(fig)


# Bind make_plot to the *throttled* slider values so the POPCON is recomputed only when a
# slider is released, not continuously while it is being dragged.
sliders_throttled = {slider_name: widget.param.value_throttled for slider_name, widget in sliders.items()}
bound_make_plot = pn.bind(make_plot, select_confinement=select_confinement, **sliders_throttled)
interactive_plot = pn.param.ParamFunction(bound_make_plot, loading_indicator=True)

app = pn.Column(interactive_plot)
app.servable(target="main")