Skip to content

Commit

Permalink
Fix mypy issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
Shyue Ping Ong committed Dec 14, 2022
1 parent 5923fb5 commit 7b668b7
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 39 deletions.
34 changes: 17 additions & 17 deletions pymatgen/analysis/diffusion/aimd/rdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
"""
RDF implementation.
"""
from __future__ import annotations

from collections import Counter
from math import ceil
from multiprocessing import cpu_count
from typing import List, Tuple, Union

import numpy as np
from joblib import Parallel, delayed
Expand All @@ -28,17 +28,17 @@ class RadialDistributionFunction:

def __init__(
self,
structures: List,
indices: List,
reference_indices: List,
structures: list,
indices: list,
reference_indices: list,
ngrid: int = 101,
rmax: float = 10.0,
cell_range: int = 1,
sigma: float = 0.1,
):
"""
Args:
structures ([Structure]): List of structure
structures ([Structure]): list of structure
objects with the same composition. Allow for ensemble averaging.
ngrid (int): Number of radial grid points.
rmax (float): Maximum of radial grid (the minimum is always zero)
Expand Down Expand Up @@ -147,19 +147,19 @@ def __init__(
@classmethod
def from_species(
cls,
structures: List,
structures: list,
ngrid: int = 101,
rmax: float = 10.0,
cell_range: int = 1,
sigma: float = 0.1,
species: Union[Tuple, List] = ("Li", "Na"),
reference_species: Union[Tuple, List] = None,
species: tuple | list = ("Li", "Na"),
reference_species: tuple | list | None = None,
):
"""
Initialize using species.
Args:
structures (list of pmg_structure objects): List of structure
structures (list of pmg_structure objects): list of structure
objects with the same composition. Allow for ensemble averaging.
ngrid (int): Number of radial grid points.
rmax (float): Maximum of radial grid (the minimum is always zero).
Expand Down Expand Up @@ -206,7 +206,7 @@ def coordination_number(self):

def get_rdf_plot(
self,
label: str = None,
label: str | None = None,
xlim: tuple = (0.0, 8.0),
ylim: tuple = (-0.005, 3.0),
loc_peak: bool = False,
Expand Down Expand Up @@ -284,7 +284,7 @@ class RadialDistributionFunctionFast:

def __init__(
self,
structures: Union[Structure, List[Structure]],
structures: Structure | list[Structure],
rmin: float = 0.0,
rmax: float = 10.0,
ngrid: float = 101,
Expand Down Expand Up @@ -346,7 +346,7 @@ def __init__(
elements = np.array([str(i.specie) for i in structures[0]]) # type: ignore
self.center_elements = [elements[i] for i in self.center_indices]
self.neighbor_elements = [elements[i] for i in self.neighbor_indices]
self.density = [{}] * len(self.structures) # type: List[Dict]
self.density = [{}] * len(self.structures) # type: list[dict]

self.natoms = [i.composition.to_data_dict["unit_cell_composition"] for i in self.structures]

Expand Down Expand Up @@ -377,8 +377,8 @@ def _dist_to_counts(self, d):

def get_rdf(
self,
ref_species: Union[str, List[str]],
species: Union[str, List[str]],
ref_species: str | list[str],
species: str | list[str],
is_average=True,
):
"""
Expand Down Expand Up @@ -406,8 +406,8 @@ def get_rdf(

def get_one_rdf(
self,
ref_species: Union[str, List[str]],
species: Union[str, List[str]],
ref_species: str | list[str],
species: str | list[str],
index=0,
):
"""
Expand Down Expand Up @@ -472,7 +472,7 @@ def get_coordination_number(self, ref_species, species, is_average=True):
return self.r, cn


def _get_neighbor_list(structure, r) -> Tuple:
def _get_neighbor_list(structure, r) -> tuple:
"""
Thin wrapper to enable parallel calculations
Expand Down
26 changes: 14 additions & 12 deletions pymatgen/analysis/diffusion/aimd/van_hove.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
"""
Van Hove analysis for correlations.
"""
from __future__ import annotations

import itertools
from collections import Counter
from typing import Callable, List, Tuple, Union
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -17,6 +18,7 @@
from pymatgen.core import Structure
from pymatgen.util.plotting import pretty_plot
from pymatgen.analysis.diffusion.analyzer import DiffusionAnalyzer
from pymatgen.util.typing import ArrayLike

from .rdf import RadialDistributionFunction

Expand Down Expand Up @@ -46,9 +48,9 @@ def __init__(
step_skip: int = 50,
sigma: float = 0.1,
cell_range: int = 1,
species: Union[Tuple, List] = ("Li", "Na"),
reference_species: Union[Tuple, List] = None,
indices: List = None,
species: tuple | list = ("Li", "Na"),
reference_species: tuple | list | None = None,
indices: list | None = None,
):
"""
Initiation.
Expand Down Expand Up @@ -197,7 +199,7 @@ def __init__(
# time interval (in ps) in gsrt and gdrt.
self.timeskip = self.obj.time_step * self.obj.step_skip * step_skip / 1000.0

def get_3d_plot(self, figsize: Tuple = (12, 8), mode: str = "distinct"):
def get_3d_plot(self, figsize: tuple = (12, 8), mode: str = "distinct"):
"""
Plot 3D self-part or distinct-part of van Hove function, which is
specified by the input argument 'type'.
Expand Down Expand Up @@ -240,7 +242,7 @@ def get_3d_plot(self, figsize: Tuple = (12, 8), mode: str = "distinct"):

return plt

def get_1d_plot(self, mode: str = "distinct", times: List = [0.0], colors: List = None):
def get_1d_plot(self, mode: str = "distinct", times: list = [0.0], colors: list | None = None):
"""
Plot the van Hove function at given r or t.
Expand Down Expand Up @@ -292,7 +294,7 @@ class EvolutionAnalyzer:
Analyze the evolution of structures during AIMD simulations.
"""

def __init__(self, structures: List, rmax: float = 10, step: int = 1, time_step: int = 2):
def __init__(self, structures: list, rmax: float = 10, step: int = 1, time_step: int = 2):
"""
Initialization the EvolutionAnalyzer from MD simulations. From the
structures obtained from MD simulations, we can analyze the structure
Expand Down Expand Up @@ -337,7 +339,7 @@ def get_pairs(structure: Structure):
return list(pairs)

@staticmethod
def rdf(structure: Structure, pair: Tuple, ngrid: int = 101, rmax: float = 10):
def rdf(structure: Structure, pair: tuple, ngrid: int = 101, rmax: float = 10):
"""
Process rdf from a given structure and pair.
Expand Down Expand Up @@ -404,7 +406,7 @@ def atom_dist(

return np.array(density)

def get_df(self, func: Callable, save_csv: str = None, **kwargs):
def get_df(self, func: Callable, save_csv: str | None = None, **kwargs):
"""
Get the data frame for a given pair. This step would be very slow if
there are hundreds or more structures to parse.
Expand Down Expand Up @@ -466,8 +468,8 @@ def get_min_dist(df: pds.DataFrame, tol: float = 1e-10):
@staticmethod
def plot_evolution_from_data(
df: pds.DataFrame,
x_label: str = None,
cb_label: str = None,
x_label: str | None = None,
cb_label: str | None = None,
cmap=plt.cm.plasma, # pylint: disable=E1101
):
"""
Expand Down Expand Up @@ -522,7 +524,7 @@ def plot_evolution_from_data(

def plot_rdf_evolution(
self,
pair: Tuple,
pair: tuple,
cmap=plt.cm.plasma, # pylint: disable=E1101
df: pds.DataFrame = None,
):
Expand Down
21 changes: 11 additions & 10 deletions pymatgen/analysis/diffusion/neb/full_path_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
Migraiton Graph Analysis
"""
from __future__ import annotations

__author__ = "Jimmy Shen"
__copyright__ = "Copyright 2019, The Materials Project"
Expand All @@ -14,7 +15,7 @@
import operator
from copy import deepcopy
from itertools import starmap
from typing import Callable, Dict, List, Union
from typing import Callable

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -209,16 +210,16 @@ def with_distance(

@staticmethod
def get_structure_from_entries(
entries: List[ComputedStructureEntry],
entries: list[ComputedStructureEntry],
migrating_ion_entry: ComputedEntry,
**kwargs,
) -> List[Structure]:
) -> list[Structure]:
"""
Read in a list of base entries and inserted entries. Return a list of structures that contains metastable
sites for the migration species decorated with a "insertion_energy" property.
Args:
entries: List of entries, must contain a mixture of inserted and empty structures.
entries: list of entries, must contain a mixture of inserted and empty structures.
migrating_ion_entry: The metallic phase of the working ion, used to calculate insertion energies.
Additional Kwargs:
Expand Down Expand Up @@ -322,7 +323,7 @@ def _group_and_label_hops(self):

def add_data_to_similar_edges(
self,
target_label: Union[int, str],
target_label: int | str,
data: dict,
m_hop: MigrationHop = None,
):
Expand Down Expand Up @@ -376,7 +377,7 @@ def get_path(self, max_val=100000, flip_hops=True):
If false, hops will retain their original orientation
from the migration graph.
Returns:
Generator for List of Dicts:
Generator for list of Dicts:
Each dict contains the information of a hop
"""

Expand Down Expand Up @@ -432,7 +433,7 @@ def get_path(self, max_val=100000, flip_hops=True):
else:
yield u, path_hops

def get_summary_dict(self, added_keys: List[str] = None) -> dict:
def get_summary_dict(self, added_keys: list[str] = None) -> dict:
"""
Dictionary format, for saving to database
"""
Expand Down Expand Up @@ -698,7 +699,7 @@ def get_least_chg_path(self):
min_path = path
return min_path

def get_summary_dict(self, add_keys: List[str] = None):
def get_summary_dict(self, add_keys: list[str] = None):
"""
Dictionary format, for saving to database
"""
Expand Down Expand Up @@ -737,7 +738,7 @@ def _shift_grid(vv):
return vv + step / 2.0


def get_hop_site_sequence(hop_list: List[Dict], start_u: Union[int, str], key: str = None) -> List:
def get_hop_site_sequence(hop_list: list[dict], start_u: int | str, key: str | None = None) -> list:
"""
Read in a list of hop dictionaries and print the sequence of sites (and relevant property values if any).
Args:
Expand Down Expand Up @@ -772,7 +773,7 @@ def get_hop_site_sequence(hop_list: List[Dict], start_u: Union[int, str], key: s
return site_seq


def order_path(hop_list: List[Dict], start_u: Union[int, str]) -> List[Dict]:
def order_path(hop_list: list[dict], start_u: int | str) -> list[dict]:
"""
Takes a list of hop dictionaries and flips hops (switches isite and esite)
as needed to form a coherent path / sequence of sites according to
Expand Down

0 comments on commit 7b668b7

Please sign in to comment.