In [None]:
from pathlib import Path


# Relative path to the sample Zarr store that the rest of the notebook opens.
zarr_path = Path("../sample-data/sample_zarr_nested_groups_from_datatree.zarr")
# Bare expression: shown as the cell output.
zarr_path

In [None]:
import xarray as xr

# Open the store at ``zarr_path`` with xarray's ``open_datatree``.
xdt = xr.open_datatree(zarr_path)
# Bare expression: shown as the cell output.
xdt


In [None]:
# Plain-text (``str``) view of the tree opened above.
print(xdt)

In [None]:
# type: ignore
import json
from pathlib import PurePosixPath
from typing import Any

from IPython.display import HTML, Markdown
from xarray.core.datatree import DataTree


def show_collapsible_json(dict_to_jsonify: dict[Any, Any], title: str) -> Markdown:
    """
    Render a JSON-serializable dict as a collapsible notebook section.

    The dict is dumped as 4-space-indented JSON inside a fenced ``json`` code
    block, itself wrapped in an HTML ``<details>`` element.

    Parameters
    ----------
    dict_to_jsonify
        Dict to dump as JSON
    title
        Text placed in the ``<summary>`` tag, i.e. the clickable header

    Returns
    -------
        IPython ``Markdown`` object holding the collapsible section
    """
    json_text = json.dumps(dict_to_jsonify, indent=4)
    summary = f"<summary>{title}</summary>"
    fenced_json = f"```json\n{json_text}\n\n```"
    return Markdown(f"<details>{summary}\n\n{fenced_json}\n</details>")


def repr_datatree(
    xdt: DataTree,
    *,
    tabsize: int = 4,
    offset: int = 1,
    with_full_path: bool = False,
    with_group_variables: bool = False,
    rich: bool = False,
    rich_plot: bool = False,
    rich_txt_dataset: bool = True,
    opened: bool = False,
) -> str | list[Any]:
    """
    Represent a DataTree, either as plain text or as notebook display objects.

    Nodes are visited in lexicographic order of their path.

    Parameters
    ----------
    xdt
        DataTree to represent
    tabsize, optional
        If ``rich`` is OFF, number of spaces per indentation level, by default 4
    offset, optional
        If ``rich`` is OFF, number of leading path levels that are not indented, by default 1
    with_full_path, optional
        Whether to repr the full path or only the name of a node, by default False
    with_group_variables, optional
        Whether to repr groups only, or groups with their variables, by default False
    rich, optional
        Use rich (Markdown/HTML) representation, by default False
    rich_plot, optional
        If ``rich`` and ``with_group_variables`` are ON, call xarray's default ``plot`` on each
        variable of the node being represented, by default False
    rich_txt_dataset, optional
        If ``rich`` is ON and ``with_group_variables`` is OFF: when True the node's dataset is
        rendered through its HTML repr, when False through its text repr inside a code block,
        by default True
    opened, optional
        If ``rich`` is ON, whether to initially display opened or closed detail sections, by default False

    Returns
    -------
        A single string when ``rich`` is OFF; a list of displayable objects
        (to be unpacked into ``display``) when ``rich`` is ON
    """
    lines: list[Any] = []

    for node in sorted(xdt.subtree, key=lambda n: n.path):
        path = PurePosixPath(node.path)
        tabs = len(path.parts)
        path_str = f"{(node.name or '') if not with_full_path else path}"
        group_title = f"{path_str} <small> (<code>{path}</code>) </small>"
        has_data_vars = len(node.data_vars) > 0

        # Groups holding data variables get their title from the dataset
        # section below; only variable-less groups get a standalone heading.
        if not has_data_vars:
            if rich:
                lines.append(Markdown(f"{'#' * tabs} {group_title}"))
            else:
                lines.append(f"{' ' * ((tabs - offset) * tabsize)}{group_title}")

            if node.attrs:
                attrs_json = json.dumps(node.attrs, indent=4)
                if rich:
                    lines.append(
                        Markdown(
                            wrap_in_details(
                                f"```json\n{attrs_json}\n```",
                                summary=f"Attributes <small> (`{path}`) </small>",
                                opened=opened,
                            ),
                        ),
                    )
                else:
                    # Text mode must only collect strings, otherwise the final
                    # "\n".join() fails; emit the JSON indented like variables.
                    attrs_indent = " " * (tabs * tabsize)
                    lines.extend(
                        f"{attrs_indent}{json_line}"
                        for json_line in attrs_json.splitlines()
                    )

        if with_group_variables:
            for varname in node.ds.data_vars:
                varname_str = f"{path / varname if with_full_path else varname}"
                if rich:
                    if rich_plot:
                        lines.append(Markdown(varname_str))
                        lines.append(node.ds.data_vars[varname].plot())
                else:
                    lines.append(f"{' ' * (tabs * tabsize)}{varname_str}")
        elif rich and has_data_vars:
            if rich_txt_dataset:
                lines.append(
                    HTML(
                        wrap_in_details(
                            node.to_dataset()._repr_html_(),
                            summary=group_title,
                            opened=opened,
                        )
                    ),
                )
            else:
                code = f"```python\n{str(node.ds)}\n```"
                lines.append(
                    Markdown(wrap_in_details(code, summary=group_title, opened=opened))
                )

    if rich:
        return lines
    return "\n".join(lines)


def repr_datatree_text(
    xdt: DataTree,
    *,
    tabsize: int = 4,
    offset: int = 1,
    with_full_path: bool = False,
    with_group_variables: bool = False,
) -> str:
    """
    Represent a DataTree as text

    Each group is printed on its own line prefixed with ``□``; when
    ``with_group_variables`` is ON its coordinates (``◇``) and data variables
    (``●``) follow, in that order.

    Parameters
    ----------
    xdt
        DataTree to represent
    tabsize, optional
        Number of spaces per indentation level (ignored when ``with_full_path`` is ON), by default 4
    offset, optional
        Number of leading path levels that are not indented (ignored when ``with_full_path`` is ON),
        by default 1
    with_full_path, optional
        Whether to repr the full path (without leading ``/`` and without indentation)
        or only the name of a node, by default False
    with_group_variables, optional
        Whether to repr groups only, or groups with their variables, by default False

    Returns
    -------
        DataTree text representation
    """
    group_dot = "□"
    coord_dot = "◇"
    variable_dot = "●"

    lines = []
    for node in xdt.subtree:
        path = PurePosixPath(node.path)
        tabs = len(path.parts)

        if with_full_path:
            group_indent = member_indent = ""
            group_label = str(path).removeprefix("/")
        else:
            group_indent = " " * ((tabs - offset) * tabsize)
            member_indent = " " * (tabs * tabsize)
            # ``name`` rather than ``stem``: ``stem`` drops everything after the
            # last dot, which would truncate group names such as "model.v1".
            group_label = path.name
        lines.append(f"{group_indent}{group_dot} {group_label}")

        if not with_group_variables:
            continue

        members = [(coord_dot, varname) for varname in node.ds.coords]
        members += [(variable_dot, varname) for varname in node.ds.data_vars]
        for dot, varname in members:
            label = str(path / varname).removeprefix("/") if with_full_path else varname
            lines.append(f"{member_indent}{dot} {label}")

    return "\n".join(lines)


def wrap_in_details(
    contents: str, *, summary: str | None = None, opened: bool = True
) -> str:
    """
    Wrap a string in details

    Parameters
    ----------
    contents
        Input string to wrap
    summary, optional
        Summary in the tag of same name in the detail section, by default None
    opened, optional
        If ``rich`` is ON, whether to initially display opened or closed detail sections, by default False

    Returns
    -------
        HTML-string wrapped with a detail section
    """
    summary_tag = "" if summary is None else f"\n<summary> {summary} </summary>"
    details_prefix = f"<details {'open' if opened else ''}>{summary_tag}\n\n"
    details_suffix = "\n\n</summary>"

    return f"{details_prefix}{contents}{details_suffix}"


def boolean_emoji(b: bool) -> str:
    """Return a check mark for a truthy value and a cross mark otherwise."""
    if b:
        return "✅"
    return "❌"


# Print the text tree for every combination of the two display toggles.
# The inner loop varies fastest, giving the order
# (vars off, path off), (vars on, path off), (vars off, path on), (vars on, path on).
for with_full_path in (False, True):
    for with_group_variables in (False, True):
        print(
            f"Show group variables: {boolean_emoji(with_group_variables)} | Show full path: {boolean_emoji(with_full_path)}"
        )
        print(
            repr_datatree_text(
                xdt,
                with_full_path=with_full_path,
                with_group_variables=with_group_variables,
            )
        )
        print()


In [None]:
# With rich=True, repr_datatree returns a list of display objects, unpacked here
# into display(). First pass: datasets rendered from their text repr in a code block.
display(*repr_datatree(xdt, rich=True, opened=False, rich_txt_dataset=False))
# Second pass: datasets rendered from their HTML repr.
display(*repr_datatree(xdt, rich=True, opened=False, rich_txt_dataset=True))


In [None]:
# Keys of the dict form of the tree, shown as the cell output.
list(xdt.to_dict())


In [None]:
# For every key of the tree's dict form: print a blank line, the key,
# then the node at that key converted to a Dataset.
for key in xdt.to_dict():
    print("", key, sep="\n")
    xds = xdt[key].to_dataset()
    print(xds)

In [None]:
# Print xarray's version report for the current environment.
xr.show_versions()

In [None]:
from pathlib import PurePosixPath

# Example group path, used to inspect how PurePosixPath splits it.
path = PurePosixPath("/topgroup/ocean_data/chemical_properties")
# Bare expression: the (parent, stem) tuple is shown as the cell output.
path.parent, path.stem

In [None]:
from dataclasses import asdict, is_dataclass
import json
import numpy as np
from typing import Any, Callable
import xarray as xr


def sanitize_nan(obj: Any, default: Callable[[Any], str]) -> Any:
    """
    Recursively replace NaN floats in nested containers.

    Dicts, lists and tuples are rebuilt with their values sanitized (dict keys
    are left untouched); every other object is returned unchanged.

    Parameters
    ----------
    obj
        Object to sanitize
    default
        Callable applied to each NaN value to produce its replacement,
        e.g. ``str`` (gives ``'nan'``) or ``repr``

    Returns
    -------
        Copy of ``obj`` where every NaN float was replaced by ``default(value)``
    """
    if isinstance(obj, dict):
        return {k: sanitize_nan(v, default) for k, v in obj.items()}
    if isinstance(obj, list):
        return [sanitize_nan(v, default) for v in obj]
    if isinstance(obj, tuple):
        # JSON encodes tuples like lists, so NaNs hidden in them matter too.
        return tuple(sanitize_nan(v, default) for v in obj)
    # np.floating covers NumPy scalars that do not subclass float (e.g. float32).
    if isinstance(obj, (float, np.floating)) and np.isnan(obj):
        # str gives 'nan', repr gives eg np.float64(nan)
        return default(obj)
    return obj


class ComplexEncoder(json.JSONEncoder):
    """
    JSON encoder with best-effort support for dataclasses, NumPy and xarray objects.

    NaN floats are replaced by their ``repr`` string before encoding, both in the
    input object and in values produced by :meth:`default`.

    A ``default`` callable passed to the constructor (e.g. through
    ``json.dumps(..., cls=ComplexEncoder, default=str)``) is used as a last-resort
    fallback, after the conversions implemented in :meth:`default`.
    """

    def __init__(
        self, *, default: Callable[[Any], Any] | None = None, **kwargs: Any
    ) -> None:
        # json.JSONEncoder.__init__ stores a user-supplied ``default`` on the
        # instance, which would shadow the ``default`` method below and bypass
        # every custom conversion. Keep it aside and call it as a fallback instead.
        super().__init__(**kwargs)
        self._fallback = default

    def iterencode(self, o: Any, _one_shot: bool = False) -> Any:
        # ``encode`` (json.dumps) and ``json.dump`` both go through iterencode,
        # so sanitizing here covers every entry point.
        return super().iterencode(sanitize_nan(o, repr), _one_shot)

    def default(self, obj: Any) -> Any:
        """
        Best effort to convert non-native Python objects to something JSON-serializable,
        falling back on the ``default`` callable given to the constructor, if any

        Parameters
        ----------
        obj
            Object the base encoder does not know how to serialize

        Returns
        -------
            Serializable representation of the object

        Raises
        ------
        TypeError
            If the object is not supported and no fallback callable was given
        """
        # is_dataclass() is also True for dataclass *types*, which asdict() rejects.
        if is_dataclass(obj) and not isinstance(obj, type):
            converted = asdict(obj)
        elif isinstance(obj, np.ndarray):
            converted = obj.tolist()
        elif isinstance(obj, (xr.Dataset, xr.DataArray)):
            converted = obj.to_dict()
        elif isinstance(obj, np.floating):
            converted = float(obj)
        elif isinstance(obj, np.integer):
            converted = int(obj)
        elif self._fallback is not None:
            converted = self._fallback(obj)
        else:
            # Let the base class default method raise the TypeError
            return super().default(obj)
        # Converted values are encoded without another pass through iterencode,
        # so NaNs they contain must be sanitized here.
        return sanitize_nan(converted, repr)


# Compare the repr and str forms of a NumPy NaN scalar.
print(repr(np.float64(np.nan)))
print(str(np.float64(np.nan)))
print(np.isnan(np.float64(np.nan)))
# Same value passed through sanitize_nan with each replacement callable.
print(sanitize_nan(np.float64(np.nan), default=repr))
print(sanitize_nan(np.float64(np.nan), default=str))
# The list contains the ``xr`` module; without a ``default=`` fallback this
# call is expected to raise TypeError, which is caught and printed.
try:
    json.dumps([np.float64(np.nan), xr], cls=ComplexEncoder)
except TypeError as e:
    print(repr(e))
# Same payload, this time with ``default=str`` supplied to json.dumps.
print(json.dumps([np.float64(np.nan), xr], default=str, cls=ComplexEncoder))
