Skip to content

Commit

Permalink
Implement defaults via base calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
edan-bainglass committed Feb 18, 2024
1 parent 4c4a357 commit a1dac8e
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 164 deletions.
139 changes: 31 additions & 108 deletions examples/coulomb_blockade/tester.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@
"leads_kpoints = orm.KpointsData()\n",
"leads_kpoints.set_kpoints_mesh([3, 1, 1])\n",
"\n",
"h, m, s = 1, 30, 60\n",
"\n",
"leads_output_prefix = orm.Str(\"leads\")"
]
},
Expand Down Expand Up @@ -148,8 +146,6 @@
"device_kpoints = orm.KpointsData()\n",
"device_kpoints.set_kpoints_mesh([1, 1, 1])\n",
"\n",
"h, m, s = 1, 30, 60\n",
"\n",
"device_output_prefix = orm.Str(\"device\")"
]
},
Expand All @@ -174,6 +170,34 @@
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Metadata"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"metadata = {\n",
" \"options\": {\n",
" \"withmpi\": False,\n",
" \"resources\": {\n",
" \"num_machines\": 1,\n",
" \"num_mpiprocs_per_machine\": NPROCS,\n",
" },\n",
" \"environment_variables\": {\n",
" \"OMP_NUM_THREADS\": OMP_NUM_THREADS,\n",
" \"NUMBA_NUM_THREADS\": NUMBA_NUM_THREADS,\n",
" },\n",
" }\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -194,32 +218,11 @@
" \"structure\": leads_structure,\n",
" \"kpoints\": leads_kpoints,\n",
" \"parameters\": leads_parameters,\n",
" \"metadata\": {\n",
" \"options\": {\n",
" \"withmpi\": False,\n",
" \"resources\": {\n",
" \"num_machines\": 1,\n",
" \"num_mpiprocs_per_machine\": 1,\n",
" },\n",
" \"max_wallclock_seconds\": h * m * s,\n",
" }\n",
" },\n",
" },\n",
" \"device\": {\n",
" \"structure\": device_structure,\n",
" \"kpoints\": device_kpoints,\n",
" \"parameters\": device_parameters,\n",
" \"write_nao\": orm.Bool(True),\n",
" \"metadata\": {\n",
" \"options\": {\n",
" \"withmpi\": False,\n",
" \"resources\": {\n",
" \"num_machines\": 1,\n",
" \"num_mpiprocs_per_machine\": 1,\n",
" },\n",
" \"max_wallclock_seconds\": h * m * s,\n",
" }\n",
" },\n",
" },\n",
" },\n",
" \"scattering\": {\n",
Expand All @@ -236,30 +239,10 @@
" \"localization\": {\n",
" \"code\": orm.load_code(\"los-script\"),\n",
" \"lowdin\": orm.Bool(True),\n",
" \"metadata\": {\n",
" \"options\": {\n",
" \"withmpi\": False,\n",
" \"resources\": {\n",
" \"num_machines\": 1,\n",
" \"num_mpiprocs_per_machine\": 1,\n",
" },\n",
" \"max_wallclock_seconds\": h * m * s,\n",
" }\n",
" },\n",
" },\n",
" \"greens_function\": {\n",
" \"code\": orm.load_code(\"greens-script\"),\n",
" \"basis\": orm.Dict(basis),\n",
" \"metadata\": {\n",
" \"options\": {\n",
" \"withmpi\": False,\n",
" \"resources\": {\n",
" \"num_machines\": 1,\n",
" \"num_mpiprocs_per_machine\": 1,\n",
" },\n",
" \"max_wallclock_seconds\": h * m * s,\n",
" }\n",
" },\n",
" },\n",
" # \"greens_function_parameters\": orm.Dict(\n",
" # {\n",
Expand All @@ -282,45 +265,22 @@
" # \"beta\": 70.0,\n",
" # }\n",
" # ),\n",
" \"metadata\": {\n",
" \"options\": {\n",
" \"withmpi\": False,\n",
" \"resources\": {\n",
" \"num_machines\": 1,\n",
" \"num_mpiprocs_per_machine\": NPROCS,\n",
" },\n",
" \"max_wallclock_seconds\": h * m * s,\n",
" \"environment_variables\": {\n",
" \"OMP_NUM_THREADS\": OMP_NUM_THREADS,\n",
" \"NUMBA_NUM_THREADS\": NUMBA_NUM_THREADS,\n",
" },\n",
" }\n",
" },\n",
" \"metadata\": metadata,\n",
" },\n",
" \"dmft\": {\n",
" \"code\": orm.load_code(\"dmft-script\"),\n",
" \"parameters\": orm.Dict(\n",
" {\n",
" # \"U\": 4.0,\n",
" # \"number_of_baths\": 4,\n",
" \"tolerance\": 200.0,\n",
" \"tolerance\": 200,\n",
" # \"alpha\": 0.0,\n",
" # \"inner_max_iter\": 1000,\n",
" # \"outer_max_iter\": 1000,\n",
" }\n",
" ),\n",
" \"converge_mu\": {\n",
" \"adjust_mu\": orm.Bool(True),\n",
" \"metadata\": {\n",
" \"options\": {\n",
" \"withmpi\": False,\n",
" \"resources\": {\n",
" \"num_machines\": 1,\n",
" \"num_mpiprocs_per_machine\": 1,\n",
" },\n",
" \"max_wallclock_seconds\": h * m * s,\n",
" },\n",
" },\n",
" },\n",
" \"sweep_mu\": {\n",
" \"parameters\": orm.Dict(\n",
Expand All @@ -330,34 +290,11 @@
" \"dmu_step\": 0.5,\n",
" }\n",
" ),\n",
" \"metadata\": {\n",
" \"options\": {\n",
" \"withmpi\": False,\n",
" \"resources\": {\n",
" \"num_machines\": 1,\n",
" \"num_mpiprocs_per_machine\": 1,\n",
" },\n",
" \"max_wallclock_seconds\": h * m * s,\n",
" }\n",
" },\n",
" },\n",
" },\n",
" \"transmission\": {\n",
" \"code\": orm.load_code(\"trans-script\"),\n",
" \"metadata\": {\n",
" \"options\": {\n",
" \"withmpi\": False,\n",
" \"resources\": {\n",
" \"num_machines\": 1,\n",
" \"num_mpiprocs_per_machine\": NPROCS,\n",
" },\n",
" \"max_wallclock_seconds\": h * m * s,\n",
" \"environment_variables\": {\n",
" \"OMP_NUM_THREADS\": OMP_NUM_THREADS,\n",
" \"NUMBA_NUM_THREADS\": NUMBA_NUM_THREADS,\n",
" },\n",
" }\n",
" },\n",
" \"metadata\": metadata,\n",
" },\n",
" \"current\": {\n",
" \"code\": orm.load_code(\"curr-script\"),\n",
Expand All @@ -367,20 +304,6 @@
" # \"dV\": 0.1,\n",
" # \"temperature\": 9,\n",
" # }),\n",
" \"metadata\": {\n",
" \"options\": {\n",
" \"withmpi\": False,\n",
" \"resources\": {\n",
" \"num_machines\": 1,\n",
" \"num_mpiprocs_per_machine\": NPROCS,\n",
" },\n",
" \"max_wallclock_seconds\": h * m * s,\n",
" \"environment_variables\": {\n",
" \"OMP_NUM_THREADS\": NPROCS,\n",
" \"NUMBA_NUM_THREADS\": NPROCS,\n",
" },\n",
" }\n",
" },\n",
" },\n",
"}"
]
Expand Down
33 changes: 33 additions & 0 deletions src/aiida_quantum_transport/calculations/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from aiida.engine import CalcJob

if TYPE_CHECKING:
from aiida.engine.processes.calcjobs.calcjob import CalcJobProcessSpec


class BaseCalculation(CalcJob):
"""docstring"""

_default_parser_name = ""

@classmethod
def define(cls, spec: CalcJobProcessSpec) -> None:
"""docstring"""

super().define(spec)

_DEFAULTS = {
"metadata.options.parser_name": cls._default_parser_name,
"metadata.options.withmpi": False,
"metadata.options.max_wallclock_seconds": 3600,
"metadata.options.resources": lambda: {
"num_machines": 1,
"num_mpiprocs_per_machine": 1,
},
}

for port, default in _DEFAULTS.items():
spec.inputs.get_port(port).default = default
11 changes: 3 additions & 8 deletions src/aiida_quantum_transport/calculations/current.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from aiida import orm
from aiida.common.datastructures import CalcInfo, CodeInfo
from aiida.common.folders import Folder
from aiida.engine import CalcJob

from .base import BaseCalculation

if TYPE_CHECKING:
from aiida.engine.processes.calcjobs.calcjob import CalcJobProcessSpec


class CurrentCalculation(CalcJob):
class CurrentCalculation(BaseCalculation):
"""docstring"""

_default_parser_name = "quantum_transport.current"
Expand Down Expand Up @@ -49,12 +50,6 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
help="The results folder of the transmission calculation",
)

spec.input(
"metadata.options.parser_name",
valid_type=str,
default=cls._default_parser_name,
)

spec.output(
"remote_results_folder",
valid_type=orm.RemoteData,
Expand Down
11 changes: 3 additions & 8 deletions src/aiida_quantum_transport/calculations/dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from aiida import orm
from aiida.common.datastructures import CalcInfo, CodeInfo
from aiida.common.folders import Folder
from aiida.engine import CalcJob

from .base import BaseCalculation

if TYPE_CHECKING:
from aiida.engine.processes.calcjobs.calcjob import CalcJobProcessSpec


class DFTCalculation(CalcJob):
class DFTCalculation(BaseCalculation):
"""docstring"""

_default_parser_name = "quantum_transport.dft"
Expand Down Expand Up @@ -48,12 +49,6 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
help="The input parameters",
)

spec.input(
"metadata.options.parser_name",
valid_type=str,
default=cls._default_parser_name,
)

spec.output(
"remote_results_folder",
valid_type=orm.RemoteData,
Expand Down
11 changes: 3 additions & 8 deletions src/aiida_quantum_transport/calculations/dmft.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from aiida import orm
from aiida.common.datastructures import CalcInfo, CodeInfo
from aiida.common.folders import Folder
from aiida.engine import CalcJob

from .base import BaseCalculation

if TYPE_CHECKING:
from aiida.engine.processes.calcjobs.calcjob import CalcJobProcessSpec


class DMFTCalculation(CalcJob):
class DMFTCalculation(BaseCalculation):
"""docstring"""

_default_parser_name = "quantum_transport.dmft"
Expand Down Expand Up @@ -89,12 +90,6 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
help="The converged chemical potential file",
)

spec.input(
"metadata.options.parser_name",
valid_type=str,
default=cls._default_parser_name,
)

spec.output(
"remote_results_folder",
valid_type=orm.RemoteData,
Expand Down
11 changes: 3 additions & 8 deletions src/aiida_quantum_transport/calculations/greens.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
from aiida import orm
from aiida.common.datastructures import CalcInfo, CodeInfo
from aiida.common.folders import Folder
from aiida.engine import CalcJob

from .base import BaseCalculation

if TYPE_CHECKING:
from aiida.engine.processes.calcjobs.calcjob import CalcJobProcessSpec


class GreensFuncionParametersCalculation(CalcJob):
class GreensFuncionParametersCalculation(BaseCalculation):
"""docstring"""

_default_parser_name = "quantum_transport.greens"
Expand Down Expand Up @@ -66,12 +67,6 @@ def define(cls, spec: CalcJobProcessSpec) -> None:
help="", # TODO fill in
)

spec.input(
"metadata.options.parser_name",
valid_type=str,
default=cls._default_parser_name,
)

spec.output(
"remote_results_folder",
valid_type=orm.RemoteData,
Expand Down
Loading

0 comments on commit a1dac8e

Please sign in to comment.