In [None]:
%%javascript
// Disable the classic Notebook's automatic scrolling of long cell outputs.
// `IPython.OutputArea` only exists in the classic Notebook front-end; in other
// front-ends (e.g. JupyterLab) the global `IPython` is undefined, so skip the
// patch there instead of raising a ReferenceError.
if (typeof IPython !== "undefined" && IPython.OutputArea) {
    IPython.OutputArea.prototype._should_scroll = function (lines) {
        return false;
    };
}

In [None]:
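# Load the AiiDA IPython extension (`%aiida` puts AiiDA helpers such as `load_node` in scope), then import the other dependencies.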
%load_ext aiida
%aiida

import ipywidgets as ipw
from IPython.display import display, clear_output

import nglview
import traitlets as tl

from cubehandler import Cube
import tempfile
import io
import os
import numpy as np
import ase.io.cube

In [None]:
# Conversion factor from angstrom to bohr (1 angstrom ~= 1.8897 bohr).
# NOTE(review): not referenced by any of the cells shown here; remove if unused.
ANG_2_BOHR = 1.889725989

In [None]:
# Primary key of the AiiDA node whose ``.cube`` files are browsed below. The
# node must provide ``listdir()`` and ``getfile(name, destpath)``
# (presumably a RemoteData node -- TODO confirm).
NODE_PK = 6793
node = load_node(NODE_PK)

In [None]:
class CubeArrayData3dViewerWidget(ipw.VBox):
    """Widget to view a 3-dimensional cube file (``cubehandler.Cube``) in 3D.

    The structure and the volumetric data are shown in an nglview widget. The
    data is drawn as two isosurfaces, ``+isovalue`` in red and ``-isovalue``
    in blue. The isovalue is controlled by the slider below the viewer.
    Set the ``cube`` trait (or pass ``cube=...`` to the constructor) to plot.
    """

    cube = tl.Instance(Cube, allow_none=True)

    # Number of copies of the cell (atoms and volumetric data) shown along the
    # first cell vector. Override on the class or on an instance before
    # setting ``cube`` to change it.
    n_repeat = 2

    def __init__(self, **kwargs):
        self.data_3d = None
        self.structure = None
        self.viewer = nglview.NGLWidget()
        self.orb_isosurf_slider = ipw.FloatSlider(
            continuous_update=False,
            value=1e-3,
            min=1e-4,
            max=1e-2,
            step=1e-4,
            description="Isovalue",
            readout_format=".1e",
        )
        # Redraw both isosurfaces whenever the isovalue changes.
        self.orb_isosurf_slider.observe(
            lambda c: self.set_cube_isosurf([c["new"], -c["new"]], ["red", "blue"]),
            names="value",
        )
        super().__init__([self.viewer, self.orb_isosurf_slider], **kwargs)

    @tl.observe("cube")
    def on_observe_cube(self, _=None):
        """Update cached data and the plot when the ``cube`` trait changes."""
        if self.cube is None:
            # ``cube`` allows None: drop the cached data and empty the viewer
            # instead of failing on ``None.data``.
            self.data_3d = None
            self.structure = None
            self._clear_viewer()
            return
        self.data_3d = self.cube.data
        self.structure = self.cube.ase_atoms
        self.update_plot()

    def _clear_viewer(self):
        """Remove every component (structure and cube volume) from the viewer."""
        # Relies on nglview exposing the next remaining component as
        # ``component_0`` after each removal, until none is left.
        while hasattr(self.viewer, "component_0"):
            self.viewer.component_0.clear_representations()
            self.viewer.remove_component(self.viewer.component_0.id)

    def update_plot(self):
        """Redraw the structure and the isosurfaces from scratch."""
        self._clear_viewer()
        if self.structure is None or self.data_3d is None:
            # Nothing to plot yet.
            return
        self.setup_cube_plot()
        isovalue = self.orb_isosurf_slider.value
        self.set_cube_isosurf([isovalue, -isovalue], ["red", "blue"])

    def setup_cube_plot(self):
        """Add the repeated structure (component 0) and cube volume (component 1)."""
        atoms_xn = self.structure.repeat((self.n_repeat, 1, 1))
        data_xn = np.tile(self.data_3d, (self.n_repeat, 1, 1))
        self.viewer.add_component(nglview.ASEStructure(atoms_xn))
        # The volume is handed to nglview as a file name, so write the repeated
        # cube to a temporary file that is deleted when the ``with`` exits.
        with tempfile.NamedTemporaryFile(mode="w") as tempf:
            ase.io.cube.write_cube(tempf, atoms_xn, data_xn)
            # Push all buffered text to disk before nglview opens the file by
            # name; otherwise it may read a truncated cube.
            tempf.flush()
            c_2 = self.viewer.add_component(tempf.name, ext="cube")
            # Remove any representation added on load; the isosurfaces are
            # added by ``set_cube_isosurf``.
            c_2.clear()

    def set_cube_isosurf(self, isovals, colors):
        """Replace the isosurfaces of the cube volume component.

        Args:
            isovals: Isovalues at which to draw surfaces.
            colors: Colors paired with ``isovals`` element-wise (via ``zip``).

        Does nothing if no cube volume (``component_1``) is in the viewer yet.
        """
        if hasattr(self.viewer, "component_1"):
            c_2 = self.viewer.component_1
            c_2.clear()
            for isov, col in zip(isovals, colors):
                c_2.add_surface(color=col, isolevelType="value", isolevel=isov)

In [None]:
class HandleCubeFiles(ipw.VBox):
    """Widget to browse the ``.cube`` files of an AiiDA node and visualize them.

    "List cube files" shows one checkbox per ``*.cube`` entry returned by
    ``node.listdir()``. "Show selected cube" fetches every checked file with
    ``node.getfile()`` and displays it in a ``CubeArrayData3dViewerWidget``.
    """

    def __init__(self, node):
        """
        Args:
            node: AiiDA node holding the cube files. It must provide
                ``listdir()`` and ``getfile(name, destpath)`` (presumably a
                RemoteData node -- TODO confirm).
        """
        self.node = node
        self.button = ipw.Button(description="List cube files")
        self.button.on_click(self.list_cube_files)
        self.output = ipw.Output()
        self.show_selected = ipw.Button(description="Show selected cube")
        self.show_selected.on_click(self.show_selected_cube)
        # Dedicated area for the 3D viewers: on some front-ends (e.g.
        # JupyterLab) ``display`` calls made from a button callback are shown
        # nowhere unless they are captured by an Output widget.
        self.viewers_output = ipw.Output()
        # Maps cube file name -> its selection Checkbox.
        self.dict_cube_files = {}
        super().__init__(
            [self.button, self.show_selected, self.output, self.viewers_output]
        )

    def list_cube_files(self, _=None):
        """Show one checkbox (checked by default) per ``.cube`` file of the node."""
        cube_files = [f for f in self.node.listdir() if f.endswith(".cube")]
        # Start from a clean mapping so entries from a previous listing do not linger.
        self.dict_cube_files = {}
        with self.output:
            clear_output()
            for fname in cube_files:
                cube_selector = ipw.Checkbox(
                    description=fname,
                    value=True,
                    # ``description_width`` is a style attribute, not a Checkbox trait.
                    style={"description_width": "initial"},
                )
                display(cube_selector)
                self.dict_cube_files[fname] = cube_selector

    def show_selected_cube(self, _=None):
        """Fetch every checked cube file and display it in a 3D viewer."""
        with self.viewers_output:
            clear_output()
            for name, widget in self.dict_cube_files.items():
                if not widget.value:
                    continue
                # Copy the file into a temporary directory that is removed
                # deterministically once the viewer has been built.
                with tempfile.TemporaryDirectory() as tempdir:
                    fpath = os.path.join(tempdir, name)
                    self.node.getfile(name, fpath)
                    cube = Cube.from_file(fpath)
                    viewer = CubeArrayData3dViewerWidget(cube=cube)
                display(viewer)

In [None]:
# Build the cube-file browser for the loaded node; the bare last expression
# renders the widget through IPython's rich display.
cube_files = HandleCubeFiles(node=node)
cube_files