From 937eb0367b326d0f69d1d28f46b5b75687a68883 Mon Sep 17 00:00:00 2001 From: Patrick Austin Date: Wed, 12 Apr 2023 11:45:29 +0000 Subject: [PATCH] Replace instances of mutable default arguments #73 --- pymuonsuite/dipolar/field.py | 7 +++--- pymuonsuite/io/castep.py | 9 +------- pymuonsuite/io/dftb.py | 8 ++----- pymuonsuite/io/readwrite.py | 26 ++++++++++++++++------ pymuonsuite/io/uep.py | 2 +- pymuonsuite/quantum/vibrational/phonons.py | 4 ++-- pymuonsuite/quantum/vibrational/schemes.py | 4 +++- 7 files changed, 32 insertions(+), 28 deletions(-) diff --git a/pymuonsuite/dipolar/field.py b/pymuonsuite/dipolar/field.py index bfcdea4..a26e19f 100644 --- a/pymuonsuite/dipolar/field.py +++ b/pymuonsuite/dipolar/field.py @@ -72,11 +72,12 @@ def __init__( self, atoms, mu_pos, - isotopes={}, + isotopes=None, isotope_list=None, cutoff=10, overlap_eps=1e-3, ): + isotopes = {} if isotopes is None else isotopes # Get positions, cell, and species, only things we care about self.cell = np.array(atoms.get_cell()) @@ -147,7 +148,7 @@ def set_moments(self, moments, moment_type="e"): def dipten(self): return np.sum(self.spins[:, None, None] * self._dT, axis=0) - def frequency(self, axis=[0, 0, 1]): + def frequency(self, axis=(0, 0, 1)): D = self.dipten() return np.sum(np.dot(D, axis) * axis) @@ -173,7 +174,7 @@ def pwd_spec(self, width=None, h_steps=100, nsteps=100): return om, spec - def random_spec_uniaxial(self, axis=[0, 0, 1], width=None, h_steps=100, occ=1.0): + def random_spec_uniaxial(self, axis=(0, 0, 1), width=None, h_steps=100, occ=1.0): # Consider individual dipolar constants DD = self.spins[:, None, None] * self._dT diff --git a/pymuonsuite/io/castep.py b/pymuonsuite/io/castep.py index 6d89d27..b96193f 100644 --- a/pymuonsuite/io/castep.py +++ b/pymuonsuite/io/castep.py @@ -21,7 +21,7 @@ class ReadWriteCastep(ReadWrite): - def __init__(self, params={}, script=None, calc=None): + def __init__(self, params=None, script=None, calc=None): """ | params (dict) Contains muon symbol, parameter file, | k_points_grid. @@ -40,13 +40,6 @@ def __init__(self, params={}, script=None, calc=None): if calc is not None and params != {}: self._create_calculator() - def _validate_params(self, params): - if not (isinstance(params, dict)): - raise ValueError("params should be a dict, not ", type(params)) - return - else: - return params - def set_params(self, params): """ | params (dict) Contains muon symbol, parameter file, diff --git a/pymuonsuite/io/dftb.py b/pymuonsuite/io/dftb.py index d1bbd7a..acfebfb 100644 --- a/pymuonsuite/io/dftb.py +++ b/pymuonsuite/io/dftb.py @@ -38,7 +38,7 @@ class ReadWriteDFTB(ReadWrite): - def __init__(self, params={}, script=None, calc=None): + def __init__(self, params=None, script=None, calc=None): """ | Args: | params (dict): Contains dftb_set, k_points_grid, @@ -70,11 +70,7 @@ def set_params(self, params): | charged are also required in the case | of writing geom_opt input files """ - if not (isinstance(params, dict)): - raise ValueError("params should be a dict, not ", type(params)) - return - - self.params = deepcopy(params) + self.params = deepcopy(self._validate_params(params)) # resetting this to None makes sure that the calc is recreated after # the params are updated: self._calc_type = None diff --git a/pymuonsuite/io/readwrite.py b/pymuonsuite/io/readwrite.py index 6129170..e292d66 100644 --- a/pymuonsuite/io/readwrite.py +++ b/pymuonsuite/io/readwrite.py @@ -5,7 +5,7 @@ class ReadWrite(object): - def __init__(self, params={}, script=None, calc=None): + def __init__(self, params=None, script=None, calc=None): """ | params (dict) parameters for writing input files | script (str): Path to a file containing a submission @@ -19,7 +19,23 @@ def __init__(self, params={}, script=None, calc=None): """ self._calc = calc self.script = script - self.params = params + self.params = self._validate_params(params) + + def _validate_params(self, params: dict) -> dict: + """ + | Args: + | params (dict): dict of parameters to validate + | Returns: + | (dict): params, or an empty dict if they were None + | Raises: + | TypeError: params is neither None nor a dict + """ + if params is None: + return {} + elif isinstance(params, dict): + return params + else: + raise TypeError(f"params should be a dict, not {type(params)}") def read(self, folder, sname=None): raise ( @@ -40,11 +56,7 @@ def set_params(self, params): | params (dict) Contains muon symbol, parameter file, | k_points_grid. """ - if not (isinstance(params, dict)): - raise ValueError("params should be a dict, not ", type(params)) - return - else: - self.params = params + self.params = self._validate_params(params) def set_script(self, script): """ diff --git a/pymuonsuite/io/uep.py b/pymuonsuite/io/uep.py index 5e3941d..f4d1a90 100644 --- a/pymuonsuite/io/uep.py +++ b/pymuonsuite/io/uep.py @@ -10,7 +10,7 @@ class ReadWriteUEP(ReadWrite): - def __init__(self, params={}, script=None): + def __init__(self, params=None, script=None): self.set_script(script) self.set_params(params) diff --git a/pymuonsuite/quantum/vibrational/phonons.py b/pymuonsuite/quantum/vibrational/phonons.py index a5ee11a..743e2c3 100644 --- a/pymuonsuite/quantum/vibrational/phonons.py +++ b/pymuonsuite/quantum/vibrational/phonons.py @@ -21,7 +21,7 @@ def ase_phonon_calc( struct, calc=None, - kpoints=[1, 1, 1], + kpoints=(1, 1, 1), ftol=0.01, force_clean=False, name="asephonon", @@ -36,7 +36,7 @@ def ase_phonon_calc( | calc (ase.Calculator): Calculator for energies and forces (if not | present, use the one from struct) | kpoints (np.ndarray): Kpoint grid for phonon calculation. If None, just - | do a Vibration modes calculation (default is [1,1,1]) + | do a Vibration modes calculation (default is (1,1,1)) | ftol (float): Tolerance for geometry optimisation (default | is 0.01 eV/Ang) | force_clean (bool): If True, force a deletion of all phonon files diff --git a/pymuonsuite/quantum/vibrational/schemes.py b/pymuonsuite/quantum/vibrational/schemes.py index 3c0a53d..f5e2bd8 100644 --- a/pymuonsuite/quantum/vibrational/schemes.py +++ b/pymuonsuite/quantum/vibrational/schemes.py @@ -126,7 +126,9 @@ def save(self, file): def load(file): return pickle.load(open(file)) - def recalc_all(self, displ_args={}, weights_args={}): + def recalc_all(self, displ_args: dict = None, weights_args: dict = None): + displ_args = {} if displ_args is None else displ_args + weights_args = {} if weights_args is None else weights_args self.recalc_displacements(**displ_args) self.recalc_weights(**weights_args)