diff --git a/pygem/rbf.py b/pygem/rbf.py index ed29ff3..0d774ae 100644 --- a/pygem/rbf.py +++ b/pygem/rbf.py @@ -108,14 +108,14 @@ class RBF(Deformation): transformation. :cvar float radius: the scaling parameter that affects the shape of the basis functions. - :cvar dict extra_parameter: the additional parameters that may be passed to - the kernel function. + :cvar dict extra: the additional parameters that may be passed to the + kernel function. :Example: >>> from pygem import RBF >>> import numpy as np - >>> rbf = RBF('gaussian_spline') + >>> rbf = RBF(func='gaussian_spline') >>> xv = np.linspace(0, 1, 20) >>> yv = np.linspace(0, 1, 20) >>> zv = np.linspace(0, 1, 20) @@ -208,7 +208,7 @@ def _get_weights(self, X, Y): """ npts, dim = X.shape H = np.zeros((npts + 3 + 1, npts + 3 + 1)) - H[:npts, :npts] = self.basis(cdist(X, X), self.radius)#, **self.extra) + H[:npts, :npts] = self.basis(cdist(X, X), self.radius, **self.extra) H[npts, :npts] = 1.0 H[:npts, npts] = 1.0 H[:npts, -3:] = X @@ -239,8 +239,11 @@ def read_parameters(self, filename='parameters_rbf.prm'): config = configparser.RawConfigParser() config.read(filename) - self.basis = config.get('Radial Basis Functions', 'basis function') - self.radius = config.getfloat('Radial Basis Functions', 'radius') + rbf_settings = dict(config.items('Radial Basis Functions')) + + self.basis = rbf_settings.pop('basis function') + self.radius = float(rbf_settings.pop('radius')) + self.extra = {k: eval(v) for k, v in rbf_settings.items()} ctrl_points = config.get('Control points', 'original control points') lines = ctrl_points.split('\n') @@ -331,6 +334,7 @@ def __str__(self): string = '' string += 'basis function = {}\n'.format(self.basis) string += 'radius = {}\n'.format(self.radius) + string += 'extra_parameter = {}\n'.format(self.extra) string += '\noriginal control points =\n' string += '{}\n'.format(self.original_control_points) string += '\ndeformed control points =\n' diff --git a/tests/test_datasets/parameters_rbf_extra.prm b/tests/test_datasets/parameters_rbf_extra.prm new file mode 100644 index 0000000..6be1f3c --- /dev/null +++ b/tests/test_datasets/parameters_rbf_extra.prm @@ -0,0 +1,39 @@ + +[Radial Basis Functions] +# This section describes the radial basis functions shape. + +# basis funtion is the name of the basis functions to use in the transformation. The functions +# implemented so far are: gaussian_spline, multi_quadratic_biharmonic_spline, +# inv_multi_quadratic_biharmonic_spline, thin_plate_spline, beckert_wendland_c2_basis, polyharmonic_spline. +# For a comprehensive list with details see the class RBF. +basis function: polyharmonic_spline + +# radius is the scaling parameter r that affects the shape of the basis functions. See the documentation +# of the class RBF for details. +radius: 0.5 + +# Any additional parameter to pass to the basis function (eg the `k` power for poliharmonic_spline) +k: 4 + +[Control points] +# This section describes the RBF control points. + +# original control points collects the coordinates of the interpolation control points before the deformation. +original control points: 0.0 0.0 0.0 + 0.0 0.0 1.0 + 0.0 1.0 0.0 + 1.0 0.0 0.0 + 0.0 1.0 1.0 + 1.0 0.0 1.0 + 1.0 1.0 0.0 + 1.0 1.0 1.0 + +# deformed control points collects the coordinates of the interpolation control points after the deformation. +deformed control points: 0.0 0.0 0.0 + 0.0 0.0 1.0 + 0.0 1.0 0.0 + 1.0 0.0 0.0 + 0.0 1.0 1.0 + 1.0 0.0 1.0 + 1.0 1.0 0.0 + 1.0 1.0 2.0 diff --git a/tests/test_datasets/parameters_rbf_radius.prm b/tests/test_datasets/parameters_rbf_radius.prm new file mode 100644 index 0000000..dbe44df --- /dev/null +++ b/tests/test_datasets/parameters_rbf_radius.prm @@ -0,0 +1,38 @@ + +[Radial Basis Functions] +# This section describes the radial basis functions shape. + +# basis funtion is the name of the basis functions to use in the transformation. The functions +# implemented so far are: gaussian_spline, multi_quadratic_biharmonic_spline, +# inv_multi_quadratic_biharmonic_spline, thin_plate_spline, beckert_wendland_c2_basis, polyharmonic_spline. +# For a comprehensive list with details see the class RBF. +basis function: gaussian_spline + +# radius is the scaling parameter r that affects the shape of the basis functions. See the documentation +# of the class RBF for details. +radius: 2.0 + + + +[Control points] +# This section describes the RBF control points. + +# original control points collects the coordinates of the interpolation control points before the deformation. +original control points: 0.0 0.0 0.0 + 0.0 0.0 1.0 + 0.0 1.0 0.0 + 1.0 0.0 0.0 + 0.0 1.0 1.0 + 1.0 0.0 1.0 + 1.0 1.0 0.0 + 1.0 1.0 1.0 + +# deformed control points collects the coordinates of the interpolation control points after the deformation. +deformed control points: 0.0 0.0 0.0 + 0.0 0.0 1.0 + 0.0 1.0 0.0 + 1.0 0.0 0.0 + 0.0 1.0 1.0 + 1.0 0.0 1.0 + 1.0 1.0 0.0 + 1.0 1.0 1.0 diff --git a/tests/test_rbf.py b/tests/test_rbf.py index 762b8cc..c29ad7d 100644 --- a/tests/test_rbf.py +++ b/tests/test_rbf.py @@ -51,6 +51,10 @@ def test_class_members_default_radius(self): rbf = RBF() assert rbf.radius == 0.5 + def test_class_members_default_extra(self): + rbf = RBF() + assert rbf.extra == {} + def test_class_members_default_n_control_points(self): rbf = RBF() assert rbf.n_control_points == 8 @@ -68,10 +72,20 @@ def test_read_parameters_basis(self): rbf.read_parameters('tests/test_datasets/parameters_rbf_default.prm') assert rbf.basis == RBFFactory('gaussian_spline') + def test_read_parameters_basis2(self): + rbf = RBF() + rbf.read_parameters('tests/test_datasets/parameters_rbf_extra.prm') + assert rbf.basis == RBFFactory('polyharmonic_spline') + def test_read_parameters_radius(self): rbf = RBF() - rbf.read_parameters('tests/test_datasets/parameters_rbf_default.prm') - assert rbf.radius == 0.5 + rbf.read_parameters('tests/test_datasets/parameters_rbf_radius.prm') + assert rbf.radius == 2.0 + + def test_read_extra_parameters(self): + rbf = RBF() + rbf.read_parameters('tests/test_datasets/parameters_rbf_extra.prm') + assert rbf.extra == {'k': 4} def test_read_parameters_n_control_points(self): rbf = RBF() @@ -145,17 +159,23 @@ def test_write_parameters(self): self.assertTrue(filecmp.cmp(outfilename, outfilename_expected)) #os.remove(outfilename) - - def test_read_parameters_filename_default(self): - params = RBF() - params.read_parameters() - outfilename = 'parameters_rbf.prm' - outfilename_expected = 'tests/test_datasets/parameters_rbf_default.prm' - - self.assertTrue(filecmp.cmp(outfilename, outfilename_expected)) - os.remove(outfilename) """ def test_print_info(self): - params = RBF() - print(params) + rbf = RBF() + print(rbf) + + def test_call_dummy_transformation(self): + rbf = RBF() + rbf.read_parameters('tests/test_datasets/parameters_rbf_default.prm') + mesh = self.get_cube_mesh_points() + new = rbf(mesh) + np.testing.assert_array_almost_equal(new[17], mesh[17]) + + def test_call(self): + rbf = RBF() + rbf.read_parameters('tests/test_datasets/parameters_rbf_extra.prm') + mesh = self.get_cube_mesh_points() + new = rbf(mesh) + np.testing.assert_array_almost_equal(new[17], [8.947368e-01, 5.353524e-17, 8.845331e-03]) +