diff --git a/pygem/rbf.py b/pygem/rbf.py index a887697d..cb1f6fff 100644 --- a/pygem/rbf.py +++ b/pygem/rbf.py @@ -70,6 +70,8 @@ import matplotlib.pyplot as plt +import warnings + class RBF(Deformation): """ @@ -93,8 +95,12 @@ class RBF(Deformation): basis functions. For details see the class :class:`RBF`. The default value is 0.5. :param dict extra_parameter: the additional parameters that may be passed to - the kernel function. Default is None. - + the kernel function. Default is None. + :param str dtype: Precision specification. Supported values: + 'fp16'/'float16', 'fp32'/'float32', 'fp64'/'float64' (default), + 'fp96'/'float96','fp128'/'float128' (if available on platform). + Default is 'fp64'. + :cvar numpy.ndarray weights: the matrix formed by the weights corresponding to the a-priori selected N control points, associated to the basis functions and c and Q terms that describe the polynomial of order one @@ -112,7 +118,7 @@ class RBF(Deformation): basis functions. :cvar dict extra: the additional parameters that may be passed to the kernel function. - + :Example: >>> from pygem import RBF @@ -125,12 +131,61 @@ class RBF(Deformation): >>> mesh = np.array([x.ravel(), y.ravel(), z.ravel()]) >>> deformed_mesh = rbf(mesh) """ + + # Precision mapping + DTYPE_MAP = { + 'fp16': np.float16, + 'float16': np.float16, + 'fp32': np.float32, + 'float32': np.float32, + 'fp64': np.float64, + 'float64': np.float64, + 'fp96': np.float96 if hasattr(np, 'float96') else np.float64, + 'float96': np.float96 if hasattr(np, 'float96') else np.float64, + 'fp128': np.float128 if hasattr(np, 'float128') else np.float64, + 'float128': np.float128 if hasattr(np, 'float128') else np.float64, + } + def __init__(self, original_control_points=None, deformed_control_points=None, func='gaussian_spline', radius=0.5, - extra_parameter=None): + extra_parameter=None, + dtype='fp64'): + + # Parse and set dtype with platform check + if isinstance(dtype, str): + dtype_lower = dtype.lower() + if dtype_lower not in self.DTYPE_MAP: + raise ValueError( + f"Unsupported dtype '{dtype}'. Supported values: " + f"{list(self.DTYPE_MAP.keys())}" + ) + + # Check for fp128 fallback + if dtype_lower in ['fp128', 'float128']: + if not hasattr(np, 'float128'): + warnings.warn( + "fp128/float128 is not supported on this platform. " + "Automatically falling back to fp64. " + "For true quad-precision, consider using Linux platform.", + RuntimeWarning + ) + + # Check for fp96 fallback + if dtype_lower in ['fp96', 'float96']: + if not hasattr(np, 'float96'): + warnings.warn( + "fp96/float96 is not supported on this platform. " + "Automatically falling back to fp64. " + "For higher precision consider using 'fp128' (if available) ", + RuntimeWarning + ) + + self._dtype = self.DTYPE_MAP[dtype_lower] + else: + self._dtype = dtype self.basis = func self.radius = radius @@ -139,26 +194,25 @@ def __init__(self, self.original_control_points = np.array([[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [1., 0., 0.], [0., 1., 1.], [1., 0., 1.], - [1., 1., 0.], [1., 1., - 1.]]) + [1., 1., 0.], [1., 1., 1.]], + dtype=self._dtype) else: - self.original_control_points = original_control_points + self.original_control_points = np.asarray(original_control_points, dtype=self._dtype) if deformed_control_points is None: self.deformed_control_points = np.array([[0., 0., 0.], [0., 0., 1.], [0., 1., 0.], [1., 0., 0.], [0., 1., 1.], [1., 0., 1.], - [1., 1., 0.], [1., 1., - 1.]]) + [1., 1., 0.], [1., 1., 1.]], + dtype=self._dtype) else: - self.deformed_control_points = deformed_control_points + self.deformed_control_points = np.asarray(deformed_control_points, dtype=self._dtype) self.extra = extra_parameter if extra_parameter else dict() self.weights = self._get_weights(self.original_control_points, self.deformed_control_points) - @property def n_control_points(self): """ @@ -209,17 +263,26 @@ def _get_weights(self, X, Y): :rtype: numpy.ndarray """ npts, dim = X.shape - H = np.zeros((npts + 3 + 1, npts + 3 + 1)) - H[:npts, :npts] = self.basis(cdist(X, X), self.radius, **self.extra) - H[npts, :npts] = 1.0 - H[:npts, npts] = 1.0 + size = npts + 3 + 1 + H = np.zeros((size, size), dtype=self._dtype) + + # Compute distances and basis values using configured precision + dists = cdist(X, X).astype(self._dtype) + basis_block = self.basis(dists, self.radius, **self.extra) + basis_block = np.asarray(basis_block, dtype=self._dtype) + + H[:npts, :npts] = basis_block + H[npts, :npts] = self._dtype(1.0) + H[:npts, npts] = self._dtype(1.0) H[:npts, -3:] = X H[-3:, :npts] = X.T - rhs = np.zeros((npts + 3 + 1, dim)) + rhs = np.zeros((size, dim), dtype=self._dtype) rhs[:npts, :] = Y - weights = np.linalg.solve(H, rhs) - return weights + + solve_dtype = np.float64 if self._dtype not in (np.float32, np.float64) else self._dtype + weights = np.linalg.solve(H.astype(solve_dtype), rhs.astype(solve_dtype)).astype(self._dtype) + return weights.astype(self._dtype) def read_parameters(self, filename='parameters_rbf.prm'): """ @@ -242,7 +305,7 @@ def read_parameters(self, filename='parameters_rbf.prm'): config.read(filename) rbf_settings = dict(config.items('Radial Basis Functions')) - + self.basis = rbf_settings.pop('basis function') self.radius = float(rbf_settings.pop('radius')) self.extra = {k: eval(v) for k, v in rbf_settings.items()} @@ -250,12 +313,12 @@ def read_parameters(self, filename='parameters_rbf.prm'): ctrl_points = config.get('Control points', 'original control points') lines = ctrl_points.split('\n') self.original_control_points = np.array( - list(map(lambda x: x.split(), lines)), dtype=float) + list(map(lambda x: x.split(), lines)), dtype=self._dtype) mod_points = config.get('Control points', 'deformed control points') lines = mod_points.split('\n') self.deformed_control_points = np.array( - list(map(lambda x: x.split(), lines)), dtype=float) + list(map(lambda x: x.split(), lines)), dtype=self._dtype) if len(lines) != self.n_control_points: raise TypeError("The number of control points must be equal both in" @@ -308,8 +371,8 @@ def write_parameters(self, filename='parameters_rbf.prm'): for i in range(0, self.n_control_points): output_string += offset * ' ' + str( self.original_control_points[i][0]) + ' ' + str( - self.original_control_points[i][1]) + ' ' + str( - self.original_control_points[i][2]) + '\n' + self.original_control_points[i][1]) + ' ' + str( + self.original_control_points[i][2]) + '\n' offset = 25 output_string += '\n# deformed control points collects the coordinates' @@ -321,8 +384,8 @@ def write_parameters(self, filename='parameters_rbf.prm'): for i in range(0, self.n_control_points): output_string += offset * ' ' + str( self.deformed_control_points[i][0]) + ' ' + str( - self.deformed_control_points[i][1]) + ' ' + str( - self.deformed_control_points[i][2]) + '\n' + self.deformed_control_points[i][1]) + ' ' + str( + self.deformed_control_points[i][2]) + '\n' offset = 25 with open(filename, 'w') as f: @@ -393,13 +456,18 @@ def __call__(self, src_pts): This method performs the deformation of the mesh points. After the execution it sets `self.modified_mesh_points`. """ + src = np.asarray(src_pts, dtype=self._dtype) self.compute_weights() - H = np.zeros((src_pts.shape[0], self.n_control_points + 3 + 1)) - H[:, :self.n_control_points] = self.basis( - cdist(src_pts, self.original_control_points), - self.radius, - **self.extra) - H[:, self.n_control_points] = 1.0 - H[:, -3:] = src_pts - return np.asarray(np.dot(H, self.weights)) + H = np.zeros((src.shape[0], self.n_control_points + 3 + 1), + dtype=self._dtype) + + dists = cdist(src, self.original_control_points).astype(self._dtype) + basis_block = self.basis(dists, self.radius, **self.extra) + + H[:, :self.n_control_points] = np.asarray(basis_block, dtype=self._dtype) + H[:, self.n_control_points] = self._dtype(1.0) + H[:, -3:] = src + + result = np.dot(H, self.weights) + return np.asarray(result, dtype=self._dtype)