Skip to content

Commit

Permalink
adding docstring autodocs (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Mar 8, 2021
1 parent 2254d22 commit f8e22f8
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 68 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
'sphinx_autodoc_typehints',
'myst_nb',
]

autosummary_generate = True
intersphinx_mapping = {
'python': ('https://docs.python.org/3/', None),
'numpy': ('https://numpy.org/doc/stable/', None),
Expand Down
11 changes: 10 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,20 @@ A type system for the automated construction of equivariant layers.

.. toctree::
:maxdepth: 2
:caption: Developer documentation
:caption: Developer Documentation

documentation.md


.. toctree::
:glob:
:maxdepth: 1
:caption: Package Reference

package/emlp.solver.representation
package/emlp.solver.groups
package/emlp.models.mlp

Indices and tables
==================

Expand Down
6 changes: 6 additions & 0 deletions docs/package/emlp.models.mlp.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
EMLP
====

.. automodule:: emlp.models.mlp
:members:
:show-inheritance:
5 changes: 5 additions & 0 deletions docs/package/emlp.solver.groups.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Groups
======

.. automodule:: emlp.solver.groups
:members:
9 changes: 9 additions & 0 deletions docs/package/emlp.solver.representation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Representation
==============

.. autoclass:: Rep
:members:

.. automodule:: emlp.solver.representation
:members:

54 changes: 5 additions & 49 deletions emlp/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def Sequential(*args):
""" Wrapped to mimic pytorch syntax"""
return nn.Sequential(args)

@export
class LieLinear(nn.Linear): #
def __init__(self, repin, repout):
nin,nout = repin.size(),repout.size()
Expand All @@ -48,6 +49,7 @@ def __call__(self, x): # (cin) -> (cout)
logging.debug(f"linear out shape:{out.shape}")
return out

@export
class BiLinear(Module):
def __init__(self, repin, repout):
super().__init__()
Expand All @@ -61,7 +63,7 @@ def __call__(self, x,training=True):
out= .1*(W@x[...,None])[...,0]
return out


@export
class GatedNonlinearity(Module): #TODO: support elementwise swish for regular reps
def __init__(self,rep):
super().__init__()
Expand All @@ -71,6 +73,7 @@ def __call__(self,values):
activations = jax.nn.sigmoid(gate_scalars) * values[..., :self.rep.size()]
return activations

@export
class EMLPBlock(Module):
def __init__(self,rep_in,rep_out):
super().__init__()
Expand All @@ -87,32 +90,7 @@ def __call__(self,x):
preact =self.bilinear(lin)+lin
return self.nonlinearity(preact)

# class EResBlock(Module):
# def __init__(self,rep_in,rep_out):
# super().__init__()
# grep_in = gated(rep_in)
# grep_out = gated(rep_out)

# self.bn1 = TensorBN(grep_in)
# self.nonlinearity1 = GatedNonlinearity(rep_in)
# self.linear1 = LieLinear(rep_in,grep_out)

# self.bn2 = TensorBN(grep_out)
# self.nonlinearity2 = GatedNonlinearity(rep_out)
# self.linear2 = LieLinear(rep_out,grep_out)


# self.bilinear1 = BiLinear(grep_in,grep_out)
# #self.bilinear2 = BiLinear(gated(rep_out),gated(rep_out))
# self.shortcut = LieLinear(grep_in,grep_out) if rep_in!=rep_out else Sequential()
# def __call__(self,x,training=True):

# z = self.nonlinearity1(self.bn1(x,training=training))
# z = self.linear1(z)
# z = self.nonlinearity2(self.bn2(x,training=training))
# z = self.linear2(z)
# return (z+self.shortcut(x)+self.bilinear1(x))/3

@export
def uniform_rep(ch,group):
""" A heuristic method for allocating a given number of channels (ch)
into tensor types. Attempts to distribute the channels evenly across
Expand Down Expand Up @@ -182,28 +160,6 @@ def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):#@
def __call__(self,x,training=True):
return self.network(x)

# @export
# class EMLP2(Module):
# def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):#@
# super().__init__()
# logging.info("Initing EMLP")
# self.rep_in =rep_in(group)
# self.rep_out = rep_out(group)
# repmiddle = uniform_rep(ch,group)
# #reps = [self.rep_in]+
# reps = num_layers*[repmiddle]# + [self.rep_out]
# logging.debug(reps)
# self.network = Sequential(
# LieLinear(self.rep_in,gated(repmiddle)),
# *[EResBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],
# TensorBN(gated(repmiddle)),
# GatedNonlinearity(repmiddle),
# LieLinear(repmiddle,self.rep_out)
# )
# def __call__(self,x,training=True):
# y = self.network(x,training=training)
# return y

def swish(x):
return jax.nn.sigmoid(x)*x

Expand Down
52 changes: 39 additions & 13 deletions emlp/solver/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@
def rel_err(A,B):
return jnp.mean(jnp.abs(A-B))/(jnp.mean(jnp.abs(A)) + jnp.mean(jnp.abs(B))+1e-6)

@export
class Group(object,metaclass=Named):
lie_algebra = NotImplemented
""" Abstract Group Object which new groups should inherit from. """
lie_algebra = NotImplemented #: The continuous generators
#lie_algebra_lazy = NotImplemented
discrete_generators = NotImplemented
discrete_generators = NotImplemented #: The discrete generators
#discrete_generators_lazy = NotImplemented
z_scale=None # For scale noise for sampling elements
is_orthogonal=None
is_regular = None
d = None
d = None #: The dimension of the base representation
def __init__(self,*args,**kwargs):
# # Set dense lie_algebra using lie_algebra_lazy if applicable
# if self.lie_algebra is NotImplemented and self.lie_algebra_lazy is not NotImplemented:
Expand Down Expand Up @@ -73,14 +75,17 @@ def __init__(self,*args,**kwargs):


def exp(self,A):
""" Matrix exponential """
return expm(A)
def num_constraints(self):
return len(self.lie_algebra)+len(self.discrete_generators)

def sample(self):
"""Draw a sample from the group (not necessarily Haar measure)"""
return self.samples(1)[0]

def samples(self,N):
""" Draw N samples from the group (not necessarily Haar measure)"""
A_dense = jnp.stack([Ai@jnp.eye(self.d) for Ai in self.lie_algebra]) if len(self.lie_algebra) else jnp.zeros((0,self.d,self.d))
h_dense = jnp.stack([hi@jnp.eye(self.d) for hi in self.discrete_generators]) if len(self.discrete_generators) else jnp.zeros((0,self.d,self.d))
z = np.random.randn(N,A_dense.shape[0])
Expand Down Expand Up @@ -170,13 +175,15 @@ def __init__(self,G1,G2):
raise NotImplementedError

@export
class Trivial(Group): #""" The trivial group G={I} in N dimensions """
class Trivial(Group):
""" The trivial group G={I} in N dimensions """
def __init__(self,N):
self.d = N
super().__init__(N)

@export
class SO(Group): #""" The special orthogonal group SO(N) in N dimensions"""
class SO(Group):
""" The special orthogonal group SO(N) in N dimensions"""
def __init__(self,N):
self.lie_algebra = np.zeros(((N*(N-1))//2,N,N))
k=0
Expand All @@ -187,25 +194,29 @@ def __init__(self,N):
k+=1
super().__init__(N)
@export
class O(SO): #""" The Orthogonal group O(N) in N dimensions"""
class O(SO):
""" The Orthogonal group O(N) in N dimensions"""
def __init__(self,N):
self.discrete_generators = np.eye(N)[None]
self.discrete_generators[0,0,0]=-1
super().__init__(N)
@export
class C(Group): #""" The Cyclic group Ck in 2 dimensions"""
class C(Group):
""" The Cyclic group Ck in 2 dimensions"""
def __init__(self,k):
theta = 2*np.pi/k
self.discrete_generators = np.zeros((1,2,2))
self.discrete_generators[0,:,:] = np.array([[np.cos(theta),np.sin(theta)],[-np.sin(theta),np.cos(theta)]])
super().__init__(k)
@export
class D(C): #""" The Dihedral group Dk in 2 dimensions"""
class D(C):
""" The Dihedral group Dk in 2 dimensions"""
def __init__(self,k):
super().__init__(k)
self.discrete_generators = np.concatenate((self.discrete_generators,np.array([[[-1,0],[0,1]]])))
@export
class Scaling(Group):
""" The scaling group Dk in 2 dimensions"""
def __init__(self,N):
self.lie_algebra = np.eye(N)[None]
super().__init__(N)
Expand All @@ -219,7 +230,8 @@ class TimeReversal(Group): #""" The time reversal group in 1+3 dimensions"""
discrete_generators[0,0,0] = -1

@export
class SO13p(Group): #""" The component of Lorentz group connected to identity"""
class SO13p(Group):
""" The component of Lorentz group connected to identity"""
lie_algebra = np.zeros((6,4,4))
lie_algebra[3:,1:,1:] = SO(3).lie_algebra
for i in range(3):
Expand All @@ -234,6 +246,7 @@ class SO13(SO13p):

@export
class O13(SO13p):
""" The full lorentz group (including Parity and Time reversal)"""
discrete_generators = np.eye(4)[None] +np.zeros((2,1,1))
discrete_generators[0] *= -1
discrete_generators[1,0,0] = -1
Expand All @@ -252,6 +265,8 @@ class O11(SO11p):

@export
class Sp(Group):
""" Symplectic group Sp(m) in 2m dimensions (sometimes referred to
instead as Sp(2m)"""
def __init__(self,m):
self.lie_algebra = np.zeros((m*(2*m+1),2*m,2*m))
k=0
Expand All @@ -275,6 +290,8 @@ class Symplectic(Sp): pass

@export
class Z(Group):
r""" The cyclic group Z_n (discrete translation group) of order n.
Features a regular base representation."""
def __init__(self,n):
self.discrete_generators = [LazyShift(n)]
super().__init__(n)
Expand All @@ -284,6 +301,7 @@ class DiscreteTranslation(Z): pass # Alias cyclic translations with Z

@export
class S(Group): #The permutation group
r""" The permutation group S_n with an n dimensional regular representation."""
def __init__(self,n):
#K=n//5
# perms = np.arange(n)[None]+np.zeros((K,1)).astype(int)
Expand Down Expand Up @@ -311,6 +329,7 @@ class Permutation(S): pass #Alias permutation group with Sn.
@export
class U(Group): # Of dimension n^2
def __init__(self,n):
""" The unitary group U(n) in n dimensions (complex)"""
lie_algebra_real = np.zeros((n**2,n,n))
lie_algebra_imag = np.zeros((n**2,n,n))
k=0
Expand All @@ -333,6 +352,7 @@ def __init__(self,n):
@export
class SU(Group): # Of dimension n^2-1
def __init__(self,n):
""" The special unitary group SU(n) in n dimensions (complex)"""
if n==1: return Trivial(1)
lie_algebra_real = np.zeros((n**2-1,n,n))
lie_algebra_imag = np.zeros((n**2-1,n,n))
Expand All @@ -359,8 +379,8 @@ def __init__(self,n):

@export
class Cube(Group):
# A discrete version of SO(3) including all 90 degree rotations in 3d space
# Implements a 6 dimensional representation on the faces of a cube
""" A discrete version of SO(3) including all 90 degree rotations in 3d space
Implements a 6 dimensional representation on the faces of a cube"""
def __init__(self):
order = np.arange(6) # []
Fperm = np.array([4,1,0,3,5,2])
Expand All @@ -382,10 +402,10 @@ def unpad(padded_perm):





@export
class RubiksCube(Group): #3x3 rubiks cube
r""" The Rubiks cube group G<S_48 consisting of all valid 3x3 Rubik's cube transformations.
Generated by the a quarter turn about each of the faces."""
def __init__(self):
#Faces are ordered U,F,R,B,L,D (the net of the cube) # B
order = np.arange(48) # L U R
Expand Down Expand Up @@ -463,7 +483,13 @@ def __init__(self,k,n):
self.discrete_generators = [LazyKron([Ik,nshift,In]),LazyKron([Ik,In,nshift]),LazyKron([kshift,Rot90(n,4//k)])]
super().__init__(k,n)

@export
class Embed(Group):
""" A method to embed a given base group representation in larger vector space.
Inputs:
G: the group (and base representation) to embed
d: the dimension in which to embed
slice: a slice object specifying which dimensions G acts on."""
def __init__(self,G,d,slice):
self.lie_algebra = np.zeros((G.lie_algebra.shape[0],d,d))
self.discrete_generators = np.zeros((G.discrete_generators.shape[0],d,d))
Expand Down
15 changes: 11 additions & 4 deletions emlp/solver/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
import matplotlib.pyplot as plt
from functools import reduce
import emlp.solver
from oil.utils.utils import export

#TODO: add rep,v = flatten({'Scalar':..., 'Vector':...,}), to_dict(rep,vector) returns {'Scalar':..., 'Vector':...,}
#TODO and simpler rep = flatten({Scalar:2,Vector:10,...}),
# Do we even want + operator to implement non canonical orderings?

__all__ = ["V", "Scalar"]

@export
class Rep(object):
""" The base Representation class. Representation objects formalize the vector space V
on which the group acts, the group representation matrix ρ(g), and the Lie Algebra
Expand Down Expand Up @@ -214,6 +218,7 @@ def __rmul__(self,other):
return other

class Base(Rep):
""" Base representation V of a group."""
def __init__(self,G=None):
self.G=G
self.concrete = (G is not None)
Expand Down Expand Up @@ -245,6 +250,7 @@ def T(self):
return Dual(self.G)

class Dual(Base):
""" The dual representation V* of the Base representation of a group."""
def __new__(cls,G=None):
if G is not None and G.is_orthogonal: return Base(G)
else: return super(Dual,cls).__new__(cls)
Expand Down Expand Up @@ -290,14 +296,15 @@ def __lt__(self,other):
# return super().__lt__(other)
# def size(self):
# return self.rep.size()
V=Vector= Base() #: Alias V or Vector for an instance of the Base representation of a group

Scalar = ScalarRep()#: An instance of the Scalar representation, equivalent to V**0

V=Vector= Base()
Scalar = ScalarRep()#V**0
@export
def T(p,q=0,G=None):
""" A convenience function for creating rank (p,q) tensors."""
return (V**p*V.T**q)(G)


def orthogonal_complement(proj):
""" Computes the orthogonal complement to a given matrix proj"""
U,S,VH = jnp.linalg.svd(proj,full_matrices=True)
Expand Down Expand Up @@ -453,7 +460,7 @@ def mul_part(bparams,x,bids):
b = prod(x.shape[:-1])
return (bparams@x[...,bids].T.reshape(bparams.shape[-1],-1)).reshape(-1,b).T


@export
def vis(repin,repout,cluster=True):
""" A function to visualize the basis of equivariant maps repin>>repout
as an image. Only use cluster=True if you know Pv will only have
Expand Down

0 comments on commit f8e22f8

Please sign in to comment.