Skip to content

Commit

Permalink
Merge pull request #89 from PerretB/fix_linearize_vertexweights
Browse files Browse the repository at this point in the history
Fix linearize vertexweights
  • Loading branch information
PerretB committed May 26, 2019
2 parents 8a32005 + 2ad62e3 commit 7e2c3f9
Show file tree
Hide file tree
Showing 7 changed files with 281 additions and 67 deletions.
8 changes: 7 additions & 1 deletion doc/source/python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,10 @@ Python modules
:caption: IO:

Pink Graph </python/pink_io.rst>
Tree IO </python/tree_io.rst>
Tree IO </python/tree_io.rst>

.. toctree::
:maxdepth: 2
:caption: Misc:

Utility functions </python/hg_utils.rst>
18 changes: 0 additions & 18 deletions doc/source/python/concept.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,15 @@ Concepts
.. autoclass:: higra.Concept
:members:

.. autoclass:: higra.CptEdgeWeightedGraph
:members:

.. autoclass:: higra.CptGraphCut
:members:

.. autoclass:: higra.CptVertexWeightedGraph
:members:

.. autoclass:: higra.CptVertexLabeledGraph
:members:

.. autoclass:: higra.CptGridGraph
:members:

.. autoclass:: higra.CptRegionAdjacencyGraph
:members:

.. autoclass:: higra.CptSaliencyMap
:members:

.. autoclass:: higra.CptHierarchy
:members:

.. autoclass:: higra.CptValuedHierarchy
:members:

.. autoclass:: higra.CptBinaryHierarchy
:members:

Expand Down
26 changes: 26 additions & 0 deletions doc/source/python/hg_utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
.. _hg_utils:

Utility functions
=================

.. autofunction:: higra.is_iterable

.. autofunction:: higra.extend_class

.. autofunction:: higra.normalize_shape

.. autofunction:: higra.linearize_vertex_weights

.. autofunction:: higra.delinearize_vertex_weights

.. autofunction:: higra.is_in_bijection

.. autofunction:: higra.mean_angle_mod_pi

.. autofunction:: higra.dtype_info

.. autofunction:: higra.common_type

.. autofunction:: higra.cast_to_common_type

.. autofunction:: higra.cast_to_dtype
2 changes: 1 addition & 1 deletion doc/source/python/tree_io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ Tree IO allows de/serialization of a tree and associated attributes in a custom

.. autofunction:: higra.read_tree

.. autofunction:: higra.save_tree_attributes
.. autofunction:: higra.save_tree
202 changes: 163 additions & 39 deletions higra/hg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,23 @@ def is_iterable(obj):

def extend_class(cls, method_name=None):
"""
Add the decorated function to the specified class.
Decorator: add the decorated function to the specified class.
If no name is specified the name of the function is used.
:example:
>>> class A:
>>> pass
>>>
>>> @extend_class(A, "shiny_method")
>>> def hello(self, name):
>>> print("Hello", name)
>>>
>>> a = A()
>>> a.shiny_method("foo")
:param cls The class to extend
:param method_name The name that the new method will take (If None, it will be determined from the decorated function)
"""

def decorate(funct):
Expand All @@ -50,37 +65,135 @@ def normalize_shape(shape):
This function ensure that the given shape will be easily convertible
in a c++ callback (ie. that it won't interfere badly in pybind11
overload resolution algorithm)
:param shape:
:return:
:example:
>>> normalize_shape([4, 5])
(4, 5)
:param shape: an iterable of integers
:return: the equivalent normalized shape
"""
return tuple(int(i) for i in shape)


@hg.argument_helper(("graph", hg.CptGridGraph))
def linearize_vertex_weights(vertex_weights, graph=None, shape=None):
if shape is None or graph.num_vertices() == vertex_weights.shape[0]:
"""
Linearize the given ndarray according to the given shape
If ``shape`` is ``None``, the input array is returned.
Else if ``shape`` is a prefix of ``vertex_weights.shape`` then the ``len(shape)`` first dimensions of ``vertex_weights`` are
collapsed (linearized).
Else if ``vertex_weights.shape[0]`` is equal to the product of the elements in ``shape`` then the input array is
returned (array was already linearized).
Else: an exception is raised.
Note that this function tries its best to guess what should be done but some ambiguity might always exist.
:Examples:
>>> r = hg.linearize_vertex_weights(np.ones((4, 5)), shape=(4, 5))
>>> r.shape
(20,)
>>> r = hg.linearize_vertex_weights(np.ones((4, 5, 10, 12)), shape=(4, 5))
>>> r.shape
(20, 10, 12)
>>> r = hg.linearize_vertex_weights(np.ones((20,)), shape=(4, 5))
>>> r.shape
(20,)
>>> r = hg.linearize_vertex_weights(np.ones((20, 4, 5, 2, 3)), shape=(4, 5))
>>> r.shape
(20, 4, 5, 2, 3)
Raises:
ValueError: if ``vertex_weights.shape`` and ``shape`` are incompatible.
:See:
:func:`~higra.delinearize_vertex_weights`
:param vertex_weights: an ndarray representing vertex weights on a square nd-grid
:param graph: a graph (optional, Concept :class:`~higra.CptGridGraph`)
:param shape: a list of integers (optional, deduced from Concept :class:`~higra.CptGridGraph`)
:return: maybe reshaped vertex_weights
"""
if shape is None:
return vertex_weights

v_shape = vertex_weights.shape
if len(v_shape) < len(shape):
raise Exception("Vertex weights shape " + str(v_shape) +
" is not compatible with graph shape " + str(shape) + ".")
shape_prefix_of_v_shape = True
if len(v_shape) >= len(shape):
for i in range(len(shape)):
if v_shape[i] != shape[i]:
shape_prefix_of_v_shape = False
break
else:
shape_prefix_of_v_shape = False

flag = True
for i in range(len(shape)):
if v_shape[i] != shape[i]:
flag = False
break
if shape_prefix_of_v_shape:
return vertex_weights.reshape([graph.num_vertices()] + list(v_shape[len(shape):]))

if not flag:
raise Exception("Vertex weights shape " + str(v_shape) +
" is not compatible with graph shape " + str(shape) + ".")
num_elements = 1
for i in shape:
num_elements *= i

return vertex_weights.reshape([graph.num_vertices()] + list(v_shape[len(shape):]))
if num_elements == vertex_weights.shape[0]:
return vertex_weights

raise ValueError("Vertex weights shape " + str(v_shape) +
" is not compatible with graph shape " + str(shape) + ".")


@hg.argument_helper(("graph", hg.CptGridGraph))
def delinearize_vertex_weights(vertex_weights, graph=None, shape=None):
"""
De-Linearize the given ndarray according to the given shape
If ``shape`` is ``None``, the input array is returned.
Else if ``shape`` is a prefix of ``vertex_weights.shape`` then the input array is returned.
Else if ``vertex_weights.shape[0]`` is equal to the product of the elements in ``shape`` then the first dimension of
``vertex_weights`` is expanded (delinearized) into the sequence of dimensions given by ``shape``
Else: an exception is raised.
Note that this function tries its best to guess what should be done but some ambiguity might always exist.
Examples:
>>> r = hg.delinearize_vertex_weights(np.ones((20,)), shape=(4, 5))
>>> r.shape
(4, 5)
>>> r = hg.delinearize_vertex_weights(np.ones((20, 4, 5, 2, 3)), shape=(4, 5))
>>> r.shape
(4, 5, 4, 5, 2, 3)
>>> r = hg.delinearize_vertex_weights(np.ones((4, 5)), shape=(4, 5))
>>> r.shape
(4, 5)
>>> r = hg.delinearize_vertex_weights(np.ones((4, 5, 10, 12)), shape=(4, 5))
>>> r.shape
(4, 5, 10, 12)
Raises:
ValueError: if ``vertex_weights.shape`` and ``shape`` are incompatible.
:See:
:func:`~higra.linearize_vertex_weights`
:param vertex_weights: an ndarray representing vertex weights on a square nd-grid
:param graph: a graph (optional, Concept :class:`~higra.CptGridGraph`)
:param shape: a list of integers (optional, deduced from Concept :class:`~higra.CptGridGraph`)
:return: maybe reshaped vertex_weights
"""
if shape is None:
return vertex_weights

Expand All @@ -90,22 +203,29 @@ def delinearize_vertex_weights(vertex_weights, graph=None, shape=None):
if shape == v_shape[:len(shape)]:
return vertex_weights

if v_shape[0] != graph.num_vertices():
raise Exception("Vertex weights shape " + str(v_shape) +
" is not compatible with graph size " + str(graph.num_vertices()) + ".")
num_elements = 1
for i in shape:
num_elements *= i

if v_shape[0] != num_elements:
raise ValueError("Vertex weights shape " + str(v_shape) +
" is not compatible with graph size " + str(graph.num_vertices()) + ".")

return vertex_weights.reshape(list(shape) + list(v_shape[1:]))


def is_in_bijection(a, b):
"""
Given two numpy arrays a and b returns true iif
- a and b have same size
- there exists a bijective function f such that, for all i a(i) = f(b(i))
Given two numpy arrays a and b returns ``True`` if
- ``a`` and ``b`` have same size
- there exists a bijective function :math:`f` such that, for all :math:`i`, :math:`a(i) = f(b(i))`
:param a:
:param b:
:return:
Note that input arrays are flattened.
:param a: an nd array
:param b: an nd array
:return: ``True`` if a bijection exists and ``False`` otherwise
"""
aa = a.flatten()
bb = b.flatten()
Expand Down Expand Up @@ -137,13 +257,13 @@ def is_in_bijection(a, b):

def mean_angle_mod_pi(angles1, angles2):
"""
Compute the element wise mean of two arrays of angles (radian) handling a modulo pi wrapping
Compute the element wise mean of two arrays of angles (radian) handling a modulo :math:`\pi` wrapping
eg: the modulo pi mean angle between 0 and 3.0 is roughly 3.07
eg: the modulo :math:`\pi` mean angle between 0 and 3.0 is roughly 3.07
:param angles1: must be in [0; pi]
:param angles2: must be in [0; pi]
:return: average of angles1 and angles2 with correct wrapping in [0; pi]
:param angles1: must be in :math:`[0; \pi]`
:param angles2: must be in :math:`[0; \pi]`
:return: average of angles1 and angles2 with correct wrapping in :math:`[0; \pi]`
"""
min_angles = np.minimum(angles1, angles2)
max_angles = np.maximum(angles1, angles2)
Expand All @@ -164,13 +284,14 @@ def mean_angle_mod_pi(angles1, angles2):

def dtype_info(dtype):
"""
Returns a `numpy.iinfo` object if given dtype is an integral type,
and a `numpy.finfo` if given dtype os a float type.
Returns a ``numpy.iinfo`` object if given ``dtype`` is an integral type,
and a ``numpy.finfo`` if given ``dtype`` is a float type.
Raises an exception is dtype is a more complex type.
Raises:
ValueError: if ``dtype`` is a more complex type.
:param dtype:
:return:
:param dtype: a numpy dtype
:return: an info object on given ``dtype``
"""
int_types = (np.uint8, np.int8, np.uint16, np.int16, np.uint32, np.int32, np.uint64, np.int64)
float_types = (np.float32, np.float64)
Expand All @@ -180,7 +301,7 @@ def dtype_info(dtype):
elif dtype in float_types:
return np.finfo(dtype)
else:
raise TypeError("Given dtype is not suported or invalid.")
raise ValueError("Given dtype is not suported or invalid.")


__int8 = 'int8'
Expand Down Expand Up @@ -327,11 +448,11 @@ def common_type(*arrays, safety_level='minimum'):
In this case the function also guaranties that the result type is integral if all the arrays have integral types.
If safety level is equal to 'overflow', the result type ensures a to have a type suitable to contain the result of common
operations involving the input arrays (additions, divisions...). This relies on `numpy.common_type`: The return type will always
operations involving the input arrays (additions, divisions...). This relies on ``numpy.common_type``: The return type will always
be an inexact (i.e. floating point) scalar type, even if all the arrays are integer arrays. If one of the inputs is
an integer array, the minimum precision type that is returned is a 64-bit floating point dtype.
All input arrays except int64 and uint64 can be safely cast to the returned dtype without loss of information.
All input arrays except ``int64`` and ``uint64`` can be safely cast to the returned ``dtype`` without loss of information.
:param arrays: a sequence of numpy arrays
:param safety_level: either 'minimum' or 'overflow'
Expand All @@ -355,7 +476,7 @@ def cast_to_common_type(*arrays, safety_level='minimum'):
"""
Find a common type to a list of numpy arrays, cast all arrays that need to be cast to this type and returns the list of arrays (with some of them casted).
see :func:`~higra.common_type`
If safety level is equal to 'minimum', then the result type is the smallest that ensures that all values in all
arrays can be represented exactly in the given type (except for `np.uint64` which is allowed to fit in a `np.int64`!)
Expand All @@ -368,6 +489,9 @@ def cast_to_common_type(*arrays, safety_level='minimum'):
All input arrays except int64 and uint64 can be safely cast to the returned dtype without loss of information.
:See:
:func:`~higra.common_type`
:param arrays: a sequence of numpy arrays
:param safety_level: either 'minimum' or 'overflow'
:return: a list of arrays
Expand Down
4 changes: 0 additions & 4 deletions test/cpp/algo/test_alignment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,6 @@ namespace test_alignment {

altitudes = xt::index_view(altitudes, node_map_res);

std::cout << tree_res.parents() << std::endl;
std::cout << supervertex_labelisation_res << std::endl;
std::cout << altitudes << std::endl;

auto sm = aligner.align_hierarchy(supervertex_labelisation_res, tree_res, altitudes);
auto sm_k = graph_4_adjacency_2_khalimsky(g, {3, 3}, sm);

Expand Down
Loading

0 comments on commit 7e2c3f9

Please sign in to comment.