Commit

update to documentation
mfinzi committed Mar 3, 2021
1 parent fc47b31 commit 4061a7f
Showing 5 changed files with 313 additions and 85 deletions.
176 changes: 166 additions & 10 deletions docs/notebooks/building_a_model.ipynb
@@ -4,7 +4,46 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model"
"# Constructing Equivariant Models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Previously we showed examples of finding equivariant bases for different groups and representations, now we'll show how these bases can be assembled into equivariant neural networks such as EMLP. \n",
"\n",
"We will give examples at a high level showing how the specific EMLP model can be applied to different groups and input-output types, and later in a lower level showing how models like EMLP can be constructed with equivariant layers and making use of the equivariant bases."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using EMLP with different groups and representations (high level)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![ex 2.13](imgs/EMLP_fig.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A basic EMLP is a sequence of EMLP layers (containing G-equivariant linear layers, bilinear layers incorporated with a shortcut connection, and gated nonlinearities. While our numerical equivariance solver can work with any finite dimensional linear representation, for EMLP we restrict ourselves to _tensor_ representations."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By tensor representations, we mean all representations which can be formed by arbitrary combinations of $\\oplus$,$\\otimes$,$^*$ (`+`,`*`,`.T`) of a base representation $\\rho$. This is useful because it simplifies the construction of our bilinear layer, which is a crucial ingredient for expressiveness and universality in EMLP.\n",
"\n",
"Following the $T_{(p,q)}=V^{\\otimes p}\\otimes (V^*)^{\\otimes p}$ notation in the paper, we provide the convenience function for constructing higher rank tensors."
]
},
{
@@ -13,18 +52,135 @@
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
"name": "stdout",
"output_type": "stream",
"text": [
"V⊗V⊗V*⊗V*⊗V*\n",
"V²⊗V*³\n"
]
}
],
"source": [
"2"
"from emlp.solver.representation import V\n",
"from emlp.solver.groups import SO13\n",
"\n",
"def T(p,q=0):\n",
" return (V**p*V.T**q)\n",
"\n",
"print(T(2,3))\n",
"print(T(2,3)(SO13()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lets get started with a toy dataset: learning how an inertia matrix depends on the positions and masses of 5 point masses distributed in different ways. The data consists of mappings (positions, masses) --> (inertia matrix) pairs, and has an $G=O(3)$ symmetry (3D rotation and reflections). If we rotate all the positions, the resulting inertia matrix should be correspondingly rotated."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input type: 5V⁰+5V, output type: V²\n"
]
}
],
"source": [
"from emlp.models.datasets import Inertia\n",
"from emlp.solver.groups import SO,O,S,Z\n",
"dataset = Inertia(100) # Initialize dataset with 100 examples\n",
"G = O(3)\n",
"print(f\"Input type: {dataset.rep_in(G)}, output type: {dataset.rep_out(G)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For convenience, we store in the dataset the types for the input and the output. `5V⁰` are the $5$ mass values and `5V` are the position vectors of those masses, `V²` is the matrix type for the output, equivalent to $T_2$. To initialize the EMLP, we just need these input and output representations, the symmetry group, and the size of the network as parametrized by number of layers and number of channels (the dimension of the feature representation)."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from emlp.models.mlp import EMLP\n",
"model = EMLP(dataset.rep_in,dataset.rep_out,group=G,num_layers=3,ch=256)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from emlp.models.mlp import uniform_rep\n",
"r = uniform_rep(512,G)\n",
"print(r)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'Module' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-5-dfa6aa01e575>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mclass\u001b[0m \u001b[0mEMLP\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmetaclass\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mNamed\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mrep_in\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mrep_out\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m384\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnum_layers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;31m#@\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Initing EMLP\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrep_in\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0mrep_in\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'Module' is not defined"
]
}
],
"source": [
"# class EMLP(Module,metaclass=Named):\n",
"# def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):#@\n",
"# super().__init__()\n",
"# logging.info(\"Initing EMLP\")\n",
"# self.rep_in =rep_in(group)\n",
"# self.rep_out = rep_out(group)\n",
" \n",
"# self.G=group\n",
"# # Parse ch as a single int, a sequence of ints, a single Rep, a sequence of Reps\n",
"# if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)]#[uniform_rep(ch,group) for _ in range(num_layers)]\n",
"# elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)]\n",
"# else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch]\n",
"# #assert all((not rep.G is None) for rep in middle_layers[0].reps)\n",
"# reps = [self.rep_in]+middle_layers\n",
"# #logging.info(f\"Reps: {reps}\")\n",
"# self.network = Sequential(\n",
"# *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],\n",
"# LieLinear(reps[-1],self.rep_out)\n",
"# )\n",
"# #self.network = LieLinear(self.rep_in,self.rep_out)\n",
"# def __call__(self,x,training=True):\n",
"# return self.network(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Equivariant Linear Layers (low level)"
]
},
{
89 changes: 87 additions & 2 deletions docs/notebooks/building_a_model.md
@@ -12,12 +12,97 @@
name: python3
---

# Constructing Equivariant Models

+++

Previously we showed examples of finding equivariant bases for different groups and representations; now we'll show how these bases can be assembled into equivariant neural networks such as EMLP.

We will first give high-level examples showing how the EMLP model can be applied to different groups and input-output types, and later show at a lower level how models like EMLP can be built from equivariant layers that make use of the equivariant bases.

+++

## Using EMLP with different groups and representations (high level)

+++

![ex 2.13](imgs/EMLP_fig.png)

+++

A basic EMLP is a sequence of EMLP layers (each containing a G-equivariant linear layer, a bilinear layer with a shortcut connection, and gated nonlinearities). While our numerical equivariance solver can work with any finite-dimensional linear representation, for EMLP we restrict ourselves to _tensor_ representations.
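
Schematically, and only as a hedged sketch rather than the library's exact implementation, a single EMLP layer composes these pieces roughly as follows (the names `linear`, `bilinear`, and `gate` are placeholders for the corresponding equivariant maps, not actual API):

```python
# Pseudocode sketch of one EMLP layer (placeholder names, not the actual API):
#   linear   -- a G-equivariant linear map between representation types
#   bilinear -- an equivariant bilinear map, used as a multiplicative shortcut
#   gate     -- gated nonlinearities applied to the resulting features
def emlp_layer_sketch(x, linear, bilinear, gate):
    h = linear(x)              # equivariant linear layer
    h = h + bilinear(h, h)     # bilinear layer with a shortcut connection
    return gate(h)             # gated nonlinearity
```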

+++

By tensor representations, we mean all representations which can be formed by arbitrary combinations of $\oplus$,$\otimes$,$^*$ (`+`,`*`,`.T`) of a base representation $\rho$. This is useful because it simplifies the construction of our bilinear layer, which is a crucial ingredient for expressiveness and universality in EMLP.

Following the $T_{(p,q)}=V^{\otimes p}\otimes (V^*)^{\otimes q}$ notation from the paper, we provide a convenience function for constructing higher-rank tensors.

```{code-cell} ipython3
from emlp.solver.representation import V
from emlp.solver.groups import SO13
def T(p,q=0):
return (V**p*V.T**q)
print(T(2,3))
print(T(2,3)(SO13()))
```
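
Representations built this way can also be combined with `+` (the direct sum $\oplus$). A brief sketch using the `V` imported above, assuming the operators behave as described:

```python
# Sketch: combining the base representation V with +, *, and .T,
# i.e. direct sums, tensor products, and duals.
W = V + V*V + V*V.T      # V ⊕ (V⊗V) ⊕ (V⊗V*)
print(W)                 # the abstract representation type
print(W(SO13()))         # the same type bound to the Lorentz group SO(1,3)
```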

Let's get started with a toy dataset: learning how an inertia matrix depends on the positions and masses of 5 point masses distributed in different ways. The data consists of (positions, masses) --> (inertia matrix) pairs, and has a $G=O(3)$ symmetry (3D rotations and reflections). If we rotate all the positions, the resulting inertia matrix should be correspondingly rotated.

```{code-cell} ipython3
from emlp.models.datasets import Inertia
from emlp.solver.groups import SO,O,S,Z
dataset = Inertia(100) # Initialize dataset with 100 examples
G = O(3)
print(f"Input type: {dataset.rep_in(G)}, output type: {dataset.rep_out(G)}")
```
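
For orientation, the physical target here is the moment of inertia matrix $I=\sum_i m_i(\|x_i\|^2\mathbb{1}-x_i x_i^\top)$. A minimal numpy sketch of computing such a target from positions and masses is below; the exact conventions used inside the `Inertia` dataset may differ.

```python
import numpy as np

# Sketch: the inertia matrix of 5 point masses, I = sum_i m_i (|x_i|^2 Id - x_i x_i^T).
# Conventions inside the Inertia dataset may differ; this only illustrates
# why the map (masses, positions) -> inertia matrix is O(3) equivariant.
masses = np.random.rand(5)            # 5 scalar masses   (type 5V⁰)
positions = np.random.randn(5, 3)     # 5 position vectors (type 5V)
I = sum(m * (x @ x * np.eye(3) - np.outer(x, x))
        for m, x in zip(masses, positions))      # a 3x3 matrix (type V²)
print(I.shape)
```

Rotating every position $x_i\mapsto Rx_i$ with $R\in O(3)$ sends $I\mapsto RIR^\top$, which is exactly the equivariance expressed by the `V²` output type.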

For convenience, we store in the dataset the types for the input and the output. `5V⁰` are the $5$ mass values and `5V` are the position vectors of those masses; `V²` is the matrix type for the output, equivalent to $T_2$. To initialize the EMLP, we just need these input and output representations, the symmetry group, and the size of the network as parametrized by the number of layers and the number of channels (the dimension of the feature representation).

```{code-cell} ipython3
from emlp.models.mlp import EMLP
model = EMLP(dataset.rep_in,dataset.rep_out,group=G,num_layers=3,ch=256)
```

```{code-cell} ipython3
from emlp.models.mlp import uniform_rep
r = uniform_rep(512,G)
print(r)
```
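
As the (commented-out) EMLP source in the next cell shows, the `ch` argument is parsed flexibly: an `int` is expanded into a representation with `uniform_rep`, while a `Rep` (or a list of ints/`Rep`s, one per layer) can be passed directly. A hedged sketch of that usage:

```python
# Sketch (based on the ch-parsing logic in the commented EMLP source below):
# instead of an integer channel count, a representation can be passed as ch directly.
middle_rep = uniform_rep(384, G)  # a middle-layer representation built from 384 channels
model_alt = EMLP(dataset.rep_in, dataset.rep_out, group=G, num_layers=3, ch=middle_rep)
```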

```{code-cell} ipython3
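# NOTE (reference only): this commented-out listing appears to mirror the EMLP class
# imported above from emlp.models.mlp. Names such as Module, Named, Sequential,
# EMLPBlock, LieLinear, and Rep are assumed to be defined in that module and its
# dependencies; they are not imported in this notebook.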
# class EMLP(Module,metaclass=Named):
# def __init__(self,rep_in,rep_out,group,ch=384,num_layers=3):#@
# super().__init__()
# logging.info("Initing EMLP")
# self.rep_in =rep_in(group)
# self.rep_out = rep_out(group)
# self.G=group
# # Parse ch as a single int, a sequence of ints, a single Rep, a sequence of Reps
# if isinstance(ch,int): middle_layers = num_layers*[uniform_rep(ch,group)]#[uniform_rep(ch,group) for _ in range(num_layers)]
# elif isinstance(ch,Rep): middle_layers = num_layers*[ch(group)]
# else: middle_layers = [(c(group) if isinstance(c,Rep) else uniform_rep(c,group)) for c in ch]
# #assert all((not rep.G is None) for rep in middle_layers[0].reps)
# reps = [self.rep_in]+middle_layers
# #logging.info(f"Reps: {reps}")
# self.network = Sequential(
# *[EMLPBlock(rin,rout) for rin,rout in zip(reps,reps[1:])],
# LieLinear(reps[-1],self.rep_out)
# )
# #self.network = LieLinear(self.rep_in,self.rep_out)
# def __call__(self,x,training=True):
# return self.network(x)
```
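
Since `__call__` above simply applies the network to its input, a minimal (hypothetical) smoke test might look like the sketch below. It assumes the model accepts a batch of flattened feature vectors: under $O(3)$ the input type `5V⁰+5V` flattens to $5+5\cdot 3=20$ numbers and the output `V²` to $3\cdot 3=9$.

```python
import numpy as np

# Hypothetical smoke test; the real model's batching/array conventions may differ.
x = np.random.randn(8, 20)   # batch of 8 inputs: 5 masses + 5 three-dimensional positions
y = model(x)                 # forward pass through the equivariant network
print(y.shape)               # expected: (8, 9), i.e. flattened 3x3 inertia matrices
```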

```{code-cell} ipython3
```

## Equivariant Linear Layers (low level)

```{code-cell} ipython3
```
Binary file added docs/notebooks/imgs/EMLP_fig.png
