Skip to content

Commit

Permalink
Fix all type checks for current set of critical classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Sep 11, 2019
1 parent 7c909de commit 58159d1
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
pycodestyle pymatgen
echo "--- Done ---"
echo "mypy checks..."
mypy pymatgen/core pymatgen/symmetry pymatgen/transformations pymatgen/command_line pymatgen/analysis
mypy pymatgen/core pymatgen/symmetry pymatgen/transformations pymatgen/command_line pymatgen/analysis pymatgen/entries
echo "--- Done ---"
# Command line tests
pmg structure --convert --filenames test_files/Li2O.cif POSCAR.pmg
Expand Down
5 changes: 2 additions & 3 deletions pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ def from_spacegroup(cls,
@classmethod
def from_magnetic_spacegroup(
cls,
msg: Union[str, 'MagneticSpaceGroup'],
msg: Union[str, 'MagneticSpaceGroup'], # type: ignore
lattice: Union[List, np.ndarray, Lattice],
species: Sequence[Union[str, Element, Specie, DummySpecie, Composition]],
coords: Sequence[Sequence[float]],
Expand Down Expand Up @@ -826,8 +826,7 @@ def from_magnetic_spacegroup(
"different!" % (len(species), len(magmoms))
)

frac_coords = coords if not coords_are_cartesian else \
lattice.get_fractional_coords(coords)
frac_coords = coords if not coords_are_cartesian else latt.get_fractional_coords(coords)

all_sp = [] # type: List[Union[str, Element, Specie, DummySpecie, Composition]]
all_coords = [] # type: List[List[float]]
Expand Down
24 changes: 12 additions & 12 deletions pymatgen/entries/entry_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import csv
import re

from typing import List, Union
from typing import List, Union, Iterable, Set
from pymatgen.core.periodic_table import Element
from pymatgen.core.composition import Composition
from pymatgen.analysis.phase_diagram import PDEntry
Expand Down Expand Up @@ -152,7 +152,7 @@ class EntrySet(collections.abc.MutableSet, MSONable):
subsets, dumping into files, etc.
"""

def __init__(self, entries: List[Union[PDEntry, ComputedEntry, ComputedStructureEntry]]):
def __init__(self, entries: Iterable[Union[PDEntry, ComputedEntry, ComputedStructureEntry]]):
"""
Args:
entries: All the entries.
Expand Down Expand Up @@ -211,14 +211,14 @@ def get_subset_in_chemsys(self, chemsys: List[str]):
Returns:
EntrySet
"""
chemsys = set(chemsys)
if not chemsys.issubset(self.chemsys):
raise ValueError("%s is not a subset of %s" % (chemsys,
chem_sys = set(chemsys)
if not chem_sys.issubset(self.chemsys):
raise ValueError("%s is not a subset of %s" % (chem_sys,
self.chemsys))
subset = set()
for e in self.entries:
elements = [sp.symbol for sp in e.composition.keys()]
if chemsys.issuperset(elements):
if chem_sys.issuperset(elements):
subset.add(e)
return EntrySet(subset)

Expand All @@ -238,19 +238,19 @@ def to_csv(self, filename: str, latexify_names: bool = False):
e.g., Li_{2}O
"""

elements = set()
els = set() # type: Set[Element]
for entry in self.entries:
elements.update(entry.composition.elements)
elements = sorted(list(elements), key=lambda a: a.X)
els.update(entry.composition.elements)
elements = sorted(list(els), key=lambda a: a.X)
writer = csv.writer(open(filename, "w"), delimiter=unicode2str(","),
quotechar=unicode2str("\""),
quoting=csv.QUOTE_MINIMAL)
writer.writerow(["Name"] + elements + ["Energy"])
writer.writerow(["Name"] + [el.symbol for el in elements] + ["Energy"])
for entry in self.entries:
row = [entry.name if not latexify_names
else re.sub(r"([0-9]+)", r"_{\1}", entry.name)]
row.extend([entry.composition[el] for el in elements])
row.append(entry.energy)
row.append(str(entry.energy))
writer.writerow(row)

@classmethod
Expand All @@ -270,7 +270,7 @@ def from_csv(cls, filename: str):
quoting=csv.QUOTE_MINIMAL)
entries = list()
header_read = False
elements = None
elements = [] # type: List[str]
for row in reader:
if not header_read:
elements = row[1:(len(row) - 1)]
Expand Down
2 changes: 1 addition & 1 deletion pymatgen/symmetry/maggroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import textwrap

from pymatgen.core import Lattice
from pymatgen.core.lattice import Lattice
from pymatgen.electronic_structure.core import Magmom
from pymatgen.symmetry.groups import SymmetryGroup, in_array_list
from pymatgen.core.operations import MagSymmOp
Expand Down

0 comments on commit 58159d1

Please sign in to comment.