Skip to content

Commit

Permalink
add generate_2d_dirac_delta function and test
Browse files Browse the repository at this point in the history
  • Loading branch information
Kymer0615 committed Dec 5, 2023
1 parent e7650f5 commit 77fc976
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
34 changes: 34 additions & 0 deletions odak/learn/tools/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,40 @@ def generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3], mu = [0, 0],
kernel_2d = kernel_2d / kernel_2d.max()
return kernel_2d

def generate_2d_dirac_delta(kernel_length = [21, 21], a = [3, 3], mu = [0, 0], theta=0, normalize = False):
"""
Generate 2D Dirac delta function by using Gaussian distribution. Inspired from https://en.wikipedia.org/wiki/Dirac_delta_function
Parameters
----------
kernel_length : list
Length of the Dirac delta function along X and Y axes.
a : list
The scale factor in Gaussian distribution to approximate the Dirac delta function.
As a approaches zero, the Gaussian distribution becomes infinitely narrow and tall at the center (x=0), approaching the Dirac delta function.
mu : list
Mu of the Gaussian kernel along X and Y axes.
theta : float
The rotation angle of the 2D Dirac delta function.
normalize : bool
If set True, normalize the output.
Returns
----------
kernel_2d : torch.tensor
Generated 2D Dirac delta function.
"""
x = torch.linspace(-kernel_length[0]/2., kernel_length[0]/2., kernel_length[0])
y = torch.linspace(-kernel_length[1]/2., kernel_length[1]/2., kernel_length[1])
X, Y = torch.meshgrid(x, y, indexing='ij')
X = X - mu[0]
Y = Y - mu[1]
X_rot = X * np.cos(theta) - Y * np.sin(theta)
Y_rot = X * np.sin(theta) + Y * np.cos(theta)
kernel_2d = (1 / (abs(a[0] * a[1]) * np.pi)) * np.exp(-((X_rot/a[0])**2 + (Y_rot/a[1])**2))
if normalize:
kernel_2d = kernel_2d / kernel_2d.max()
return kernel_2d

def blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3], padding = 'same'):
"""
Expand Down
10 changes: 10 additions & 0 deletions test/test_generate_2d_dirac_delta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import sys
from odak.learn.tools.matrix import generate_2d_dirac_delta


def test():
dirac_delta = generate_2d_dirac_delta(normalize=True, a=[0.1, 0.1])
assert dirac_delta[10][10] == 1., "The Dirac delta fucntion does not approximate peak correctly"

if __name__ == '__main__':
sys.exit(test())

0 comments on commit 77fc976

Please sign in to comment.