Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 101 additions & 33 deletions pygem/rbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@

import matplotlib.pyplot as plt

import warnings


class RBF(Deformation):
"""
Expand All @@ -93,8 +95,12 @@ class RBF(Deformation):
basis functions. For details see the class
:class:`RBF`. The default value is 0.5.
:param dict extra_parameter: the additional parameters that may be passed to
the kernel function. Default is None.

the kernel function. Default is None.
:param str dtype: Precision specification. Supported values:
'fp16'/'float16', 'fp32'/'float32', 'fp64'/'float64' (default),
'fp96'/'float96','fp128'/'float128' (if available on platform).
Default is 'fp64'.

:cvar numpy.ndarray weights: the matrix formed by the weights corresponding
to the a-priori selected N control points, associated to the basis
functions and c and Q terms that describe the polynomial of order one
Expand All @@ -112,7 +118,7 @@ class RBF(Deformation):
basis functions.
:cvar dict extra: the additional parameters that may be passed to the
kernel function.

:Example:

>>> from pygem import RBF
Expand All @@ -125,12 +131,61 @@ class RBF(Deformation):
>>> mesh = np.array([x.ravel(), y.ravel(), z.ravel()])
>>> deformed_mesh = rbf(mesh)
"""

# Precision mapping
DTYPE_MAP = {
'fp16': np.float16,
'float16': np.float16,
'fp32': np.float32,
'float32': np.float32,
'fp64': np.float64,
'float64': np.float64,
'fp96': np.float96 if hasattr(np, 'float96') else np.float64,
'float96': np.float96 if hasattr(np, 'float96') else np.float64,
'fp128': np.float128 if hasattr(np, 'float128') else np.float64,
'float128': np.float128 if hasattr(np, 'float128') else np.float64,
}

def __init__(self,
original_control_points=None,
deformed_control_points=None,
func='gaussian_spline',
radius=0.5,
extra_parameter=None):
extra_parameter=None,
dtype='fp64'):

# Parse and set dtype with platform check
if isinstance(dtype, str):
dtype_lower = dtype.lower()
if dtype_lower not in self.DTYPE_MAP:
raise ValueError(
f"Unsupported dtype '{dtype}'. Supported values: "
f"{list(self.DTYPE_MAP.keys())}"
)

# Check for fp128 fallback
if dtype_lower in ['fp128', 'float128']:
if not hasattr(np, 'float128'):
warnings.warn(
"fp128/float128 is not supported on this platform. "
"Automatically falling back to fp64. "
"For true quad-precision, consider using Linux platform.",
RuntimeWarning
)

# Check for fp96 fallback
if dtype_lower in ['fp96', 'float96']:
if not hasattr(np, 'float96'):
warnings.warn(
"fp96/float96 is not supported on this platform. "
"Automatically falling back to fp64. "
"For higher precision consider using 'fp128' (if available) ",
RuntimeWarning
)

self._dtype = self.DTYPE_MAP[dtype_lower]
else:
self._dtype = dtype

self.basis = func
self.radius = radius
Expand All @@ -139,26 +194,25 @@ def __init__(self,
self.original_control_points = np.array([[0., 0., 0.], [0., 0., 1.],
[0., 1., 0.], [1., 0., 0.],
[0., 1., 1.], [1., 0., 1.],
[1., 1., 0.], [1., 1.,
1.]])
[1., 1., 0.], [1., 1., 1.]],
dtype=self._dtype)
else:
self.original_control_points = original_control_points
self.original_control_points = np.asarray(original_control_points, dtype=self._dtype)

if deformed_control_points is None:
self.deformed_control_points = np.array([[0., 0., 0.], [0., 0., 1.],
[0., 1., 0.], [1., 0., 0.],
[0., 1., 1.], [1., 0., 1.],
[1., 1., 0.], [1., 1.,
1.]])
[1., 1., 0.], [1., 1., 1.]],
dtype=self._dtype)
else:
self.deformed_control_points = deformed_control_points
self.deformed_control_points = np.asarray(deformed_control_points, dtype=self._dtype)

self.extra = extra_parameter if extra_parameter else dict()

self.weights = self._get_weights(self.original_control_points,
self.deformed_control_points)


@property
def n_control_points(self):
"""
Expand Down Expand Up @@ -209,17 +263,26 @@ def _get_weights(self, X, Y):
:rtype: numpy.ndarray
"""
npts, dim = X.shape
H = np.zeros((npts + 3 + 1, npts + 3 + 1))
H[:npts, :npts] = self.basis(cdist(X, X), self.radius, **self.extra)
H[npts, :npts] = 1.0
H[:npts, npts] = 1.0
size = npts + 3 + 1
H = np.zeros((size, size), dtype=self._dtype)

# Compute distances and basis values using configured precision
dists = cdist(X, X).astype(self._dtype)
basis_block = self.basis(dists, self.radius, **self.extra)
basis_block = np.asarray(basis_block, dtype=self._dtype)

H[:npts, :npts] = basis_block
H[npts, :npts] = self._dtype(1.0)
H[:npts, npts] = self._dtype(1.0)
H[:npts, -3:] = X
H[-3:, :npts] = X.T

rhs = np.zeros((npts + 3 + 1, dim))
rhs = np.zeros((size, dim), dtype=self._dtype)
rhs[:npts, :] = Y
weights = np.linalg.solve(H, rhs)
return weights

solve_dtype = np.float64 if self._dtype not in (np.float32, np.float64) else self._dtype
weights = np.linalg.solve(H.astype(solve_dtype), rhs.astype(solve_dtype)).astype(self._dtype)
return weights.astype(self._dtype)

def read_parameters(self, filename='parameters_rbf.prm'):
"""
Expand All @@ -242,20 +305,20 @@ def read_parameters(self, filename='parameters_rbf.prm'):
config.read(filename)

rbf_settings = dict(config.items('Radial Basis Functions'))

self.basis = rbf_settings.pop('basis function')
self.radius = float(rbf_settings.pop('radius'))
self.extra = {k: eval(v) for k, v in rbf_settings.items()}

ctrl_points = config.get('Control points', 'original control points')
lines = ctrl_points.split('\n')
self.original_control_points = np.array(
list(map(lambda x: x.split(), lines)), dtype=float)
list(map(lambda x: x.split(), lines)), dtype=self._dtype)

mod_points = config.get('Control points', 'deformed control points')
lines = mod_points.split('\n')
self.deformed_control_points = np.array(
list(map(lambda x: x.split(), lines)), dtype=float)
list(map(lambda x: x.split(), lines)), dtype=self._dtype)

if len(lines) != self.n_control_points:
raise TypeError("The number of control points must be equal both in"
Expand Down Expand Up @@ -308,8 +371,8 @@ def write_parameters(self, filename='parameters_rbf.prm'):
for i in range(0, self.n_control_points):
output_string += offset * ' ' + str(
self.original_control_points[i][0]) + ' ' + str(
self.original_control_points[i][1]) + ' ' + str(
self.original_control_points[i][2]) + '\n'
self.original_control_points[i][1]) + ' ' + str(
self.original_control_points[i][2]) + '\n'
offset = 25

output_string += '\n# deformed control points collects the coordinates'
Expand All @@ -321,8 +384,8 @@ def write_parameters(self, filename='parameters_rbf.prm'):
for i in range(0, self.n_control_points):
output_string += offset * ' ' + str(
self.deformed_control_points[i][0]) + ' ' + str(
self.deformed_control_points[i][1]) + ' ' + str(
self.deformed_control_points[i][2]) + '\n'
self.deformed_control_points[i][1]) + ' ' + str(
self.deformed_control_points[i][2]) + '\n'
offset = 25

with open(filename, 'w') as f:
Expand Down Expand Up @@ -393,13 +456,18 @@ def __call__(self, src_pts):
This method performs the deformation of the mesh points. After the
execution it sets `self.modified_mesh_points`.
"""
src = np.asarray(src_pts, dtype=self._dtype)
self.compute_weights()

H = np.zeros((src_pts.shape[0], self.n_control_points + 3 + 1))
H[:, :self.n_control_points] = self.basis(
cdist(src_pts, self.original_control_points),
self.radius,
**self.extra)
H[:, self.n_control_points] = 1.0
H[:, -3:] = src_pts
return np.asarray(np.dot(H, self.weights))
H = np.zeros((src.shape[0], self.n_control_points + 3 + 1),
dtype=self._dtype)

dists = cdist(src, self.original_control_points).astype(self._dtype)
basis_block = self.basis(dists, self.radius, **self.extra)

H[:, :self.n_control_points] = np.asarray(basis_block, dtype=self._dtype)
H[:, self.n_control_points] = self._dtype(1.0)
H[:, -3:] = src

result = np.dot(H, self.weights)
return np.asarray(result, dtype=self._dtype)
Loading