Skip to content

Commit

Permalink
Merge pull request #3 from mfinzi/haiku
Browse files Browse the repository at this point in the history
integrate Haiku and pytorch compatibility
  • Loading branch information
mfinzi committed Apr 19, 2021
2 parents 70db7f9 + 223d0d5 commit c325a44
Show file tree
Hide file tree
Showing 18 changed files with 1,319 additions and 533 deletions.
45 changes: 37 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,46 @@ Use at your own caution. But if you notice things behaving unexpectedly or get f

--------------------------------------------------------------------------------

Our type system is centered on it making it easy to combine representations using ρᵤ⊗ρᵥ, ρᵤ⊕ρᵥ, ρ*. For any given matrix group and representation formed in our type system, you can get the equivariant basis with [`rep.equivariant_basis()`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.equivariant_basis) or a matrix which projects to that subspace with [`rep.equivariant_projector()`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.equivariant_projector). For example:
We provide a type system for representations. With the operators ρᵤ⊗ρᵥ, ρᵤ⊕ρᵥ, ρ* implemented as `*`,`+` and `.T` build up different representations. The basic building blocks for representations are the base vector representation `V` and tensor representations `T(p,q) = V**p*V.T**q`.

For any given matrix group and representation formed in our type system, you can get the equivariant basis with [`rep.equivariant_basis()`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.equivariant_basis) or a matrix which projects to that subspace with [`rep.equivariant_projector()`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.equivariant_projector).


For example to find all O(1,3) (Lorentz) equivariant linear maps from from a 4-Vector Xᶜ to a rank (2,1) tensor Mᵇᵈₐ, you can run

```python
from emlp.reps import V
from emlp.reps import V,T
from emlp.groups import *

G = O13()
Q = (T(1,0)>>T(1,2))(G).equivariant_basis()
```

or how about equivariant maps from one Rubik's cube to another?
```python
G = RubiksCube()

Q = (V(G)>>V(G)).equivariant_basis()
```

Using `+` and `*` you can put together composite representations (where multiple representations are concatenated together). For example lets find all equivariant linear maps from 5 node features and 2 edge features to 3 global invariants and 1 edge feature of a graph of size n=5:
```python
G=S(5)

repin = 10*T(1)+5*T(2)
repout = 3*T(0)+T(2)
Q = (repin(G)>>repout(G)).equivariant_basis()
```

From the examples above, there are many different ways of writing a representation like `10*T(1)+5*T(2)` which are all equivalent.
`10*T(1)+5*T(2)` = `10*V+5*V**2` = `5*V*(2+V)`
<!-- Feel free to go wild:
```python
W=V(O13())
repin = (W+2*W**2)*(W.T+1*W).T + W.T
repout = 3*W**0 + W + W*W.T
Q = (repin>>repout).equivariant_basis()
```

is code that will run and produce the basis for linear maps from repin to repout that are equivariant to the Lorentz group O(1,3).
``` -->

You can even mix and match representations from different groups. For example with the cyclic group ℤ₃, the permutation group 𝕊₄, and the orthogonal group O(3)

Expand All @@ -34,9 +61,11 @@ rep = 2*V(Z(3))*V(S(4))+V(O(3))**2
Q = (rep>>rep).equivariant_basis()
```

You can visualize these equivariant bases with [`vis(repin,repout)`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.vis), such as with the two examples above
Outside of these tensor representations, our type system works with any finite dimensional linear representation and you can even build your own bespoke representations following the instructions [here](https://emlp.readthedocs.io/en/latest/notebooks/4new_representations.html).

You can visualize these equivariant bases with [`vis(repin,repout)`](https://emlp.readthedocs.io/en/latest/package/emlp.reps.html#emlp.reps.vis), such as with the three examples above

<img src="https://user-images.githubusercontent.com/12687085/111226517-a2192b80-85b7-11eb-8dba-c01399fb7105.png" width="350"/> <img src="https://user-images.githubusercontent.com/12687085/111226510-a0e7fe80-85b7-11eb-913b-09776cdaa92e.png" width="230"/>
<img src="https://user-images.githubusercontent.com/12687085/115313228-e19be000-a140-11eb-994f-d4eae4057eba.png" width="200"/> <img src="https://user-images.githubusercontent.com/12687085/115312972-6afee280-a140-11eb-82f0-603748694645.png" width="360"/> <img src="https://user-images.githubusercontent.com/12687085/111226510-a0e7fe80-85b7-11eb-913b-09776cdaa92e.png" width="200"/>
<!-- ![basis B](https://user-images.githubusercontent.com/12687085/111226517-a2192b80-85b7-11eb-8dba-c01399fb7105.png "title2")
![basis A](https://user-images.githubusercontent.com/12687085/111226510-a0e7fe80-85b7-11eb-913b-09776cdaa92e.png "title1") -->

Expand Down Expand Up @@ -93,7 +122,7 @@ python experiments/hnn.py --network EMLPH --group="DkeR3(6)"

These models are trained to fit a double spring dynamical system. 30s rollouts of the dataset, along with rollout error on these trajectories, and conservation of angular momentum are shown below.

<img src="https://user-images.githubusercontent.com/12687085/114937183-759d3d00-9e0b-11eb-9310-bbfc606e6bda.gif" width="250"/> <img src="https://user-images.githubusercontent.com/12687085/114937167-703ff280-9e0b-11eb-8421-d8408b31908a.PNG" width="300"/> <img src="https://user-images.githubusercontent.com/12687085/114937171-71711f80-9e0b-11eb-885e-a541ae1d28cc.PNG" width="260"/>
<img src="https://user-images.githubusercontent.com/12687085/114937183-759d3d00-9e0b-11eb-9310-bbfc606e6bda.gif" width="240"/> <img src="https://user-images.githubusercontent.com/12687085/114937167-703ff280-9e0b-11eb-8421-d8408b31908a.PNG" width="285"/> <img src="https://user-images.githubusercontent.com/12687085/114937171-71711f80-9e0b-11eb-885e-a541ae1d28cc.PNG" width="250"/>

<!-- #
<p align="center">
Expand Down
8 changes: 8 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ This is a comment.
Remember to align the itemized text with the first line of an item within a list.
-->

## EMLP 0.9.0
* Cross Platform Support:
* You can now use EMLP in PyTorch, check out `Using EMLP in PyTorch`
* You can also use EMLP with Haiku in jax, check out `Using EMLP with Haiku`

* Bug Fixes
* Fixed broken constraints with Trivial group

## EMLP 0.8.0 (Unreleased)

* New features:
Expand Down
8 changes: 7 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ and representations an easy task, and one that does not require knowledge of ana
notebooks/4new_representations.ipynb
notebooks/5mixed_tensors.ipynb
notebooks/6multilinear_maps.ipynb
notebooks/7pytorch_support.ipynb

.. toctree::
:maxdepth: 1
:caption: Cross Platform Support

notebooks/pytorch_support.ipynb
notebooks/haiku_support.ipynb

.. toctree::
:maxdepth: 1
Expand Down
182 changes: 0 additions & 182 deletions docs/notebooks/7pytorch_support.ipynb

This file was deleted.

0 comments on commit c325a44

Please sign in to comment.