Skip to content

Commit

Permalink
[feat]: Adding numpy broadcast to replace explicit for loops
Browse files Browse the repository at this point in the history
  • Loading branch information
hentt30 committed Jun 9, 2021
1 parent 547164a commit 5424590
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 35 deletions.
14 changes: 5 additions & 9 deletions minushalf/utils/atomic_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,12 @@ def correct_potential(self,
self.potential_file.get_maximum_module_wave_vector() /
len(self.potential_file.get_potential_fourier_transform()))

correct_potential = np.vectorize(
correct_potential_fourier_transform,
excluded=['rays', 'occupation_potential'],
)

potential = correct_potential(
coefficient=self.potential_file.get_potential_fourier_transform(),
potential = correct_potential_fourier_transform(
coefficient=np.array(
self.potential_file.get_potential_fourier_transform()),
k=wave_vectors,
rays=np.array(self.vtotal.radius, dtype=object),
occupation_potential=np.array(occupation_potential, dtype=object),
rays=np.array(self.vtotal.radius, dtype=float),
occupation_potential=np.array(occupation_potential, dtype=float),
cut=cut,
)

Expand Down
24 changes: 17 additions & 7 deletions minushalf/utils/correct_potential_fourier_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
from minushalf.data import Constants


def correct_potential_fourier_transform(coefficient: float, k: float,
rays: np.array,
occupation_potential: np.array,
cut: float) -> float:
def correct_potential_fourier_transform(
coefficient: np.array,
k: np.array,
rays: np.array,
occupation_potential: np.array,
cut: float,
) -> np.array:
r"""
The pseudopotential is given in terms of the radial distance, and is only defined for r >= 0,
as expected. Since it is only evaluated inside an integral from 0 to infinity, it does not
Expand Down Expand Up @@ -55,9 +58,10 @@ def correct_potential_fourier_transform(coefficient: float, k: float,
Fourier transform of the potential for the state with fractional occupation of the crystal
"""
const = Constants()
k = np.reshape(k, (len(k), 1))

if not k:
k = 10**(-12)
if not k[0][0]:
k[0][0] = 10**(-12)

try:
filter_rays = rays[np.where(rays < cut)]
Expand All @@ -75,8 +79,14 @@ def correct_potential_fourier_transform(coefficient: float, k: float,
(potential * np.sin(const.bohr_radius * k * radius) +
lazy_potential * np.sin(const.bohr_radius * k * lazy_radius)) / 2)

## Sum all rows and transpose

initial_term = (occupation_potential[0] *
np.sin(const.bohr_radius * k * rays[0]) / 2 * rays[0])

return (coefficient + (initial_term + partial_fourier_sum.sum()) /
## Reshape vectors to give the correct output format
initial_term = np.reshape(initial_term, (len(initial_term)))
k = np.reshape(k, (len(k)))

return (coefficient + (initial_term + partial_fourier_sum.sum(axis=1)) /
(const.bohr_radius * k))
50 changes: 31 additions & 19 deletions tests/unit/utils/test_correct_potential_fourier_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ def test_correct_potential_fourier_transform_first():
cut: 3.0
"""
res = correct_potential_fourier_transform(0, 0, np.arange(0, 5, 0.1),
res = correct_potential_fourier_transform(np.array([0]),
np.array([0], dtype=float),
np.arange(0, 5, 0.1),
np.arange(20, 25, 0.1), 3.0)
assert np.isclose(res, 92.23450000000014)
assert np.isclose(res[0], 92.23450000000014)


def test_correct_potential_fourier_transform_second():
Expand All @@ -35,9 +37,11 @@ def test_correct_potential_fourier_transform_second():
cut: 3.0
"""
res = correct_potential_fourier_transform(1, 0, np.arange(0, 5, 0.1),
res = correct_potential_fourier_transform(np.array([1]),
np.array([0], dtype=float),
np.arange(0, 5, 0.1),
np.arange(20, 25, 0.1), 3.0)
assert np.isclose(res, 93.23450000000014)
assert np.isclose(res[0], 93.23450000000014)


def test_correct_potential_fourier_transform_third():
Expand All @@ -52,9 +56,10 @@ def test_correct_potential_fourier_transform_third():
cut: 3.0
"""
res = correct_potential_fourier_transform(1, 2e-1, np.arange(0, 5, 0.1),
res = correct_potential_fourier_transform(np.array([1]), np.array([2e-1]),
np.arange(0, 5, 0.1),
np.arange(20, 25, 0.1), 3.0)
assert np.isclose(res, 92.49911868134099)
assert np.isclose(res[0], 92.49911868134099)


def test_correct_potential_fourier_transform_fourth():
Expand All @@ -69,9 +74,10 @@ def test_correct_potential_fourier_transform_fourth():
cut: 13.6
"""
res = correct_potential_fourier_transform(4, 3, np.arange(12, 19, 0.01),
res = correct_potential_fourier_transform(np.array([4]), np.array([3]),
np.arange(12, 19, 0.01),
np.arange(70, 79, 0.01), 13.6)
assert np.isclose(res, 110.19498579254733)
assert np.isclose(res[0], 110.19498579254733)


def test_correct_potential_fourier_transform_fifth():
Expand All @@ -86,9 +92,10 @@ def test_correct_potential_fourier_transform_fifth():
cut: 13.6
"""
res = correct_potential_fourier_transform(0, 3, np.arange(12, 19, 0.01),
res = correct_potential_fourier_transform(np.array([0]), np.array([3]),
np.arange(12, 19, 0.01),
np.arange(70, 79, 0.01), 13.6)
assert np.isclose(res, 106.19498579254733)
assert np.isclose(res[0], 106.19498579254733)


@pytest.mark.xfail
Expand All @@ -104,7 +111,8 @@ def test_correct_potential_fourier_transform_sixth():
cut: 11.6
"""
correct_potential_fourier_transform(0, 3, np.arange(12, 19, 0.01),
correct_potential_fourier_transform(np.array([0]), np.array([3]),
np.arange(12, 19, 0.01),
np.arange(70, 79, 0.01), 11.6)


Expand All @@ -120,9 +128,10 @@ def test_correct_potential_fourier_transform_seventh():
cut: 1.23
"""
res = correct_potential_fourier_transform(0, 67, np.arange(0, 12, 0.01),
res = correct_potential_fourier_transform(np.array([0]), np.array([67]),
np.arange(0, 12, 0.01),
np.arange(70, 79, 0.01), 1.23)
assert np.isclose(res, 0.013212154569150683)
assert np.isclose(res[0], 0.013212154569150683)


def test_correct_potential_fourier_transform_eighth():
Expand All @@ -137,9 +146,10 @@ def test_correct_potential_fourier_transform_eighth():
cut: 1.23
"""
res = correct_potential_fourier_transform(0, 67, np.arange(0, 12, 0.1),
res = correct_potential_fourier_transform(np.array([0]), np.array([67]),
np.arange(0, 12, 0.1),
np.arange(70, 79, 0.1), 1.23)
assert np.isclose(res, -0.017530596213630182)
assert np.isclose(res[0], -0.017530596213630182)


def test_correct_potential_fourier_transform_nineth():
Expand All @@ -154,9 +164,10 @@ def test_correct_potential_fourier_transform_nineth():
cut: 1.23
"""
res = correct_potential_fourier_transform(0, 67, np.arange(0, 12, 0.2),
res = correct_potential_fourier_transform(np.array([0]), np.array([67]),
np.arange(0, 12, 0.2),
np.arange(70, 79, 0.2), 1.23)
assert np.isclose(res, 0.3972120670340923)
assert np.isclose(res[0], 0.3972120670340923)


def test_correct_potential_fourier_transform_tenth():
Expand All @@ -171,6 +182,7 @@ def test_correct_potential_fourier_transform_tenth():
cut: 1.23
"""
res = correct_potential_fourier_transform(0, 67, np.arange(0, 12, 0.4),
res = correct_potential_fourier_transform(np.array([0]), np.array([67]),
np.arange(0, 12, 0.4),
np.arange(70, 79, 0.4), 1.23)
assert np.isclose(res, 0.32399512659710605)
assert np.isclose(res[0], 0.32399512659710605)

0 comments on commit 5424590

Please sign in to comment.