Skip to content

Commit

Permalink
[fix]: Using numpy broadcast to optimize trimming function
Browse files Browse the repository at this point in the history
  • Loading branch information
hentt30 committed Jun 9, 2021
1 parent c54cb90 commit c56a0e9
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 26 deletions.
9 changes: 4 additions & 5 deletions minushalf/utils/atomic_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,10 @@ def occupy_potential(self, cut: float, amplitude) -> list:
A list that contains the potentials of fractional electron
occupation at the exact level to be corrected.
"""
trimming = np.vectorize(trimming_function)
occupation_potential = trimming(
self.vtotal.radius,
self.vtotal_occupied.down_potential,
self.vtotal.down_potential,
occupation_potential = trimming_function(
np.array(self.vtotal.radius, dtype=float),
np.array(self.vtotal_occupied.down_potential, dtype=float),
np.array(self.vtotal.down_potential, dtype=float),
cut,
amplitude,
)
Expand Down
29 changes: 18 additions & 11 deletions minushalf/utils/trimming_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
from minushalf.data import Constants


def trimming_function(radius: float, ion_potential: float, atom_potential,
cut: float, amplitude: float) -> float:
def trimming_function(
radius: np.array,
ion_potential: np.array,
atom_potential: np.array,
cut: float,
amplitude: float,
) -> np.array:
r"""
Function that generate the potential for fractional occupation. The potential
is cuted by a a function theta(r) to avoid divergence in calculations.
Expand All @@ -32,11 +37,11 @@ def trimming_function(radius: float, ion_potential: float, atom_potential,
amplitude (float): multiplicative factor of the potential function
radius (float): radius in which the potential was calculated
radius (np.array): rays in which the potential was calculated
ion_potential (float): Atom pseudopotential with fractional occupation
ion_potential (np.array): Atom pseudopotential with fractional occupation
atom_potential (float): Atom pseudopotential with all electrons
atom_potential (np.array): Atom pseudopotential with all electrons
Returns:
potential of fractional electron occupation at the exact level to be corrected
Expand All @@ -45,10 +50,12 @@ def trimming_function(radius: float, ion_potential: float, atom_potential,
"""
const = Constants()

if radius >= cut:
return 0
potential = np.array(
4 * const.pi_constant * const.rydberg *
np.power(const.bohr_radius, 3) *
np.power(1 - np.power(radius / cut, const.trimming_exponent), 3) *
(ion_potential - atom_potential) * amplitude)

return (4 * const.pi_constant * const.rydberg *
np.power(const.bohr_radius, 3) *
np.power(1 - np.power(radius / cut, const.trimming_exponent), 3) *
(ion_potential - atom_potential) * amplitude)
potential[radius >= cut] = 0

return potential
30 changes: 20 additions & 10 deletions tests/unit/utils/test_trimming_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ def test_trimming_first():
ion_potential: 0.77e-5
atom_potential: 0.64e-4
"""
res = trimming_function(3.0, 0.77e-5, 0.64e-4, 3.0, 1.0)
res = trimming_function(np.array(3.0), np.array(0.77e-5),
np.array(0.64e-4), 3.0, 1.0)
assert np.isclose(res, 0)


Expand All @@ -27,7 +28,8 @@ def test_trimming_second():
ion_potential: 0.77e-5
atom_potential: 0.64e-4
"""
res = trimming_function(2.9, 0.77e-5, 0.64e-4, 3.0, 1.0)
res = trimming_function(np.array(2.9), np.array(0.77e-5),
np.array(0.64e-4), 3.0, 1.0)
assert np.isclose(res, -1.911990822901309e-05)


Expand All @@ -40,7 +42,8 @@ def test_trimming_third():
ion_potential: 0.77e-5
atom_potential: 0.64e-4
"""
res = trimming_function(2.9, 0.77e-5, 0.64e-4, 3.0, 2.0)
res = trimming_function(np.array(2.9), np.array(0.77e-5),
np.array(0.64e-4), 3.0, 2.0)
assert np.isclose(res, -3.823981645802618e-05)


Expand All @@ -53,7 +56,8 @@ def test_trimming_fourth():
ion_potential: 1.2
atom_potential: 0.34
"""
res = trimming_function(0.23e-2, 1.2, 0.34, 1.55, 1.54)
res = trimming_function(np.array(0.23e-2), np.array(1.2), np.array(0.34),
1.55, 1.54)
assert np.isclose(res, 33.55492614903707)


Expand All @@ -66,7 +70,8 @@ def test_trimming_fifth():
ion_potential: 1.2
atom_potential: 0.34
"""
res = trimming_function(100, 1.2, 0.34, 1.55, 1.54)
res = trimming_function(np.array(100.0), np.array(1.2), np.array(0.34),
1.55, 1.54)
assert np.isclose(res, 0)


Expand All @@ -79,7 +84,8 @@ def test_trimming_sixth():
ion_potential: 1.2e-23
atom_potential: 0.34e-34
"""
res = trimming_function(1, 1.2e-23, 0.34e-34, 1.55, 1.54)
res = trimming_function(np.array(1), np.array(1.2e-23), np.array(0.34e-34),
1.55, 1.54)
assert np.isclose(res, 4.273004874113689e-22)


Expand All @@ -92,7 +98,8 @@ def test_trimming_seventh():
ion_potential: 1.643
atom_potential: 1.643
"""
res = trimming_function(2.95, 1.643, 1.643, 3.76, 0.34)
res = trimming_function(np.array(2.95), np.array(1.643), np.array(1.643),
3.76, 0.34)
assert np.isclose(res, 0)


Expand All @@ -105,7 +112,8 @@ def test_trimming_eighth():
ion_potential: 1.643
atom_potential: 1.835
"""
res = trimming_function(2.95, 1.643, 1.835, 3.76, 1.0)
res = trimming_function(np.array(2.95), np.array(1.643), np.array(1.835),
3.76, 1.0)
assert np.isclose(res, -3.0556886231671694)


Expand All @@ -118,7 +126,8 @@ def test_trimming_nineth():
ion_potential: 1.643
atom_potential: 1.835
"""
res = trimming_function(0, 1.643, 1.835, 3.76, 1.0)
res = trimming_function(np.array(0), np.array(1.643), np.array(1.835),
3.76, 1.0)
assert np.isclose(res, -4.8645015256834165)


Expand All @@ -131,5 +140,6 @@ def test_trimming_tenth():
ion_potential: 12e-23
atom_potential: 15e-34
"""
res = trimming_function(4.32, 12e-23, 15e-34, 4.55, 1.0)
res = trimming_function(np.array(4.32), np.array(12e-23), np.array(15e-34),
4.55, 1.0)
assert np.isclose(res, 1.1912043652717141e-22)

0 comments on commit c56a0e9

Please sign in to comment.