Skip to content

Commit

Permalink
Merge pull request #2501 from jmmshn/no_complex
Browse files Browse the repository at this point in the history
Remove complex numbers from the definition of WSWQ
  • Loading branch information
mkhorton committed Apr 20, 2022
2 parents dcc4576 + 3fef4d0 commit ec9eb7c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
11 changes: 9 additions & 2 deletions pymatgen/io/vasp/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5652,12 +5652,19 @@ class WSWQ(MSONable):
The indices of WSWQ.data are:
[spin][kpoint][band_i][band_j]
Attributes:
nspin: Number of spin channels
nkpoints: Number of k-points
nbands: Number of bands
me_real: Real part of the overlap matrix elements
me_imag: Imaginary part of the overlap matrix elements
"""

nspin: int
nkpoints: int
nbands: int
data: np.ndarray
me_real: np.ndarray
me_imag: np.ndarray

@classmethod
def from_file(cls, filename):
Expand Down Expand Up @@ -5694,7 +5701,7 @@ def from_file(cls, filename):
# NOTE: loop order (slow->fast) spin -> kpoint -> j -> i
data = data.reshape((nspin, nkpoints, nbands, nbands))
data = np.swapaxes(data, 2, 3) # swap i and j
return cls(nspin=nspin, nkpoints=nkpoints, nbands=nbands, data=data)
return cls(nspin=nspin, nkpoints=nkpoints, nbands=nbands, me_real=np.real(data), me_imag=np.imag(data))


class UnconvergedVASPWarning(Warning):
Expand Down
12 changes: 9 additions & 3 deletions pymatgen/io/vasp/tests/test_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2162,10 +2162,16 @@ def setUp(self):

def test_consistency(self):
self.assertEqual(True, True)
self.assertEqual(self.wswq.nbands, 288)
self.assertEqual(self.wswq.nkpoints, 2)
self.assertEqual(self.wswq.nbands, 18)
self.assertEqual(self.wswq.nkpoints, 20)
self.assertEqual(self.wswq.nspin, 2)
self.assertEqual(self.wswq.data.shape, (2, 2, 288, 288))
self.assertEqual(self.wswq.me_real.shape, (2, 20, 18, 18))
self.assertEqual(self.wswq.me_imag.shape, (2, 20, 18, 18))
for itr, (r, i) in enumerate(zip(self.wswq.me_real[0][0][4], self.wswq.me_imag[0][0][4])):
if itr == 4:
assert np.linalg.norm([r, i]) > 0.999
else:
assert np.linalg.norm([r, i]) < 0.001


if __name__ == "__main__":
Expand Down
Binary file modified test_files/WSWQ.gz
Binary file not shown.

0 comments on commit ec9eb7c

Please sign in to comment.