In [None]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
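# Load the AiiDA IPython extension; `%aiida` loads the profile and provides names such as `load_node` and `QueryBuilder`, which are used below without an explicit import.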
%load_ext aiida
%aiida

import io
import os
import tempfile

import ase.io.cube
import ipywidgets as ipw
import nglview
import numpy as np
import traitlets as tl
from aiida import orm
from aiida_shell import ShellJob
from cubehandler import Cube
from IPython.display import clear_output, display

In [None]:
# NOTE: 1123 is a primary key from the author's AiiDA database; adjust it when
# running this notebook against a different profile.
node = load_node(pk=1123)

In [None]:
class CubeArrayData3dViewerWidget(ipw.VBox):
    """Widget to view the volumetric data of a ``Cube`` object in 3D.

    Setting the ``cube`` trait (e.g. ``CubeArrayData3dViewerWidget(cube=cube)``)
    draws the structure plus two isosurfaces at +/- the slider value:
    positive in red, negative in blue.
    """

    cube = tl.Instance(Cube, allow_none=True)

    def __init__(self, **kwargs):
        self.data_3d = None
        self.structure = None
        self.viewer = nglview.NGLWidget()
        self.orb_isosurf_slider = ipw.FloatSlider(
            continuous_update=False,
            value=1e-3,
            min=1e-4,
            max=1e-2,
            step=1e-4,
            description="Isovalue",
            readout_format=".1e",
        )
        self.orb_isosurf_slider.observe(
            lambda change: self._apply_isovalue(change["new"]),
            names="value",
        )
        # The viewer and slider must exist before ``super().__init__``:
        # passing ``cube=...`` in ``kwargs`` fires ``on_observe_cube``.
        super().__init__([self.viewer, self.orb_isosurf_slider], **kwargs)

    @tl.observe("cube")
    def on_observe_cube(self, _=None):
        """Update cached data and the plot when the ``cube`` trait changes."""
        if self.cube is None:
            # The trait allows None: resetting it simply empties the viewer.
            self.data_3d = None
            self.structure = None
            self._remove_all_components()
            return
        self.data_3d = self.cube.data
        self.structure = self.cube.ase_atoms
        self.update_plot()

    def _remove_all_components(self):
        """Remove every component currently loaded in the NGL viewer."""
        # Relies on nglview renaming the remaining components after each
        # removal, so that ``component_0`` always refers to the first one.
        while hasattr(self.viewer, "component_0"):
            self.viewer.component_0.clear_representations()
            self.viewer.remove_component(self.viewer.component_0.id)

    def _apply_isovalue(self, isovalue):
        """Draw the +``isovalue`` (red) and -``isovalue`` (blue) isosurfaces."""
        self.set_cube_isosurf([isovalue, -isovalue], ["red", "blue"])

    def update_plot(self):
        """Redraw the structure and isosurfaces from ``structure``/``data_3d``."""
        self._remove_all_components()
        self.setup_cube_plot()
        self._apply_isovalue(self.orb_isosurf_slider.value)

    def setup_cube_plot(self, n_repeat=2):
        """Load the structure and the cube data into the viewer.

        The atoms are repeated ``n_repeat`` times along the first cell vector,
        and the data are tiled ``n_repeat`` times along its first axis. This
        adds component 0 (the atoms) and component 1 (the volumetric data,
        with its default representation cleared).

        Args:
            n_repeat: Number of periodic images along the first cell vector.
        """
        atoms_xn = self.structure.repeat((n_repeat, 1, 1))
        data_xn = np.tile(self.data_3d, (n_repeat, 1, 1))
        self.viewer.add_component(nglview.ASEStructure(atoms_xn))
        with tempfile.NamedTemporaryFile(mode="w") as tempf:
            ase.io.cube.write_cube(tempf, atoms_xn, data_xn)
            # Flush buffered text so the file on disk is complete before
            # nglview opens it by name.
            tempf.flush()
            c_2 = self.viewer.add_component(tempf.name, ext="cube")
            c_2.clear()

    def set_cube_isosurf(self, isovals, colors):
        """Replace the isosurfaces drawn on the cube-data component.

        Args:
            isovals: Isovalues to draw, one surface per value.
            colors: Surface colors, paired element-wise with ``isovals``.

        Does nothing if no cube data has been loaded into the viewer yet.
        """
        if hasattr(self.viewer, "component_1"):
            c_2 = self.viewer.component_1
            c_2.clear()
            for isov, col in zip(isovals, colors):
                c_2.add_surface(color=col, isolevelType="value", isolevel=isov)

In [None]:
class HandleCubeFiles(ipw.VBox):
    """Browse the cube files produced by ``cube-shrink`` ShellJobs and view them.

    A dropdown selects the calculation (by UUID). Checkboxes select which
    ``.cube`` files of its ``out_cubes`` output to show. The button renders
    one ``CubeArrayData3dViewerWidget`` per checked file.
    """

    def __init__(self):
        self.node = None
        self.output = ipw.Output()
        # Dedicated area for the 3D viewers: output produced inside a button
        # callback is only shown when captured by an ``Output`` widget.
        self.viewer_output = ipw.Output()
        self.show_selected = ipw.Button(description="Show selected cube")
        self.show_selected.on_click(self.show_selected_cube)
        self.dict_cube_files = {}
        self.select_calc_widget = ipw.Dropdown(
            description="Calculation:", options=self.get_calcs()
        )
        # React only to selection changes. Without ``names`` the callback
        # would also fire for ``index``, ``label`` and other trait updates.
        self.select_calc_widget.observe(self.select_calculation, names="value")
        self.select_calculation()
        super().__init__(
            [
                self.select_calc_widget,
                self.show_selected,
                self.output,
                self.viewer_output,
            ]
        )

    def list_cube_files(self):
        """Show one checkbox per ``.cube`` file of ``self.node``.

        ``dict_cube_files`` is rebuilt from scratch, so it never refers to
        files of a previously selected calculation.
        """
        self.dict_cube_files = {}
        with self.output:
            clear_output()
            if self.node is None:
                print("No cube files available for this selection.")
                return
            # Keep only the files that end with .cube
            cube_files = [
                name
                for name in self.node.list_object_names()
                if name.endswith(".cube")
            ]
            if not cube_files:
                print("No .cube files found in the selected calculation.")
            for f in cube_files:
                cube_selector = ipw.Checkbox(
                    description=f,
                    value=True,
                    style={"description_width": "initial"},
                    layout={"width": "initial"},
                )
                display(cube_selector)
                self.dict_cube_files[f] = cube_selector

    def show_selected_cube(self, _=None):
        """Render a 3D viewer for every checked cube file of ``self.node``."""
        self.viewer_output.clear_output()
        if self.node is None:
            return
        with self.viewer_output:
            for name, widget in self.dict_cube_files.items():
                if widget.value:
                    # Read from the currently selected calculation's output,
                    # not from any globally loaded node.
                    cube = Cube.from_content(self.node.get_object_content(name))
                    display(CubeArrayData3dViewerWidget(cube=cube))

    def get_calcs(self):
        """Return the UUIDs of all ShellJobs labelled ``cube-shrink``."""
        query = orm.QueryBuilder()
        query.append(ShellJob, filters={"label": "cube-shrink"}, project="uuid")
        return query.all(flat=True)

    def select_calculation(self, _=None):
        """Load the selected calculation and list its output cube files."""
        uuid = self.select_calc_widget.value
        if uuid is None:
            # Empty dropdown: no ``cube-shrink`` calculations were found.
            self.node = None
        else:
            calc = orm.load_node(uuid)
            # A job may lack the ``out_cubes`` output (e.g. if it failed);
            # treat that as "no files" instead of raising.
            self.node = getattr(calc.outputs, "out_cubes", None)
        # Viewers from a previously selected calculation are now stale.
        self.viewer_output.clear_output()
        self.list_cube_files()

In [None]:
# Build the cube-file browser and render it as this cell's output.
cube_files = HandleCubeFiles()
cube_files