Skip to content

Commit

Permalink
Small pydantic related bug fixes and linting (#836)
Browse files Browse the repository at this point in the history
* No underscore prefix on field

* Mypy fixes

* Fix model validator for task doc

* Add int as possible icsd_id

* Linting

* More mypy fixes

* Fix type union
  • Loading branch information
munrojm committed Sep 27, 2023
1 parent dac2e71 commit c2fad96
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 25 deletions.
4 changes: 3 additions & 1 deletion emmet-builders/emmet/builders/abinit/sound_velocity.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def prechunk(self, number_splits: int): # pragma: no cover
# All relevant materials that have been updated since phonon props were last calculated
q = dict(self.query)

mats = self.phonon.newer_in(self.phonon_materials, exhaustive=True, criteria=q)
mats = self.sound_vel.newer_in(
self.phonon_materials, exhaustive=True, criteria=q
)

N = ceil(len(mats) / number_splits)

Expand Down
2 changes: 1 addition & 1 deletion emmet-core/emmet/core/elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ def sanity_check(
warnings.append(CM.STRAIN_RANK.format(rank))

# elastic tensor eigenvalues
eig_vals, _ = np.linalg.eig(elastic_doc.raw)
eig_vals, _ = np.linalg.eig(elastic_doc.raw) # type: ignore
if np.any(eig_vals < 0.0):
warnings.append(WM.NEGATIVE_EIGVAL)

Expand Down
14 changes: 7 additions & 7 deletions emmet-core/emmet/core/electronic_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,22 +444,22 @@ def from_bsdos( # type: ignore[override]
new_origin_task_id = None

if bs_gap is not None and bs_gap <= dos_gap + 0.2:
summary_task = bs_entry.setyawan_curtarolo.task_id
summary_task = bs_entry.setyawan_curtarolo.task_id # type: ignore
summary_band_gap = bs_gap
summary_cbm = (
bs_entry.setyawan_curtarolo.cbm.get("energy", None) # type: ignore
if bs_entry.setyawan_curtarolo.cbm is not None
if bs_entry.setyawan_curtarolo.cbm is not None # type: ignore
else None
)
summary_vbm = (
bs_entry.setyawan_curtarolo.vbm.get("energy", None) # type: ignore
if bs_entry.setyawan_curtarolo.cbm is not None
if bs_entry.setyawan_curtarolo.cbm is not None # type: ignore
else None
) # type: ignore
summary_efermi = bs_entry.setyawan_curtarolo.efermi
is_gap_direct = bs_entry.setyawan_curtarolo.is_gap_direct
is_metal = bs_entry.setyawan_curtarolo.is_metal
summary_magnetic_ordering = bs_entry.setyawan_curtarolo.magnetic_ordering
summary_efermi = bs_entry.setyawan_curtarolo.efermi # type: ignore
is_gap_direct = bs_entry.setyawan_curtarolo.is_gap_direct # type: ignore
is_metal = bs_entry.setyawan_curtarolo.is_metal # type: ignore
summary_magnetic_ordering = bs_entry.setyawan_curtarolo.magnetic_ordering # type: ignore

for origin in origins:
if origin["name"] == "setyawan_curtarolo":
Expand Down
2 changes: 1 addition & 1 deletion emmet-core/emmet/core/qchem/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def best_lot(
"""

sorted_lots = sorted(
mol_doc.best_entries.keys(),
mol_doc.best_entries.keys(), # type: ignore
key=lambda x: evaluate_lot(x, funct_scores, basis_scores, solvent_scores),
)

Expand Down
4 changes: 2 additions & 2 deletions emmet-core/emmet/core/structure_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def from_ungrouped_structure_entries(
struct_group = cls.from_grouped_entries(
f_group_l, ignored_specie=ignored_specie
)
cnt_ += len(struct_group.material_ids)
cnt_ += len(struct_group.material_ids) # type: ignore
continue

logger.debug(
Expand All @@ -203,7 +203,7 @@ def from_ungrouped_structure_entries(
struct_group = cls.from_grouped_entries(
g, ignored_specie=ignored_specie
)
cnt_ += len(struct_group.material_ids)
cnt_ += len(struct_group.material_ids) # type: ignore
results.append(struct_group)
if cnt_ != len(entries):
raise RuntimeError(
Expand Down
3 changes: 2 additions & 1 deletion emmet-core/emmet/core/substrates.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ class SubstratesDoc(BaseModel):
description="The Materials Project ID of the film material. This comes in the form: mp-******.",
)

_norients: Optional[int] = Field(
norients: Optional[int] = Field(
None,
description="Number of possible surface orientations for the substrate.",
alias="_norients",
)

orient: Optional[str] = Field(
Expand Down
15 changes: 6 additions & 9 deletions emmet-core/emmet/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ class TaskDoc(StructureMetadata, extra="allow"):
author: Optional[str] = Field(
None, description="Author extracted from transformations"
)
icsd_id: Optional[str] = Field(
icsd_id: Optional[Union[str, int]] = Field(
None, description="Inorganic Crystal Structure Database id of the structure"
)
transformations: Optional[Dict[str, Any]] = Field(
Expand Down Expand Up @@ -404,14 +404,11 @@ class TaskDoc(StructureMetadata, extra="allow"):
def last_updated_dict_ok(cls, v) -> datetime:
return v if isinstance(v, datetime) else monty_decoder.process_decoded(v)

@model_validator(mode="before")
@classmethod
def set_entry(cls, values) -> datetime:
if not values.get("entry", None) and (
values.get("calcs_reversed", None) and values.get("task_id")
):
values["entry"] = cls.get_entry(values["calcs_reversed"], values["task_id"])
return values
@model_validator(mode="after")
def set_entry(self) -> datetime:
if not self.entry and (self.calcs_reversed and self.task_id):
self.entry = self.get_entry(self.calcs_reversed, self.task_id)
return self

@classmethod
def from_directory(
Expand Down
6 changes: 3 additions & 3 deletions emmet-core/emmet/core/vasp/validation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datetime import datetime
from typing import Dict, List, Union, Optional
from typing import Dict, List, Union, Optional, Any

import numpy as np
from pydantic import ConfigDict, Field, PyObject
from pydantic import ConfigDict, Field
from pymatgen.core.structure import Structure
from pymatgen.io.vasp.sets import VaspInputSet

Expand Down Expand Up @@ -63,7 +63,7 @@ def from_task_doc(
task_doc: TaskDocument,
kpts_tolerance: float = SETTINGS.VASP_KPTS_TOLERANCE,
kspacing_tolerance: float = SETTINGS.VASP_KSPACING_TOLERANCE,
input_sets: Dict[str, PyObject] = SETTINGS.VASP_DEFAULT_INPUT_SETS,
input_sets: Dict[str, Any] = SETTINGS.VASP_DEFAULT_INPUT_SETS,
LDAU_fields: List[str] = SETTINGS.VASP_CHECKED_LDAU_FIELDS,
max_allowed_scf_gradient: float = SETTINGS.VASP_MAX_SCF_GRADIENT,
potcar_hashes: Optional[Dict[CalcType, Dict[str, str]]] = None,
Expand Down

0 comments on commit c2fad96

Please sign in to comment.