Skip to content

Commit

Permalink
Merge 03434a7 into 08a45b7
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrine committed Mar 2, 2019
2 parents 08a45b7 + 03434a7 commit b38d37c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
11 changes: 9 additions & 2 deletions geoopt/manifolds/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
import abc
import torch
import torch.nn
import re

__all__ = ["Manifold"]
Expand Down Expand Up @@ -148,7 +148,7 @@ def not_implemented(*args, **kwargs):
raise NotImplementedError


class Manifold(metaclass=ManifoldMeta):
class Manifold(torch.nn.Module, metaclass=ManifoldMeta):
r"""
Base class for Manifolds
Expand Down Expand Up @@ -191,6 +191,13 @@ class Manifold(metaclass=ManifoldMeta):
reversible = None
_default_order = 1

def __init__(self, **kwargs):
super().__init__()

def forward(self, *input):
# this removes all warnings about implementing abstract methods
raise TypeError("Manifold is not callable")

# noinspection PyAttributeOutsideInit
def set_default_order(self, order):
"""
Expand Down
1 change: 1 addition & 0 deletions geoopt/manifolds/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class SphereSubspaceIntersection(Sphere):
name = "SphereSubspace"

def __init__(self, span):
super().__init__()
self._configure_manifold(span)
if (geoopt.linalg.batch_linalg.matrix_rank(self._projector) == 1).any():
raise ValueError(
Expand Down

0 comments on commit b38d37c

Please sign in to comment.