Skip to content

Commit

Permalink
removed gated from bn, docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Mar 10, 2021
1 parent debaaa3 commit cc353b0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
7 changes: 3 additions & 4 deletions emlp/models/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
from functools import lru_cache as cache
from emlp.solver.utils import export

def gated(sumrep):
return sumrep+sum([Scalar(rep.G) for rep in sumrep if rep!=Scalar and not rep.is_regular])


@cache(maxsize=None)
def gate_indices(sumrep): #TODO: add regular
def gate_indices(sumrep): #TODO: add support for mixed_tensors
""" Indices for scalars, and also additional scalar gates
added by gated(sumrep)"""
assert isinstance(sumrep,SumRep), f"unexpected type for gate indices {type(sumrep)}"
Expand Down Expand Up @@ -57,7 +56,7 @@ def regular_mask(sumrep):
return mask

@export
class TensorBN(nn.BatchNorm0D):
class TensorBN(nn.BatchNorm0D): #TODO: add suport for mixed tensors.
""" Equivariant Batchnorm for tensor representations.
Applies BN on Scalar channels and Mean only BN on others """
def __init__(self,rep):
Expand Down
12 changes: 9 additions & 3 deletions emlp/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,17 @@ def __call__(self, x,training=True):
return out

@export
def gated(sumrep):
def gated(sumrep): #TODO: generalize to mixed tensors?
""" Returns the rep with an additional scalar 'gate' for each of the nonscalars and non regular
reps in the input. To be used as the output for linear (and or bilinear) layers directly
before a :func:`GatedNonlinearity` to produce its scalar gates. """
return sumrep+sum([Scalar(rep.G) for rep in sumrep if rep!=Scalar and not rep.is_regular])

@export
class GatedNonlinearity(Module):
class GatedNonlinearity(Module): #TODO: add support for mixed tensors and non sumreps
""" Gated nonlinearity. Requires input to have the additional gate scalars
for every non regular and non scalar rep. Applies swish to regular and
scalar reps. (Right now assumes rep is a SumRep. TODO: extend to non sumreps)"""
scalar reps. (Right now assumes rep is a SumRep)"""
def __init__(self,rep):
super().__init__()
self.rep=rep
Expand All @@ -101,6 +101,12 @@ def __call__(self,x):
preact =self.bilinear(lin)+lin
return self.nonlinearity(preact)

def uniform_rep_general(ch,*rep_types):
""" adds all combinations of (powers of) rep_types up to
a total of ch channels."""
#TODO: write this function
raise NotImplementedError

@export
def uniform_rep(ch,group):
""" A heuristic method for allocating a given number of channels (ch)
Expand Down

0 comments on commit cc353b0

Please sign in to comment.