diff --git a/odak/learn/tools/matrix.py b/odak/learn/tools/matrix.py index 8f641622..828bd3f4 100644 --- a/odak/learn/tools/matrix.py +++ b/odak/learn/tools/matrix.py @@ -184,6 +184,40 @@ def generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3], mu = [0, 0], kernel_2d = kernel_2d / kernel_2d.max() return kernel_2d +def generate_2d_dirac_delta(kernel_length = [21, 21], a = [3, 3], mu = [0, 0], theta=0, normalize = False): + """ + Generate 2D Dirac delta function by using Gaussian distribution. Inspired from https://en.wikipedia.org/wiki/Dirac_delta_function + + Parameters + ---------- + kernel_length : list + Length of the Dirac delta function along X and Y axes. + a : list + The scale factor in Gaussian distribution to approximate the Dirac delta function. + As a approaches zero, the Gaussian distribution becomes infinitely narrow and tall at the center (x=0), approaching the Dirac delta function. + mu : list + Mu of the Gaussian kernel along X and Y axes. + theta : float + The rotation angle of the 2D Dirac delta function. + normalize : bool + If set True, normalize the output. + + Returns + ---------- + kernel_2d : torch.tensor + Generated 2D Dirac delta function. + """ + x = torch.linspace(-kernel_length[0]/2., kernel_length[0]/2., kernel_length[0]) + y = torch.linspace(-kernel_length[1]/2., kernel_length[1]/2., kernel_length[1]) + X, Y = torch.meshgrid(x, y, indexing='ij') + X = X - mu[0] + Y = Y - mu[1] + X_rot = X * np.cos(theta) - Y * np.sin(theta) + Y_rot = X * np.sin(theta) + Y * np.cos(theta) + kernel_2d = (1 / (abs(a[0] * a[1]) * np.pi)) * np.exp(-((X_rot/a[0])**2 + (Y_rot/a[1])**2)) + if normalize: + kernel_2d = kernel_2d / kernel_2d.max() + return kernel_2d def blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3], padding = 'same'): """ diff --git a/test/test_generate_2d_dirac_delta.py b/test/test_generate_2d_dirac_delta.py new file mode 100644 index 00000000..9d606fcc --- /dev/null +++ b/test/test_generate_2d_dirac_delta.py @@ -0,0 +1,10 @@ +import sys +from odak.learn.tools.matrix import generate_2d_dirac_delta + + +def test(): + dirac_delta = generate_2d_dirac_delta(normalize=True, a=[0.1, 0.1]) + assert dirac_delta[10][10] == 1., "The Dirac delta fucntion does not approximate peak correctly" + +if __name__ == '__main__': + sys.exit(test())