Skip to content

Commit

Permalink
moar docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
mfinzi committed Mar 7, 2021
1 parent 8ce6ce9 commit ffedc1b
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 209 deletions.
34 changes: 2 additions & 32 deletions emlp/models/batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,37 +56,7 @@ def regular_mask(sumrep):
return mask


# class TensorBN(nn.BatchNorm0D): #TODO find discrepancies with pytorch version
# """ Equivariant Batchnorm for tensor representations.
# Applies BN on Scalar channels and Mean only BN on others """
# def __init__(self,rep):
# super().__init__(rep.size(),momentum=0.9)
# self.rep=rep
# def __call__(self,x,training): #TODO: support elementwise for regular reps
# return x
# smask = jax.device_put(scalar_mask(self.rep))
# rmask = jax.device_put(regular_mask(self.rep))
# if training:
# m = ragged_gather_scatter(x.mean(self.redux),self.rep)
# squared = ragged_gather_scatter((x ** 2).mean(self.redux),self.rep)
# v = squared - m ** 2
# v = jnp.where(smask|rmask,v,squared) #in non scalar indices, divide by sum squared
# m,v = m[None],v[None]
# self.running_mean.value += (1 - self.momentum) * (m - self.running_mean.value)
# self.running_var.value += (1 - self.momentum) * (v - self.running_var.value)
# else:
# m, v = self.running_mean.value, self.running_var.value
# g = ragged_gather_scatter(self.gamma.value[0],self.rep)
# b = ragged_gather_scatter(self.beta.value[0],self.rep)
# normed_scalars = g * (x - m) * F.rsqrt(v + self.eps) + b
# normed_regulars = normed_scalars
# normed_else = g*x*F.rsqrt(v + self.eps)
# normed_nonscalars = jnp.where(rmask,normed_regulars,normed_else)
# y = jnp.where(smask,normed_scalars,normed_nonscalars)#(x-m)*F.rsqrt(v + self.eps))
# return y # switch to or (x-m)


class TensorBN(nn.BatchNorm0D): #TODO find discrepancies with pytorch version
class TensorBN(nn.BatchNorm0D):
""" Equivariant Batchnorm for tensor representations.
Applies BN on Scalar channels and Mean only BN on others """
def __init__(self,rep):
Expand All @@ -106,7 +76,7 @@ def __call__(self,x,training): #TODO: support elementwise for regular reps
y = jnp.where(smask,self.gamma.value * (x - m) * F.rsqrt(v + self.eps) + self.beta.value,x*F.rsqrt(v+self.eps))#(x-m)*F.rsqrt(v + self.eps))
return y # switch to or (x-m)

class MaskBN(nn.BatchNorm0D): #TODO find discrepancies with pytorch version
class MaskBN(nn.BatchNorm0D):
""" Equivariant Batchnorm for tensor representations.
Applies BN on Scalar channels and Mean only BN on others """
def __init__(self,ch):
Expand Down
1 change: 0 additions & 1 deletion emlp/models/hamiltonian_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def symplectic_form(z):
def hamiltonian_dynamics(hamiltonian, z,t):
grad_h = grad(hamiltonian)
gh = grad_h(z)
#print(z.shape,gh.shape)
return symplectic_form(gh)

def HamiltonianFlow(H,z0,T):
Expand Down

0 comments on commit ffedc1b

Please sign in to comment.