Commit

fixed training loop example in docs, docs updates
mfinzi committed Mar 5, 2021
1 parent b37e50f commit 4c428b2
Showing 6 changed files with 342 additions and 99 deletions.
200 changes: 137 additions & 63 deletions docs/notebooks/building_a_model.ipynb

Large diffs are not rendered by default.

102 changes: 73 additions & 29 deletions docs/notebooks/building_a_model.md
@@ -64,21 +64,28 @@ For convenience, we store in the dataset the types for the input and the output.

```{code-cell} ipython3
from emlp.models.mlp import EMLP,MLP
-model = EMLP(dataset.rep_in,dataset.rep_out,group=G,num_layers=3,ch=384)
+model = EMLP(trainset.rep_in,trainset.rep_out,group=G,num_layers=3,ch=384)
# uncomment the following line to instead try the MLP baseline
-#model = MLP(dataset.rep_in,dataset.rep_out,group=G,num_layers=3,ch=384)
+#model = MLP(trainset.rep_in,trainset.rep_out,group=G,num_layers=3,ch=384)
```
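For context, `trainset` and the group `G` referenced in this hunk come from earlier notebook cells that the diff does not show. A minimal sketch of that setup in the style of the other EMLP docs (the `Inertia` dataset and its `symmetry` attribute are assumptions, not part of this commit):

```python
# Hedged sketch of the elided setup cells; Inertia and .symmetry are assumed
# from the surrounding EMLP documentation, not shown in this diff.
from emlp.models.datasets import Inertia

trainset = Inertia(1000)   # 1000 moment-of-inertia training examples
testset = Inertia(2000)    # held-out test set
G = trainset.symmetry      # the dataset's symmetry group, O(3)
```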

## Example Objax Training Loop

+++

We build our EMLP model with [objax](https://objax.readthedocs.io/en/latest/) because we feel the object-oriented design makes building complicated layers easier. Below is a minimal training loop that you could use to train EMLP.

```{code-cell} ipython3
BS=500
lr=3e-3
-NUM_EPOCHS=1000
+NUM_EPOCHS=500
import objax
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from jax import vmap
opt = objax.optimizer.Adam(model.vars())
@@ -87,10 +94,9 @@ def loss(x, y,training=True):
    yhat = model(x, training=training)
    return ((yhat-y)**2).mean()
gv = objax.GradValues(loss, model.vars())
-@objax.Function.with_vars(model.vars())
+@objax.Function.with_vars(model.vars()+opt.vars())
def train_op(x, y, lr):
    g, v = gv(x, y)
    opt(lr=lr, grads=g)
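# Note: the decorator above now includes opt.vars() as well as model.vars(),
# because opt(lr=lr, grads=g) updates the optimizer's state variables and an
# objax Function may only modify variables it declares.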
@@ -117,45 +123,83 @@ import matplotlib.pyplot as plt
plt.plot(np.arange(NUM_EPOCHS),train_losses,label='Train loss')
plt.plot(np.arange(0,NUM_EPOCHS,10),test_losses,label='Test loss')
plt.legend()
plt.yscale('log')
```
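The epoch loop itself sits between the two hunks shown above and is elided from this diff. A hedged sketch of what connects `train_op` to the `train_losses`/`test_losses` arrays being plotted (the loader names, the jit call, and `train_op` returning the loss value are assumptions):

```python
# Hedged sketch of the elided epoch loop; trainloader/testloader are assumed
# DataLoaders over trainset/testset, and the logging cadence matches the plot
# (test loss recorded every 10 epochs).
trainloader = DataLoader(trainset, batch_size=BS, shuffle=True)
testloader = DataLoader(testset, batch_size=BS, shuffle=False)
train_op = objax.Jit(train_op)  # compile the update step

train_losses, test_losses = [], []
for epoch in tqdm(range(NUM_EPOCHS)):
    train_losses.append(np.mean([train_op(jnp.array(x), jnp.array(y), lr)
                                 for (x, y) in trainloader]))
    if not epoch % 10:
        test_losses.append(np.mean([loss(jnp.array(x), jnp.array(y), training=False)
                                    for (x, y) in testloader]))
```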

```{code-cell} ipython3
def rel_err(a,b):
    return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean()))

rin,rout = trainset.rep_in(G),trainset.rep_out(G)

def equivariance_err(mb):
    x,y = mb
    x,y = jnp.array(x),jnp.array(y)
    gs = G.samples(x.shape[0])          # one random group element per example
    rho_gin = vmap(rin.rho_dense)(gs)   # dense input representation matrices
    rho_gout = vmap(rout.rho_dense)(gs) # dense output representation matrices
    y1 = model((rho_gin@x[...,None])[...,0],training=False)   # transform input, then apply model
    y2 = (rho_gout@model(x,training=False)[...,None])[...,0]  # apply model, then transform output
    return rel_err(y1,y2)
```

As expected, the network continues to be equivariant as it is trained.

```{code-cell} ipython3
print(f"Average test equivariance error {np.mean([equivariance_err(mb) for mb in testloader]):.2e}")
```

## Equivariant Linear Layers (low level)

+++

Internally, EMLP uses hidden-layer representations that allocate their dimensions uniformly across the different tensor types.

```{code-cell} ipython3
from emlp.models.mlp import uniform_rep
r = uniform_rep(512,G)
print(r)
```
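The printed representation shows how the 512 hidden dimensions get split across tensor ranks. As a rough illustration of the counting at work (a sketch of the idea, not the library's exact `uniform_rep` algorithm): rank-r tensors of a d-dimensional group occupy d**r dimensions each, so an even split across ranks yields fewer copies of the higher-rank types.

```python
# Illustrative only: give each rank an equal share of `ch` total dimensions;
# a rank-r tensor of a d-dimensional group costs d**r dims per copy.
d, ch, max_rank = 3, 512, 3   # e.g. a 3-dimensional group like SO(3)
for r in range(max_rank + 1):
    copies = ch // ((max_rank + 1) * d**r)
    print(f"T({r}): {copies} copies = {copies * d**r} of {ch} dims")
```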

Below is a trimmed-down version of EMLP, so you can see how it is built from the component layers.

```{code-cell} ipython3
# class EMLP(Module,metaclass=Named):
# def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):#@
# super().__init__()
# logging.info("Initing EMLP")
# self.rep_in =rep_in(group)
# self.rep_out = rep_out(group)
# self.G=group
# # Parse ch as a single int, a sequence of ints, a single Rep, a sequence of Reps
# if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)]#[uniform_rep(ch,group) for _ in range(num_layers)]
# elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)]
# else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch]
# #assert all((not rep.G is None) for rep in middle_layers[0].reps)
# reps = [self.rep_in]+middle_layers
# #logging.info(f"Reps: {reps}")
# self.network = Sequential(
# *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],
# LieLinear(reps[-1],self.rep_out)
# )
# #self.network = LieLinear(self.rep_in,self.rep_out)
# def __call__(self,x,training=True):
# return self.network(x)
from objax.module import Module
from emlp.models.mlp import Sequential,EMLPBlock,LieLinear

class EMLP(Module):
    def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):
        super().__init__()
        reps = [rep_in(group)]+num_layers*[uniform_rep(ch,group)]
        self.network = Sequential(
            *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],
            LieLinear(reps[-1],rep_out(group))
        )
    def __call__(self,x,training=True):
        return self.network(x)
```
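As a quick usage sketch (assuming the `trainset` and `G` from earlier), the trimmed-down class is constructed exactly like the library version at the top of the notebook:

```python
# Sketch: the trimmed-down EMLP takes the same constructor arguments as the
# library version used earlier in this notebook.
model2 = EMLP(trainset.rep_in, trainset.rep_out, group=G, ch=384, num_layers=3)
```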


```{code-cell} ipython3
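# Alternative left commented out in the notebook: the packaged RegressorPlus
# trainer replaces the manual loop above.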
# from emlp.models.mlp import Standardize
# from emlp.models.model_trainer import RegressorPlus
# from emlp.slax.utils import LoaderTo
# BS=500
# lr=3e-3
# NUM_EPOCHS=100
# dataloaders = {k:LoaderTo(DataLoader(v,batch_size=BS,shuffle=(k=='train'),
# num_workers=0,pin_memory=False)) for k,v in {'train':trainset,'test':testset}.items()}
# dataloaders['Train'] = dataloaders['train']
# opt_constr = objax.optimizer.Adam
# lr_sched = lambda e: lr
# trainer = RegressorPlus(model,dataloaders,opt_constr,lr_sched,log_args={'minPeriod':.02,'timeFrac':.25})
# trainer.train(NUM_EPOCHS)
```
101 changes: 100 additions & 1 deletion docs/notebooks/mixed_tensors.ipynb
@@ -4,7 +4,106 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# mixed_tensors"
"# Combining Representations from Different Groups (experimental)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from emlp.solver.groups import *\n",
"from emlp.solver.representation import T,vis"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"rep = 2*T(1)(Z(3))*T(1)(S(4))+T(1)(SO(2))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"V²+4V_S(4)⊗V_Z(3)⊗V_SO(2)+4V²_S(4)⊗V²_Z(3)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(rep>>rep)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIJ0lEQVR4nO3dT+jfdR0H8Pe0UlIIIle5P6HSamixWMmKTagYZtEkwUAvGa7LDrWDHZLw4GEd1KjlQaqD0mHUIdmCmHpqRZEGmTDopLUaiSNiVma1tS7dmu/3m96f93fP33o8jvt8vu/vez9+z71hr/f7/Vp37ty5AuS55EJPADg/4YRQwgmhhBNCCSeEel3t4a5bH6j+V+7lP3h62dm8hjMf2V59/sJt1b/GIt72k3XV56f2vDp9DqWUcmznw9XnNx364vQ5XPH7+s+ilFJOb/vH9Hls2vjH6vN7rn1i+hwe2fOJ4TGOHj9w3h+olRNCCSeEEk4IJZwQSjghlHBCKOGEUNUCYauO+eonb2x+wRK10FYd85rvnxn6fI9WHfOqI5cPj9GjVcc8dscDQ5/v0VPDfNOzbxgeo6VVx3zw+ZuHPn+hWTkhlHBCKOGEUMIJoYQTQgknhBJOCCWcEGqoOt+zwaC1USFhk0LPGC09GwxaGxUSNin0jNGjtclgLWxS6BljJisnhBJOCCWcEEo4IZRwQijhhFDCCaHW1VoA7r7k9un9AXsObP/u5vYlxqNatdATe89On0PXge33zf9ZtGqhH3pq//Q5tOqgpZRy/+cfnT6PVi30yv3jB/ldKg1rjHBCKOGEUMIJoYQTQgknhBJOCCWcEKq6CWHHnQ9VNyGsqpvz1YfqBelVdNhO6K5dSikb3/lS9fkqDgd/4LL6HErJ6LC9iu7aWx88PTyGTQiwxggnhBJOCCWcEEo4IZRwQijhhFDV4lxKN+fRDtsXy8XVpWR0c+6pYSZ02F7FxdUzWTkhlHBCKOGEUMIJoYQTQgknhBJOCDVUeEtpGDtaB+0Zo6WnhtmqhSbUQXvG6DHaxDehDtozxkxWTgglnBBKOCGUcEIo4YRQwgmhhBNCCSeEmn4T8hIHtkf1bDDo2agwaokD26W9h6CqZ4NBz0aFUaObFEoZ77Dds8GgZ6PCLFZOCCWcEEo4IZRwQijhhFDCCaGEE0JVC2+bv31p9cNLHA5+cWe1P+9/7Kg+XaaJb73mtSmggW8ppdx38K7q81UcDt607cXmO6MHtp/5+/rmOz/d/bXq8yUObP91Y8/v5xxWTgglnBBKOCGUcEIo4YRQwgmhhBNCVQuVKQ1jR8+EXiwXV5eS0TB2iTOhF8vF1TNZOSGUcEIo4YRQwgmhhBNCCSeEEk4IJZwQamiHQEo35yUurh7dqLDExdUJmxR6xugx2mE7YZNCzxgzWTkhlHBCKOGEUMIJoYQTQgknhBJOCDW9ee4SB7ZP7B2bQ08NM6GJb18D338NzWGRhrE3DU2hlDJeB11CTw2zVQu9+9HPLjWd/2LlhFDCCaGEE0IJJ4QSTgglnBBKOCGUcEKo6g6BLY+9Uv3w+7/17PgMdrZfeflg/aXXvzK/+/CBF+obCO7b9anh77jylyeb72z90Z+rzzc8dXZ4HuWG+uMvH7+1OcRzNx4amsKe9zzefGfrz/ZVn3/h04eH5lBKKYf/8q7hMf5XVk4IJZwQSjghlHBCKOGEUMIJoYQTQlXrnK065i8+t635BUvUQlt1zH++cd3Q53u06pj3/7hdl1uiFtqqY57cfenQ53v01DDf+/Qdw2O0tOqYX/9evR67RB10JisnhBJOCCWcEEo4IZRwQijhhFDCCaGGLpXuqWG2aqEJddCeMVp6apitWmhCHbRnjB6tOuZaqIP2jDGTlRNCCSeEEk4IJZwQSjghlHBCKOGEUMIJodadO/faxfd7n7tt+m3NPQe2X77uitnTaG5UWP/kielz6Dmw/chLH54+j9ZGhR/++tj0ObQ2KZRSyr4t8+fR2qhwzXdPDX/H0eMHzvvLZ+WEUMIJoYQTQgknhBJOCCWcEEo4IVS1zvnMb99RrXMucTi4x9m3v7n6fJEmvg1PBDTwLaWUP727/u/pKg4HH9lxXfOdRZr4Nvz8D5urz5c4sN3y8Y/ePjyGOiesMcIJoYQTQgknhBJOCCWcEEo4IZRwQqjqje8p3ZxHO2xfLLfKl5LRzblng0FCh+1V3Co/k5UTQgknhBJOCCWcEEo4IZRwQijhhFBDna1TujmP1kF7xmjpqWG2aqEJddCeMXqMdthOqIP2jDGTlRNCCSeEEk4IJZwQSjghlHBCKOGEUNVLpW/ZvH/6Tck9Z0IfP7199jSatdCEBr6llHL3PUemz6NVC915y6+mz6FVBy0lo4nvxi+dGf4Ol0rDGiOcEEo4IZRwQijhhFDCCaGEE0IJJ4SqbkLYvver1U0Iq+rm/I2vHKw+X0WH7YTu2qWUcuQ3N1Sfr+Jw8NZv7mu+k9BhexXdtU/edfXwGDYhwBojnBBKOCGUcEIo4YRQwgmhhBNCVS+VTmkYO9rE92K5uLqUjIaxPTXMhCa+q7i4eiYrJ4QSTgglnBBKOCGUcEIo4YRQwgmhhprnpjSMHa2D9ozR0lPDbNVCE+qgPWP0GG3im1AH7RljJisnhBJOCCWcEEo4IZRwQijhhFDCCaGEE0JVL5XecedD02+N7jmwvf7JE7On0dyokNBdu5RSjh7+zvR5tDYq7Nsyv6N0a5NCKSEdtje8dfg7XCoNa4xwQijhhFDCCaGEE0IJJ4QSTghVrXN+7Pp7V9MdF9aoUx98S/OdDZ95vvr8yK6H1TlhLRFOCCWcEEo4IZRwQijhhFDCCaGEE0IN3fgO/+/+dlX7soCTj11bf2HX+f/YygmhhBNCCSeEEk4IJZwQSjghlHBCqOpha+DCsXJCKOGEUMIJoYQTQgknhBJOCPVvSJB6KNEsKR0AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"vis(rep,rep)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"repin,repout = T(1)(SO(3))*T(2)(S(4)),T(2)(SO(3))*T(1)(S(4))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"V³_S(4)⊗V³_SO(3)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"repin>>repout"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"#vis(repin,repout)"
]
},
{
31 changes: 30 additions & 1 deletion docs/notebooks/mixed_tensors.md
@@ -12,7 +12,36 @@ kernelspec:
name: python3
---

-# Combining Representations from Different Groups
+# Combining Representations from Different Groups (experimental)

```{code-cell} ipython3
from emlp.solver.groups import *
from emlp.solver.representation import T,vis
```

```{code-cell} ipython3
rep = 2*T(1)(Z(3))*T(1)(S(4))+T(1)(SO(2))
```
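Here `T(1)` of a group is that group's base vector representation, `*` is a tensor product across the different groups, and `+` a direct sum, so `rep` should be 2·(3·4) + 2 = 26 dimensional. A quick sanity check, assuming `size()` returns a representation's total dimension:

```python
# Hedged check: 2*T(1)(Z(3))*T(1)(S(4)) contributes 2*(3*4)=24 dims and
# T(1)(SO(2)) another 2, for 26 total; size() is an assumed API.
print(rep.size())
```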

```{code-cell} ipython3
(rep>>rep)
```

```{code-cell} ipython3
vis(rep,rep)
```

```{code-cell} ipython3
repin,repout = T(1)(SO(3))*T(2)(S(4)),T(2)(SO(3))*T(1)(S(4))
```

```{code-cell} ipython3
repin>>repout
```
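The `>>` operator builds the representation under which linear maps from the input type to the output type transform, so its total dimension is dim(repin)·dim(repout). A hedged check that the sizes agree, again assuming the `size()` API:

```python
# (3 * 4**2) * (3**2 * 4) = 48 * 36 = 1728, which should match
# V³_S(4)⊗V³_SO(3) = 4**3 * 3**3 = 1728.
print(repin.size() * repout.size(), (repin >> repout).size())
```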

```{code-cell} ipython3
#vis(repin,repout)
```

2 changes: 2 additions & 0 deletions emlp/models/mlp.py
@@ -77,6 +77,8 @@ def __init__(self,rep_in,rep_out):
        self.rep_out=rep_out
        self.linear = LieLinear(rep_in,gated(rep_out))
        self.bilinear = BiLinear(gated(rep_out),gated(rep_out))
+       # self.linear = LieLinear(rep_in,rep_out)
+       # self.bilinear = BiLinear(rep_out,rep_out)
        #self.bn = TensorBN(gated(rep_out))
        self.nonlinearity = GatedNonlinearity(rep_out)
    def __call__(self,x):
5 changes: 0 additions & 5 deletions emlp/models/model_trainer.py
@@ -16,17 +16,12 @@ def rel_err(a,b):
def scale_adjusted_rel_err(a,b,g):
    return jnp.sqrt(((a-b)**2).mean())/(jnp.sqrt((a**2).mean())+jnp.sqrt((b**2).mean())+jnp.abs(g-jnp.eye(g.shape[-1])).mean())

-# def scale_adjusted_rel_error(t1,t2,g):
-#     return jnp.mean(jnp.abs(t1-t2))/(jnp.mean(jnp.abs(t1)) + jnp.mean(jnp.abs(t2))+jnp.mean()+1e-7)
-
def equivariance_err(model,mb,group=None):
    x,y = mb
    group = model.model.G if group is None else group
    gs = group.samples(x.shape[0])
    rho_gin = vmap(model.model.rep_in.rho_dense)(gs)
    rho_gout = vmap(model.model.rep_out.rho_dense)(gs)
-   #rho_gin = jnp.stack([model.model.rep_in.rho(g) for g in ])
-   #rho_gout = jnp.stack([model.model.rep_out.rho(g) for g in group.samples(x.shape[0])])
    y1 = model.predict((rho_gin@x[...,None])[...,0])
    y2 = (rho_gout@model.predict(x)[...,None])[...,0]
    return np.asarray(scale_adjusted_rel_err(y1,y2,gs))
