In [153]:
from __future__ import annotations

import typing as t
from enum import Enum
from pathlib import Path

from aiida import load_profile, orm
from aiida_quantumespresso.workflows.pw.relax import PwRelaxWorkChain
from aiida_workgraph import dynamic, shelljob, task
from ase.io import read

# Bind the loaded AiiDA profile to a real name rather than `_`: assigning `_`
# by hand makes IPython stop updating its `_`/`__`/`___` output-history
# variables for the rest of the session. The assignment still keeps the
# profile repr out of the cell output.
profile = load_profile()

In [None]:
class AfmCase(Enum):
    """Kinds of AFM simulation that ``AfmWorkflow`` can be asked to run.

    Members carry explicit integer values in declaration order. In this
    notebook only ``EMPIRICAL`` has an implemented pipeline; the other cases
    are rejected by ``AfmWorkflow``.
    """

    # Empirical Lennard-Jones force field (ppafm-generate-ljff pipeline).
    EMPIRICAL = 1

    # Hartree-potential based variant (not yet implemented in AfmWorkflow).
    HARTREE = 2

    # Hartree + density variant (not yet implemented in AfmWorkflow).
    HARTREE_RHO = 3

In [None]:
@task
def write_afm_params(params: dict) -> orm.SinglefileData:
    """Serialize ppafm parameters into a ``params.ini`` SinglefileData node.

    Each ``(key, value)`` pair of ``params`` becomes one ``"<key> <value>"``
    line, in the iteration order of ``params``.

    The file is staged in a private temporary directory rather than the
    current working directory. Concurrent task executions therefore cannot
    overwrite each other's file (or an unrelated ``params.ini`` of the user),
    and no stray file is left behind. The node's filename is still
    ``params.ini``, which the downstream shell jobs rely on.

    :param params: mapping of ppafm parameter names to values; values are
        rendered with f-string formatting.
    :return: a ``SinglefileData`` node holding the generated ``params.ini``.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        afm_filepath = Path(tmpdir) / "params.ini"
        with open(afm_filepath, "w", encoding="utf-8") as config_file:
            for key, value in params.items():
                config_file.write(f"{key} {value}\n")
        # The node is built before the temporary directory is removed.
        return orm.SinglefileData(file=afm_filepath.as_posix())

In [None]:
@task
def write_structure_file(structure: orm.StructureData) -> orm.SinglefileData:
    """Export ``structure`` as an XYZ file wrapped in a SinglefileData node.

    The structure is converted with ``get_ase()`` and written with ASE's
    ``xyz`` format. The file is staged in a private temporary directory
    instead of the current working directory, so concurrent runs cannot
    clobber each other (or a user's own ``geo.xyz``) and nothing is left
    behind. The node's filename stays ``geo.xyz``, the name passed to
    ``ppafm-generate-ljff -i``.

    :param structure: the AiiDA structure to export.
    :return: a ``SinglefileData`` node holding ``geo.xyz``.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        geom_filepath = Path(tmpdir) / "geo.xyz"
        structure.get_ase().write(geom_filepath, format="xyz")
        # The node is built before the temporary directory is removed.
        return orm.SinglefileData(file=geom_filepath.as_posix())

In [None]:
RelaxationJob = task(PwRelaxWorkChain)


@task.graph
def AfmWorkflow(
    case: AfmCase,
    structure: orm.StructureData,
    afm_params: dict,
    relax: bool = False,
    dft_params: t.Annotated[dict, RelaxationJob.inputs] = None,
) -> t.Annotated[dict, dynamic(t.Any)]:
    """Assemble the AFM simulation graph, optionally preceded by a DFT relaxation.

    Pipeline for ``AfmCase.EMPIRICAL``:

    1. (optional) relax ``structure`` with ``PwRelaxWorkChain``;
    2. write the geometry (``geo.xyz``) and ppafm parameters (``params.ini``);
    3. run ``ppafm-generate-ljff`` -> ``ppafm-relaxed-scan`` ->
       ``ppafm-plot-results`` as shell jobs.

    :param case: simulation type, either an ``AfmCase`` member or its name
        (e.g. ``"EMPIRICAL"``). Only ``EMPIRICAL`` is implemented so far.
    :param structure: input geometry.
    :param afm_params: ppafm parameters, written to ``params.ini``.
    :param relax: if true, relax ``structure`` first and feed the relaxation's
        ``output_structure`` to the AFM steps.
    :param dft_params: keyword inputs for ``PwRelaxWorkChain``; required when
        ``relax`` is true.
    :return: what ``shelljob`` returns for the final ``ppafm-plot-results``
        step.
    :raises ValueError: if ``case`` is not supported, or if ``relax`` is true
        and ``dft_params`` is missing or empty.
    """
    # Accept the enum member itself as well as its name (the name is what the
    # notebook passes to ``build``). Reject unsupported cases before any task
    # is added to the graph.
    case_name = case.name if isinstance(case, AfmCase) else case
    if case_name != AfmCase.EMPIRICAL.name:
        raise ValueError(f"Unsupported case: {case}")

    if relax:
        # Explicit check rather than ``assert``: assertions are stripped when
        # Python runs with ``-O``.
        if not dft_params:
            raise ValueError("Missing DFT calculation parameters")
        dft_task = RelaxationJob(
            structure=structure,
            **dft_params,
        )
        # All downstream steps use the relaxed geometry.
        structure = dft_task.output_structure

    geometry_file = write_structure_file(structure=structure).result
    afm_params_file = write_afm_params(params=afm_params).result

    # Step 1: Lennard-Jones force field from geo.xyz + params.ini, in npy format.
    ljff = shelljob(
        command="ppafm-generate-ljff",
        arguments=["-i", "geo.xyz", "-f", "npy"],
        nodes={
            "geometry": geometry_file,
            "parameters": afm_params_file,
        },
        outputs=["FFLJ.npz"],
    )
    # Step 2: relaxed probe scan over the force field from step 1.
    # NOTE(review): "Q0.00K0.35" appears to encode the probe charge (0.0) and
    # lateral stiffness (klat ~ 0.35) from ``afm_params``; it must stay in sync
    # with those values -- confirm against ppafm's directory naming.
    scan = shelljob(
        command="ppafm-relaxed-scan",
        arguments=["-f", "npy"],
        nodes={
            "parameters": afm_params_file,
            "ljff_data": ljff.FFLJ_npz,
        },
        metadata={
            "options": {
                # Presumably links upstream data instead of copying it
                # (aiida-shell option).
                "use_symlinks": True,
            }
        },
        outputs=["Q0.00K0.35"],
    )
    # Step 3: plot and save the scan results (--df / --save_df flags).
    results = shelljob(
        command="ppafm-plot-results",
        arguments=["--df", "--cbar", "--save_df", "-f", "npy"],
        nodes={
            "parameters": afm_params_file,
            "scan_dir": scan.Q0_00K0_35,
        },
        # Restore the directory name ppafm expects for the scan data.
        filenames={"scan_dir": "Q0.00K0.35"},
        metadata={
            "options": {
                "use_symlinks": True,
            }
        },
        outputs=["Q0.00K0.35"],
    )

    return results

In [None]:
# Starting geometry, read from a test directory relative to the notebook's
# working directory (``read`` is imported in the import cell at the top).
GEOMETRY_PATH = Path("AFM_TEST") / "Empirical" / "geo.xyz"

atoms = read(GEOMETRY_PATH)
structure = orm.StructureData(ase=atoms)
# Non-periodic in all three directions, matching "PBC": "False" in afm_params.
# NOTE(review): pw.x always works in a periodic simulation cell -- confirm that
# geo.xyz provides a non-degenerate cell and that PwRelaxWorkChain accepts
# pbc=(False, False, False).
structure.set_pbc((False, False, False))

# 1x1x1 mesh, i.e. a single k-point, as used for an isolated system.
kpoints = orm.KpointsData()
kpoints.set_kpoints_mesh([1, 1, 1])

In [None]:
# ppafm settings. write_afm_params turns each entry into one "<key> <value>"
# line of params.ini, so every value is spelled as the literal text to write.
afm_params = dict(
    PBC="False",  # non-periodic, consistent with structure.set_pbc
    tip="s",
    klat="0.3490127886809",
    krad="21.913190531846",
    # Grid vectors; their lengths reappear as the x/y bounds of scanMax.
    gridA="14.9412827110 0.0000000000 0.0000000000",
    gridB="0.0000000000 14.5091262213 0.0000000000",
    gridC="0.0000000000 0.0000000000 10.0820001747",
    sigma="0.7",
    charge="0.0",
    r0Probe="0.0 0.0 2.97",
    # Scan box bounds (three components each) and step size.
    scanMax="14.9412827110 14.5091262213 11",
    scanMin="0.0 0.0  8",
    scanStep="0.1 0.1 0.1",
    Amplitude="1.4",
    probeType="O",
    f0Cantilever="22352.5",
    # NOTE(review): -1 presumably lets ppafm derive the grid point counts -- confirm.
    gridN="-1 -1 -1",
)

In [None]:
# Keyword inputs for PwRelaxWorkChain; AfmWorkflow unpacks them into its
# relaxation task when called with relax=True.
# NOTE(review): the pseudopotential UUIDs and the "pw-7.4@localhost" code label
# are specific to one AiiDA profile; they must exist in the loaded profile for
# the load_node/load_code calls below to succeed.
dft_params = {
    "base": {
        "metadata": {},
        "pw": {
            "metadata": {
                "options": dict(
                    stash={},
                    resources={"num_machines": 1},
                    max_wallclock_seconds=43200,  # 12 hours
                    withmpi=True,
                ),
            },
            # Pseudopotential node per element, looked up by UUID.
            "pseudos": dict(
                C=orm.load_node("175b8a59-2850-4c60-b627-2100429cfdaa"),
                H=orm.load_node("0f39fda8-1939-4b63-a2f7-d8bfba7449dc"),
            ),
            "code": orm.load_code("pw-7.4@localhost"),
            # pw.x input namelists.
            "parameters": {
                "CONTROL": dict(
                    calculation="relax",  # fixed-cell ionic relaxation
                    forc_conv_thr=0.001,
                    tprnfor=True,
                    tstress=True,
                    etot_conv_thr=0.0002,
                    nstep=50,
                ),
                "SYSTEM": dict(
                    nosym=False,
                    occupations="fixed",
                    ecutrho=240.0,  # 8 x ecutwfc
                    ecutwfc=30.0,
                    tot_charge=0.0,
                    vdw_corr="none",
                ),
                "ELECTRONS": dict(
                    electron_maxstep=80,
                    mixing_beta=0.4,
                    conv_thr=8e-10,
                ),
            },
        },
        "kpoints": kpoints,
        "kpoints_force_parity": False,
        "max_iterations": 5,
    },
    # NOTE(review): meta convergence on volume presumably matters only for
    # variable-cell runs; with calculation="relax" the cell is fixed -- confirm
    # these settings are intended.
    "max_meta_convergence_iterations": 5,
    "meta_convergence": True,
    "volume_convergence": 0.05,
}

In [None]:
# Build the AFM workgraph from the inputs defined above. Submission is a
# separate step (see the wg.submit() cell). The case is passed by name (a plain
# string); relax=True puts a PwRelaxWorkChain step before the ppafm jobs.
wg = AfmWorkflow.build(
    case=AfmCase.EMPIRICAL.name,
    structure=structure,
    afm_params=afm_params,
    relax=True,
    dft_params=dft_params,
)
# Last expression: display the graph as the cell output.
wg

In [None]:
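# Uncomment the line below to submit the workgraph to the AiiDA daemon.
# It is commented out, so running all cells only builds and displays the graph.
# wg.submit()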