Skip to content

Commit

Permalink
exposing groups, V,T,Rep, EMLP layers to __all__
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Mar 8, 2021
1 parent aefad18 commit b5de640
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 54 deletions.
10 changes: 10 additions & 0 deletions emlp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import importlib
import pkgutil
__all__ = []
for loader, module_name, is_pkg in pkgutil.walk_packages(__path__):
module = importlib.import_module('.'+module_name,package=__name__)
try:
globals().update({k: getattr(module, k) for k in module.__all__})
__all__ += module.__all__
except AttributeError: continue
# concatenate the __all__ from each of the submodules (expose to user)
5 changes: 5 additions & 0 deletions emlp/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from emlp.models import mlp
from emlp.models import batchnorm
__all__ = []
__all__ += mlp.__all__
__all__ += batchnorm.__all__
5 changes: 3 additions & 2 deletions emlp/models/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from functools import partial
import objax
from functools import lru_cache as cache
from emlp.solver.utils import export

def gated(sumrep):
return sumrep+sum([Scalar(rep.G) for rep in sumrep if rep!=Scalar and not rep.is_regular])
Expand Down Expand Up @@ -55,15 +56,15 @@ def regular_mask(sumrep):
i+=rep.size()
return mask


@export
class TensorBN(nn.BatchNorm0D):
""" Equivariant Batchnorm for tensor representations.
Applies BN on Scalar channels and Mean only BN on others """
def __init__(self,rep):
super().__init__(rep.size(),momentum=0.9)
self.rep=rep
def __call__(self,x,training): #TODO: support elementwise for regular reps
return x #DISABLE BN, harms performance!! !!
#return x #DISABLE BN, harms performance!! !!
smask = jax.device_put(scalar_mask(self.rep))
if training:
m = x.mean(self.redux, keepdims=True)
Expand Down
13 changes: 12 additions & 1 deletion emlp/solver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
# from .product_sum_reps import SumRep,DeferredSumRep,ProductRep,DeferredProductRep,DirectProduct
# __all__=["SumRep","DeferredSumRep","ProductRep","DeferredProductRep","DirectProduct"]
# __all__=["SumRep","DeferredSumRep","ProductRep","DeferredProductRep","DirectProduct"]
import importlib
import pkgutil
__all__ = []
for loader, module_name, is_pkg in pkgutil.walk_packages(__path__):
module = importlib.import_module('.'+module_name,package=__name__)
try:
globals().update({k: getattr(module, k) for k in module.__all__})
__all__ += module.__all__
except AttributeError: continue

# concatenate __all__ from each of the modules
72 changes: 23 additions & 49 deletions emlp/solver/groups.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

import numpy as np
from scipy.linalg import expm
from oil.utils.utils import Named,export
from oil.utils.utils import Named
from .utils import export
import jax
import jax.numpy as jnp
from objax.nn.init import kaiming_normal, xavier_normal
Expand All @@ -27,19 +28,10 @@ class Group(object,metaclass=Named):
z_scale=None # For scale noise for sampling elements
is_orthogonal=None
is_regular = None
d = None #: The dimension of the base representation
d = NotImplemented #: The dimension of the base representation
def __init__(self,*args,**kwargs):
# # Set dense lie_algebra using lie_algebra_lazy if applicable
# if self.lie_algebra is NotImplemented and self.lie_algebra_lazy is not NotImplemented:
# Idense = np.eye(self.lie_algebra_lazy[0].shape[0])
# self.lie_algebra = np.stack([h@Idense for h in self.lie_algebra_lazy])
# # Set dense discrete_generators using discrete_generators_lazy if applicable
# if self.discrete_generators is NotImplemented and self.discrete_generators_lazy is not NotImplemented:
# Idense = np.eye(self.discrete_generators_lazy[0].shape[0])
# self.discrete_generators = np.stack([h@Idense for h in self.discrete_generators_lazy])

# get the dimension of the base group representation
if self.d is None:
if self.d is NotImplemented:
if self.lie_algebra is not NotImplemented and len(self.lie_algebra):
self.d= self.lie_algebra[0].shape[-1]
if self.discrete_generators is not NotImplemented and len(self.discrete_generators):
Expand Down Expand Up @@ -109,15 +101,10 @@ def __repr__(self):
return outstr
def __eq__(self,G2): #TODO: check that spans are equal?
return repr(self)==repr(G2)
# if self.lie_algebra.shape!=G2.lie_algebra.shape or \
# self.discrete_generators.shape!=G2.discrete_generators.shape:
# return False
# return (self.lie_algebra==G2.lie_algebra).all() and (self.discrete_generators==G2.discrete_generators).all()

def __hash__(self):
return hash(repr(self))
# algebra = jax.device_get(self.lie_algebra).tobytes()
# gens = jax.device_get(self.discrete_generators).tobytes()
# return hash((algebra,gens,self.lie_algebra.shape,self.discrete_generators.shape))

def __lt__(self, other):
return hash(self) < hash(other) #For sorting purposes only

Expand Down Expand Up @@ -153,8 +140,6 @@ def noise2samples(zs,ks,lie_algebra,discrete_generators,seed=0):
return vmap(noise2sample,(0,0,None,None,None),0)(zs,ks,lie_algebra,discrete_generators,seed)




class DirectProduct(Group):
def __init__(self,G1,G2):
I1,I2 = I(G1.d),I(G2.d)
Expand Down Expand Up @@ -193,13 +178,15 @@ def __init__(self,N):
self.lie_algebra[k,j,i] = -1
k+=1
super().__init__(N)

@export
class O(SO):
""" The Orthogonal group O(N) in N dimensions"""
def __init__(self,N):
self.discrete_generators = np.eye(N)[None]
self.discrete_generators[0,0,0]=-1
super().__init__(N)

@export
class C(Group):
""" The Cyclic group Ck in 2 dimensions"""
Expand Down Expand Up @@ -239,7 +226,7 @@ class SO13p(Group):

# Adjust variance for samples along boost generators. For equivariance checks
# the exps for high order tensors can get very large numbers
z_scale = np.array([.3,.3,.3,1,1,1])
z_scale = np.array([.3,.3,.3,1,1,1]) # can get rid of now
@export
class SO13(SO13p):
discrete_generators = -np.eye(4)[None]
Expand All @@ -255,18 +242,20 @@ class Lorentz(O13): pass

@export
class SO11p(Group):
""" The identity component of O(1,1) (Lorentz group in 1+1 dimensions)"""
lie_algebra = np.array([[0.,1.],[1.,0.]])[None]

@export
class O11(SO11p):
""" The Lorentz group O(1,1) in 1+1 dimensions """
discrete_generators = np.eye(2)[None]+np.zeros((2,1,1))
discrete_generators[0]*=-1
discrete_generators[1,0,0] = -1

@export
class Sp(Group):
""" Symplectic group Sp(m) in 2m dimensions (sometimes referred to
instead as Sp(2m)"""
instead as Sp(2m) )"""
def __init__(self,m):
self.lie_algebra = np.zeros((m*(2*m+1),2*m,2*m))
k=0
Expand All @@ -284,10 +273,7 @@ def __init__(self,m):
self.lie_algebra[k,j,m+i] = 1
k+=1
super().__init__(m)

@export
class Symplectic(Sp): pass


@export
class Z(Group):
r""" The cyclic group Z_n (discrete translation group) of order n.
Expand All @@ -296,40 +282,23 @@ def __init__(self,n):
self.discrete_generators = [LazyShift(n)]
super().__init__(n)

@export
class DiscreteTranslation(Z): pass # Alias cyclic translations with Z

@export
class S(Group): #The permutation group
r""" The permutation group S_n with an n dimensional regular representation."""
def __init__(self,n):
#K=n//5
# perms = np.arange(n)[None]+np.zeros((K,1)).astype(int)
# for i in range(1,K):
# perms[i,[0,(i*n)//K]] = perms[i,[(i*n)//K,0]]
# print(perms)
# self.discrete_generators = [LazyPerm(perm) for perm in perms]+[LazyShift(n)]
perms = np.arange(n)[None]+np.zeros((n-1,1)).astype(int)
perms[:,0] = np.arange(1,n)
perms[np.arange(n-1),np.arange(1,n)[None]]=0
self.discrete_generators = [LazyPerm(perm) for perm in perms]
#self.discrete_generators = [SwapMatrix((0,i),n) for i in range(1,n)]
# Adding superflous extra generators can actually *decrease* the runtime of the iterative
# krylov solver by improving the conditioning of the constraint matrix
# swap_perm = np.arange(n).astype(int)
# swap_perm[[0,1]] = swap_perm[[1,0]]
# swap_perm2 = np.arange(n).astype(int)
# swap_perm2[[0,n//2]] = swap_perm2[[n//2,0]]
# self.discrete_generators = [LazyPerm(swap_perm)]+[LazyShift(n,2**i) for i in range(int(np.log2(n)))]
# Adding superflous extra generators surprisingly can sometimes actually *decrease*
# the runtime of the iterative krylov solver by improving the conditioning
# of the constraint matrix
super().__init__(n)

@export
class Permutation(S): pass #Alias permutation group with Sn.

@export
class U(Group): # Of dimension n^2
""" The unitary group U(n) in n dimensions (complex)"""
def __init__(self,n):
""" The unitary group U(n) in n dimensions (complex)"""
lie_algebra_real = np.zeros((n**2,n,n))
lie_algebra_imag = np.zeros((n**2,n,n))
k=0
Expand All @@ -351,8 +320,8 @@ def __init__(self,n):
super().__init__(n)
@export
class SU(Group): # Of dimension n^2-1
""" The special unitary group SU(n) in n dimensions (complex)"""
def __init__(self,n):
""" The special unitary group SU(n) in n dimensions (complex)"""
if n==1: return Trivial(1)
lie_algebra_real = np.zeros((n**2-1,n,n))
lie_algebra_imag = np.zeros((n**2-1,n,n))
Expand Down Expand Up @@ -471,6 +440,8 @@ def __init__(self):

@export
class ZksZnxZn(Group):
""" One of the original GCNN groups ℤₖ⋉(ℤₙ×ℤₙ) for translation in x,y
and rotation with the discrete 90 degree rotations (k=4) or 180 degree (k=2)"""
def __init__(self,k,n):
Zn = Z(n)
Zk = Z(k)
Expand Down Expand Up @@ -504,12 +475,15 @@ def __repr__(self):

@export
def SO2eR3():
""" SO(2) embedded in R^3 with rotations about z axis"""
return Embed(SO(2),3,slice(2))

@export
def O2eR3():
""" O(2) embedded in R^3 with rotations about z axis"""
return Embed(O(2),3,slice(2))

@export
def DkeR3(k):
""" Dihedral D(k) embedded in R^3 with rotations about z axis"""
return Embed(D(k),3,slice(2))
2 changes: 1 addition & 1 deletion emlp/solver/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import matplotlib.pyplot as plt
from functools import reduce
import emlp.solver
from oil.utils.utils import export
from .utils import export

#TODO: add rep,v = flatten({'Scalar':..., 'Vector':...,}), to_dict(rep,vector) returns {'Scalar':..., 'Vector':...,}
#TODO and simpler rep = flatten({Scalar:2,Vector:10,...}),
Expand Down
10 changes: 9 additions & 1 deletion emlp/solver/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pickle, atexit
import logging

import sys
import time
import io
from functools import reduce
Expand All @@ -10,6 +10,14 @@
prod = lambda c: reduce(lambda a,b:a*b,c)



def export(fn):
mod = sys.modules[fn.__module__]
if hasattr(mod, '__all__'):
mod.__all__.append(fn.__name__)
else:
mod.__all__ = [fn.__name__]
return fn
# class TqdmToLogger(io.StringIO):
# """
# Output stream for TQDM which will output to logger module instead of
Expand Down

0 comments on commit b5de640

Please sign in to comment.