This notebook presents **manifold factories** and shows how to use them with **metrics**. For a real use case, see *special_euclidean.py* in the geometry module.

In [1]:
import logging

import geomstats.backend as gs
from geomstats.geometry.manifold import AbstractManifoldFactory, Manifold
from geomstats.geometry.connection import Connection

gs.random.seed(2020)
#logging.getLogger().setLevel('DEBUG')

INFO: Using numpy backend


Suppose we have a family of manifolds called **tata**, each of which can carry certain types of metrics.

Depending on some parameters, we want to obtain a specific manifold of this family.

In [2]:
# first let's create the factory
class TataManifoldFactory(AbstractManifoldFactory):
    """Factory for Tata Manifolds."""

    metrics_creators = {} # These are class variables
    manifolds_creators = {} 

In [3]:
# Now that we have a factory for our type of manifolds, we can create different types associated with this factory

@TataManifoldFactory.register(color='blue')
class BlueTataManifold(Manifold):
    def belongs(self, point, atol=gs.atol):
        return True

@TataManifoldFactory.register(color='yellow')
class YellowTataManifold(Manifold):
    def belongs(self, point, atol=gs.atol):
        return True

@TataManifoldFactory.register(dim=3, color='blue')
class MatrixBlueTataManifold(Manifold):
    def belongs(self, point, atol=gs.atol):
        return True

In [4]:
# let's now test what happens when we try to create some manifold
manifold_b = TataManifoldFactory.create(dim=2, color='blue')
print(manifold_b)

manifold_y = TataManifoldFactory.create(dim=2, color='yellow')
print(manifold_y)

manifold_mat_b = TataManifoldFactory.create(dim=3, color='blue')
print(manifold_mat_b)

try:
    manifold_bad = TataManifoldFactory.create(dim=2, color='grey')
except Exception as e:
    print(f"Not Happy! {e}")




<__main__.BlueTataManifold object at 0x7efb8566a400>
<__main__.YellowTataManifold object at 0x7efb8566a470>
<__main__.MatrixBlueTataManifold object at 0x7efb8566a4e0>
Not Happy! no manifold with key containing [('color', 'grey'), ('dim', 2)] .keys ars dict_keys([(('color', 'blue'),), (('color', 'yellow'),), (('color', 'blue'), ('dim', 3))])


We can see that the type of manifold created is what we would expect. 
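The outputs above suggest that the factory picks the registered class whose registration arguments are all contained in the arguments passed to `create`, preferring the most specific match (so `dim=3, color='blue'` gives the matrix variant rather than the plain blue one). The following stand-alone sketch illustrates that lookup under this assumption. `SketchFactory`, `Blue`, and `MatrixBlue` are hypothetical names, and this is not the geomstats implementation.

```python
class SketchFactory:
    """Hypothetical stand-in for a registry-based factory (not geomstats code)."""

    creators = {}  # maps a sorted tuple of (name, value) pairs to a class

    @classmethod
    def register(cls, **kwargs):
        """Return a class decorator that stores the class under the given arguments."""
        key = tuple(sorted(kwargs.items()))

        def decorator(target):
            cls.creators[key] = target
            return target

        return decorator

    @classmethod
    def create(cls, **kwargs):
        """Instantiate the most specific class whose key is contained in kwargs."""
        requested = set(kwargs.items())
        matches = [key for key in cls.creators if set(key) <= requested]
        if not matches:
            raise KeyError(f"no class registered for {sorted(requested)}")
        # Assumption: the key with the most matching arguments wins.
        best = max(matches, key=len)
        return cls.creators[best]()


@SketchFactory.register(color="blue")
class Blue:
    pass


@SketchFactory.register(dim=3, color="blue")
class MatrixBlue:
    pass


blue = SketchFactory.create(dim=2, color="blue")
matrix_blue = SketchFactory.create(dim=3, color="blue")
print(type(blue).__name__, type(matrix_blue).__name__)
```

Storing keys as sorted tuples makes the lookup independent of the order in which keyword arguments are written, which matches the sorted key shown in the error message above.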

For the next step, let's add some metrics registered to our factory.
We will have a first metric with a custom name and a second one without a name.

In [5]:
# TODO metrics args and kwargs

@TataManifoldFactory.registerMetric(name='TheName')
class FirstTataMetric(Connection):
    def __init__(self):
        super().__init__(dim=2)

    def dummy_method(self, my_arg, my_kwarg=None):
        return f"First, dummy method. my manifold is {self.manifold}"

@TataManifoldFactory.registerMetric()
class SecondTataMetric(Connection):
    def __init__(self):
        super().__init__(dim=2)

    def dummy_method(self, my_arg, my_kwarg=None):
        return f"Second, dummy method, args: {my_arg} , kwargs: {my_kwarg}"


Now we will create a new manifold with instances of these two metrics associated with it.

In [6]:
print(f"the available metrics are: {TataManifoldFactory.metric_keys()}")

manifold_with_one_metric = TataManifoldFactory.create(dim=2, color='yellow', metrics_names='TheName')
manifold_with_metrics = TataManifoldFactory.create(dim=2, color='yellow', metrics_names=['TheName', 'SecondTataMetric'])
print(f"my newly created manifold is associated with the following metrics: {manifold_with_metrics.metrics}")

the available metrics are: dict_keys(['TheName', 'SecondTataMetric'])
my newly created manifold is associated with the following metrics: [<__main__.FirstTataMetric object at 0x7efb8566a278>, <__main__.SecondTataMetric object at 0x7efb8566a358>]
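The keys listed above indicate that a metric registered without a name ends up under its class name. A minimal sketch of a decorator with that behavior is shown here. The names `metric_creators`, `register_metric`, `FirstMetric`, and `SecondMetric` are hypothetical and chosen for illustration only; the actual factory may work differently.

```python
metric_creators = {}  # hypothetical registry: metric name -> metric class


def register_metric(name=None):
    """Return a decorator that registers a class under `name`.

    Assumption: when no name is given, the class name is used as the key.
    """

    def decorator(metric_class):
        key = name if name is not None else metric_class.__name__
        metric_creators[key] = metric_class
        return metric_class

    return decorator


@register_metric(name="TheName")
class FirstMetric:
    pass


@register_metric()
class SecondMetric:
    pass


print(list(metric_creators))  # ['TheName', 'SecondMetric']
```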


In [7]:
# We can see what happens when we give a bad name.
manifold_with_bad_metric = TataManifoldFactory.create(dim=2, color='yellow', metrics_names='IDontExist')



In [12]:
# We can now call a method on all the metrics of this manifold with some args and/or some kwargs.
# When there are multiple metrics, the result is a dictionary with the metric class names as keys.
manifold_with_metrics.call_method_on_metrics('dummy_method', 4, my_kwarg='YES')

{'FirstTataMetric': 'First, dummy method. my manifold is <__main__.YellowTataManifold object at 0x7f89a63913c8>',
 'SecondTataMetric': 'Second, dummy method, args: 4 , kwargs: YES'}

In [13]:
# If there is only one metric, the result of the method is directly returned.
manifold_with_one_metric.call_method_on_metrics('dummy_method', 4, my_kwarg='YES')

'First, dummy method. my manifold is <__main__.YellowTataManifold object at 0x7f89a63954a8>'
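The dispatch behavior shown in the last two cells (a dictionary for several metrics, a bare value for a single one) can be sketched in a few lines. This is an illustrative stand-alone version, not the geomstats code: `call_method_on_all`, `Doubler`, and `Tripler` are hypothetical names.

```python
def call_method_on_all(objects, method_name, *args, **kwargs):
    """Call `method_name` on every object, forwarding args and kwargs.

    Returns a dict keyed by class name, or the bare result when there is
    only one object. Two objects of the same class would share one key.
    """
    results = {
        type(obj).__name__: getattr(obj, method_name)(*args, **kwargs)
        for obj in objects
    }
    if len(results) == 1:
        return next(iter(results.values()))
    return results


class Doubler:
    def apply(self, x, offset=0):
        return 2 * x + offset


class Tripler:
    def apply(self, x, offset=0):
        return 3 * x + offset


several = call_method_on_all([Doubler(), Tripler()], "apply", 4, offset=1)
single = call_method_on_all([Doubler()], "apply", 4, offset=1)
print(several)  # {'Doubler': 9, 'Tripler': 13}
print(single)   # 9
```

Returning a different type depending on the number of metrics is convenient interactively, but callers that do not know the count in advance have to check the result type before using it.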