Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Mar 30, 2024
1 parent 61f8c1f commit 9fac4eb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 36 deletions.
46 changes: 18 additions & 28 deletions pymatgen/io/aims/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TYPE_CHECKING, Any

import numpy as np
from monty.io import zopen
from monty.json import MontyDecoder, MSONable

from pymatgen.core import Lattice, Molecule, Structure
Expand All @@ -35,7 +36,7 @@ class AimsGeometryIn(MSONable):
Attributes:
_content (str): The content of the input file
_structure (Structure or Molecule): The structure or molecule
_structure (Structure | Molecule): The structure or molecule
representation of the file
"""

Expand All @@ -56,12 +57,8 @@ def from_str(cls, contents: str) -> Self:
line.strip() for line in contents.split("\n") if len(line.strip()) > 0 and line.strip()[0] != "#"
]

species = []
coords = []
is_frac = []
lattice_vectors = []
charges_dct = {}
moments_dct = {}
species, coords, is_frac, lattice_vectors = [], [], [], []
charges_dct, moments_dct = {}, {}

for line in content_lines:
inp = line.split()
Expand Down Expand Up @@ -116,25 +113,21 @@ def from_file(cls, filepath: str | Path) -> Self:
Returns:
AimsGeometryIn: The input object represented in the file
"""
if str(filepath).endswith(".gz"):
with gzip.open(filepath, mode="rt") as infile:
content = infile.read()
else:
with open(filepath) as infile:
content = infile.read()
with zopen(filepath, mode="rt") as in_file:
content = in_file.read()
return cls.from_str(content)

@classmethod
def from_structure(cls, structure: Structure | Molecule) -> Self:
"""Construct an input file from an input structure.
Args:
structure (Structure or Molecule): The structure for the file
structure (Structure | Molecule): The structure for the file
Returns:
AimsGeometryIn: The input object for the structure
"""
content_lines = []
content_lines: list[str] = []

if isinstance(structure, Structure):
for lv in structure.lattice.matrix:
Expand Down Expand Up @@ -408,8 +401,7 @@ class AimsControlIn(MSONable):

def __post_init__(self) -> None:
"""Initialize the output list of _parameters"""
if "output" not in self._parameters:
self._parameters["output"] = []
self._parameters.setdefault("output", [])

def __getitem__(self, key: str) -> Any:
"""Get an input parameter
Expand Down Expand Up @@ -466,8 +458,7 @@ def parameters(self, parameters: dict[str, Any]) -> None:
parameters (dict[str, Any]): The new set of parameters to use
"""
self._parameters = parameters
if "output" not in self._parameters:
self._parameters["output"] = []
self._parameters.setdefault("output", [])

def get_aims_control_parameter_str(self, key: str, value: Any, fmt: str) -> str:
"""Get the string needed to add a parameter to the control.in file
Expand All @@ -490,7 +481,7 @@ def get_content(
"""Get the content of the file
Args:
structure (Structure or Molecule): The structure to write the input
structure (Structure | Molecule): The structure to write the input
file for
verbose_header (bool): If True print the input option dictionary
directory: str | Path | None = The directory for the calculation,
Expand All @@ -514,8 +505,7 @@ def get_content(
if verbose_header:
content += "# \n# List of parameters used to initialize the calculator:"
for param, val in parameters.items():
s = f"# {param}:{val}\n"
content += s
content += f"# {param}:{val}\n"
content += lim + "\n"

assert ("smearing" in parameters and "occupation_type" in parameters) is False
Expand Down Expand Up @@ -543,7 +533,7 @@ def get_content(
elif isinstance(value, bool):
content += self.get_aims_control_parameter_str(key, str(value).lower(), ".%s.")
elif isinstance(value, (tuple, list)):
content += self.get_aims_control_parameter_str(key, " ".join([str(x) for x in value]), "%s")
content += self.get_aims_control_parameter_str(key, " ".join(map(str, value)), "%s")
elif isinstance(value, str):
content += self.get_aims_control_parameter_str(key, value, "%s")
else:
Expand All @@ -569,7 +559,7 @@ def write_file(
"""Writes the control.in file
Args:
structure (Structure or Molecule): The structure to write the input
structure (Structure | Molecule): The structure to write the input
file for
directory (str or Path): The directory to write the control.in file.
If None use cwd
Expand Down Expand Up @@ -614,20 +604,20 @@ def get_species_block(self, structure: Structure | Molecule, species_dir: str |
Raises:
ValueError: If a file for the species is not found
"""
sb = ""
block = ""
species = np.unique(structure.species)
for sp in species:
filename = f"{species_dir}/{sp.Z:02d}_{sp.symbol}_default"
if Path(filename).exists():
with open(filename) as sf:
sb += "".join(sf.readlines())
block += "".join(sf.readlines())
elif Path(f"{filename}.gz").exists():
with gzip.open(f"{filename}.gz", mode="rt") as sf:
sb += "".join(sf.readlines())
block += "".join(sf.readlines())
else:
raise ValueError(f"Species file for {sp.symbol} not found.")

return sb
return block

def as_dict(self) -> dict[str, Any]:
"""Get a dictionary representation of the geometry.in file."""
Expand Down
16 changes: 8 additions & 8 deletions tests/io/aims/test_aims_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
def test_read_write_si_in(tmp_path: Path):
si = AimsGeometryIn.from_file(TEST_DIR / "geometry.in.si.gz")

in_lattice = np.array([[0.0, 2.715, 2.716], [2.717, 0.0, 2.718], [2.719, 2.720, 0.0]])
in_coords = np.array([[0.0, 0.0, 0.0], [0.25, 0.24, 0.26]])
in_lattice = np.array([[0, 2.715, 2.716], [2.717, 0, 2.718], [2.719, 2.720, 0]])
in_coords = np.array([[0, 0, 0], [0.25, 0.24, 0.26]])

assert all(sp.symbol == "Si" for sp in si.structure.species)
assert_allclose(si.structure.lattice.matrix, in_lattice)
Expand All @@ -50,9 +50,9 @@ def test_read_h2o_in(tmp_path: Path):
h2o = AimsGeometryIn.from_file(TEST_DIR / "geometry.in.h2o.gz")

in_coords = [
[0.0, 0.0, 0.119262],
[0.0, 0.763239, -0.477047],
[0.0, -0.763239, -0.477047],
[0, 0, 0.119262],
[0, 0.763239, -0.477047],
[0, -0.763239, -0.477047],
]

assert all(sp.symbol == symb for sp, symb in zip(h2o.structure.species, ["O", "H", "H"]))
Expand Down Expand Up @@ -107,12 +107,12 @@ def test_aims_cube():
AimsCube(type=ALLOWED_AIMS_CUBE_TYPES[0], origin=[0])

with pytest.raises(ValueError, match="Only three cube edges can be passed"):
AimsCube(type=ALLOWED_AIMS_CUBE_TYPES[0], edges=[[0.0, 0.0, 0.1]])
AimsCube(type=ALLOWED_AIMS_CUBE_TYPES[0], edges=[[0, 0, 0.1]])

with pytest.raises(ValueError, match="Each cube edge must have 3 components"):
AimsCube(
type=ALLOWED_AIMS_CUBE_TYPES[0],
edges=[[0.0, 0.0, 0.1], [0.1, 0.0, 0.0], [0.1, 0.0]],
edges=[[0, 0, 0.1], [0.1, 0, 0], [0.1, 0]],
)

with pytest.raises(ValueError, match="elf_type is only used when the cube type is elf. Otherwise it must be None"):
Expand All @@ -124,7 +124,7 @@ def test_aims_cube():
test_cube = AimsCube(
type="elf",
origin=[0, 0, 0],
edges=[[0.01, 0, 0], [0.0, 0.01, 0], [0.0, 0, 0.01]],
edges=[[0.01, 0, 0], [0, 0.01, 0], [0, 0, 0.01]],
points=[100, 100, 100],
spin_state=1,
kpoint=1,
Expand Down

0 comments on commit 9fac4eb

Please sign in to comment.