In [None]:
from __future__ import annotations

import typing as t

import solara
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass

In [None]:
class ReactiveMeta(type):
    """Metaclass that wraps plain class-level data attributes in ``solara.reactive``.

    When a class using this metaclass is created, each entry of its class body is
    inspected. Plain data values (e.g. ``"brown"``, ``0``, ``[]``) are replaced by
    ``solara.reactive(value)``. Everything else is kept unchanged:

    - dunder entries (``__module__``, ``__doc__``, ``__annotations__``, ...),
    - callables (methods, nested classes, ...),
    - ``property`` / ``classmethod`` / ``staticmethod`` descriptors, which are not
      (always) callable but must remain descriptors to keep working,
    - values that already are ``solara.Reactive`` (no double wrapping),
    - nested model instances: instances of classes built by this metaclass, or any
      object exposing ``__annotations__``.
    """

    def __new__(cls, name, bases, class_dict):
        def should_wrap(key, value) -> bool:
            if key.startswith("__"):  # built-in method/property
                return False
            if callable(value):  # added method / nested class
                return False
            if isinstance(value, (property, classmethod, staticmethod)):
                # property/classmethod objects are not callable, so the check above
                # lets them through; wrapping them would break the descriptor.
                return False
            if isinstance(value, solara.Reactive):
                # Defaults declared as `solara.reactive(...)` must not be re-wrapped.
                return False
            if isinstance(type(value), ReactiveMeta) or hasattr(
                value, "__annotations__"
            ):
                # Nested reactive model instance: keep it as a plain attribute.
                return False
            return True

        namespace = {
            key: solara.reactive(value) if should_wrap(key, value) else value
            for key, value in class_dict.items()
        }
        return super().__new__(cls, name, bases, namespace)

In [None]:
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ReactiveDataClass(metaclass=ReactiveMeta):
    """Base class for the demo models in this notebook section.

    The ``ReactiveMeta`` metaclass processes each subclass body when the subclass
    is created, wrapping plain (non-dunder, non-callable) class-level values in
    ``solara.reactive``. This class declares no fields or methods of its own.
    """

    # NOTE(review): sketch of a transparent `.value` proxy, kept for reference.
    # Do not re-enable as written: `__getattr__` only runs after normal lookup has
    # failed, so calling `getattr(self, name)` inside it recurses forever, and the
    # final `attr = value` in `__setattr__` only rebinds a local variable.
    # def __getattr__(self, name):
    #     attr = getattr(self, name)
    #     if isinstance(attr, solara.Reactive):
    #         return attr.value
    #     return attr

    # def __setattr__(self, name: str, value) -> None:
    #     attr = getattr(self, name)
    #     if isinstance(attr, solara.Reactive):
    #         setattr(attr, "value", value)
    #     attr = value

In [None]:
class AttributesModel(ReactiveDataClass):
    """Physical attributes of the demo person, each held in a ``solara.Reactive``."""

    # Restricted to four colours by the Literal; starts out as "brown".
    hair_color: solara.Reactive[
        t.Literal["brown", "blonde", "black", "red"]
    ] = solara.reactive("brown")


class Model(ReactiveDataClass):
    """Demo person model: reactive name and age plus a nested attributes model.

    NOTE(review): all defaults are class attributes, so every ``Model()`` presumably
    shares the same Reactive objects and the same ``AttributesModel`` instance.
    """

    name: solara.Reactive[str] = solara.reactive("Bob")
    age: solara.Reactive[int] = solara.reactive(0)
    # Nested model instance; meant to be left unwrapped by the metaclass's
    # nested-model check.
    attributes: AttributesModel = AttributesModel()

In [None]:
@solara.component
def Component(model: Model):
    """Describe ``model`` in one sentence, with buttons to age it and dye its hair."""

    def advance_time():
        # Write back through `.value` so Solara sees the change.
        model.age.value = model.age.value + 1

    def toggle_hair():
        hair = model.attributes.hair_color
        hair.value = "blonde" if hair.value == "brown" else "brown"

    name, years, color = (
        model.name.value,
        model.age.value,
        model.attributes.hair_color.value,
    )
    solara.Text(f"{name} is {years} years old and has {color} hair.")
    solara.Button(label="Time forward", on_click=advance_time)
    solara.Button(label="Dye hair", on_click=toggle_hair)

In [None]:
# Render the demo component for a freshly built model (argument passed by keyword).
Component(model=Model())

In [None]:
import solara
from aiida.orm import ProcessNode, StructureData
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass

from aiidalab_qe.common.services.aiida import AiiDAService

# TODO dynamically "concatenate" announced plugin schemas
# TODO provide descriptions throughout
# TODO consider implementing model-level validation
# TODO consider using the models for UI selector options (dropdowns, toggles, etc.)


class ReactiveMeta(type):
    """Metaclass that wraps plain class-level data attributes in ``solara.reactive``.

    When a class using this metaclass is created, each entry of its class body is
    inspected. Plain data values (e.g. ``"moderate"``, ``0.0``, ``{}``) are replaced
    by ``solara.reactive(value)``. Everything else is kept unchanged:

    - dunder entries (``__module__``, ``__doc__``, ``__annotations__``, ...),
    - callables (methods, nested classes, ...),
    - ``property`` / ``classmethod`` / ``staticmethod`` descriptors, which are not
      (always) callable but must remain descriptors to keep working,
    - values that already are ``solara.Reactive`` (no double wrapping),
    - nested model instances: instances of classes built by this metaclass, or any
      object exposing ``__annotations__``.
    """

    def __new__(cls, name, bases, class_dict):
        def should_wrap(key, value) -> bool:
            if key.startswith("__"):  # built-in method/property
                return False
            if callable(value):  # added method / nested class
                return False
            if isinstance(value, (property, classmethod, staticmethod)):
                # property/classmethod objects are not callable, so the check above
                # lets them through; wrapping them would break the descriptor.
                return False
            if isinstance(value, solara.Reactive):
                # Defaults declared as `solara.reactive(...)` must not be re-wrapped.
                return False
            if isinstance(type(value), ReactiveMeta) or hasattr(
                value, "__annotations__"
            ):
                # Nested reactive model instance: keep it as a plain attribute.
                return False
            return True

        namespace = {
            key: solara.reactive(value) if should_wrap(key, value) else value
            for key, value in class_dict.items()
        }
        return super().__new__(cls, name, bases, namespace)


@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ReactiveDataclass(metaclass=ReactiveMeta):
    """Base class for the QE app models below.

    ``ReactiveMeta`` wraps each subclass's plain class-level defaults in
    ``solara.reactive``. Nested model instances are left as ordinary attributes.
    Without the metaclass, the defaults stay plain values and every ``.value``
    access below fails.

    NOTE(review): the Reactive objects live on the class, so all instances of a
    model share them. Constructing a second instance overwrites the values seen
    through the first.
    NOTE(review): confirm that pydantic's ``dataclass`` decorator keeps this custom
    ``__init__``. If it installs its own, keyword arguments never reach this code.
    """

    def __init__(self, **data):
        for key, value in data.items():
            current = getattr(self, key)
            if isinstance(current, solara.Reactive):
                # Reactive field: update in place so subscribers are notified.
                current.value = value
            else:
                # Nested model (or other non-reactive attribute): replace it on the
                # instance instead of attaching a stray `.value` to the default.
                setattr(self, key, value)

    def __repr__(self):
        def shown(key):
            attr = getattr(self, key)
            return attr.value if isinstance(attr, solara.Reactive) else attr

        # `getattr` default: a class without annotations has no `__annotations__`.
        keys = getattr(self, "__annotations__", {})
        fields = ", ".join(f"{key}={shown(key)}" for key in keys)
        return f"{self.__class__.__name__}({fields})"


class BasicModel(ReactiveDataclass):
    """Basic workflow choices: protocol, spin type and electronic type."""

    protocol: t.Literal["fast", "moderate", "precise"] = "moderate"
    spin_type: t.Literal["none", "collinear"] = "none"
    electronic_type: t.Literal["metal", "insulator"] = "metal"


class SystemModel(ReactiveDataclass):
    """Settings held under the ``SYSTEM`` key of ``PwParametersModel``.

    The field names appear to follow pw.x ``&SYSTEM`` namelist keywords. Confirm
    units and meanings against the Quantum ESPRESSO input docs.
    """

    tot_charge: float = 0.0
    # NOTE(review): class-level list default, shared by every instance.
    starting_ns_eigenvalue: list[tuple[int, int, str, int]] = []
    ecutwfc: float = 0.0
    ecutrho: float = 0.0
    vdw_corr: t.Literal[
        "none", "dft-d3", "dft-d3bj", "dft-d3m", "dft-d3mbj", "ts-vdw"
    ] = "none"
    smearing: t.Literal[
        "cold", "gaussian", "fermi-dirac", "methfessel-paxton"
    ] = "cold"
    degauss: float = 0.0
    lspinorb: bool = False
    noncolin: bool = False
    nspin: int = 1
    tot_magnetization: float = 0.0


class ControlModel(ReactiveDataclass):
    """Settings held under the ``CONTROL`` key of ``PwParametersModel``.

    The names suggest force and total-energy convergence thresholds. Both default
    to ``0.0``.
    """

    forc_conv_thr: float = 0.0
    etot_conv_thr: float = 0.0


class ElectronsModel(ReactiveDataclass):
    """Settings held under the ``ELECTRONS`` key of ``PwParametersModel``."""

    conv_thr: float = 0.0
    electron_maxstep: int = 80  # default step cap


class PwParametersModel(ReactiveDataclass):
    """Groups the SYSTEM / CONTROL / ELECTRONS sub-models.

    The upper-case attribute names presumably mirror pw.x namelist names. Each
    default is a single shared sub-model instance created at class definition.
    """

    SYSTEM: SystemModel = SystemModel()
    CONTROL: ControlModel = ControlModel()
    ELECTRONS: ElectronsModel = ElectronsModel()


class PwModel(ReactiveDataclass):
    """pw.x settings: namelist parameters plus a pseudopotential mapping."""

    parameters: PwParametersModel = PwParametersModel()
    # Presumably element/kind name -> pseudopotential identifier (confirm).
    pseudos: dict[str, str] = {}


class HubbardParametersModel(ReactiveDataclass):
    """Hubbard settings; ``hubbard_u`` starts empty."""

    # Presumably keyed by element/kind name (confirm against the caller).
    hubbard_u: dict[str, float] = {}


class AdvancedModel(ReactiveDataclass):
    """Advanced settings: pw.x overrides plus additional workflow options.

    NOTE(review): the nested and mutable defaults below are class attributes and
    are therefore shared by every instance.
    """

    pw: PwModel = PwModel()
    clean_workdir: bool = False
    kpoints_distance: float = 0.0
    optimization_maxsteps: int = 50
    pseudo_family: str = ""
    hubbard_parameters: HubbardParametersModel = HubbardParametersModel()
    initial_magnetic_moments: dict[str, float] = {}


# class BandsModel(ReactiveDataclass):
#     projwfc_bands: bool


# class PdosModel(ReactiveDataclass):
#     nscf_kpoints_distance: float
#     use_pdos_degauss: bool
#     pdos_degauss: float
#     energy_grid_step: float


# class XasPseudosModel(ReactiveDataclass):
#     gipaw: str
#     core_hole: str


# class XasModel(ReactiveDataclass):
#     elements_list: list[str]
#     core_hole_treatments: dict[str, str]
#     pseudo_labels: dict[str, XasPseudosModel]
#     core_wfc_data_labels: dict[str, str]
#     supercell_min_parameter: float


# class CorrectionEnergyModel(ReactiveDataclass):
#     exp: float
#     core: float


# class XpsModel(ReactiveDataclass):
#     structure_type: str
#     pseudo_group: str
#     correction_energies: dict[str, CorrectionEnergyModel]
#     core_level_list: list[str]


class CodeParallelizationModel(ReactiveDataclass):
    """Parallelization options for a code; ``npools`` is unset (None) by default."""

    npools: t.Optional[int] = None


class CodeModel(ReactiveDataclass):
    """Code selection and resource request for one code."""

    # Not yet modelled: options: list[list[tuple[str, str]]]
    code: str = ""
    nodes: int = 1
    cpus: int = 1
    ntasks_per_node: int = 1
    cpus_per_task: int = 1
    max_wallclock_seconds: int = 3600  # one hour
    parallelization: CodeParallelizationModel = CodeParallelizationModel()


class CodesModel(ReactiveDataclass):
    """A set of per-code settings plus an optional ``override`` flag (None = unset)."""

    override: t.Optional[bool] = None
    # Mapping of code identifier -> CodeModel; empty by default.
    codes: dict[str, CodeModel] = {}


class ComputationalResourcesModel(ReactiveDataclass):
    """Computational resources, currently only the global code settings."""

    # Trailing underscore because `global` is a reserved Python keyword.
    global_: CodesModel = CodesModel()
    # Planned per-plugin resources:
    # bands: CodesModel
    # pdos: CodesModel
    # xas: CodesModel


class CalculationParametersModel(ReactiveDataclass):
    """Calculation parameters: relaxation type, selected properties, basic and advanced settings."""

    relax_type: t.Literal[
        "none", "positions", "positions_cell"
    ] = "positions_cell"
    properties: list[str] = []
    basic: BasicModel = BasicModel()
    advanced: AdvancedModel = AdvancedModel()
    # Planned plugin sections:
    # bands: BandsModel
    # pdos: PdosModel
    # xas: XasModel
    # xps: XpsModel


class QeAppModel(ReactiveDataclass):
    """Top-level app state: input structure, calculation parameters, resources, process.

    ``from_process`` builds one of these from a stored workflow node.
    """

    input_structure: StructureData | None = None
    calculation_parameters: CalculationParametersModel = CalculationParametersModel()
    computational_resources: ComputationalResourcesModel = (
        ComputationalResourcesModel()
    )
    process: ProcessNode | None = None


def from_process(pk: int | None) -> QeAppModel:
    """Rebuild the app model from a stored QE app workflow node.

    Args:
        pk: primary key of the workflow node, or None.

    Returns:
        A populated ``QeAppModel``. A default ``QeAppModel()`` is returned when no
        workflow node is loaded or the node carries no usable "ui_parameters" extra.
    """
    from aiida.orm.utils.serialize import deserialize_unsafe

    # Explicit checks instead of `assert`: assertions are stripped under
    # `python -O`, which would turn these fallbacks into crashes on `None`.
    process = AiiDAService.load_qe_app_workflow_node(pk)
    if not process:
        return QeAppModel()

    serialized = process.base.extras.get("ui_parameters")
    if not serialized:
        # Short-circuit rather than pass `{}` into `deserialize_unsafe`.
        # NOTE(review): it presumably expects serialized text, not a dict.
        return QeAppModel()

    # SECURITY NOTE: "unsafe" deserialization. Only use it on trusted, locally
    # stored node extras, never on external input.
    ui_parameters = deserialize_unsafe(serialized)
    if not ui_parameters:
        return QeAppModel()

    calculation_parameters = _extract_calculation_parameters(ui_parameters)
    computational_resources = _extract_computational_resources(ui_parameters)

    return QeAppModel(
        input_structure=process.inputs.structure,
        calculation_parameters=calculation_parameters,
        computational_resources=computational_resources,
        process=process,
    )


def _extract_calculation_parameters(parameters: dict) -> CalculationParametersModel:
    """Build a ``CalculationParametersModel`` from deserialized UI parameters.

    ``relax_type`` is read from the "workchain" section. The "basic" and "advanced"
    sub-models are built from their own sections. Anything missing keeps the
    model's default.

    Args:
        parameters: deserialized "ui_parameters" mapping.

    Returns:
        The populated ``CalculationParametersModel``.
    """
    model = CalculationParametersModel()

    # `or {}` also covers a stored `"workchain": None`.
    workchain_parameters: dict = parameters.get("workchain") or {}
    relax_type = workchain_parameters.get("relax_type")
    if relax_type is not None:
        # Only override when present. A bare `.get()` would replace the default
        # with None, which is not one of the declared relax types.
        model.relax_type = relax_type

    models = {
        "basic": BasicModel,
        "advanced": AdvancedModel,
    }

    # TODO extend models by plugins

    for identifier, sub_model in models.items():
        if sub_model_parameters := parameters.get(identifier):
            setattr(model, identifier, sub_model(**sub_model_parameters))

    return model


def _extract_computational_resources(parameters: dict) -> ComputationalResourcesModel:
    """Build a ``ComputationalResourcesModel`` from the "codes" section.

    A missing, empty or None "codes" section yields the default model instead of
    raising ``KeyError``. Other sections of the UI parameters are likewise read
    with defaults.

    Args:
        parameters: deserialized "ui_parameters" mapping.

    Returns:
        The populated ``ComputationalResourcesModel``.
    """
    # NOTE(review): the model's field is `global_` (`global` is a keyword);
    # confirm the stored "codes" keys match the model's field names.
    codes = parameters.get("codes") or {}
    return ComputationalResourcesModel(**codes)

In [None]:
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ReactiveDataclass:
    """Experimental variant that routes attribute writes into ``solara.Reactive`` objects.

    Assigning to an attribute that currently resolves to a ``solara.Reactive``
    updates that Reactive's ``.value``. Assigning to any other attribute stores it
    normally.

    NOTE(review): Python only calls ``__getattr__`` after normal lookup has failed,
    so it can never see (and unwrap) an existing Reactive attribute. Transparent
    unwrapping on read would need ``__getattribute__``. That would also change what
    ``__init__``/``__repr__`` get back from ``getattr``.
    """

    def __init__(self, **data):
        # Expects each keyword to name an existing attribute holding a Reactive.
        for key, value in data.items():
            setattr(getattr(self, key), "value", value)

    def __repr__(self):
        return f"{self.__class__.__name__}({', '.join([f'{key}={getattr(self, key).value}' for key in self.__annotations__.keys()])})"

    def __getattr__(self, name):
        # Reached only when `name` does not exist. The previous body called
        # `getattr(self, name)`, which re-entered this method until RecursionError.
        # hasattr() does not swallow RecursionError, so raise the standard error.
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

    def __setattr__(self, name: str, value) -> None:
        # Probe with object.__getattribute__ so a not-yet-set name cannot recurse.
        try:
            attr = object.__getattribute__(self, name)
        except AttributeError:
            attr = None
        if isinstance(attr, solara.Reactive):
            attr.value = value
        else:
            # The previous `attr = value` only rebound a local variable,
            # silently dropping the assignment.
            object.__setattr__(self, name, value)

In [None]:
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class Test(ReactiveDataclass):
    """Scratch model for trying out ``ReactiveDataclass`` with two reactive fields."""

    a: solara.Reactive[int] = solara.reactive(0)
    # Restricted to two greetings by the Literal; starts as "hello".
    b: solara.Reactive[t.Literal["hello", "goodbye"]] = solara.reactive("hello")

In [None]:
# Bind the scratch instance to `test`. The parenthesised walrus expression is also
# the cell's last expression, so IPython displays it.
(test := Test())

In [None]:
import dataclasses
import solara


@dataclasses.dataclass(frozen=True)
class TodoItem:
    """Immutable to-do entry; changing it means building a new instance."""

    text: str  # description shown to the user
    done: bool  # completion flag

In [None]:
@solara.component
def Component(todo: TodoItem):
    """Show whether ``todo`` is done, with a button that flips its ``done`` flag."""
    todo_item = solara.use_reactive(todo)
    # Read/write handle on just the `done` field of the reactive item.
    done_ref = solara.lab.Ref(todo_item.fields.done)

    def flip_done():
        done_ref.value = not done_ref.value

    item = todo_item.value
    status = "done" if item.done else "not done"
    solara.Text(f"{item.text} is {status}.")
    solara.Button(label="Toggle", on_click=flip_done)

In [None]:
# Render the to-do demo for an unfinished item (argument passed by keyword).
Component(todo=TodoItem(text="Buy milk", done=False))