Skip to content

Commit

Permalink
Merge pull request #2247 from janosh/literal-types
Browse files Browse the repository at this point in the history
Type hints for literal string kwargs
  • Loading branch information
mkhorton committed Oct 6, 2021
2 parents a66e98b + 247d4fa commit 05c0b2b
Show file tree
Hide file tree
Showing 16 changed files with 179 additions and 149 deletions.
14 changes: 10 additions & 4 deletions pymatgen/analysis/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

import numpy as np
from monty.json import MSONable
from scipy.special import comb, erfc
from scipy import constants
from scipy.special import comb, erfc

from pymatgen.core.structure import Structure

Expand Down Expand Up @@ -471,9 +471,15 @@ def as_dict(self, verbosity: int = 0) -> Dict:
return d

@classmethod
def from_dict(cls, d: Dict, fmt: str = None, **kwargs):
"""
Create an EwaldSummation instance from json serialized dictionary.
def from_dict(cls, d: Dict, fmt: str = None, **kwargs) -> "EwaldSummation":
"""Create an EwaldSummation instance from JSON serialized dictionary.
Args:
d (Dict): Dictionary representation
fmt (str, optional): Unused. Defaults to None.
Returns:
EwaldSummation: class instance
"""
summation = cls(
structure=Structure.from_dict(d["structure"]),
Expand Down
14 changes: 7 additions & 7 deletions pymatgen/analysis/interface_reactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@
import json
import os
import warnings
from typing import List, Tuple, Union
from typing import List, Literal, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from monty.dev import deprecated
from monty.json import MSONable
from pandas import DataFrame
from plotly.graph_objects import Scatter, Figure
from plotly.graph_objects import Figure, Scatter

from pymatgen.analysis.phase_diagram import PhaseDiagram, GrandPotentialPhaseDiagram
from pymatgen.analysis.phase_diagram import GrandPotentialPhaseDiagram, PhaseDiagram
from pymatgen.analysis.reaction_calculator import Reaction
from pymatgen.core.composition import Composition
from pymatgen.util.plotting import pretty_plot
from pymatgen.util.string import latexify, htmlify
from pymatgen.util.string import htmlify, latexify

__author__ = "Yihan Xiao, Matthew McDermott"
__maintainer__ = "Matthew McDermott"
Expand Down Expand Up @@ -195,14 +195,14 @@ def get_kinks(self) -> List[Tuple[int, float, float, Reaction, float]]:

return list(zip(index_kink, x_kink, energy_kink, react_kink, energy_per_rxt_formula))

def plot(self, backend: str = "plotly") -> Union[Figure, plt.Figure]:
def plot(self, backend: Literal["plotly", "matplotlib"] = "plotly") -> Union[Figure, plt.Figure]:
"""
Plots reaction energy as a function of mixing ratio x in self.c1 - self.c2
tie line.
Args:
backend: Plotting library used to create plot. Defaults to "plotly".
Can alternatively be set to "matplotlib"
backend ("plotly" | "matplotlib"): Plotting library used to create the plot. Defaults to
"plotly" but can also be "matplotlib".
Returns:
Plot of reaction energies as a function of mixing ratio
Expand Down
9 changes: 5 additions & 4 deletions pymatgen/analysis/magnetism/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
MagOrderingTransformation,
MagOrderParameterConstraint,
)
from pymatgen.transformations.standard_transformations import AutoOxiStateDecorationTransformation
from pymatgen.transformations.standard_transformations import (
AutoOxiStateDecorationTransformation,
)
from pymatgen.util.typing import VectorLike


__author__ = "Matthew Horton"
__copyright__ = "Copyright 2017, The Materials Project"
__version__ = "0.1"
Expand Down Expand Up @@ -291,7 +292,7 @@ def __init__(
if set_net_positive:
sign = np.sum(magmoms)
if sign < 0:
magmoms = -np.array(magmoms)
magmoms = [-x for x in magmoms]

structure.add_site_property("magmom", magmoms)

Expand Down Expand Up @@ -342,7 +343,7 @@ def _round_magmoms(magmoms: VectorLike, round_magmoms_mode: Union[int, float]) -
num_decimals = len(str(round_magmoms_mode).split(".")[1]) + 1
magmoms = np.around(magmoms, decimals=num_decimals)

return magmoms
return np.array(magmoms)

def get_structure_with_spin(self) -> Structure:
"""Returns a Structure with species decorated with spin values instead
Expand Down
32 changes: 19 additions & 13 deletions pymatgen/analysis/magnetism/jahnteller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import os
import warnings
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Literal, Optional, Tuple, Union, cast

import numpy as np

Expand Down Expand Up @@ -153,6 +153,8 @@ def get_analysis_and_structure(

if motif in ["oct", "tet"]:

motif = cast(Literal["oct", "tet"], motif) # mypy needs help

# guess spin of metal ion
if guesstimate_spin and "magmom" in site.properties:

Expand Down Expand Up @@ -435,20 +437,22 @@ def get_magnitude_of_effect_from_spin_config(motif: str, spin_config: Dict[str,
return magnitude

@staticmethod
def _estimate_spin_state(species: Union[str, Species], motif: str, known_magmom: float) -> str:
def _estimate_spin_state(
species: Union[str, Species], motif: Literal["oct", "tet"], known_magmom: float
) -> Literal["undefined", "low", "high", "unknown"]:
"""Simple heuristic to estimate spin state. If magnetic moment
is sufficiently close to that predicted for a given spin state,
we assign it that state. If we only have data for one spin
state then that's the one we use (e.g. we assume all tetrahedral
complexes are high-spin, since this is typically the case).
Args:
species: str or Species
motif: "oct" or "tet"
known_magmom: magnetic moment in Bohr magnetons
species: str or Species
motif ("oct" | "tet"): Tetrahedron or octahedron crystal site coordination
known_magmom: magnetic moment in Bohr magnetons
Returns: "undefined" (if only one spin state possible), "low",
"high" or "unknown"
Returns:
"undefined" (if only one spin state possible), "low", "high" or "unknown"
"""
mu_so_high = JahnTellerAnalyzer.mu_so(species, motif=motif, spin_state="high")
mu_so_low = JahnTellerAnalyzer.mu_so(species, motif=motif, spin_state="low")
Expand All @@ -469,18 +473,20 @@ def _estimate_spin_state(species: Union[str, Species], motif: str, known_magmom:
return "unknown"

@staticmethod
def mu_so(species: Union[str, Species], motif: str, spin_state: str) -> Optional[float]:
def mu_so(
species: Union[str, Species], motif: Literal["oct", "tet"], spin_state: Literal["high", "low"]
) -> Optional[float]:
"""Calculates the spin-only magnetic moment for a
given species. Only supports transition metals.
Args:
species: Species
motif: "oct" or "tet"
spin_state: "high" or "low"
species: Species
motif ("oct" | "tet"): Tetrahedron or octahedron crystal site coordination
spin_state ("low" | "high"): Whether the species is in a high or low spin state
Returns:
Spin-only magnetic moment in Bohr magnetons or None if
species crystal field not defined
float: Spin-only magnetic moment in Bohr magnetons or None if
species crystal field not defined
"""
try:
sp = get_el_sp(species)
Expand Down
8 changes: 4 additions & 4 deletions pymatgen/analysis/phase_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import re
from functools import lru_cache
from typing import Literal

import numpy as np
import plotly.graph_objs as go
Expand All @@ -27,7 +28,7 @@
from pymatgen.entries import Entry
from pymatgen.util.coord import Simplex, in_coord_list
from pymatgen.util.plotting import pretty_plot
from pymatgen.util.string import latexify, htmlify
from pymatgen.util.string import htmlify, latexify

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1616,7 +1617,7 @@ def __init__(
self,
phasediagram: PhaseDiagram,
show_unstable: float = 0.2,
backend: str = "plotly",
backend: Literal["plotly", "matplotlib"] = "plotly",
**plotkwargs,
):
"""
Expand All @@ -1625,8 +1626,7 @@ def __init__(
show_unstable (float): Whether unstable (above the hull) phases will be
plotted. If a number > 0 is entered, all phases with
e_hull < show_unstable (eV/atom) will be shown.
backend (str): Python package used for plotting ("matplotlib" or
"plotly"). Defaults to "plotly".
backend ("plotly" | "matplotlib"): Python package used for plotting. Defaults to "plotly".
**plotkwargs (dict): Keyword args passed to matplotlib.pyplot.plot. Can
be used to customize markers etc. If not set, the default is
{
Expand Down
8 changes: 4 additions & 4 deletions pymatgen/analysis/xas/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""
import math
import warnings
from typing import List
from typing import List, Literal

import numpy as np
from scipy.interpolate import interp1d
Expand Down Expand Up @@ -102,7 +102,7 @@ def __str__(self):
super().__str__(),
)

def stitch(self, other: "XAS", num_samples: int = 500, mode: str = "XAFS") -> "XAS":
def stitch(self, other: "XAS", num_samples: int = 500, mode: Literal["XAFS", "L23"] = "XAFS") -> "XAS":
"""
Stitch XAS objects to get the full XAFS spectrum or L23 edge XANES
spectrum depending on the mode.
Expand All @@ -120,8 +120,8 @@ def stitch(self, other: "XAS", num_samples: int = 500, mode: str = "XAFS") -> "X
Args:
other: Another XAS object.
num_samples(int): Number of samples for interpolation.
mode(str): Either XAFS mode for stitching XANES and EXAFS
or L23 mode for stitching L2 and L3.
mode("XAFS" | "L23"): Either XAFS mode for stitching XANES and EXAFS
or L23 mode for stitching L2 and L3.
Returns:
XAS object: The stitched spectrum.
Expand Down
26 changes: 16 additions & 10 deletions pymatgen/core/periodic_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from io import open
from itertools import combinations, product
from pathlib import Path
from typing import Callable, Optional, Union, Dict, Tuple, List
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
from monty.json import MSONable

from pymatgen.core.units import SUPPORTED_UNIT_NAMES, FloatWithUnit, Length, Mass, Unit
from pymatgen.util.string import formula_double_format, Stringify
from pymatgen.util.string import Stringify, formula_double_format

# Loads element data from json file
with open(str(Path(__file__).absolute().parent / "periodic_table.json"), "rt") as f:
Expand Down Expand Up @@ -1216,7 +1216,12 @@ def get_nmr_quadrupole_moment(self, isotope: Optional[str] = None) -> float:
raise ValueError("No quadrupole moment for isotope {}".format(isotope))
return quad_mom.get(isotope, 0.0)

def get_shannon_radius(self, cn: str, spin: str = "", radius_type: str = "ionic") -> float:
def get_shannon_radius(
self,
cn: str,
spin: Literal["", "Low Spin", "High Spin"] = "",
radius_type: Literal["ionic", "crystal"] = "ionic",
) -> float:
"""
Get the local environment specific ionic radius for species.
Expand All @@ -1239,29 +1244,30 @@ def get_shannon_radius(self, cn: str, spin: str = "", radius_type: str = "ionic"
k, data = list(radii.items())[0] # type: ignore
if k != spin:
warnings.warn(
"Specified spin state of %s not consistent with database "
"spin of %s. Only one spin data available, and "
"that value is returned." % (spin, k)
f"Specified spin state of {spin} not consistent with database "
f"spin of {k}. Only one spin data available, and that value is returned."
)
else:
data = radii[spin]
return data["%s_radius" % radius_type]

def get_crystal_field_spin(self, coordination: str = "oct", spin_config: str = "high") -> float:
def get_crystal_field_spin(
self, coordination: Literal["oct", "tet"] = "oct", spin_config: Literal["low", "high"] = "high"
) -> float:
"""
Calculate the crystal field spin based on coordination and spin
configuration. Only works for transition metal species.
Args:
coordination (str): Only oct and tet are supported at the moment.
spin_config (str): Supported keywords are "high" or "low".
coordination ("oct" | "tet"): Tetrahedron or octahedron crystal site coordination
spin_config ("low" | "high"): Whether the species is in a high or low spin state
Returns:
Crystal field spin in Bohr magneton.
Raises:
AttributeError if species is not a valid transition metal or has
an invalid oxidation state.
an invalid oxidation state.
ValueError if invalid coordination or spin_config.
"""
if coordination not in ("oct", "tet") or spin_config not in ("high", "low"):
Expand Down
8 changes: 4 additions & 4 deletions pymatgen/core/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
x y value pairs.
"""

from typing import List, Union, Callable
from typing import Callable, List, Literal, Union

import numpy as np
from monty.json import MSONable
Expand Down Expand Up @@ -74,13 +74,13 @@ def __getattr__(self, item):
def __len__(self):
return self.ydim[0]

def normalize(self, mode: str = "max", value: float = 1.0):
def normalize(self, mode: Literal["max", "sum"] = "max", value: float = 1.0):
"""
Normalize the spectrum with respect to the sum of intensity
Args:
mode (str): Normalization mode. Supported modes are "max" (set the
max y value to value, e.g., in XRD patterns), "sum" (set the
mode ("max" | "sum"): Normalization mode. Supported modes are "max" (set
the max y value to value, e.g., in XRD patterns), "sum" (set the
sum of y to a value, i.e., like a probability density).
value (float): Value to normalize to. Defaults to 1.
"""
Expand Down
17 changes: 9 additions & 8 deletions pymatgen/core/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,13 +1642,15 @@ def get_sorted_structure(
sites = sorted(self, key=key, reverse=reverse)
return self.__class__.from_sites(sites, charge=self._charge)

def get_reduced_structure(self, reduction_algo: str = "niggli") -> Union["IStructure", "Structure"]:
def get_reduced_structure(
self, reduction_algo: Literal["niggli", "LLL"] = "niggli"
) -> Union["IStructure", "Structure"]:
"""
Get a reduced structure.
Args:
reduction_algo (str): The lattice reduction algorithm to use.
Currently supported options are "niggli" or "LLL".
reduction_algo ("niggli" | "LLL"): The lattice reduction algorithm to use.
Defaults to "niggli".
"""
if reduction_algo == "niggli":
reduced_latt = self._lattice.get_niggli_reduced_lattice()
Expand Down Expand Up @@ -3731,20 +3733,19 @@ def scale_lattice(self, volume: float):
"""
self.lattice = self._lattice.scale(volume)

def merge_sites(self, tol: float = 0.01, mode: str = "sum"):
def merge_sites(self, tol: float = 0.01, mode: Literal["sum", "delete", "average"] = "sum"):
"""
Merges sites (adding occupancies) within tol of each other.
Removes site properties.
Args:
tol (float): Tolerance for distance to merge sites.
mode (str): Three modes supported. "delete" means duplicate sites are
mode ('sum' | 'delete' | 'average'): "delete" means duplicate sites are
deleted. "sum" means the occupancies are summed for the sites.
"average" means that the site is deleted but the properties are averaged
Only first letter is considered.
"""
mode = mode.lower()[0]
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

Expand All @@ -3759,13 +3760,13 @@ def merge_sites(self, tol: float = 0.01, mode: str = "sum"):
props = self[inds[0]].properties
for n, i in enumerate(inds[1:]):
sp = self[i].species
if mode == "s":
if mode.lower()[0] == "s":
species += sp
offset = self[i].frac_coords - coords
coords = coords + ((offset - np.round(offset)) / (n + 2)).astype(coords.dtype)
for key in props.keys():
if props[key] is not None and self[i].properties[key] != props[key]:
if mode == "a" and isinstance(props[key], float):
if mode.lower()[0] == "a" and isinstance(props[key], float):
# update a running total
props[key] = props[key] * (n + 1) / (n + 2) + self[i].properties[key] / (n + 2)
else:
Expand Down

0 comments on commit 05c0b2b

Please sign in to comment.