Skip to content

Commit

Permalink
Add ability to blend any number of transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
AnsonTran committed May 10, 2024
1 parent f683fc7 commit 91d3329
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 94 deletions.
4 changes: 2 additions & 2 deletions lib/matplotlib/tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ def test_str_transform():
CompositeGenericTransform(
CompositeGenericTransform(
TransformWrapper(
BlendedAffine2D(
BlendedAffine(
IdentityTransform(),
IdentityTransform())),
CompositeAffine2D(
Expand All @@ -864,7 +864,7 @@ def test_str_transform():
CompositeGenericTransform(
PolarAffine(
TransformWrapper(
BlendedAffine2D(
BlendedAffine(
IdentityTransform(),
IdentityTransform())),
LockableBbox(
Expand Down
193 changes: 110 additions & 83 deletions lib/matplotlib/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2203,182 +2203,208 @@ def inverted(self):


class _BlendedMixin:
"""Common methods for `BlendedGenericTransform` and `BlendedAffine2D`."""
"""Common methods for `BlendedGenericTransform` and `BlendedAffine`."""

def __eq__(self, other):
if isinstance(other, (BlendedAffine2D, BlendedGenericTransform)):
return (self._x == other._x) and (self._y == other._y)
elif self._x == self._y:
return self._x == other
num_transforms = len(self._transforms)

if (isinstance(other, (BlendedGenericTransform, BlendedAffine))
and num_transforms == len(other._transforms)):
return all(self._transforms[i] == other._transforms[i]
for i in range(num_transforms))
else:
return NotImplemented

def contains_branch_seperately(self, transform):
return (self._x.contains_branch(transform),
self._y.contains_branch(transform))
return tuple(branch.contains_branch(transform) for branch in self._transforms)

__str__ = _make_str_method("_x", "_y")
def __str__(self):
indent = functools.partial(textwrap.indent, prefix=" " * 4)
return (
type(self).__name__ + "("
+ ",".join([*(indent("\n" + transform.__str__())
for transform in self._transforms)])
+ ")")


class BlendedGenericTransform(_BlendedMixin, Transform):
"""
A "blended" transform uses one transform for the *x*-direction, and
another transform for the *y*-direction.
A "blended" transform uses one transform for each direction
This "generic" version can handle any given child transform in the
*x*- and *y*-directions.
This "generic" version can handle any number of given child transforms, each
handling a different axis.
"""
input_dims = 2
output_dims = 2
is_separable = True
pass_through = True

def __init__(self, x_transform, y_transform, **kwargs):
def __init__(self, *args, **kwargs):
"""
Create a new "blended" transform using *x_transform* to transform the
*x*-axis and *y_transform* to transform the *y*-axis.
Create a new "blended" transform, with the first argument providing
a transform for the *x*-axis, the second argument providing a transform
for the *y*-axis, etc.
You will generally not call this constructor directly but use the
`blended_transform_factory` function instead, which can determine
automatically which kind of blended transform to create.
"""
self.input_dims = self.output_dims = len(args)

for i in range(self.input_dims):
transform = args[i]
if transform.input_dims > 1 and transform.input_dims <= i:
raise TypeError("Invalid transform provided to"
"`BlendedGenericTransform`")

Transform.__init__(self, **kwargs)
self._x = x_transform
self._y = y_transform
self.set_children(x_transform, y_transform)
self.set_children(*args)
self._transforms = args
self._affine = None

@property
def depth(self):
return max(self._x.depth, self._y.depth)
return max(transform.depth for transform in self._transforms)

def contains_branch(self, other):
# A blended transform cannot possibly contain a branch from two
# different transforms.
return False

is_affine = property(lambda self: self._x.is_affine and self._y.is_affine)
has_inverse = property(
lambda self: self._x.has_inverse and self._y.has_inverse)
is_affine = property(lambda self: all(transform.is_affine
for transform in self._transforms))
has_inverse = property(lambda self: all(transform.has_inverse
for transform in self._transforms))

def frozen(self):
# docstring inherited
return blended_transform_factory(self._x.frozen(), self._y.frozen())
return blended_transform_factory(*(transform.frozen()
for transform in self._transforms))

@_api.rename_parameter("3.8", "points", "values")
def transform_non_affine(self, values):
# docstring inherited
if self._x.is_affine and self._y.is_affine:
if self.is_affine:
return values
x = self._x
y = self._y

if x == y and x.input_dims == 2:
return x.transform_non_affine(values)
if all(transform == self._transforms[0]
for transform in self._transforms) and self.input_dims >= 2:
return self._transforms[0].transform_non_affine(values)

if x.input_dims == 2:
x_points = x.transform_non_affine(values)[:, 0:1]
else:
x_points = x.transform_non_affine(values[:, 0])
x_points = x_points.reshape((len(x_points), 1))
all_points = []
masked = False

if y.input_dims == 2:
y_points = y.transform_non_affine(values)[:, 1:]
else:
y_points = y.transform_non_affine(values[:, 1])
y_points = y_points.reshape((len(y_points), 1))
for dim in range(self.input_dims):
transform = self._transforms[dim]
if transform.input_dims == 1:
points = transform.transform_non_affine(values[:, dim])
points = points.reshape((len(points), 1))
else:
points = transform.transform_non_affine(values)[:, dim:dim+1]

if (isinstance(x_points, np.ma.MaskedArray) or
isinstance(y_points, np.ma.MaskedArray)):
return np.ma.concatenate((x_points, y_points), 1)
masked = masked or isinstance(points, np.ma.MaskedArray)
all_points.append(points)

if masked:
return np.ma.concatenate(tuple(all_points), 1)
else:
return np.concatenate((x_points, y_points), 1)
return np.concatenate(tuple(all_points), 1)

def inverted(self):
# docstring inherited
return BlendedGenericTransform(self._x.inverted(), self._y.inverted())
return BlendedGenericTransform(*(transform.inverted()
for transform in self._transforms))

def get_affine(self):
# docstring inherited
if self._invalid or self._affine is None:
if self._x == self._y:
self._affine = self._x.get_affine()
if all(transform == self._transforms[0] for transform in self._transforms):
self._affine = self._transforms[0].get_affine()
else:
x_mtx = self._x.get_affine().get_matrix()
y_mtx = self._y.get_affine().get_matrix()
# We already know the transforms are separable, so we can skip
# setting b and c to zero.
mtx = np.array([x_mtx[0], y_mtx[1], [0.0, 0.0, 1.0]])
self._affine = Affine2D(mtx)
mtx = np.identity(self.input_dims + 1)
for i in range(self.input_dims):
transform = self._transforms[i]
if transform.output_dims > 1:
mtx[i] = transform.get_affine().get_matrix()[i]

self._affine = _affine_factory(mtx, dims=self.input_dims)
self._invalid = 0
return self._affine


class BlendedAffine2D(_BlendedMixin, Affine2DBase):
class BlendedAffine(_BlendedMixin, AffineImmutable):
"""
A "blended" transform uses one transform for the *x*-direction, and
another transform for the *y*-direction.
This version is an optimization for the case where both child
transforms are of type `Affine2DBase`.
transforms are of type `AffineImmutable`.
"""

is_separable = True

def __init__(self, x_transform, y_transform, **kwargs):
def __init__(self, *args, **kwargs):
"""
Create a new "blended" transform using *x_transform* to transform the
*x*-axis and *y_transform* to transform the *y*-axis.
Create a new "blended" transform, with the first argument providing
a transform for the *x*-axis, the second argument providing a transform
for the *y*-axis, etc.
Both *x_transform* and *y_transform* must be 2D affine transforms.
All provided transforms must be affine transforms.
You will generally not call this constructor directly but use the
`blended_transform_factory` function instead, which can determine
automatically which kind of blended transform to create.
"""
is_affine = x_transform.is_affine and y_transform.is_affine
is_separable = x_transform.is_separable and y_transform.is_separable
is_correct = is_affine and is_separable
if not is_correct:
raise ValueError("Both *x_transform* and *y_transform* must be 2D "
"affine transforms")

dims = len(args)
Transform.__init__(self, **kwargs)
self._x = x_transform
self._y = y_transform
self.set_children(x_transform, y_transform)
AffineImmutable.__init__(self, dims=dims, **kwargs)

if not all(transform.is_affine and transform.is_separable
for transform in args):
raise ValueError("Given transforms must be affine")

for i in range(self.input_dims):
transform = args[i]
if transform.input_dims > 1 and transform.input_dims <= i:
raise TypeError("Invalid transform provided to"
"`BlendedGenericTransform`")

self._transforms = args
self.set_children(*args)

Affine2DBase.__init__(self)
self._mtx = None

def get_matrix(self):
# docstring inherited
if self._invalid:
if self._x == self._y:
self._mtx = self._x.get_matrix()
if all(transform == self._transforms[0] for transform in self._transforms):
self._mtx = self._transforms[0].get_matrix()
else:
x_mtx = self._x.get_matrix()
y_mtx = self._y.get_matrix()
# We already know the transforms are separable, so we can skip
# setting b and c to zero.
self._mtx = np.array([x_mtx[0], y_mtx[1], [0.0, 0.0, 1.0]])
# setting non-diagonal values to zero.
self._mtx = np.array(
[self._transforms[i].get_affine().get_matrix()[i]
for i in range(self.input_dims)] +
[[0.0] * self.input_dims + [1.0]])
self._inverted = None
self._invalid = 0
return self._mtx


def blended_transform_factory(x_transform, y_transform):
@_api.deprecated("3.9", alternative="BlendedAffine")
class BlendedAffine2D(BlendedAffine):
pass


def blended_transform_factory(*args):
"""
Create a new "blended" transform using *x_transform* to transform
the *x*-axis and *y_transform* to transform the *y*-axis.
A faster version of the blended transform is returned for the case
where both child transforms are affine.
"""
if (isinstance(x_transform, Affine2DBase) and
isinstance(y_transform, Affine2DBase)):
return BlendedAffine2D(x_transform, y_transform)
return BlendedGenericTransform(x_transform, y_transform)
if all(isinstance(transform, AffineImmutable) for transform in args):
return BlendedAffine(*args)
return BlendedGenericTransform(*args)


class CompositeGenericTransform(Transform):
Expand Down Expand Up @@ -2479,8 +2505,9 @@ def get_affine(self):
if not self._b.is_affine:
return self._b.get_affine()
else:
return Affine2D(np.dot(self._b.get_affine().get_matrix(),
self._a.get_affine().get_matrix()))
return _affine_factory(np.dot(self._b.get_affine().get_matrix(),
self._a.get_affine().get_matrix()),
dims=self.input_dims)

def inverted(self):
# docstring inherited
Expand Down
17 changes: 8 additions & 9 deletions lib/matplotlib/transforms.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -258,26 +258,25 @@ class _BlendedMixin:
def contains_branch_seperately(self, transform: Transform) -> Sequence[bool]: ...

class BlendedGenericTransform(_BlendedMixin, Transform):
input_dims: Literal[2]
output_dims: Literal[2]
pass_through: bool
def __init__(
self, x_transform: Transform, y_transform: Transform, **kwargs
self, *args: Transform, **kwargs
) -> None: ...
@property
def depth(self) -> int: ...
def contains_branch(self, other: Transform) -> Literal[False]: ...
@property
def is_affine(self) -> bool: ...

class BlendedAffine2D(_BlendedMixin, Affine2DBase):
def __init__(
self, x_transform: Transform, y_transform: Transform, **kwargs
) -> None: ...
class BlendedAffine(_BlendedMixin, AffineImmutable):
def __init__(self, *args: Transform, **kwargs) -> None: ...

class BlendedAffine2D(BlendedAffine):
pass

def blended_transform_factory(
x_transform: Transform, y_transform: Transform
) -> BlendedGenericTransform | BlendedAffine2D: ...
*args: Transform
) -> BlendedGenericTransform | BlendedAffine: ...

class CompositeGenericTransform(Transform):
pass_through: bool
Expand Down

0 comments on commit 91d3329

Please sign in to comment.