Skip to content

Commit

Permalink
make Manifold as Module (#50)
Browse files Browse the repository at this point in the history
* make Manifold as Module

* add a record to changelog
  • Loading branch information
ferrine committed Mar 2, 2019
1 parent e080c25 commit 5fb91db
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Maintenance
-----------
* Add gitter chat (#31)
* Maintain torch>=1.0.0 only (#39)
* Manifolds are Modules (#49)

Deprecations
------------
Expand Down
11 changes: 9 additions & 2 deletions geoopt/manifolds/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
import abc
import torch
import torch.nn
import re

__all__ = ["Manifold"]
Expand Down Expand Up @@ -148,7 +148,7 @@ def not_implemented(*args, **kwargs):
raise NotImplementedError


class Manifold(metaclass=ManifoldMeta):
class Manifold(torch.nn.Module, metaclass=ManifoldMeta):
r"""
Base class for Manifolds
Expand Down Expand Up @@ -191,6 +191,13 @@ class Manifold(metaclass=ManifoldMeta):
reversible = None
_default_order = 1

def __init__(self, **kwargs):
super().__init__()

def forward(self, *input):
# this removes all warnings about implementing abstract methods
raise TypeError("Manifold is not callable")

# noinspection PyAttributeOutsideInit
def set_default_order(self, order):
"""
Expand Down
1 change: 1 addition & 0 deletions geoopt/manifolds/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class SphereSubspaceIntersection(Sphere):
name = "SphereSubspace"

def __init__(self, span):
super().__init__()
self._configure_manifold(span)
if (geoopt.linalg.batch_linalg.matrix_rank(self._projector) == 1).any():
raise ValueError(
Expand Down

0 comments on commit 5fb91db

Please sign in to comment.