Skip to content

Commit

Permalink
replace set methods (union|intersection|difference) with arithmetic o…
Browse files Browse the repository at this point in the history
…perators
  • Loading branch information
janosh committed Jan 16, 2023
1 parent 5f8a9fc commit b3bbaf0
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 38 deletions.
4 changes: 2 additions & 2 deletions pymatgen/analysis/dimensionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def rank(vertices):

def rank_increase(seen, candidate):
rank0 = len(seen) - 1
rank1 = rank(seen.union({candidate}))
rank1 = rank(seen | {candidate})
return rank1 > rank0

connected_sites = {i: neighbours(i) for i in range(bonded_structure.structure.num_sites)}
Expand Down Expand Up @@ -482,7 +482,7 @@ def find_clusters(struct, connected_matrix):

def visit(atom, atom_cluster):
visited[atom] = True
new_cluster = set(np.where(connected_matrix[atom] != 0)[0]).union(atom_cluster)
new_cluster = set(np.where(connected_matrix[atom] != 0)[0]) | {atom_cluster}
atom_cluster = new_cluster
for new_atom in atom_cluster:
if not visited[new_atom]:
Expand Down
10 changes: 5 additions & 5 deletions pymatgen/analysis/functional_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def link_marked_atoms(self, atoms):
# Graph representation of only marked atoms
subgraph = self.molgraph.graph.subgraph(list(atoms)).to_undirected()

func_grps = []
func_groups = []
for func_grp in nx.connected_components(subgraph):
grp_hs = set()
for node in func_grp:
Expand All @@ -219,11 +219,11 @@ def link_marked_atoms(self, atoms):
# Add all associated hydrogens into the functional group
if neighbor in hydrogens:
grp_hs.add(neighbor)
func_grp = func_grp.union(grp_hs)
func_grp = func_grp | grp_hs

func_grps.append(func_grp)
func_groups.append(func_grp)

return func_grps
return func_groups

def get_basic_functional_groups(self, func_groups=None):
"""
Expand Down Expand Up @@ -305,7 +305,7 @@ def get_all_functional_groups(self, elements=None, func_groups=None, catch_basic
"""
heteroatoms = self.get_heteroatoms(elements=elements)
special_cs = self.get_special_carbon(elements=elements)
groups = self.link_marked_atoms(heteroatoms.union(special_cs))
groups = self.link_marked_atoms(heteroatoms | special_cs)

if catch_basic:
groups += self.get_basic_functional_groups(func_groups=func_groups)
Expand Down
13 changes: 7 additions & 6 deletions pymatgen/analysis/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,9 +1376,10 @@ def __len__(self):
def sort(self, key=None, reverse=False):
"""
Same as Structure.sort(), also remaps nodes in graph.
:param key:
:param reverse:
:return:
Args:
key: key to sort by
reverse: reverse sort order
"""
old_structure = self.structure.copy()

Expand Down Expand Up @@ -1490,7 +1491,7 @@ def diff(self, other, strict=True):
if len(edges) == 0 and len(edges_other) == 0:
jaccard_dist = 0 # by definition
else:
jaccard_dist = 1 - len(edges.intersection(edges_other)) / len(edges.union(edges_other))
jaccard_dist = 1 - len(edges ^ edges_other) / len(edges | edges_other)

return {
"self": edges - edges_other,
Expand Down Expand Up @@ -2930,11 +2931,11 @@ def diff(self, other, strict=True):
if len(edges) == 0 and len(edges_other) == 0:
jaccard_dist = 0 # by definition
else:
jaccard_dist = 1 - len(edges.intersection(edges_other)) / len(edges.union(edges_other))
jaccard_dist = 1 - len(edges ^ edges_other) / len(edges | edges_other)

return {
"self": edges - edges_other,
"other": edges_other - edges,
"both": edges.intersection(edges_other),
"both": edges ^ edges_other,
"dist": jaccard_dist,
}
2 changes: 1 addition & 1 deletion pymatgen/analysis/interface_reactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def products(self):
"""
products = set()
for _, _, _, react, _ in self.get_kinks():
products = products.union({k.reduced_formula for k in react.products})
products = products | {k.reduced_formula for k in react.products}
return list(products)


Expand Down
2 changes: 1 addition & 1 deletion pymatgen/analysis/local_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def _get_nn_shell_info(
raise ValueError("Shell must be positive")

# Append this site to the list of previously-visited sites
_previous_steps = _previous_steps.union({(site_idx, _cur_image)})
_previous_steps = _previous_steps | {(site_idx, _cur_image)}

# Get all the neighbors of this site
possible_steps = list(all_nn_info[site_idx])
Expand Down
15 changes: 7 additions & 8 deletions pymatgen/analysis/phase_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,16 +895,15 @@ def get_decomp_and_phase_separation_energy(
"additional unstable entries"
)

reduced_space = (
set(competing_entries)
.difference(self._get_stable_entries_in_space(entry_elems))
.union(self.el_refs.values())
)
reduced_space = competing_entries - {self._get_stable_entries_in_space(entry_elems)} | {
self.el_refs.values()
}

# NOTE calling PhaseDiagram is only reasonable if the composition has fewer than 5 elements
# TODO can we call PatchedPhaseDiagram in the here?
# TODO can we call PatchedPhaseDiagram here?
inner_hull = PhaseDiagram(reduced_space)

competing_entries = inner_hull.stable_entries.union(self._get_stable_entries_in_space(entry_elems))
competing_entries = inner_hull.stable_entries | {self._get_stable_entries_in_space(entry_elems)}
competing_entries = {c for c in compare_entries if id(c) not in same_comp_mem_ids}

if len(competing_entries) > space_limit:
Expand Down Expand Up @@ -1593,7 +1592,7 @@ def __init__(
# Add terminal elements as we may not have PD patches including them
# NOTE add el_refs in case no multielement entries are present for el
_stable_entries = {se for pd in self.pds.values() for se in pd._stable_entries}
self._stable_entries = tuple(_stable_entries.union(self.el_refs.values()))
self._stable_entries = tuple(_stable_entries | {*self.el_refs.values()})
self._stable_spaces = tuple(frozenset(e.composition.elements) for e in self._stable_entries)

def __repr__(self):
Expand Down
6 changes: 3 additions & 3 deletions pymatgen/analysis/tests/test_functional_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ def test_link_marked_atoms(self):
heteroatoms = self.extractor.get_heteroatoms()
special_cs = self.extractor.get_special_carbon()

link = self.extractor.link_marked_atoms(heteroatoms.union(special_cs))
link = self.extractor.link_marked_atoms(heteroatoms | special_cs)

self.assertEqual(len(link), 1)
self.assertEqual(len(link[0]), 9)

# Exclude Oxygen-related functional groups
heteroatoms_no_o = self.extractor.get_heteroatoms(elements=["N"])
special_cs_no_o = self.extractor.get_special_carbon(elements=["N"])
all_marked = heteroatoms_no_o.union(special_cs_no_o)
all_marked = heteroatoms_no_o | special_cs_no_o

link_no_o = self.extractor.link_marked_atoms(all_marked)

Expand All @@ -118,7 +118,7 @@ def test_get_all_functional_groups(self):
heteroatoms = self.extractor.get_heteroatoms()
special_cs = self.extractor.get_special_carbon()

link = self.extractor.link_marked_atoms(heteroatoms.union(special_cs))
link = self.extractor.link_marked_atoms(heteroatoms | special_cs)
basics = self.extractor.get_basic_functional_groups()

all_func = self.extractor.get_all_functional_groups()
Expand Down
6 changes: 3 additions & 3 deletions pymatgen/apps/battery/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def get_removals_int_oxid(self):

numa = set()
for oxid_el in oxid_els:
numa = numa.union(self._get_int_removals_helper(self.comp.copy(), oxid_el, oxid_els, numa))
numa = numa | self._get_int_removals_helper(self.comp.copy(), oxid_el, oxid_els, numa)
# convert from num A in structure to num A removed
num_working_ion = self.comp[Species(self.working_ion.symbol, self.working_ion_charge)]
return {num_working_ion - a for a in numa}
Expand Down Expand Up @@ -235,13 +235,13 @@ def _get_int_removals_helper(self, spec_amts_oxi, redox_el, redox_els, numa):
spec.oxi_state * spec_amts_oxi[spec] for spec in spec_amts_oxi if spec.symbol not in self.working_ion.symbol
)
a = max(0, -oxi_noA / self.working_ion_charge)
numa = numa.union({a})
numa = numa | {a}

# recursively try the other oxidation states
if a == 0:
return numa
for red in redox_els:
numa = numa.union(self._get_int_removals_helper(spec_amts_oxi.copy(), red, redox_els, numa))
numa = numa | self._get_int_removals_helper(spec_amts_oxi.copy(), red, redox_els, numa)
return numa


Expand Down
7 changes: 5 additions & 2 deletions pymatgen/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,9 @@ def from_weight_dict(cls, weight_dict) -> Composition:
Args:
weight_dict (dict): {symbol: weight_fraction} dict.
Returns:
Composition
"""

weight_sum = sum([val / Element(el).atomic_mass for el, val in weight_dict.items()])
Expand Down Expand Up @@ -1275,13 +1278,13 @@ def __truediv__(self, other: object) -> ChemicalPotential:

def __sub__(self, other: object) -> ChemicalPotential:
if isinstance(other, ChemicalPotential):
els = set(self).union(other)
els = {*self} | {other}
return ChemicalPotential({e: self.get(e, 0) - other.get(e, 0) for e in els})
return NotImplemented

def __add__(self, other: object) -> ChemicalPotential:
if isinstance(other, ChemicalPotential):
els = set(self).union(other)
els = {*self} | {other}
return ChemicalPotential({e: self.get(e, 0) + other.get(e, 0) for e in els})
return NotImplemented

Expand Down
4 changes: 2 additions & 2 deletions pymatgen/ext/optimade.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,10 @@ def _handle_response_fields(self, additional_response_fields: str | list[str] |
A string of comma-separated OPTIMADE response fields.
"""
if isinstance(additional_response_fields, str):
additional_response_fields = [additional_response_fields]
additional_response_fields = {additional_response_fields}
if not additional_response_fields:
additional_response_fields = set()
return ",".join(set(additional_response_fields).union(self.mandatory_response_fields))
return ",".join({additional_response_fields} | self.mandatory_response_fields)

def refresh_aliases(self, providers_url="https://providers.optimade.org/providers.json"):
"""
Expand Down
9 changes: 4 additions & 5 deletions pymatgen/io/abinit/abitimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class AbinitTimerParser(collections.abc.Iterable):
Assume the Abinit output files have been produced with `timopt -1`.
Example:
parser = AbinitTimerParser()
parser.parse(list_of_files)
Expand Down Expand Up @@ -223,10 +222,10 @@ def section_names(self, ordkey="wall_time"):
if idx == 0:
section_names = [s.name for s in timer.order_sections(ordkey)]
# check = section_names
# else:
# new_set = set( [s.name for s in timer.order_sections(ordkey)])
# section_names.intersection_update(new_set)
# check = check.union(new_set)
# else:
# new_set = {s.name for s in timer.order_sections(ordkey)}
# section_names.intersection_update(new_set)
# check = check | new_set

# if check != section_names:
# print("sections", section_names)
Expand Down

0 comments on commit b3bbaf0

Please sign in to comment.