Skip to content

Commit

Permalink
Add PhononBandStructureSymmLine.__eq__ method
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Feb 23, 2024
1 parent 50efc02 commit 8b63b10
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
23 changes: 17 additions & 6 deletions pymatgen/phonon/bandstructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class PhononBandStructure(MSONable):

def __init__(
self,
qpoints: list[Kpoint],
qpoints: Sequence[Kpoint],
frequencies: ArrayLike,
lattice: Lattice,
nac_frequencies: Sequence[Sequence] | None = None,
Expand Down Expand Up @@ -127,7 +127,7 @@ def __init__(
q_pt, lattice, label=label, coords_are_cartesian=coords_are_cartesian
)
self.qpoints += [Kpoint(q_pt, lattice, label=label, coords_are_cartesian=coords_are_cartesian)]
self.bands = frequencies
self.bands = np.asarray(frequencies)
self.nb_bands = len(self.bands)
self.nb_qpoints = len(self.qpoints)

Expand Down Expand Up @@ -342,7 +342,7 @@ class PhononBandStructureSymmLine(PhononBandStructure):

def __init__(
self,
qpoints: list[Kpoint],
qpoints: Sequence[Kpoint],
frequencies: ArrayLike,
lattice: Lattice,
has_nac: bool = False,
Expand Down Expand Up @@ -394,7 +394,7 @@ def __repr__(self) -> str:
return f"{type(self).__name__}({bands=}, {labels=})"

def _reuse_init(
self, eigendisplacements: ArrayLike, frequencies: ArrayLike, has_nac: bool, qpoints: list[Kpoint]
self, eigendisplacements: ArrayLike, frequencies: ArrayLike, has_nac: bool, qpoints: Sequence[Kpoint]
) -> None:
self.distance = []
self.branches = []
Expand Down Expand Up @@ -642,13 +642,24 @@ def from_dict(cls, dct: dict) -> PhononBandStructureSymmLine:
eigendisplacements = (
np.array(dct["eigendisplacements"]["real"]) + np.array(dct["eigendisplacements"]["imag"]) * 1j
)
struct = Structure.from_dict(dct["structure"]) if "structure" in dct else None
return cls(
dct["qpoints"],
np.array(dct["bands"]),
lattice_rec,
dct["has_nac"],
eigendisplacements,
dct["labels_dict"],
structure=struct,
structure=Structure.from_dict(dct["structure"]) if "structure" in dct else None,
)

def __eq__(self, other: object) -> bool:
if not isinstance(other, PhononBandStructureSymmLine):
return NotImplemented
return (
self.bands.shape == other.bands.shape
and np.allclose(self.bands, other.bands)
and self.lattice_rec == other.lattice_rec
# and self.qpoints == other.qpoints
and self.labels_dict == other.labels_dict
and self.structure == other.structure
)
7 changes: 7 additions & 0 deletions tests/phonon/test_bandstructure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import json

from numpy.testing import assert_allclose, assert_array_equal
Expand All @@ -26,6 +27,12 @@ def test_repr(self):
r"PhononBandStructureSymmLine(bands=(6, 130), labels=['$\\Gamma$', 'X', 'W', 'K', 'L', 'U'])"
)

def test_eq(self):
assert self.bs == self.bs
assert self.bs == copy.deepcopy(self.bs)
assert self.bs2 == self.bs2
assert self.bs != self.bs2

def test_basic(self):
assert self.bs.bands[1][10] == approx(0.7753555184)
assert self.bs.bands[5][100] == approx(5.2548379776)
Expand Down

0 comments on commit 8b63b10

Please sign in to comment.