Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions pygem/rbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,14 @@ class RBF(Deformation):
transformation.
:cvar float radius: the scaling parameter that affects the shape of the
basis functions.
:cvar dict extra_parameter: the additional parameters that may be passed to
the kernel function.
:cvar dict extra: the additional parameters that may be passed to the
kernel function.

:Example:

>>> from pygem import RBF
>>> import numpy as np
>>> rbf = RBF('gaussian_spline')
>>> rbf = RBF(func='gaussian_spline')
>>> xv = np.linspace(0, 1, 20)
>>> yv = np.linspace(0, 1, 20)
>>> zv = np.linspace(0, 1, 20)
Expand Down Expand Up @@ -208,7 +208,7 @@ def _get_weights(self, X, Y):
"""
npts, dim = X.shape
H = np.zeros((npts + 3 + 1, npts + 3 + 1))
H[:npts, :npts] = self.basis(cdist(X, X), self.radius)#, **self.extra)
H[:npts, :npts] = self.basis(cdist(X, X), self.radius, **self.extra)
H[npts, :npts] = 1.0
H[:npts, npts] = 1.0
H[:npts, -3:] = X
Expand Down Expand Up @@ -239,8 +239,11 @@ def read_parameters(self, filename='parameters_rbf.prm'):
config = configparser.RawConfigParser()
config.read(filename)

self.basis = config.get('Radial Basis Functions', 'basis function')
self.radius = config.getfloat('Radial Basis Functions', 'radius')
rbf_settings = dict(config.items('Radial Basis Functions'))

self.basis = rbf_settings.pop('basis function')
self.radius = float(rbf_settings.pop('radius'))
self.extra = {k: eval(v) for k, v in rbf_settings.items()}

ctrl_points = config.get('Control points', 'original control points')
lines = ctrl_points.split('\n')
Expand Down Expand Up @@ -331,6 +334,7 @@ def __str__(self):
string = ''
string += 'basis function = {}\n'.format(self.basis)
string += 'radius = {}\n'.format(self.radius)
string += 'extra_parameter = {}\n'.format(self.extra)
string += '\noriginal control points =\n'
string += '{}\n'.format(self.original_control_points)
string += '\ndeformed control points =\n'
Expand Down
39 changes: 39 additions & 0 deletions tests/test_datasets/parameters_rbf_extra.prm
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

[Radial Basis Functions]
# This section describes the radial basis functions shape.

# basis funtion is the name of the basis functions to use in the transformation. The functions
# implemented so far are: gaussian_spline, multi_quadratic_biharmonic_spline,
# inv_multi_quadratic_biharmonic_spline, thin_plate_spline, beckert_wendland_c2_basis, polyharmonic_spline.
# For a comprehensive list with details see the class RBF.
basis function: polyharmonic_spline

# radius is the scaling parameter r that affects the shape of the basis functions. See the documentation
# of the class RBF for details.
radius: 0.5

# Any additional parameter to pass to the basis function (eg the `k` power for poliharmonic_spline)
k: 4

[Control points]
# This section describes the RBF control points.

# original control points collects the coordinates of the interpolation control points before the deformation.
original control points: 0.0 0.0 0.0
0.0 0.0 1.0
0.0 1.0 0.0
1.0 0.0 0.0
0.0 1.0 1.0
1.0 0.0 1.0
1.0 1.0 0.0
1.0 1.0 1.0

# deformed control points collects the coordinates of the interpolation control points after the deformation.
deformed control points: 0.0 0.0 0.0
0.0 0.0 1.0
0.0 1.0 0.0
1.0 0.0 0.0
0.0 1.0 1.0
1.0 0.0 1.0
1.0 1.0 0.0
1.0 1.0 2.0
38 changes: 38 additions & 0 deletions tests/test_datasets/parameters_rbf_radius.prm
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

[Radial Basis Functions]
# This section describes the radial basis functions shape.

# basis funtion is the name of the basis functions to use in the transformation. The functions
# implemented so far are: gaussian_spline, multi_quadratic_biharmonic_spline,
# inv_multi_quadratic_biharmonic_spline, thin_plate_spline, beckert_wendland_c2_basis, polyharmonic_spline.
# For a comprehensive list with details see the class RBF.
basis function: gaussian_spline

# radius is the scaling parameter r that affects the shape of the basis functions. See the documentation
# of the class RBF for details.
radius: 2.0



[Control points]
# This section describes the RBF control points.

# original control points collects the coordinates of the interpolation control points before the deformation.
original control points: 0.0 0.0 0.0
0.0 0.0 1.0
0.0 1.0 0.0
1.0 0.0 0.0
0.0 1.0 1.0
1.0 0.0 1.0
1.0 1.0 0.0
1.0 1.0 1.0

# deformed control points collects the coordinates of the interpolation control points after the deformation.
deformed control points: 0.0 0.0 0.0
0.0 0.0 1.0
0.0 1.0 0.0
1.0 0.0 0.0
0.0 1.0 1.0
1.0 0.0 1.0
1.0 1.0 0.0
1.0 1.0 1.0
46 changes: 33 additions & 13 deletions tests/test_rbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def test_class_members_default_radius(self):
rbf = RBF()
assert rbf.radius == 0.5

def test_class_members_default_extra(self):
rbf = RBF()
assert rbf.extra == {}

def test_class_members_default_n_control_points(self):
rbf = RBF()
assert rbf.n_control_points == 8
Expand All @@ -68,10 +72,20 @@ def test_read_parameters_basis(self):
rbf.read_parameters('tests/test_datasets/parameters_rbf_default.prm')
assert rbf.basis == RBFFactory('gaussian_spline')

def test_read_parameters_basis2(self):
rbf = RBF()
rbf.read_parameters('tests/test_datasets/parameters_rbf_extra.prm')
assert rbf.basis == RBFFactory('polyharmonic_spline')

def test_read_parameters_radius(self):
rbf = RBF()
rbf.read_parameters('tests/test_datasets/parameters_rbf_default.prm')
assert rbf.radius == 0.5
rbf.read_parameters('tests/test_datasets/parameters_rbf_radius.prm')
assert rbf.radius == 2.0

def test_read_extra_parameters(self):
rbf = RBF()
rbf.read_parameters('tests/test_datasets/parameters_rbf_extra.prm')
assert rbf.extra == {'k': 4}

def test_read_parameters_n_control_points(self):
rbf = RBF()
Expand Down Expand Up @@ -145,17 +159,23 @@ def test_write_parameters(self):

self.assertTrue(filecmp.cmp(outfilename, outfilename_expected))
#os.remove(outfilename)

def test_read_parameters_filename_default(self):
params = RBF()
params.read_parameters()
outfilename = 'parameters_rbf.prm'
outfilename_expected = 'tests/test_datasets/parameters_rbf_default.prm'

self.assertTrue(filecmp.cmp(outfilename, outfilename_expected))
os.remove(outfilename)
"""

def test_print_info(self):
params = RBF()
print(params)
rbf = RBF()
print(rbf)

def test_call_dummy_transformation(self):
rbf = RBF()
rbf.read_parameters('tests/test_datasets/parameters_rbf_default.prm')
mesh = self.get_cube_mesh_points()
new = rbf(mesh)
np.testing.assert_array_almost_equal(new[17], mesh[17])

def test_call(self):
rbf = RBF()
rbf.read_parameters('tests/test_datasets/parameters_rbf_extra.prm')
mesh = self.get_cube_mesh_points()
new = rbf(mesh)
np.testing.assert_array_almost_equal(new[17], [8.947368e-01, 5.353524e-17, 8.845331e-03])