Skip to content

Commit

Permalink
nndata -> nn_data
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Mar 28, 2023
1 parent 4d0947e commit 7566e92
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 33 deletions.
36 changes: 18 additions & 18 deletions pymatgen/analysis/local_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3856,25 +3856,25 @@ def get_nn_info(self, structure: Structure, n: int) -> list[dict]:
to the coordination number (1 or smaller), 'site_index' gives index of
the corresponding site in the original structure.
"""
nndata = self.get_nn_data(structure, n)
nn_data = self.get_nn_data(structure, n)

if not self.weighted_cn:
max_key = max(nndata.cn_weights, key=lambda k: nndata.cn_weights[k])
nn = nndata.cn_nninfo[max_key]
max_key = max(nn_data.cn_weights, key=lambda k: nn_data.cn_weights[k])
nn = nn_data.cn_nninfo[max_key]
for entry in nn:
entry["weight"] = 1
return nn

for entry in nndata.all_nninfo:
for entry in nn_data.all_nninfo:
weight = 0
for cn in nndata.cn_nninfo:
for cn_entry in nndata.cn_nninfo[cn]:
for cn in nn_data.cn_nninfo:
for cn_entry in nn_data.cn_nninfo[cn]:
if entry["site"] == cn_entry["site"]:
weight += nndata.cn_weights[cn]
weight += nn_data.cn_weights[cn]

entry["weight"] = weight

return nndata.all_nninfo
return nn_data.all_nninfo

def get_nn_data(self, structure: Structure, n: int, length=None):
"""
Expand All @@ -3887,9 +3887,9 @@ def get_nn_data(self, structure: Structure, n: int, length=None):
Returns:
a namedtuple (NNData) object that contains:
- all near neighbor sites with weights
- a dict of CN -> weight
- a dict of CN -> associated near neighbor sites
- all near neighbor sites with weights
- a dict of CN -> weight
- a dict of CN -> associated near neighbor sites
"""
length = length or self.fingerprint_length

Expand Down Expand Up @@ -4082,23 +4082,23 @@ def _semicircle_integral(dist_bins, idx):
return (area1 - area2) / (0.25 * math.pi * r**2)

@staticmethod
def transform_to_length(nndata, length):
def transform_to_length(nn_data, length):
"""
Given NNData, transforms data to the specified fingerprint length
Args:
nndata: (NNData)
nn_data: (NNData)
length: (int) desired length of NNData
"""
if length is None:
return nndata
return nn_data

if length:
for cn in range(length):
if cn not in nndata.cn_weights:
nndata.cn_weights[cn] = 0
nndata.cn_nninfo[cn] = []
if cn not in nn_data.cn_weights:
nn_data.cn_weights[cn] = 0
nn_data.cn_nninfo[cn] = []

return nndata
return nn_data


def _get_default_radius(site):
Expand Down
6 changes: 3 additions & 3 deletions pymatgen/analysis/tests/test_local_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,9 +1317,9 @@ def test_weighted_cn_no_oxid(self):

def test_fixed_length(self):
cnn = CrystalNN(fingerprint_length=30)
nndata = cnn.get_nn_data(self.lifepo4, 0)
assert len(nndata.cn_weights) == 30
assert len(nndata.cn_nninfo) == 30
nn_data = cnn.get_nn_data(self.lifepo4, 0)
assert len(nn_data.cn_weights) == 30
assert len(nn_data.cn_nninfo) == 30

def test_cation_anion(self):
cnn = CrystalNN(weighted_cn=True, cation_anion=True)
Expand Down
13 changes: 1 addition & 12 deletions pymatgen/util/coord.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,18 +289,7 @@ def lattice_points_in_supercell(supercell_matrix):
Returns:
numpy array of the fractional coordinates
"""
diagonals = np.array(
[
[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[0, 1, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
[1, 1, 1],
]
)
diagonals = np.array([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]])
d_points = np.dot(diagonals, supercell_matrix)

mins = np.min(d_points, axis=0)
Expand Down

0 comments on commit 7566e92

Please sign in to comment.