Spherical CNN JAX Library

This repo contains a JAX library implementing spherical CNNs and spin-weighted spherical CNNs. It was used in our ICML 2023 paper "Scaling Spherical CNNs". The code can also reproduce results from our previous papers, "Spin-Weighted Spherical CNNs" (NeurIPS'20) and "Learning SO(3) Equivariant Representations with Spherical CNNs" (ECCV'18).
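
To give a flavor of the core operation these models are built on, here is a conceptual sketch in plain JAX (not this library's API; the function name and coefficient layout are ours). By the convolution theorem on the sphere, convolving with a zonal (rotationally symmetric) filter multiplies each degree-l block of spherical harmonic coefficients by a single learned scalar, which is what makes the layer rotation equivariant:

import jax
import jax.numpy as jnp

def spectral_spherical_conv(coeffs, filter_weights):
  """Zonal spherical convolution in the spectral domain (conceptual sketch).

  Args:
    coeffs: [(lmax + 1)**2] spherical harmonic coefficients, stored
      degree-major: degree l occupies entries l**2, ..., l**2 + 2*l.
    filter_weights: [lmax + 1] learned scalars, one per degree l.

  Returns:
    Filtered coefficients with the same layout as `coeffs`.
  """
  lmax = filter_weights.shape[0] - 1
  # Broadcast each per-degree scalar over the 2l + 1 orders of degree l.
  degrees = jnp.concatenate(
      [jnp.full(2 * l + 1, l) for l in range(lmax + 1)])
  return coeffs * filter_weights[degrees]

# Usage: filter random coefficients up to degree 3.
lmax = 3
coeffs = jax.random.normal(jax.random.PRNGKey(0), ((lmax + 1) ** 2,))
filtered = spectral_spherical_conv(coeffs, jnp.ones(lmax + 1))

Spin-weighted spherical CNNs generalize this picture by filtering spin-weighted spherical harmonic coefficients, which lets intermediate feature maps transform nontrivially under rotations.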

Experiments

Weather forecasting

Coming soon!

QM9

Use the following instructions to launch a short training job on QM9/H. See default.py for the longer configurations that reproduce the results in the paper.

git clone https://github.com/google-research/spherical-cnn.git
cd spherical-cnn
# Build a docker image; download and install dependencies; download and
# process the dataset.
docker build -f dockerfile-qm9 -t spherical_cnn_qm9 .
# Start training.
docker run spherical_cnn_qm9 \
    --workdir=/tmp/training_logs \
    --config=spherical_cnn/molecules/configs/small.py \
    --config.per_device_batch_size=2

It should train at around 21.9 steps/s with a per-device batch size of 2 (an effective batch size of 16) on 8 V100s and reach around 10.83 meV error on the enthalpy of atomization H. This configuration trains for 250 epochs; the 5.69 meV error reported in the paper was obtained by training for 2000 epochs (see default.py).

Spherical MNIST

Use the following instructions to train a small model on the spherical MNIST dataset on a single GPU.

git clone https://github.com/google-research/spherical-cnn.git
cd spherical-cnn
# Build a docker image; download and install dependencies; download and
# process the dataset.
docker build -f dockerfile-spherical-mnist -t spherical_cnn_mnist .
# Start training.
docker run spherical_cnn_mnist \
    --workdir=/tmp/training_logs \
    --config=spherical_cnn/spherical_mnist/configs/default.py \
    --config.model_name=spin_classifier_6_layers \
    --config.dataset=spherical_mnist/rotated \
    --config.combine_train_val_and_eval_on_test

It should train at around 22 steps/s on a single P100 and reach around 99.5% accuracy. With this configuration (12 epochs over the 60,000 combined train and validation examples at batch size 32), training runs for 22,500 steps. The output should look something like this:

INFO 2023-08-21T19:30:28.855726181Z [Hyperparameters] {'checkpoint_every_steps': 1000, 'combine_train_val_and_eval_on_test': True, 'dataset': 'spherical_mnist/rotated', 'eval_every_steps': 1000, 'eval_pad_last_batch': True, 'learning_rate': 0.001, 'learning_rate_schedule': 'cosine', 'log_loss_every_steps': 100, 'model_name': 'spin_classifier_6_layers', 'num_epochs': 12, 'num_eval_steps': -1, 'num_train_steps': -1, 'per_device_batch_size': 32, 'seed': 42, 'shuffle_buffer_size': 1000, 'trial': 0, 'warmup_epochs': 1, 'weight_decay': 0.0}
INFO 2023-08-21T19:30:28.856940603Z Starting training loop at step 1.
INFO 2023-08-21T19:30:28.857277764Z [1] param_count=39146
INFO 2023-08-21T19:31:12.653463819Z [100] learning_rate=5.333333683665842e-05, loss=2.2819416522979736, loss_std=0.10880677402019501, train_accuracy=0.19312499463558197
INFO 2023-08-21T19:31:29.503783929Z [200] learning_rate=0.00010666667367331684, loss=1.8683496713638306, loss_std=0.14256055653095245, train_accuracy=0.3765625059604645

(...)

INFO 2023-08-21T19:48:41.827125703Z [22400] learning_rate=5.799532232231286e-08, loss=0.006118293385952711, loss_std=0.015895500779151917, train_accuracy=0.9984374642372131
INFO 2023-08-21T19:48:44.634986829Z [22500] steps_per_sec=22.0576
INFO 2023-08-21T19:48:44.635090221Z [22500] uptime=1095.78
INFO 2023-08-21T19:48:44.695150873Z [22500] learning_rate=0.0, loss=0.003276276867836714, loss_std=0.005533786956220865, train_accuracy=0.9993749856948853
INFO 2023-08-21T19:49:00.926620697Z Starting evaluation.
INFO 2023-08-21T19:49:16.283256304Z [22500] accuracy=0.9949000477790833, eval_loss=0.033049359917640686
INFO 2023-08-21T19:49:16.288987270Z Finishing training at step 22500

Unit tests

The code is extensively tested. The command below runs all tests in a docker image built from the instructions above.

docker run --entrypoint pytest -it spherical_cnn_mnist -vv spherical_cnn
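
As a toy illustration of the kind of property such tests revolve around (hypothetical code, not taken from the repo's test suite): a per-degree spectral filter, like the sketch near the top of this README, commutes with rotations, because a rotation acts on spherical harmonic coefficients block-diagonally within each degree. The check below uses a random orthogonal block-diagonal matrix as a stand-in for the rotation's Wigner-D action.

import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

def test_spectral_conv_is_rotation_equivariant():
  lmax = 3
  key = jax.random.PRNGKey(0)
  coeffs = jax.random.normal(key, ((lmax + 1) ** 2,))
  weights = jax.random.normal(jax.random.fold_in(key, 1), (lmax + 1,))
  # Per-degree filtering: one learned scalar per degree l.
  degrees = jnp.concatenate(
      [jnp.full(2 * l + 1, l) for l in range(lmax + 1)])
  conv = lambda c: c * weights[degrees]
  # A rotation acts on coefficients block-diagonally, one orthogonal
  # (Wigner-D) block per degree; random orthogonal blocks stand in here.
  blocks = []
  for l in range(lmax + 1):
    m = jax.random.normal(
        jax.random.fold_in(key, 10 + l), (2 * l + 1, 2 * l + 1))
    q, _ = jnp.linalg.qr(m)
    blocks.append(q)
  rotation = block_diag(*blocks)
  # Filtering then rotating must equal rotating then filtering.
  assert jnp.allclose(conv(rotation @ coeffs),
                      rotation @ conv(coeffs), atol=1e-5)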

References

If you use this code, please cite the papers:

@inproceedings{pmlr-v202-esteves23a,
  title = {Scaling Spherical {CNN}s},
  author = {Esteves, Carlos and Slotine, Jean-Jacques and Makadia, Ameesh},
  booktitle = {Proceedings of the 40th International Conference on Machine Learning},
  pages = {9396--9411},
  year = {2023},
  editor = {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},
  volume = {202},
  series = {Proceedings of Machine Learning Research},
  month = {23--29 Jul},
  publisher = {PMLR},
  pdf = {https://proceedings.mlr.press/v202/esteves23a/esteves23a.pdf},
  url = {https://proceedings.mlr.press/v202/esteves23a.html}
}
@inproceedings{esteves20_swscnn,
  title = {Spin-Weighted Spherical CNNs},
  author = {Esteves, Carlos and Makadia, Ameesh and Daniilidis, Kostas},
  booktitle = {Advances in Neural Information Processing Systems},
  editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
  pages = {8614--8625},
  publisher = {Curran Associates, Inc.},
  volume = {33},
  year = {2020},
  url = {https://proceedings.neurips.cc/paper_files/paper/2020/file/6217b2f7e4634fa665d31d3b4df81b56-Paper.pdf}
}
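@inproceedings{esteves18_spherical,
  title = {Learning {SO}(3) Equivariant Representations with Spherical {CNN}s},
  author = {Esteves, Carlos and Allen-Blanchette, Christine and Makadia, Ameesh and Daniilidis, Kostas},
  booktitle = {European Conference on Computer Vision (ECCV)},
  year = {2018}
}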


This is not an officially supported Google product.
