In [None]:
from __future__ import annotations

import typing as t
from enum import Enum
from pathlib import Path

from aiida import load_profile, orm
from aiida_quantumespresso.calculations.pp import PpCalculation
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
from aiida_workgraph import dynamic, namespace, shelljob, task
from ase import Atoms

# Load the AiiDA profile (no name given, so presumably the default one);
# binding the result to ``_`` keeps this last line of the cell from echoing.
_ = load_profile()

In [None]:
class AfmCase(Enum):
    """Selects which AFM simulation variant ``AfmWorkflow`` runs.

    NOTE(review): a later cell redefines ``AfmCase`` with string values
    ("empirical", "hartree", "hartree_rho"), which shadows this integer-valued
    definition once that cell runs. Keep only one definition.
    """

    # Lennard-Jones force field only; no DFT step.
    EMPIRICAL = 1
    # Adds an electrostatic force field built from a DFT Hartree potential.
    HARTREE = 2
    # Experimental: additionally uses sample and tip charge densities.
    HARTREE_RHO = 3

In [None]:
@task
def write_afm_params(params: dict) -> orm.SinglefileData:
    """Serialize ppafm parameters into a ``params.ini`` file node.

    Each ``key: value`` pair becomes one ``"key value"`` line. List/tuple values
    are written as space-separated items (``[0.1, 0.1, 0.1]`` ->
    ``"0.1 0.1 0.1"``). Any other value is written with ``str()``.

    The file is written in a private temporary directory instead of the
    current working directory. This stops concurrent task executions from
    clobbering each other's ``params.ini`` and leaves no stray file behind.
    The ``params.ini`` name is kept because downstream ppafm commands
    presumably look the file up by that name.

    :param params: mapping of ppafm parameter names to values.
    :returns: a ``SinglefileData`` node named ``params.ini``.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        afm_filepath = Path(tmpdir) / "params.ini"
        with open(afm_filepath, "w", encoding="utf-8") as config_file:
            for key, value in params.items():
                if isinstance(value, (list, tuple)):
                    value = " ".join(map(str, value))
                config_file.write(f"{key} {value}\n")
        # SinglefileData copies the file content into its own repository when
        # constructed, so the temporary directory can be removed afterwards.
        return orm.SinglefileData(file=afm_filepath.as_posix())


@task
def write_structure_file(structure: Atoms) -> orm.SinglefileData:
    """Write ``structure`` as an XYZ file node named ``geo.xyz``.

    The file is written in a private temporary directory instead of the
    current working directory. This stops concurrent task executions from
    overwriting each other's ``geo.xyz`` and leaves no stray file behind.
    The ``geo.xyz`` name is kept because ``ppafm-generate-ljff`` is invoked
    with ``-i geo.xyz``.

    :param structure: geometry exposing ASE's ``write(path, format=...)``.
    :returns: a ``SinglefileData`` node named ``geo.xyz``.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        geom_filepath = Path(tmpdir) / "geo.xyz"
        structure.write(geom_filepath, format="xyz")
        # SinglefileData copies the file content on construction, so the
        # temporary directory can be removed afterwards.
        return orm.SinglefileData(file=geom_filepath.as_posix())

# Wrap the AiiDA processes with ``task`` so they can be called inside a graph.
# NOTE(review): the later cell binds the relax wrapper as ``RelaxationJob``;
# ``RelaxJob`` is not referenced anywhere in this notebook.
RelaxJob, ScfJob, PpJob = map(
    task,
    (PwRelaxWorkChain, PwBaseWorkChain, PpCalculation),
)

In [None]:
from __future__ import annotations

import typing as t
from enum import Enum
from pathlib import Path

from aiida import orm
from aiida_quantumespresso.calculations.pp import PpCalculation
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
from aiida_workgraph import dynamic, namespace, shelljob, task
from ase import Atoms


class AfmCase(Enum):
    """Selects which AFM simulation variant ``AfmWorkflow`` runs.

    This string-valued definition replaces the integer-valued ``AfmCase``
    from the earlier cell.
    """

    # Lennard-Jones force field only; no DFT step.
    EMPIRICAL = "empirical"
    # Adds an electrostatic force field built from a DFT Hartree potential.
    HARTREE = "hartree"
    # Experimental: additionally uses sample and tip charge densities.
    HARTREE_RHO = "hartree_rho"


@task
def write_afm_params(params: dict) -> orm.SinglefileData:
    """Serialize ppafm parameters into a ``params.ini`` file node.

    Each ``key: value`` pair becomes one ``"key value"`` line. List/tuple values
    are written as space-separated items (``[0.1, 0.1, 0.1]`` ->
    ``"0.1 0.1 0.1"``). Any other value is written with ``str()``.

    The file is written in a private temporary directory instead of the
    current working directory. This stops concurrent task executions from
    clobbering each other's ``params.ini`` and leaves no stray file behind.
    The ``params.ini`` name is kept because downstream ppafm commands
    presumably look the file up by that name.

    :param params: mapping of ppafm parameter names to values.
    :returns: a ``SinglefileData`` node named ``params.ini``.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        afm_filepath = Path(tmpdir) / "params.ini"
        with open(afm_filepath, "w", encoding="utf-8") as config_file:
            for key, value in params.items():
                if isinstance(value, (list, tuple)):
                    value = " ".join(map(str, value))
                config_file.write(f"{key} {value}\n")
        # SinglefileData copies the file content into its own repository when
        # constructed, so the temporary directory can be removed afterwards.
        return orm.SinglefileData(file=afm_filepath.as_posix())


@task
def write_structure_file(structure: Atoms) -> orm.SinglefileData:
    """Write ``structure`` as an XYZ file node named ``geo.xyz``.

    The file is written in a private temporary directory instead of the
    current working directory. This stops concurrent task executions from
    overwriting each other's ``geo.xyz`` and leaves no stray file behind.
    The ``geo.xyz`` name is kept because ``ppafm-generate-ljff`` is invoked
    with ``-i geo.xyz``.

    :param structure: geometry exposing ASE's ``write(path, format=...)``.
    :returns: a ``SinglefileData`` node named ``geo.xyz``.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        geom_filepath = Path(tmpdir) / "geo.xyz"
        structure.write(geom_filepath, format="xyz")
        # SinglefileData copies the file content on construction, so the
        # temporary directory can be removed afterwards.
        return orm.SinglefileData(file=geom_filepath.as_posix())


# Wrap the AiiDA processes with ``task`` so ``AfmWorkflow`` can call them as
# graph tasks and reuse their input specs in its type annotations.
RelaxationJob, ScfJob, PpJob = map(
    task,
    (PwRelaxWorkChain, PwBaseWorkChain, PpCalculation),
)


def _coerce_afm_case(case: t.Any) -> AfmCase:
    """Normalize an AFM case selector to an ``AfmCase`` member.

    Accepts an ``AfmCase`` member, a member name (e.g. ``"HARTREE"``, as given
    by ``AfmCase.HARTREE.name``) or a member value. A member of another Enum
    class, such as a stale ``AfmCase`` from a re-run cell, is matched by name.
    Other objects exposing ``.value`` (e.g. AiiDA ``orm.Str``) are unwrapped
    first.

    :raises ValueError: if ``case`` matches no ``AfmCase`` member.
    """
    if isinstance(case, AfmCase):
        return case
    if isinstance(case, Enum):
        raw = case.name
    else:
        raw = getattr(case, "value", case)
    try:
        return AfmCase[raw]
    except (KeyError, TypeError):
        pass
    try:
        return AfmCase(raw)
    except ValueError:
        raise ValueError(f"Unsupported case: {case}") from None


@task.graph
def AfmWorkflow(
    case: AfmCase,
    structure: orm.StructureData,
    afm_params: dict,
    relax: bool = False,
    dft_params: t.Annotated[
        dict,
        RelaxationJob.inputs,
    ] = None,
    pp_params: t.Annotated[
        dict,
        namespace(
            hartree=PpJob.inputs,
            charge=namespace(
                structure=PpJob.inputs,
                tip=PpJob.inputs,
            ),
        ),
    ] = None,
    tip: orm.StructureData = None,
) -> t.Annotated[dict, dynamic(t.Any)]:
    """AFM simulation workflow.

    The graph writes the geometry and ``params.ini`` files and generates the
    Lennard-Jones force field. For non-empirical cases it also builds an
    electrostatic force field from Quantum ESPRESSO results. Finally it runs
    ``ppafm-relaxed-scan`` and ``ppafm-plot-results``.

    :param case: simulation variant. Accepts an ``AfmCase`` member, its name
        (e.g. ``AfmCase.HARTREE.name``) or its value.
    :param structure: sample geometry. It is replaced by the relaxed structure
        when ``relax`` is true.
    :param afm_params: ppafm parameters, written to ``params.ini``.
    :param relax: run ``PwRelaxWorkChain`` with ``dft_params`` first.
    :param dft_params: ``PwRelaxWorkChain`` inputs. For non-empirical cases
        without relaxation, only ``dft_params["base"]`` is used, for a
        ``PwBaseWorkChain`` run. This dict is not modified.
    :param pp_params: ``PpCalculation`` inputs. ``"hartree"`` is always
        required for non-empirical cases. ``"charge"`` -> ``"structure"``/
        ``"tip"`` is used only for ``HARTREE_RHO``.
    :param tip: tip geometry, required only for ``HARTREE_RHO``.
    :returns: the ``ppafm-plot-results`` shell job.
    :raises ValueError: if ``case`` does not identify an ``AfmCase`` member.
    :raises AssertionError: if inputs required by the selected case are missing.
    """
    # Normalize once so every comparison below is member-to-member. Previously
    # the HARTREE branch compared against the member *name* while the other
    # checks compared against members, so no single input type worked for
    # every case.
    case = _coerce_afm_case(case)

    dft_task = None
    hartree_task = None
    rho_task = None
    tip_rho_task = None

    if relax:
        assert dft_params, "Missing DFT parameters"
        dft_task = RelaxationJob(
            structure=structure,
            **dft_params,
        )
        structure = dft_task.output_structure

    assert structure, "Missing structure"
    geometry_file = write_structure_file(structure=structure).result

    assert afm_params, "Missing AFM parameters"
    afm_params_file = write_afm_params(params=afm_params).result

    ljff = shelljob(
        command="ppafm-generate-ljff",
        nodes={
            "geometry": geometry_file,
            "parameters": afm_params_file,
        },
        arguments=[
            "-i",
            "geo.xyz",
            "-f",
            "npy",
        ],
        outputs=["FFLJ.npz"],
    )

    scan_nodes = {
        "parameters": afm_params_file,
        "ljff_data": ljff.FFLJ_npz,
    }

    # Shared shell-job metadata; ``use_symlinks`` presumably links remote
    # inputs instead of copying them.
    metadata = {
        "options": {
            "use_symlinks": True,
        }
    }

    if case is not AfmCase.EMPIRICAL:
        if not relax:
            assert dft_params, "Missing DFT parameters"
            # Shallow-copy so adding "pw.structure" does not mutate the
            # caller's ``dft_params["base"]`` dict across builds.
            scf_params = dict(dft_params.get("base") or {})
            assert scf_params, "Missing base SCF parameters"
            scf_params["pw.structure"] = structure
            dft_task = ScfJob(**scf_params)

        assert pp_params, "Missing post-processing parameters"
        hartree_params = pp_params.get("hartree", {})
        assert hartree_params, "Missing Hartree parameters"
        hartree_task = PpJob(
            parent_folder=dft_task.remote_folder,
            **hartree_params,
        )

        if case is AfmCase.HARTREE:
            elff = shelljob(
                command="ppafm-generate-elff",
                metadata=metadata,
                nodes={
                    "parameters": afm_params_file,
                    "ljff_data": ljff.FFLJ_npz,
                    "hartree_data": hartree_task.remote_folder,
                },
                filenames={
                    "hartree_data": "hartree",
                },
                arguments=[
                    "-i",
                    "hartree/aiida.fileout",
                    "-F",
                    "cube",
                    "-f",
                    "npy",
                ],
                outputs=["FFel.npz"],
            )

            scan_nodes["elff_data"] = elff.FFel_npz

        # Experimental feature, not fully tested
        elif case is AfmCase.HARTREE_RHO:
            charge_namespace: dict = pp_params.get("charge", {})
            geom_charge_params = charge_namespace.get("structure", {})
            assert geom_charge_params, "Missing structure charge density parameters"
            # NOTE(review): confirm that PpCalculation accepts a ``structure``
            # input; the HARTREE path passes only ``parent_folder``.
            rho_task = PpJob(
                structure=structure,
                parent_folder=dft_task.remote_folder,
                **geom_charge_params,
            )

            # write tip file

            tip_charge_params = charge_namespace.get("tip", {})
            assert tip, "Missing tip structure"
            assert tip_charge_params, "Missing tip charge density parameters"
            # NOTE(review): the tip density is post-processed from the
            # *sample's* DFT run (same parent_folder) — confirm this is intended.
            tip_rho_task = PpJob(
                structure=tip,
                parent_folder=dft_task.remote_folder,
                **tip_charge_params,
            )

            conv_rho = shelljob(
                command="ppafm-conv-rho",
                nodes={
                    "geom_density": rho_task.remote_folder,
                    "tip_density": tip_rho_task.remote_folder,
                },
                filenames={
                    "geom_density": "structure",
                    "tip_density": "tip",
                },
                arguments=[
                    "-s",
                    "structure/charge.cube",
                    "-t",
                    "tip/charge.cube",
                    "-B",
                    "1.0",
                    "-E",
                ],
                outputs=["charge.cube"],
            )

            # NOTE(review): the HARTREE branch reads "hartree/aiida.fileout"
            # while this branch reads "hartree/hartree.cube" — confirm which
            # file the Hartree PpCalculation actually produces.
            charge_elff = shelljob(
                command="ppafm-generate-elff",
                nodes={
                    "hartree_data": hartree_task.remote_folder,
                    "conv_density": conv_rho.charge_cube,
                    "tip_density": tip_rho_task.remote_folder,
                },
                filenames={
                    "hartree_data": "hartree",
                    "tip_density": "tip",
                },
                arguments=[
                    "-i",
                    "hartree/hartree.cube",
                    "-tip-dens",
                    "tip/charge.cube",
                    "--Rcode",
                    "0.7",
                    "-E",
                    "--doDensity",
                ],
                outputs=["FFel.npz"],
            )

            dftd3 = shelljob(
                command="ppafm-generate-dftd3",
                nodes={
                    "hartree_data": hartree_task.remote_folder,
                },
                filenames={
                    "hartree_data": "hartree",
                },
                arguments=[
                    "-i",
                    "hartree/hartree.cube",
                    "--df_name",
                    "PBE",
                ],
                outputs=["dftd3.dat"],
            )

            # Map the Hartree folder to "hartree/" like every other job here,
            # so the "-i hartree/..." argument resolves.
            # NOTE(review): unlike the HARTREE branch, no "parameters"
            # (params.ini) node is staged for this job — confirm.
            elff = shelljob(
                command="ppafm-generate-elff",
                nodes={
                    "hartree_data": hartree_task.remote_folder,
                    "charge_elff_data": charge_elff.FFel_npz,
                    "dftd3_data": dftd3.dftd3_dat,
                },
                filenames={
                    "hartree_data": "hartree",
                },
                arguments=[
                    "-i",
                    "hartree/hartree.cube",
                    "-f",
                    "npy",
                ],
                outputs=["FFel.npz"],
            )

            # Feed the electrostatic force field into the scan; previously it
            # was computed and then discarded.
            scan_nodes["elff_data"] = elff.FFel_npz

        else:
            # Only reachable if AfmCase gains members not handled above.
            raise ValueError(f"Unsupported case: {case}")

    scan = shelljob(
        command="ppafm-relaxed-scan",
        metadata=metadata,
        nodes=scan_nodes,
        arguments=[
            "-f",
            "npy",
        ],
        outputs=["Q0.00K0.35"],
    )

    results = shelljob(
        command="ppafm-plot-results",
        metadata=metadata,
        nodes={
            "parameters": afm_params_file,
            "scan_dir": scan.Q0_00K0_35,
        },
        filenames={
            "scan_dir": "Q0.00K0.35",
        },
        arguments=[
            "--df",
            "--cbar",
            "--save_df",
            "-f",
            "npy",
        ],
        outputs=["Q0.00K0.35"],
    )

    return results


In [None]:
from aiida_pseudo.groups.family.pseudo import PseudoPotentialFamily

# Sample geometry: six C and six H atoms forming a planar ring near z ~ 5,
# in a 10 x 10 x 10 non-periodic box (StructureData coordinates, presumably Å).
molecule_atoms = (
    ("C", (8.2766055186, 6.1177492411, 5.0595246063)),
    ("C", (8.8582363932, 7.3839854920, 5.0363959612)),
    ("C", (8.0523006568, 8.5208073256, 5.0178902940)),
    ("C", (6.6646772714, 8.3913773689, 5.0224782237)),
    ("C", (6.0830463468, 7.1251409580, 5.0456056988)),
    ("C", (6.8889822232, 5.9883191444, 5.0641114560)),
    ("H", (8.9093056269, 5.2253257080, 5.0739448106)),
    ("H", (9.9475591624, 7.4856249613, 5.0328185108)),
    ("H", (8.5089059407, 9.5148483491, 4.9998780126)),
    ("H", (6.0319773431, 9.2838011020, 5.0080584394)),
    ("H", (4.9937235476, 7.0235011287, 5.0491811392)),
    ("H", (6.4323774593, 4.9942778709, 5.0821221574)),
)

structure = orm.StructureData()
structure.set_cell([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]])
structure.set_pbc((False, False, False))
for symbol, position in molecule_atoms:
    structure.append_atom(symbols=symbol, position=position)

# A single 1 x 1 x 1 k-point mesh.
kpoints = orm.KpointsData()
kpoints.set_kpoints_mesh([1, 1, 1])

# NOTE(review): group pk 4 is specific to one AiiDA profile; loading the
# pseudopotential family by label would make the notebook portable.
pseudo_family = t.cast(PseudoPotentialFamily, orm.load_group(4))
C_pp, H_pp = (pseudo_family.get_pseudo(element) for element in ("C", "H"))

In [None]:
# ppafm input parameters. ``write_afm_params`` writes each entry, in this
# order, as one "key value" line of params.ini; lists become space-separated.
afm_params = {
    "PBC": False,
    "tip": "s",
    "klat": 0.3490127886809,
    "krad": 21.913190531846,
    "gridA": [14.9412827110, 0.0000000000, 0.0000000000],
    "gridB": [0.0000000000, 14.5091262213, 0.0000000000],
    "gridC": [0.0000000000, 0.0000000000, 10.0820001747],
    "sigma": 0.7,
    "charge": 0.0,
    "r0Probe": [0.0, 0.0, 2.97],
    "scanMax": [14.9412827110, 14.5091262213, 11],
    "scanMin": [0.0, 0.0, 8],
    "scanStep": [0.1, 0.1, 0.1],
    # NOTE(review): the final cell reads images from an "Amp1.40" folder,
    # which appears to be derived from this value — keep them in sync.
    "Amplitude": 1.4,
    "probeType": "O",
    "f0Cantilever": 22352.5,
    "gridN": [-1, -1, -1],
}

In [None]:
# Inputs for the DFT step of ``AfmWorkflow``:
# - with ``relax=True`` the whole dict is passed to ``PwRelaxWorkChain``;
# - otherwise only ``dft_params["base"]`` is used, for a ``PwBaseWorkChain``
#   run (the workflow adds ``pw.structure`` itself).
# NOTE(review): CONTROL.calculation is "relax", so the relax=False build below
# also asks pw.x to relax the geometry, while geo.xyz is written from the
# unrelaxed input structure — confirm whether "scf" was intended there.
dft_params = {
    "base": {
        "metadata": {},
        "pw": {
            "metadata": {
                "options": {
                    "resources": {
                        "num_machines": 1,
                    },
                    "max_wallclock_seconds": 43200,
                }
            },
            "pseudos": {
                "C": C_pp,
                "H": H_pp,
            },
            # Profile-specific code label; adjust for your AiiDA setup.
            "code": orm.load_code("pw-7.4@localhost"),
            "parameters": {
                "CONTROL": {
                    "calculation": "relax",
                    "forc_conv_thr": 0.001,
                    "tprnfor": True,
                    "tstress": True,
                    "etot_conv_thr": 0.0002,
                    "nstep": 50,
                },
                "SYSTEM": {
                    "nosym": False,
                    "occupations": "fixed",
                    "ecutrho": 240.0,
                    "ecutwfc": 30.0,
                    "tot_charge": 0.0,
                    "vdw_corr": "none",
                },
                "ELECTRONS": {
                    "electron_maxstep": 80,
                    "mixing_beta": 0.4,
                    "conv_thr": 8e-10,
                },
            },
        },
        "kpoints": kpoints,
        "kpoints_force_parity": False,
        "max_iterations": 5,
    },
    # The keys below are only consumed by PwRelaxWorkChain (relax=True).
    "max_meta_convergence_iterations": 5,
    "meta_convergence": True,
    "volume_convergence": 0.05,
}

In [None]:
# ``PpCalculation`` inputs for ``AfmWorkflow``. "hartree" drives the Hartree
# post-processing step. "charge" (sample/tip) is only read for the
# experimental HARTREE_RHO case and is left empty here.
pp_params = {
    "hartree": {
        "metadata": {
            "options": {
                "resources": {
                    "num_machines": 1,
                },
                "max_wallclock_seconds": 43200,
            }
        },
        # Profile-specific code label; adjust for your AiiDA setup.
        "code": orm.load_code("pp-7.4@localhost"),
        "parameters": {
            "INPUTPP": {
                # NOTE(review): per the pp.x input docs, presumably the bare +
                # Hartree potential — confirm.
                "plot_num": 11,
            },
            "PLOT": {
                # NOTE(review): presumably a 3D plot per the pp.x docs — confirm.
                "iflag": 3,
            },
        },
    },
    "charge": {
        "structure": {},
        "tip": {},
    },
}


In [None]:
# Gather the workflow inputs, then build the WorkGraph. The case is passed by
# its member name; no relaxation is requested.
afm_inputs = dict(
    case=AfmCase.HARTREE.name,
    structure=structure,
    afm_params=afm_params,
    relax=False,
    dft_params=dft_params,
    pp_params=pp_params,
)
wg = AfmWorkflow.build(**afm_inputs)

In [None]:
# Show the built WorkGraph as this cell's output.
wg

In [None]:
# NOTE(review): submit() presumably returns without waiting for completion.
# The cells below read the workflow's outputs, so it must have finished first.
wg.submit()

In [None]:
# Load the process node of the submitted WorkGraph by its pk.
node = orm.load_node(wg.pk)

In [None]:
import base64

from IPython.display import HTML

# Fail with a clear message instead of an obscure missing-output error when the
# workflow has not (successfully) finished yet, e.g. under "Run All".
if not node.is_finished_ok:
    raise RuntimeError(
        f"WorkGraph <{node.pk}> has not finished successfully "
        f"(process state: {node.process_state}); wait for it to finish "
        "before rendering its images."
    )

# Output folder of the ``ppafm-plot-results`` step.
fd: orm.FolderData = node.outputs.Q0_00K0_35

# Amplitude sub-folder holding the PNG images. It was hard-coded as "Amp1.40"
# for ``afm_params["Amplitude"] == 1.4``; derive it with the same two-decimal
# format so that changing the amplitude does not silently break this cell.
# NOTE(review): assumes ppafm always formats the folder name this way.
png_folder = f"Amp{afm_params['Amplitude']:.2f}"

imgs = []
# Sort by name so the images render in a stable order.
for obj in sorted(fd.list_objects(png_folder), key=lambda entry: entry.name):
    if not obj.name.endswith(".png"):
        continue
    with fd.open(f"{png_folder}/{obj.name}", "rb") as handle:
        data64 = base64.b64encode(handle.read()).decode("utf-8")
    # Inline each image as a base64 data URI so the notebook is self-contained.
    imgs.append(f"""
        <img
            src="data:image/png;base64,{data64}"
            style="max-width:150px; margin:5px;"
        />
    """)

HTML("".join(imgs))