## Imports

In [1]:
import cProfile
import json
from enum import Enum
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.graph_objs.scatter as s
import plotly.graph_objs.scatter3d as s3
from plotly.graph_objs.layout import Annotation, Legend
from plotly.subplots import make_subplots

## Flags

In [2]:
# Notebook-wide switches
render_plots = True  # show each figure inline once it has been built
generate = False  # when True, profile the JSON/HTML/PNG export of the 2D figure
info_annotations = True  # append the miss-distance / missile-info labels

# Classification banner visibility per corner:
# top-left, top-right, bottom-left, bottom-right
tl = tr = bl = br = True

## Path Setup

In [3]:
# Input locations, resolved against the current working directory
working_dir = Path().resolve()
data_dir = working_dir / "data"
template_dir = working_dir / "templates"


def _load_template(file_name: str) -> dict:
    """Parse one JSON plot template from ``template_dir``.

    JSON text is UTF-8 by specification. ``Path.read_text()`` without an
    explicit encoding would use the platform locale encoding instead, which
    mis-decodes non-ASCII titles/labels on some systems (e.g. Windows).
    """
    return json.loads((template_dir / file_name).read_text(encoding="utf-8"))


# Single plots
template_2d = _load_template("template_2d.json")
data_2d = pd.read_csv(data_dir / "data_2d.csv")

template_3d = _load_template("template_3d.json")
data_3d = pd.read_csv(data_dir / "data_3d.csv")

template_heat = _load_template("template_heatmap.json")
data_heat = pd.read_csv(data_dir / "data_heatmap.csv")

template_discrete = _load_template("template_discrete.json")
data_discrete = pd.read_csv(data_dir / "data_discrete.csv")

# Subplots
template_subplot_2d = _load_template("template_subplot_2d.json")
template_subplot_3d = _load_template("template_subplot_3d.json")
template_mixed_322 = _load_template("template_mixed_322.json")

# Plot id -> template; the cells below select one by assigning `id`
templates = {
    "2d": template_2d,
    "3d": template_3d,
    "heatmap": template_heat,
    "discrete": template_discrete,
    "subplot_2d": template_subplot_2d,
    "subplot_3d": template_subplot_3d,
    "mixed_3-2-2": template_mixed_322,
}
# Plot id -> data frame; the subplot figures reuse the single-plot data sets
data = {
    "2d": data_2d,
    "3d": data_3d,
    "heatmap": data_heat,
    "discrete": data_discrete,
    "subplot_2d": data_2d,
    "subplot_3d": data_3d,
    "mixed_3-2-2": data_3d,
}

## Global Setup

 ### Unit Conversion

In [4]:
def unit_transformation(data: pd.Series, unit_conversion: str) -> pd.Series:
    """
    Converts data to a specified unit of measurement

    Parameters
    ----------
    data : ndarray
        Data being transformed
    unit_conversion : str
        Unit conversion being applied

    Returns
    -------
    Series
        Pandas Series of transformed data

    Raises
    ------
    ValueError
        Invalid scale factor
    """
    if unit_conversion is None or unit_conversion == "None":
        return data
    elif unit_conversion == "m to km":
        return data * 0.001
    elif unit_conversion == "m to kft":
        return data * 0.0032808
    elif unit_conversion == "m to NMI":
        return data * 0.00053996
    elif unit_conversion == "km to m":
        return data * 1000
    elif unit_conversion == "km to kft":
        return data * 3.28084
    elif unit_conversion == "km to NMI":
        return data * 0.539957
    elif unit_conversion == "ft to km":
        return data * 0.0003048
    elif unit_conversion == "ft to kft":
        return data * 0.001
    elif unit_conversion == "ft to NMI":
        return data * 0.00016458
    elif unit_conversion == "mps to kts":
        return data * 1.9438
    elif unit_conversion == "kts to mps":
        return data * 0.51444
    elif unit_conversion == "m^2 to dBsm":
        return 10 * np.log10(data + np.spacing(0))
    elif unit_conversion == "dBsm to m^2":
        return pd.Series(np.power(10, data / 10))
    elif unit_conversion == "rad to deg":
        return data * 180 / np.pi
    elif unit_conversion == "deg to rad":
        return data * np.pi / 180
    else:
        raise ValueError(f"Invalid scaling factor: {unit_conversion}")

 ### Style Settings

In [5]:
def margin_dict(is_3d: bool = False):
    """Layout margins sized for the enabled banners and info labels.

    The bottom margin grows by 20 when info labels are on and by 10 when
    either bottom banner is shown; the top margin is 40 when either top
    banner is shown, otherwise 30. ``is_3d`` is accepted but currently has
    no effect on the result.
    """
    extra_bottom = (20 if info_annotations else 0) + (10 if (bl or br) else 0)
    return {
        "t": 30 if not (tl or tr) else 40,
        "l": 20,
        "r": 1,
        "b": 40 + extra_bottom,
    }

def title_dict(title: str):
    """Title spec: bold, whitespace-trimmed text centred at the top, size 13."""
    bold_text = f"<b>{title.strip()}</b>"
    return dict(
        text=bold_text,
        font=dict(size=13),
        x=0.5,
        xanchor="center",
        yanchor="top",
    )

 ### Axis Range

In [6]:
class AxisType(str, Enum):
    """Supported axis layout sizing"""

    AUTO = "Auto"
    EQUAL = "Equal"
    MANUAL = "Manual"


def axis_ranges(axis_type: AxisType, axes: list[dict], data: list):
    """
    Compute the ``range`` setting for each axis of one grid.

    Parameters
    ----------
    axis_type : AxisType or str
        ``"Auto"``, ``"Equal"`` or ``"Manual"``. Plain strings are accepted
        because the templates come from JSON.
    axes : list[dict]
        Axis definitions of the grid. ``Manual`` reads ``min``, ``max`` and
        ``scaleFactor`` from each.
    data : list
        Plotted series, either flat (``[Series, ...]``) or grouped per trace
        (``[[Series, ...], ...]`` as returned by ``build_trace``). ``Equal``
        spans their combined min/max.

    Returns
    -------
    list
        One entry per axis: ``None`` (range left unset), a shared
        ``(min, max)`` tuple, or a 2-element Series of converted manual limits.

    Raises
    ------
    ValueError
        ``axis_type`` is not one of the supported values.
    """
    try:
        axis_type = AxisType(axis_type)
    except ValueError:
        raise ValueError(
            f"Unsupported axis type {axis_type!r}; expected one of "
            f"{[member.value for member in AxisType]}"
        ) from None

    if axis_type == AxisType.MANUAL:
        return [
            unit_transformation(
                pd.Series([axis["min"], axis["max"]]), axis["scaleFactor"]
            )
            for axis in axes
        ]

    if axis_type == AxisType.EQUAL:
        # Accept per-trace groups of series as well as a flat list
        series_list = [
            series
            for entry in data
            for series in (entry if isinstance(entry, (list, tuple)) else [entry])
        ]
        if not series_list:
            # Nothing to measure: leave the range unset rather than letting
            # min() fail on an empty sequence
            return [None] * len(axes)
        shared = (
            min(series.min() for series in series_list),
            max(series.max() for series in series_list),
        )
        return [shared] * len(axes)

    return [None] * len(axes)

## Annotations

### Classification

In [7]:
def resolution_shifts(is_3d: bool = False, show_sidebar: bool = False):
    """Pixel shifts for the corner banners, keyed by anchor side.

    The keys ("top", "bottom", "left", "right") are the anchor names that
    get_annotation looks up.
    """
    bottom_shift = -50
    if info_annotations:
        bottom_shift = -70  # pushed further down when the info labels are on
    return {
        "top": 40,
        "bottom": bottom_shift,
        "left": -30 if not is_3d else -20,
        "right": 0 if not show_sidebar else 65,
    }

def get_annotation(
    xanchor: str,
    yanchor: str,
    show: bool = False,
    is_3d: bool = False,
    sidebar: bool = False,
    classification="UNCLASSIFIED",
) -> Annotation:
    """Build one classification banner pinned to a corner of the plot area.

    ``xanchor`` must be "left" or "right" and ``yanchor`` "top" or "bottom":
    they pick the paper-coordinate corner and index the pixel shifts from
    resolution_shifts. ``show`` controls visibility.
    """
    offsets = resolution_shifts(is_3d, sidebar)
    at_right = xanchor == "right"
    at_top = yanchor == "top"
    return Annotation(
        text=f"<b>{classification}</b>",
        font={"size": 10},
        showarrow=False,
        visible=show,
        xref="paper",
        x=1 if at_right else 0,
        xanchor=xanchor,
        xshift=offsets[xanchor],
        yref="paper",
        y=1 if at_top else 0,
        yanchor=yanchor,
        yshift=offsets[yanchor],
    )

def get_classifications(is_3d: bool = False, show_sidebar: bool = False) -> list[Annotation]:
    """Classification banners for all four corners.

    Order: top-left, bottom-left, top-right, bottom-right. Each banner's
    visibility comes from the matching tl / bl / tr / br flag.
    """
    corners = (
        ("left", "top", tl),
        ("left", "bottom", bl),
        ("right", "top", tr),
        ("right", "bottom", br),
    )
    return [
        get_annotation(horizontal, vertical, visible, is_3d, show_sidebar)
        for horizontal, vertical, visible in corners
    ]

### Info Labels

In [8]:
# Shared placement of the info labels: just below the plot area, in paper
# coordinates, without arrows
_INFO_LABEL_STYLE = dict(y=-0.13, xref="paper", yref="paper", showarrow=False)

# NOTE(review): the label texts look like hard-coded placeholders - confirm
# whether they should come from the data/template.
miss_distance = Annotation(x=-0.01, text="Miss distance 20.1 m", **_INFO_LABEL_STYLE)
missile_info = Annotation(x=1.01, text="Missile Info: MSL_1", **_INFO_LABEL_STYLE)

## Layout

In [9]:
def get_core_layout(is_3d: bool, show_legend: bool, template_id: str | None = None):
    """
    Base figure layout: title, theme, size, legend, margins and annotations.

    Parameters
    ----------
    is_3d : bool
        Forwarded to the banner placement and margin helpers.
    show_legend : bool
        Whether a legend sidebar is shown; shifts the right-hand banners.
    template_id : str, optional
        Key into ``templates``. Defaults to the notebook-global ``id`` so
        existing calls keep working. Pass it explicitly to avoid depending on
        whichever cell last assigned ``id``.

    Returns
    -------
    go.Layout
    """
    key = id if template_id is None else template_id
    template = templates[key]
    layout_cfg = template["layout"]

    annotations = get_classifications(is_3d, show_legend)
    if info_annotations:
        annotations.extend([miss_distance, missile_info])

    return go.Layout(
        title=title_dict(template["title"]),
        template=layout_cfg["theme"],
        height=layout_cfg["height"],
        width=layout_cfg["width"],
        legend=Legend(title=layout_cfg["legendTitle"]),
        margin=margin_dict(is_3d),
        annotations=annotations,
    )

## Traces

In [10]:
def get_marker(variable: dict, is_3d: bool):
    """Marker spec for a trace, built from its template ``variable`` entry.

    Uses the Scatter3d marker class for 3D traces and the Scatter marker
    class otherwise; the symbol name is lower-cased.
    """
    marker_cls = s3.Marker if is_3d else s.Marker
    return marker_cls(
        color=variable["markerColor"],
        size=variable["markerSize"],
        symbol=variable["markerType"].lower(),
    )


def get_line(variable: dict, is_3d: bool):
    """Line spec for a trace, built from its template ``variable`` entry.

    ``lineShape`` is read and forwarded only for 2D traces; the 3D branch
    neither reads nor passes it.
    """
    line_kwargs = {
        "color": variable["lineColor"],
        "width": variable["lineWidth"],
        "dash": variable["lineType"],
    }
    if is_3d:
        return s3.Line(**line_kwargs)
    line_kwargs["shape"] = variable["lineShape"]
    return s.Line(**line_kwargs)

In [11]:
def build_scatter(
    data: dict, variable: dict, grid: dict, is_3d: bool, legendgroup: str | None = None
):
    """Create a Scatter (2D) or Scatter3d trace.

    ``data`` holds the coordinate arrays keyed by axis name (e.g. ``x``,
    ``y``, ``z``). Styling comes from the template ``variable``, and legend
    visibility from ``grid``.
    """
    trace_options = dict(
        name=variable["traceName"],
        connectgaps=variable["connectgaps"],
        mode=variable["mode"],
        marker=get_marker(variable, is_3d),
        line=get_line(variable, is_3d),
        legendgroup=legendgroup,
        legendgrouptitle_text=variable["legendGroupTitle"],
        showlegend=grid["showLegend"],
    )
    if is_3d:
        return go.Scatter3d(**data, **trace_options)
    return go.Scatter(**data, **trace_options)


def get_axis_data(data: pd.DataFrame, variable: dict, axis: dict):
    """Single-entry mapping ``{axis name: converted column}`` for one axis.

    The column name comes from the variable's ``<axis name>Variable`` key
    (e.g. ``xVariable``) and is converted with the axis ``scaleFactor``.
    """
    axis_name = axis["name"]
    column = data[variable[axis_name + "Variable"]]
    return {axis_name: unit_transformation(column, axis["scaleFactor"])}


def build_trace(
    fig: go.Figure,
    data: pd.DataFrame,
    grids: list[dict],
    variable: dict,
    is_3d: bool,
    row: int | None = None,
    col: int | None = None,
    legendgroup: str | None = None,
):
    """
    Add one template variable to ``fig`` as a trace and return its data.

    The grid is chosen by the variable's 1-based ``subplot`` number
    (``grids[subplot - 1]``), so ``grids`` is assumed to be ordered by
    subplot.

    Returns
    -------
    list[pd.Series]
        The converted coordinate series, in the grid's axis order.
    """
    grid = grids[variable["subplot"] - 1]

    # Gather the converted column for every axis of the grid
    coordinates: dict = {}
    for axis in grid["axes"]:
        coordinates |= get_axis_data(data, variable, axis)

    trace = build_scatter(coordinates, variable, grid, is_3d, legendgroup)
    fig.add_trace(trace, row=row, col=col)

    return [*coordinates.values()]

## Axes

In [12]:
def get_axis_layout(axis_dict, range, axis, grid, is_3d):
    """
    Axis settings for one template axis, as a fresh dict.

    ``axis_dict`` and ``grid`` are accepted for call compatibility but are not
    read. 2D axis titles get a standoff (8 for ``y``, 1 otherwise); 3D axis
    titles get none. Domains are applied by the callers.
    """
    axis_title = {"text": axis["label"], "font": {"size": 12}}
    if not is_3d:
        axis_title["standoff"] = 1 if axis["name"] != "y" else 8

    return {
        "range": range,
        "showgrid": axis["enableGrid"],
        "title": axis_title,
        "linecolor": "black",
        "linewidth": 1,
        "mirror": axis["enableBox"],
    }


def get_2d_layout(trace_data: list[pd.Series], axes_dict: dict, grid: dict):
    """Write the 2D axis settings of one grid into ``axes_dict`` (in place).

    Keys are ``<axis name>axis<subplot>`` (e.g. ``xaxis2``). When the grid
    sets ``overwriteDomain``, each axis also gets a ``domain`` built from its
    ``domainMin``/``domainMax``.
    """
    ranges = axis_ranges(grid["axisType"], grid["axes"], trace_data)

    for axis, axis_range in zip(grid["axes"], ranges):
        # TODO: Ordering of the subplots need to be shifted for mixed plots
        key = f"{axis['name']}axis{grid['subplot']}"
        axes_dict[key] = get_axis_layout(
            axes_dict.setdefault(key, {}), axis_range, axis, grid, False
        )
        if grid["overwriteDomain"]:
            axes_dict[key]["domain"] = [axis["domainMin"], axis["domainMax"]]

def get_3d_layout(trace_data: list[pd.Series], axes_dict: dict, grid: dict):
    """
    Write the 3D scene settings of one grid into ``axes_dict`` (in place).

    Creates or updates ``axes_dict["scene<subplot>"]`` with:
    - one ``<axis name>axis`` entry per grid axis (e.g. ``xaxis``);
    - when the grid sets ``overwriteDomain``, a scene ``domain`` with ``x``
      and ``y`` ranges from those axes' ``domainMin``/``domainMax``;
    - an orthographic camera.
    """
    ranges = axis_ranges(grid["axisType"], grid["axes"], trace_data)
    scene = axes_dict.setdefault(f"scene{grid['subplot']}", {})

    for axis, axis_range in zip(grid["axes"], ranges):
        name = f"{axis['name']}axis"
        scene[name] = get_axis_layout(scene.get(name, {}), axis_range, axis, grid, True)

        if grid["overwriteDomain"] and axis["name"] in ("x", "y"):
            # Merge per axis so both x and y domains are kept; assigning a new
            # dict for each axis would drop the one written before it.
            scene.setdefault("domain", {})[axis["name"]] = [
                axis["domainMin"],
                axis["domainMax"],
            ]

    scene["camera"] = {
        "projection": {"type": "orthographic"},
        "eye": {"x": -1.25, "z": 0.8},
    }

 ## 2D Plots

In [13]:
# Variables
# NOTE: `id` shadows the builtin; get_core_layout reads this global, so it
# must be assigned before the layout is built.
id = "2d"

# Initialize figures: start empty, then apply the shared core layout. The
# legend flag is True when any grid shows a legend (it shifts the
# right-hand banners).
fig2d = go.Figure()
fig2d.update_layout(get_core_layout(
        False, any(grid["showLegend"] for grid in templates[id]["grid"])
    )
)

# Create data: build_trace adds one trace per template variable to fig2d and
# returns that trace's converted series, so trace_data is a list of lists
trace_data = [
    build_trace(fig2d, data[id], templates[id]["grid"], variable, False)
    for variable in templates[id]["variables"]
]

# Axes: collect axis (2D) or scene (3D) settings for every grid, then apply
# them in a single layout update
axes_dict: dict = {}
for grid_item in templates[id]["grid"]:
    if grid_item["plotType"] == "2d":
        get_2d_layout(trace_data, axes_dict, grid_item)
    else:
        get_3d_layout(trace_data, axes_dict, grid_item)

fig2d.update_layout(axes_dict)

if render_plots:
    fig2d.show(renderer="notebook_connected")

In [14]:
if generate:
    # Profile serialization and each export format of the 2D figure.
    # cProfile.run executes its string in __main__'s namespace (in a notebook
    # normally the user namespace), so fig2d/html_file/png_file must be
    # module globals.
    image_dir = working_dir / "images"
    # Make sure the output directory exists before writing into it
    image_dir.mkdir(parents=True, exist_ok=True)

    cProfile.run("fig2d.to_plotly_json()")

    html_file = image_dir / "plot2d.html"
    cProfile.run("fig2d.write_html(html_file)")

    png_file = image_dir / "plot2d.png"
    cProfile.run("fig2d.write_image(png_file, engine='kaleido')")

## 3D Plot

In [15]:
# Variables
id = "3d"

# Initialize figures
fig3d = go.Figure(
    layout=get_core_layout(
        True, any(grid["showLegend"] for grid in templates[id]["grid"])
    )
)

# Create data
trace_data = [
    build_trace(fig3d, data[id], templates[id]["grid"], variable, True)
    for variable in templates[id]["variables"]
]

# Axes
axes_dict: dict = {}
for grid_item in templates[id]["grid"]:
    if grid_item["plotType"] == "2d":
        get_2d_layout(trace_data, axes_dict, grid_item)
    else:
        get_3d_layout(trace_data, axes_dict, grid_item)
        
fig3d.update_layout(axes_dict)

if render_plots:
    fig3d.show(renderer="notebook_connected")

## Heatmap

### Colors

In [17]:
def additional_colorscales() -> dict[str, list[str]]:
    """Additional RGB colorscales

    Parula scale came from MATLAB (256 entries). Every channel must be in
    0-255: plotly.colors.make_colorscale validates "rgb(...)" strings and
    rejects channel values above 255.

    A new dict/list is built on every call, so callers may modify the result
    freely.

    Returns
    -------
    dict[str, list[str]]
        Name of the colorscale and its values
    """
    return {
        "Parula": [
            "rgb(53,43,135)",
            "rgb(54,44,139)",
            "rgb(54,46,142)",
            "rgb(54,47,145)",
            "rgb(54,49,148)",
            "rgb(54,50,151)",
            "rgb(54,52,154)",
            "rgb(54,53,157)",
            "rgb(54,55,160)",
            "rgb(54,56,163)",
            "rgb(54,58,167)",
            "rgb(54,59,170)",
            "rgb(53,61,173)",
            "rgb(53,62,176)",
            "rgb(52,64,179)",
            "rgb(51,66,183)",
            "rgb(50,67,186)",
            "rgb(49,69,189)",
            "rgb(48,71,192)",
            "rgb(46,72,195)",
            "rgb(44,74,199)",
            "rgb(42,76,202)",
            "rgb(39,78,205)",
            "rgb(37,80,209)",
            "rgb(33,82,212)",
            "rgb(29,85,215)",
            "rgb(25,87,217)",
            "rgb(21,89,220)",
            "rgb(17,91,222)",
            "rgb(12,93,223)",
            "rgb(8,95,224)",
            "rgb(5,97,225)",
            "rgb(3,99,226)",
            "rgb(2,100,226)",
            "rgb(2,102,226)",
            "rgb(1,103,226)",
            "rgb(1,104,226)",
            "rgb(2,105,226)",
            "rgb(2,106,226)",
            "rgb(3,108,225)",
            "rgb(4,109,225)",
            "rgb(5,110,225)",
            "rgb(6,111,224)",
            "rgb(7,112,224)",
            "rgb(8,113,223)",
            "rgb(9,114,223)",
            "rgb(10,115,223)",
            "rgb(11,116,222)",
            "rgb(12,117,221)",
            "rgb(13,118,221)",
            "rgb(14,119,220)",
            "rgb(15,120,220)",
            "rgb(16,121,219)",
            "rgb(16,122,219)",
            "rgb(17,123,218)",
            "rgb(18,124,218)",
            "rgb(18,124,217)",
            "rgb(19,125,217)",
            "rgb(19,126,216)",
            "rgb(19,127,216)",
            "rgb(20,128,215)",
            "rgb(20,129,214)",
            "rgb(20,130,214)",
            "rgb(20,131,214)",
            "rgb(20,132,213)",
            "rgb(20,133,213)",
            "rgb(20,134,212)",
            "rgb(20,136,212)",
            "rgb(20,137,212)",
            "rgb(19,138,211)",
            "rgb(19,139,211)",
            "rgb(18,140,211)",
            "rgb(17,141,211)",
            "rgb(16,143,211)",
            "rgb(15,144,211)",
            "rgb(14,145,211)",
            "rgb(13,147,211)",
            "rgb(12,148,211)",
            "rgb(11,149,211)",
            "rgb(10,150,210)",
            "rgb(10,152,210)",
            "rgb(9,153,210)",
            "rgb(8,154,210)",
            "rgb(8,155,209)",
            "rgb(7,156,209)",
            "rgb(7,157,208)",
            "rgb(7,158,208)",
            "rgb(6,159,207)",
            "rgb(6,160,206)",
            "rgb(6,161,206)",
            "rgb(6,162,205)",
            "rgb(6,163,204)",
            "rgb(6,163,203)",
            "rgb(6,164,203)",
            "rgb(6,165,202)",
            "rgb(6,166,201)",
            "rgb(6,166,200)",
            "rgb(6,167,199)",
            "rgb(6,168,198)",
            "rgb(6,169,197)",
            "rgb(6,169,196)",
            "rgb(7,170,195)",
            "rgb(7,171,194)",
            "rgb(8,171,193)",
            "rgb(9,172,192)",
            "rgb(10,172,191)",
            "rgb(11,173,189)",
            "rgb(12,174,188)",
            "rgb(13,174,187)",
            "rgb(15,175,186)",
            "rgb(16,175,185)",
            "rgb(18,176,184)",
            "rgb(19,177,182)",
            "rgb(21,177,181)",
            "rgb(23,178,180)",
            "rgb(24,178,179)",
            "rgb(26,179,177)",
            "rgb(28,179,176)",
            "rgb(30,180,175)",
            "rgb(32,180,173)",
            "rgb(34,181,172)",
            "rgb(36,181,171)",
            "rgb(38,182,169)",
            "rgb(41,183,168)",
            "rgb(43,183,167)",
            "rgb(45,184,165)",
            "rgb(47,184,164)",
            "rgb(50,184,162)",
            "rgb(52,185,161)",
            "rgb(54,185,159)",
            "rgb(57,186,158)",
            "rgb(59,186,156)",
            "rgb(62,187,155)",
            "rgb(65,187,153)",
            "rgb(67,187,152)",
            "rgb(70,188,150)",
            "rgb(73,188,149)",
            "rgb(76,189,147)",
            "rgb(78,189,146)",
            "rgb(81,189,144)",
            "rgb(84,190,143)",
            "rgb(87,190,141)",
            "rgb(90,190,140)",
            "rgb(93,190,138)",
            "rgb(96,191,137)",
            "rgb(99,191,135)",
            "rgb(102,191,134)",
            "rgb(105,191,132)",
            "rgb(108,191,131)",
            "rgb(111,191,130)",
            "rgb(114,192,129)",
            "rgb(117,192,127)",
            "rgb(120,192,126)",
            "rgb(122,192,125)",
            "rgb(125,192,124)",
            "rgb(128,192,123)",
            "rgb(131,192,121)",
            "rgb(133,192,120)",
            "rgb(136,192,119)",
            "rgb(139,192,118)",
            "rgb(141,192,117)",
            "rgb(144,192,116)",
            "rgb(146,192,115)",
            "rgb(149,192,114)",
            "rgb(151,191,113)",
            "rgb(154,191,112)",
            "rgb(156,191,111)",
            "rgb(159,191,110)",
            "rgb(161,191,109)",
            "rgb(163,191,108)",
            "rgb(166,191,107)",
            "rgb(168,191,106)",
            "rgb(170,191,105)",
            "rgb(173,190,104)",
            "rgb(175,190,104)",
            "rgb(177,190,103)",
            "rgb(179,190,102)",
            "rgb(182,190,101)",
            "rgb(184,190,100)",
            "rgb(186,190,99)",
            "rgb(188,189,98)",
            "rgb(190,189,97)",
            "rgb(192,189,97)",
            "rgb(194,189,96)",
            "rgb(197,189,95)",
            "rgb(199,189,94)",
            "rgb(201,188,93)",
            "rgb(203,188,92)",
            "rgb(205,188,91)",
            "rgb(207,188,91)",
            "rgb(209,188,90)",
            "rgb(211,187,89)",
            "rgb(213,187,88)",
            "rgb(215,187,87)",
            "rgb(217,187,86)",
            "rgb(219,187,85)",
            "rgb(221,187,84)",
            "rgb(223,186,84)",
            "rgb(225,186,83)",
            "rgb(227,186,82)",
            "rgb(229,186,81)",
            "rgb(231,186,80)",
            "rgb(233,186,79)",
            "rgb(235,186,78)",
            "rgb(237,186,77)",
            "rgb(239,186,76)",
            "rgb(241,186,74)",
            "rgb(243,186,73)",
            "rgb(245,186,72)",
            "rgb(247,186,71)",
            "rgb(249,187,69)",
            "rgb(250,188,67)",
            "rgb(252,188,66)",
            "rgb(253,189,64)",
            "rgb(254,190,62)",
            "rgb(255,191,61)",
            "rgb(255,193,59)",
            # The next four red channels were 256 (out of range); clamped to 255
            "rgb(255,194,58)",
            "rgb(255,195,56)",
            "rgb(255,196,55)",
            "rgb(255,198,54)",
            "rgb(255,199,52)",
            "rgb(255,200,51)",
            "rgb(255,202,50)",
            "rgb(254,203,49)",
            "rgb(254,204,48)",
            "rgb(253,206,47)",
            "rgb(253,207,46)",
            "rgb(252,208,45)",
            "rgb(252,210,44)",
            "rgb(251,211,43)",
            "rgb(250,212,42)",
            "rgb(250,213,41)",
            "rgb(249,215,40)",
            "rgb(249,216,39)",
            "rgb(248,217,38)",
            "rgb(248,219,36)",
            "rgb(247,220,35)",
            "rgb(247,222,34)",
            "rgb(246,223,33)",
            "rgb(246,225,32)",
            "rgb(246,226,31)",
            "rgb(246,228,30)",
            "rgb(245,229,29)",
            "rgb(245,231,28)",
            "rgb(245,233,27)",
            "rgb(246,234,25)",
            "rgb(246,236,24)",
            "rgb(246,238,23)",
            "rgb(246,240,22)",
            "rgb(247,242,21)",
            "rgb(247,244,19)",
            "rgb(248,246,18)",
            "rgb(249,248,17)",
            "rgb(249,250,15)",
            "rgb(250,252,14)",
        ]
    }

In [18]:
import inspect
from decimal import Decimal, getcontext
from typing import Optional, TypedDict

import numpy as np
import pandas as pd
import plotly.express as px

class LevelDict(TypedDict):
    """Structure of a level item: evenly spaced levels from start to stop"""

    start: float
    stop: float
    step: float


class HeatMap():
    """
    A specialized trace data variable that creates a Plotly heatmap

    The heatmap is expressed as per-point marker colors (with an optional
    colorbar) for a scatter trace; see ``data_to_dict``.
    """

    def __init__(
        self,
        name: str,
        colors: Optional[pd.Series] = None,
        default_max: Optional[float] = None,
        default_length: Optional[int] = None,
        colorscale_file: Optional[str] = None,
        contours: Optional[list[LevelDict]] = None,
        colorBarTitle: Optional[str] = None,
        colorScale: str = "",
        color: bool = False,
        **kwargs,
    ) -> None:
        """Creates a heatmap variable

        Parameters
        ----------
        name : str
            Name of the color variable the data originate from
        colors : Optional[pd.Series], optional
            Heatmap values for each point, by default None
        default_max : Optional[float], optional
            Max value for the default heatmap levels; used only when
            ``colors`` is None, by default None
        default_length : Optional[int], optional
            Size of the default heatmap; used only when ``colors`` is None,
            by default None
        colorscale_file : Optional[str], optional
            Custom colorscale file; accepted but not currently read,
            by default None
        contours : Optional[list[LevelDict]], optional
            User defined heatmap levels, by default None
        colorBarTitle : Optional[str], optional
            Color bar title, by default None
        colorScale : str, optional
            Color scale to apply, by default ""
        color : bool, optional
            Show the colorbar, by default False
        **kwargs
            Accepted and ignored

        Raises
        ------
        ValueError
            ``colorScale`` is not a supported color scale
        """
        self.data = self.__create_colors(colors, default_max, default_length)
        self.colors = self.data.copy()

        self.color_variable = name
        self.title = colorBarTitle
        self.color_scale = colorScale
        self.show_colorbar = color

        # User defined heatmap bins
        self.levels = [] if contours is None else contours

        # Plotly heatmap properties
        self.__tickvals: list[float] = []
        self.__cmin: Optional[float] = None
        self.__cmax: Optional[float] = None
        self.__color_scale_values: list[str] = []

        self.__create_color_scale(colorscale_file)

    def data_to_dict(self) -> dict:
        """
        Scatter-trace keyword arguments carrying the color data.

        Returns
        -------
        dict
            ``{"marker": <marker JSON dict>, "text": <per-point labels>}``, or
            an empty dict when there are fewer than two color values.
        """
        if len(self.data) <= 1:
            return {}

        # NOTE(review): marker colors use the raw ``self.colors``; the binned
        # values produced by __digitize (stored in ``self.data``) are not used
        # here. Confirm whether that is intended.
        marker = s.Marker(
            cmin=self.__cmin,
            cmax=self.__cmax,
            color=self.colors,
            colorbar=s.marker.ColorBar(
                title={
                    "text": f"<b>{self.title}</b>",
                    "side": "right",
                    "font": {"size": 12},
                },
                thickness=20,
                tickmode="array",
                tickvals=self.__tickvals if len(self.__tickvals) > 0 else None,
            ),
            colorscale=list(self.__color_scale_values),
            showscale=self.show_colorbar,
        )

        return {
            "marker": marker.to_plotly_json(),
            "text": [f"{self.color_variable}: {c:.3f}" for c in self.colors.to_list()],
        }

    def __create_colors(
        self,
        colors: Optional[pd.Series],
        default_max: Optional[float],
        default_length: Optional[int],
    ) -> pd.Series:
        """
        Checks and provides colors for the heatmap

        Parameters
        ----------
        colors : Optional[pd.Series]
            Colors chosen for heatmap
        default_max : Optional[float]
            Maximum color for default color levels
        default_length : Optional[int]
            Default length for the auto-generated heatmap colors

        Returns
        -------
        pd.Series
            ``colors`` if given; otherwise ``default_length`` evenly spaced
            values from 0 to ``default_max``; otherwise an empty float Series
        """
        if colors is not None:
            return colors

        if default_max is not None and default_length is not None:
            return pd.Series(np.linspace(0, default_max, default_length))

        # Explicit dtype: an empty Series has no values to infer one from
        return pd.Series([], dtype=float)

    def __create_color_scale(self, colorscale_file: Optional[str] = None) -> None:
        """
        Validates the color scale, bins the data and builds the Plotly
        colorscale

        Raises
        ------
        ValueError
            Invalid color scales
        """
        supported_scales = self.supported_colorscales(colorscale_file)
        if self.color_scale not in supported_scales:
            raise ValueError(
                f"{self.color_scale} is not a supported color scale."
                f" Please choose from {supported_scales}"
            )

        # Converts bin ranges into a single, sorted numpy array
        bins = self.__generate_bins()

        # Digitize the heatmap data into user defined levels
        if len(bins) > 0:
            self.__digitize(bins)

        # Limits set to avoid auto-scaling
        self.__cmin = bins[0] if len(bins) > 0 else None
        self.__cmax = bins[-1] if len(bins) > 0 else None

        # Apply the colorscale
        custom_colors = additional_colorscales()
        scale = (
            custom_colors[self.color_scale]
            if self.color_scale in custom_colors
            else getattr(px.colors.sequential, self.color_scale)
        )
        self.__color_scale_values = px.colors.make_colorscale(scale)

    def __generate_bins(self) -> np.ndarray:
        """
        The levels define individual breakpoints for the color gradient.
        Heatmap bins are created for mapping the actual color data to the
        coarser user-defined gradient

        Each level contributes ``start, start + step, ...`` (excluding
        ``stop``) followed by ``stop`` itself, to both the colorbar tick
        values and the bins.

        Returns
        -------
        np.ndarray
            Sorted, de-duplicated color container bins

        Raises
        ------
        ValueError
            A level has a zero step
        """
        # Function-scope import: only this method needs these names
        from decimal import ROUND_CEILING, localcontext

        edges: list[float] = []
        # A local context keeps the precision setting from leaking into the
        # caller's (thread-wide) decimal context
        with localcontext() as ctx:
            ctx.prec = 15
            for level in self.levels:
                # Go through str() so 0.1 becomes Decimal("0.1") rather than the
                # exact expansion of the binary float, which keeps its error
                start = Decimal(str(level["start"]))
                stop = Decimal(str(level["stop"]))
                step = Decimal(str(level["step"]))
                if step == 0:
                    raise ValueError(f"Heatmap level step must be non-zero: {level}")

                count = max(
                    0,
                    int(((stop - start) / step).to_integral_value(rounding=ROUND_CEILING)),
                )
                ticks = [float(start + index * step) for index in range(count)]
                ticks.append(float(stop))

                self.__tickvals += ticks
                # Bins use the same exact values as the ticks; a float
                # np.arange can overshoot ``stop`` and duplicate it
                edges.extend(ticks)

        # np.unique sorts and drops the shared edge of adjacent levels
        return np.unique(np.asarray(edges, dtype=float))

    def __digitize(self, bins: np.ndarray) -> None:
        """Maps each color value to the lower edge of its bin

        Values below the first edge map to the first bin and values at or
        above the last edge map to the last bin. The result is stored in
        ``self.data`` with the same index as ``self.colors``.

        Parameters
        ----------
        bins : np.ndarray
            Sorted color container bins (non-empty)
        """
        values = np.asarray(self.colors, dtype=float)
        bin_idx = np.searchsorted(bins, values, side="right")
        bin_idx[values < bins[0]] = 1
        self.data = pd.Series(bins[bin_idx - 1], index=self.colors.index)

    @staticmethod
    def supported_colorscales(custom_colorscale_file: Optional[str] = None) -> list[str]:
        """
        List of supported heatmap colorscales

        ``custom_colorscale_file`` is accepted but not currently read; the
        list is the built-in additional scales plus Plotly's sequential
        scales (excluding reversed ``_r`` variants and private names).

        Returns
        -------
        list[str]
            Supported color scales, sorted
        """
        names = list(additional_colorscales())
        for name, body in inspect.getmembers(getattr(px.colors, "sequential")):
            if isinstance(body, list) and name[-2:] != "_r" and name[0] != "_":
                names.append(name)

        return sorted(names)

### Plot

In [19]:
# Variables
# NOTE: `id` shadows the builtin; it selects the template/data used below.
id = "heatmap"
grid = templates[id]["grid"][0]

# Annotations: classification banners only. The info labels that
# get_core_layout adds are not included here.
annotations = get_classifications(show_sidebar=grid["showLegend"])

# Layout
layout = go.Layout(
    title=title_dict(templates[id]["title"]),
    template=templates[id]["layout"]["theme"],
    height=templates[id]["layout"]["height"],
    width=templates[id]["layout"]["width"],
    legend=Legend(title=templates[id]["layout"]["legendTitle"]),
    margin=margin_dict(),
    annotations=annotations,
)
figHeat = go.Figure(layout=layout)


# Flat list of every plotted x/y series, used for "Equal" axis ranges
trace_data = []
for variable in templates[id]["variables"]:
    # Create data; x/y use the unit conversions of the first two grid axes
    x = unit_transformation(data[id][variable["xVariable"]], grid["axes"][0]["scaleFactor"])
    y = unit_transformation(data[id][variable["yVariable"]], grid["axes"][1]["scaleFactor"])
    color = data[id][variable["colorVariable"]]
    trace_data += [x, y]

    heatmap = HeatMap(
        variable["colorVariable"],
        colors=color,
        default_max=max(x.max(), y.max()),
        default_length=len(color),
        colorscale_file=None,
        contours=None,
        colorBarTitle=grid["colorBarTitle"],
        colorScale=grid["colorScale"],
        color=grid["showColorBar"],
    )

    heatmap_data = heatmap.data_to_dict()
    # data_to_dict() returns {} for fewer than two color values; create the
    # marker entry in that case instead of raising KeyError on "marker"
    heatmap_data.setdefault("marker", {}).update(
        size=variable["markerSize"],
        symbol=variable["markerType"].lower(),
    )

    figHeat.add_trace(
        go.Scatter(
            name=variable["traceName"],
            x=x,
            y=y,
            connectgaps=variable["connectgaps"],
            mode=variable["mode"],
            **heatmap_data,
            line=s.Line(
                color=variable["lineColor"],
                width=variable["lineWidth"],
                dash=variable["lineType"],
                shape=variable["lineShape"],
            ),
            legendgroup=None,
            legendgrouptitle_text=variable["legendGroupTitle"],
            showlegend=grid["showLegend"],
        )
    )


# Axes: same per-axis settings as the other 2D plots, via get_axis_layout
ranges = axis_ranges(grid["axisType"], grid["axes"], trace_data)
axes_dict = {}
for axis, axis_range in zip(grid["axes"], ranges):
    name = f"{axis['name']}axis"
    axes_dict[name] = get_axis_layout({}, axis_range, axis, grid, False)
    if grid["overwriteDomain"]:
        axes_dict[name]["domain"] = [axis["domainMin"], axis["domainMax"]]

figHeat.update_layout(**axes_dict)

if render_plots:
    figHeat.show(renderer="notebook_connected")

## Subplots

### 2D

In [20]:
# Variables
# NOTE: `id` shadows the builtin; get_core_layout reads this global, so it
# must be assigned before the layout is built.
id = "subplot_2d"


# Initialize figures
layout = get_core_layout(
    False, any(grid["showLegend"] for grid in templates[id]["grid"])
)
# Grid size = number of distinct row / column values used by the variables
# (assumes they are numbered contiguously from 1 - TODO confirm)
num_rows = len(set(variable["row"] for variable in templates[id]["variables"]))
num_columns = len(set(variable["column"] for variable in templates[id]["variables"]))
# One title per grid, in template order
subplot_titles = [grid["title"] for grid in templates[id]["grid"]]
subplots2d = make_subplots(num_rows, num_columns, subplot_titles=subplot_titles)
subplots2d.update_layout(layout)

# Create data: each variable is placed in its (row, column) cell, and its
# legend group is named "<row>-<column>" so entries are grouped per cell
trace_data = [
    build_trace(
        subplots2d,
        data[id],
        templates[id]["grid"],
        variable,
        False,
        variable["row"],
        variable["column"],
        f"{variable['row']}-{variable['column']}",
    )
    for variable in templates[id]["variables"]
]

# Axes: collect axis (2D) or scene (3D) settings for every grid, then apply
# them in a single layout update
axes_dict: dict = {}
for grid_item in templates[id]["grid"]:
    if grid_item["plotType"] == "2d":
        get_2d_layout(trace_data, axes_dict, grid_item)
    else:
        get_3d_layout(trace_data, axes_dict, grid_item)

subplots2d.update_layout(axes_dict)

if render_plots:
    subplots2d.show(renderer="notebook_connected")

### 3D

In [21]:
# Variables
# NOTE: `id` shadows the builtin; get_core_layout reads this global, so it
# must be assigned before the layout is built.
id = "subplot_3d"

# Initialize figures
layout = get_core_layout(
    True, any(grid["showLegend"] for grid in templates[id]["grid"])
)
# Grid size = number of distinct row / column values used by the variables
num_rows = len(set(variable["row"] for variable in templates[id]["variables"]))
num_columns = len(set(variable["column"] for variable in templates[id]["variables"]))
subplot_titles = [grid["title"] for grid in templates[id]["grid"]]
# One 3D scene per cell, sized from the template rather than a fixed 3x1
# grid, so templates with a different row/column count also work
subplot3d_specs = [
    [{"is_3d": True} for _ in range(num_columns)] for _ in range(num_rows)
]
subplots3d = make_subplots(
    num_rows,
    num_columns,
    vertical_spacing=0.05,
    subplot_titles=subplot_titles,
    specs=subplot3d_specs,
)
subplots3d.update_layout(layout)

# Create data: each variable is placed in its (row, column) cell with a
# "<row>-<column>" legend group
trace_data = [
    build_trace(
        subplots3d,
        data[id],
        templates[id]["grid"],
        variable,
        True,
        variable["row"],
        variable["column"],
        f"{variable['row']}-{variable['column']}",
    )
    for variable in templates[id]["variables"]
]

# Axes: collect axis (2D) or scene (3D) settings for every grid, then apply
# them in a single layout update
axes_dict: dict = {}
for grid_item in templates[id]["grid"]:
    if grid_item["plotType"] == "2d":
        get_2d_layout(trace_data, axes_dict, grid_item)
    else:
        get_3d_layout(trace_data, axes_dict, grid_item)

subplots3d.update_layout(axes_dict)
if render_plots:
    subplots3d.show(renderer="notebook_connected")

### Mixed 3-2-2

In [22]:
# Variables
# NOTE: `id` shadows the builtin; get_core_layout reads this global, so it
# must be assigned before the layout is built.
id = "mixed_3-2-2"

# Initialize figures
# NOTE(review): the core layout uses is_3d=False although this figure mixes
# 2D and 3D subplots - confirm that is intended.
layout = get_core_layout(
    False, any(grid["showLegend"] for grid in templates[id]["grid"])
)
# Grid size = number of distinct row / column values used by the variables
num_rows = len(set(variable["row"] for variable in templates[id]["variables"]))
num_columns = len(set(variable["column"] for variable in templates[id]["variables"]))
# Currently unused: passing titles to make_subplots is disabled below
subplot_titles = [grid["title"] for grid in templates[id]["grid"]]

# One spec row per grid, single column; a grid with three axes becomes a 3D
# scene. This assumes grid order matches row order and num_columns == 1.
specs = [[{"is_3d": len(grid["axes"]) == 3}] for grid in templates[id]["grid"]]
fig_mixed_322 = make_subplots(
    num_rows,
    num_columns,
    vertical_spacing=0.05,  # FIXME: should be a file attribute
    # shared_xaxes=True,
    # subplot_titles=subplot_titles, # https://github.com/plotly/plotly.js/issues/2746
    specs=specs,
)
fig_mixed_322.update_layout(layout)

# Create data: each variable's own plotType decides whether a Scatter or a
# Scatter3d trace is built for its (row, column) cell
trace_data = [
    build_trace(
        fig_mixed_322,
        data[id],
        templates[id]["grid"],
        variable,
        variable["plotType"] == "3d",
        variable["row"],
        variable["column"],
        f"{variable['row']}-{variable['column']}",
    )
    for variable in templates[id]["variables"]
]
# Axes: collect axis (2D) or scene (3D) settings for every grid, then apply
# them in a single layout update
axes_dict: dict = {}
for grid_item in templates[id]["grid"]:
    if grid_item["plotType"] ==  "2d":
        get_2d_layout(trace_data, axes_dict, grid_item)
    else:
        get_3d_layout(trace_data, axes_dict, grid_item)

fig_mixed_322.update_layout(axes_dict)
if render_plots:
    fig_mixed_322.show(renderer="notebook_connected")