In [None]:
from aiida import load_profile

# Load the AiiDA profile (no name is passed, so presumably the configured
# default) so that the ``orm`` calls in later cells have a database to use.
# Binding the result to ``_`` instead of leaving it as the cell's last
# expression keeps the profile object from being echoed as cell output.
_ = load_profile()

In [None]:
import asyncio
from time import sleep

import ipywidgets as ipw
import matplotlib.pyplot as plt
import numpy as np
import traitlets as tl

from aiida import orm
from aiidalab_qe.common.widgets import LoadingWidget


class Model(tl.HasTraits):
    """Observable state shared between a band-structure viewer and its driver.

    Setting ``node_id`` does not load anything by itself; an observer is
    expected to call :meth:`load_node` in response.
    """

    # Identifier handed to ``orm.load_node``.
    node_id = tl.Unicode()
    # The loaded AiiDA node, or ``None`` when nothing is loaded or loading failed.
    node = tl.Instance(orm.Node, allow_none=True)

    # Arrays fetched from the node's band-structure output; empty until fetched.
    # The default arrays are only ever replaced (never mutated), so sharing
    # the same empty default between instances is harmless.
    bands = tl.Instance(
        np.ndarray,
        default_value=np.array([]),
    )
    kpoints = tl.Instance(
        np.ndarray,
        default_value=np.array([]),
    )

    # Plot title. A plain class attribute rather than a trait, so assigning
    # it does not notify observers.
    title = ""

    def load_node(self):
        """Load the node identified by ``node_id`` into ``self.node``.

        Returns:
            The loaded node, or ``None`` if loading failed.

        On failure the error is printed (best effort, nothing is raised) and
        ``node`` is reset to ``None``, so anything linked to it (e.g. a plot
        button's ``disabled`` state) stops acting on a stale node.
        """
        try:
            self.node = orm.load_node(self.node_id)
        except Exception as e:
            print(f"Failed to load node: {e}")
            self.node = None
        return self.node


class BandsViewer(ipw.VBox):
    """Widget that shows band/k-point counts for a ``Model`` and plots its bands.

    The viewer is created with only a loading placeholder; call :meth:`render`
    to build the real controls. It observes ``model.node_id`` and asks the
    model to load the node whenever a non-empty id is set.

    Args:
        vid: Identifier displayed as a label in front of the controls.
        model: The :class:`Model` this viewer reads from and writes to.
        **kwargs: Forwarded to ``ipw.VBox``.
    """

    def __init__(self, vid, model: Model, **kwargs):
        super().__init__(
            children=[LoadingWidget("Loading data fetcher")],
            **kwargs,
        )

        self._id = vid
        self._model = model
        self._model.observe(
            self._on_node_uuid_change,
            "node_id",
        )

        # The event loop only keeps weak references to tasks, so hold strong
        # references to in-flight fetch/plot tasks until they finish.
        self._pending_tasks = set()

        self.rendered = False

    def render(self):
        """Replace the loading placeholder with the real controls.

        Safe to call repeatedly: only the first call builds the widgets.
        """
        if self.rendered:
            return

        # dlink syncs the current model value on creation, so the label shows
        # the (initially empty) count immediately.
        self.bands = ipw.Label(value="Fetching attribute 1...")
        ipw.dlink(
            (self._model, "bands"),
            (self.bands, "value"),
            lambda bands: f"Bands count: {len(bands)}",
        )

        self.kpoints = ipw.Label(value="Fetching attribute 2...")
        ipw.dlink(
            (self._model, "kpoints"),
            (self.kpoints, "value"),
            lambda kpoints: f"Kpoints count: {len(kpoints)}",
        )

        self.plot_button = ipw.Button(description="Plot bands")
        # Enable the button only while a node is loaded. Compare against None
        # explicitly rather than relying on the node object's truthiness.
        ipw.dlink(
            (self._model, "node"),
            (self.plot_button, "disabled"),
            lambda node: node is None,
        )
        self.plot_button.on_click(self._on_plot_click)

        self.plot_area = ipw.Output()

        self.children = [
            ipw.HBox(
                children=[
                    ipw.Label(str(self._id)),
                    ipw.VBox(
                        children=[
                            self.bands,
                            self.kpoints,
                            self.plot_button,
                        ]
                    ),
                ]
            ),
            self.plot_area,
        ]

        self.rendered = True

    def _on_plot_click(self, _):
        """Schedule :meth:`fetch_and_plot` on the running event loop."""
        task = asyncio.create_task(self.fetch_and_plot())
        self._pending_tasks.add(task)
        task.add_done_callback(self._on_task_done)

    def _on_task_done(self, task):
        """Forget a finished task and report its error (if any) in the plot area."""
        self._pending_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            # Without this, a failure would only surface as an
            # "exception was never retrieved" log message.
            with self.plot_area:
                print(f"Failed to fetch and plot bands: {exc!r}")

    def _on_node_uuid_change(self, change):
        """Ask the model to load its node whenever a non-empty ``node_id`` is set."""
        if change["new"]:
            # ``Model.load_node`` stores the result on ``self._model.node``;
            # the viewer keeps no separate copy of the node.
            self._model.load_node()

    def _band_structure(self):
        """Return the band-structure output of the model's current node."""
        # NOTE(review): assumes the node exposes the nested output
        # ``outputs.bands.bands.band_structure`` — confirm for other node types.
        return self._model.node.outputs.bands.bands.band_structure

    async def fetch_bands(self):
        """Store the node's ``bands`` array on the model after a 1 s artificial delay."""
        await asyncio.sleep(1)
        self._model.bands = self._band_structure().get_array("bands")

    async def fetch_kpoints(self):
        """Store the node's ``kpoints`` array on the model after a 2 s artificial delay."""
        await asyncio.sleep(2)
        self._model.kpoints = self._band_structure().get_array("kpoints")

    async def plot_bands(self):
        """Redraw the model's bands inside this viewer's output area."""
        await asyncio.sleep(0.5)
        self.plot_area.clear_output()
        with self.plot_area:
            # Use an explicit figure/axes rather than the shared pyplot state,
            # and call show() inside the Output context so the figure is
            # rendered into this viewer rather than the cell's output.
            _, ax = plt.subplots()
            ax.set_title(self._model.title)
            ax.plot(self._model.bands)
            plt.show()

    async def fetch_and_plot(self):
        """Fetch bands and k-points concurrently, then plot the bands."""
        await asyncio.gather(
            self.fetch_bands(),
            self.fetch_kpoints(),
        )
        await self.plot_bands()

In [None]:
# Build four independent model/viewer pairs. Each viewer is displayed with its
# loading placeholder first and only rendered after a short blocking pause
# (presumably so the placeholder is visible in the demo).
models = []
for index in range(1, 5):
    band_model = Model()
    band_model.title = f"Band structure {index}"
    models.append(band_model)

    viewer = BandsViewer(vid=index, model=band_model)
    display(viewer)
    sleep(0.5)
    viewer.render()

In [None]:
# Identifier of the AiiDA node whose band structure every viewer shows.
# NOTE(review): hard-coded value that looks like a pk from the author's
# database. Replace it with a node from your own profile that exposes
# ``outputs.bands.bands.band_structure``.
BANDS_NODE_ID = "5216"

for model in models:
    model.node_id = BANDS_NODE_ID