
emilemathieu/escnn_jax

 
 


E(n)-equivariant Steerable CNNs (escnn)

Documentation | escnn library

🚀 ~20% faster than the PyTorch version (preliminary; a proper benchmark is still on the TODO list)

escnn_jax is a JAX port of the PyTorch escnn library for equivariant deep learning. It supports steerable CNNs equivariant to both 2D and 3D isometries, as well as equivariant MLPs.


The library is structured into four subpackages with different high-level features:

| Component         | Dependency  | Description                                                   |
|-------------------|-------------|---------------------------------------------------------------|
| escnn_jax.group   | Pure Python | implements basic concepts of group and representation theory |
| escnn_jax.gspaces | Pure Python | defines the Euclidean spaces and their symmetries             |
| escnn_jax.kernels | JAX         | solves for spaces of equivariant convolution kernels          |
| escnn_jax.nn      | Equinox     | contains equivariant modules to build deep neural networks    |
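To give a rough idea of what "solving for a space of equivariant kernels" means, here is a minimal NumPy sketch of the simplest case, written independently of the library: kernels sampled on a square grid that map a scalar field to a scalar field and must be invariant under 90° rotations (the group C4). The constraint is linear in the kernel weights, so the admissible kernels form the null space of a matrix. The helper names `rotation_permutation` and `invariant_kernel_basis` are hypothetical; the solver in escnn_jax.kernels is far more general (arbitrary input and output representations, continuous kernels) and does not work this way.

```python
import numpy as np

def rotation_permutation(size):
    """Hypothetical helper: matrix P with P @ k.ravel() == np.rot90(k).ravel()
    for every size x size kernel k (a 90-degree rotation permutes the pixels)."""
    n = size * size
    source = np.rot90(np.arange(n).reshape(size, size)).ravel()
    P = np.zeros((n, n))
    P[np.arange(n), source] = 1.0
    return P

def invariant_kernel_basis(size, tol=1e-10):
    """Hypothetical helper: orthonormal basis of size x size kernels with k == rot90(k).

    C4 is generated by a single 90-degree rotation, so it suffices to impose
    (P - I) vec(k) = 0; the solutions are the null space of P - I.
    """
    A = rotation_permutation(size) - np.eye(size * size)
    _, s, vt = np.linalg.svd(A)
    # Rows of vt whose singular value is (numerically) zero span the null space.
    return vt[s < tol].reshape(-1, size, size)

basis = invariant_kernel_basis(3)
print(basis.shape)  # (3, 3, 3): one degree of freedom each for center, edges, corners
```

For 5×5 kernels the same function returns 7 basis kernels, one per orbit of pixels under rotation. A learnable equivariant kernel is then a linear combination of such basis elements, which is the general strategy steerable CNNs follow.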

TODOs

Priority

  • reproduce examples and baselines
    • mlp.ipynb
      • apart from the IIDBatchNorm1d module
    • introduction.ipynb
    • model.ipynb
    • octahedral_cnn.ipynb
  • mimic PyTorch's requires_grad=False for 'buffer' variables, to avoid including them in opt_state and grads
    • added the EquivariantModule methods set_buffer and get_buffer, which wrap the variable in lax.stop_gradient
    • added the EquivariantModule methods set_parameter and get_parameter, which wrap the Array in a custom type escnn_jax.nn.ParameterArray that can later be used to filter the parameters
  • enhance model.eval() behaviour; make EquivariantModule.eval recursively call submodules?
  • speed up modules' __init__, e.g. nn.Linear and nn.R2Conv
  • speed up modules' __call__ if possible?
  • better __repr__ for EquivariantModule and eqx.nn.Module more generally
  • make sure that tests pass for implemented modules and kernels
  • Bug? InnerBatchNorm.eval() without training returns high values
  • add export method for layers
  • properly measure the speed-up with respect to the PyTorch version

Nice to have

  • add support for haiku / flax under escnn.nn.haiku / escnn.nn.flax
  • use jaxlinop for the Representation class, as in emlp, and more generally rewrite escnn_jax.group in JAX?
  • add missing modules (see /nn/__init__.py)
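The buffer item under Priority relies on jax.lax.stop_gradient. Below is a minimal JAX sketch of that mechanism on a hypothetical parameter tree (a plain dict standing in for a module; it does not use the EquivariantModule API or the set_buffer/get_buffer methods themselves): reading the buffer through stop_gradient makes autodiff treat it as a constant.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree (a plain dict, not the EquivariantModule API):
# "weight" is trainable, "running_mean" plays the role of a buffer.
params = {"weight": jnp.array([1.0, 2.0]), "running_mean": jnp.array([0.5, -0.5])}

def loss(tree, x):
    # stop_gradient blocks gradient flow into the buffer.
    mean = jax.lax.stop_gradient(tree["running_mean"])
    return jnp.sum((tree["weight"] * x - mean) ** 2)

grads = jax.grad(loss)(params, jnp.array([1.0, 1.0]))
# grads["weight"] is [1., 5.] and grads["running_mean"] is [0., 0.]
```

Note that the buffer still appears as a (zero) leaf in grads, and an optimizer initialized on the whole tree would still allocate state for it. stop_gradient alone therefore does not keep buffers out of opt_state, which is why the list above also introduces a marker type for filtering parameters.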

Getting Started

escnn_jax is easy to use: it provides a high-level user interface that abstracts away most intricacies of group and representation theory. The following code snippet shows how to perform an equivariant convolution from an RGB image to 10 regular feature fields (corresponding to a group convolution).

```python
from escnn_jax import gspaces
from escnn_jax import nn
import jax

key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key, 2)

# symmetry group: rotations by multiples of 2*pi/8 acting on the plane R^2
r2_act = gspaces.rot2dOnR2(N=8)

# input: 3 scalar (trivial) fields, one per RGB channel
feat_type_in  = nn.FieldType(r2_act,  3*[r2_act.trivial_repr])
# output: 10 regular feature fields
feat_type_out = nn.FieldType(r2_act, 10*[r2_act.regular_repr])

conv = nn.R2Conv(feat_type_in, feat_type_out, kernel_size=5, key=key1)
relu = nn.ReLU(feat_type_out)

# a batch of 16 random 32x32 RGB images, associated with the input field type
x = jax.random.normal(key2, (16, 3, 32, 32))
x = feat_type_in(x)

y = relu(conv(x))
```
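Under a rotation of the input, the channels of a regular feature field are cyclically permuted, which is what makes such a layer a group convolution. The NumPy sketch below checks this property for a hypothetical, hand-written lifting convolution with the group C4 (90° rotations, which act exactly on the pixel grid, unlike the N=8 rotations used above). It only illustrates the principle; `correlate2d_valid` and `lifting_conv_c4` are illustrative names, and R2Conv is not implemented this way.

```python
import numpy as np

def correlate2d_valid(x, k):
    """'Valid' 2D cross-correlation of a single-channel image x with kernel k."""
    kh, kw = k.shape
    out = np.empty((x.shape[0] - kh + 1, x.shape[1] - kw + 1))
    for p in range(out.shape[0]):
        for q in range(out.shape[1]):
            out[p, q] = np.sum(x[p:p + kh, q:q + kw] * k)
    return out

def lifting_conv_c4(x, k):
    """Hypothetical C4 lifting convolution: one output channel per rotated copy of k."""
    return np.stack([correlate2d_valid(x, np.rot90(k, r)) for r in range(4)])

rng = np.random.default_rng(0)
x = rng.normal(size=(12, 12))  # square single-channel image
k = rng.normal(size=(5, 5))    # square filter

y = lifting_conv_c4(x, k)                # shape (4, 8, 8): one regular C4 field
y_rot = lifting_conv_c4(np.rot90(x), k)  # same filter, input rotated by 90 degrees

# Rotating the input rotates every output channel spatially and cyclically
# shifts the 4 orientation channels by one position.
expected = np.rot90(np.roll(y, 1, axis=0), axes=(1, 2))
print(np.allclose(y_rot, expected))  # True
```

The spatial rotation alone is not enough: the orientation axis must be shifted too, which is exactly how the regular representation acts on the fibers of a regular feature field.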

Dependencies

The library is based on Python 3.7

jax
equinox
jaxtyping
numpy
scipy
lie_learn
joblib
py3nj

Optional:

pymanopt>=1.0.0
optax
chex

WARNING: py3nj enables a fast computation of Clebsch-Gordan coefficients. If this package is not installed, our library relies on a numerical method to estimate them. This numerical method is not guaranteed to return the same coefficients computed by py3nj (they can differ by a sign). For this reason, models built with and without py3nj might not be compatible.

To successfully install py3nj you may need a Fortran compiler installed in your environment.
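The sign ambiguity has a simple origin: Clebsch-Gordan coefficients form a change of basis that decomposes a tensor-product representation into irreducible representations, and multiplying all basis vectors of one irreducible component by -1 yields an equally valid change of basis. The NumPy sketch below shows the same effect on a simpler decomposition that is easy to write by hand, the regular representation of the cyclic group C4. It is not a tensor product and not how py3nj or escnn_jax compute anything; all names are illustrative.

```python
import numpy as np

N = 4  # cyclic group C4: rotations by multiples of 90 degrees
theta = 2 * np.pi / N
i = np.arange(N)

# Regular representation of the generator: cyclic shift, (P v)[i] = v[i - 1].
P = np.roll(np.eye(N), 1, axis=0)

# Orthonormal change of basis whose rows split R^4 into the real irreps of C4:
# trivial (1-dim), frequency-1 rotation (2-dim), sign (1-dim).
Q = np.stack([
    np.ones(N) / 2,
    np.cos(theta * i) / np.sqrt(2),
    np.sin(theta * i) / np.sqrt(2),
    (-1.0) ** i / 2,
])

def block_diag_irreps(r):
    """Direct sum of the three irreps evaluated at the group element r."""
    c, s = np.cos(theta * r), np.sin(theta * r)
    B = np.zeros((N, N))
    B[0, 0] = 1.0
    B[1:3, 1:3] = [[c, -s], [s, c]]
    B[3, 3] = (-1.0) ** r
    return B

def is_valid_change_of_basis(Q):
    """Q must be orthogonal and map every rho(r) = P^r to the block-diagonal irreps."""
    return np.allclose(Q @ Q.T, np.eye(N)) and all(
        np.allclose(Q @ np.linalg.matrix_power(P, r) @ Q.T, block_diag_irreps(r))
        for r in range(N)
    )

# Flip the sign of both rows belonging to the 2-dim irrep.
D = np.diag([1.0, -1.0, -1.0, 1.0])
Q_flipped = D @ Q
print(is_valid_change_of_basis(Q), is_valid_change_of_basis(Q_flipped))  # True True
```

Both matrices are correct decompositions, yet weights stored in one basis would be misinterpreted in the other, which is the kind of incompatibility the warning describes.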

Installation

You can install the latest release with

pip install escnn_jax

or install it directly from the GitHub repository with

pip install git+https://github.com/QUVA-Lab/escnn_jax

Contributing

Would you like to contribute to escnn_jax? That's great!

Then, check the instructions in CONTRIBUTING.md and help us to improve the library!

Cite

The development of this library was part of the work done for our papers A Program to Build E(N)-Equivariant Steerable CNNs and General E(2)-Equivariant Steerable CNNs. Please cite these works if you use our code:


```bibtex
@inproceedings{cesa2022a,
    title={A Program to Build {E(N)}-Equivariant Steerable {CNN}s},
    author={Cesa, Gabriele and Lang, Leon and Weiler, Maurice},
    booktitle={International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=WE4qe9xlnQw}
}

@inproceedings{e2cnn,
    title={{General E(2)-Equivariant Steerable CNNs}},
    author={Weiler, Maurice and Cesa, Gabriele},
    booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
    year={2019},
}
```

Feel free to contact us.

License

escnn_jax is distributed under BSD Clear license. See LICENSE file.
