-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b91f2eb
commit c7a65a2
Showing
27 changed files
with
4,369 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
tags |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,24 @@ | ||
# noether-networks | ||
Meta-learning inductive biases in the form of useful conserved quantities. | ||
# Noether Networks: meta-learning useful conserved quantities | ||
|
||
This repository contains the code necessary to reproduce experiments from "Noether Networks: | ||
meta-learning useful conserved quantities." | ||
Noether Networks meta-learn inductive biases in the form of useful conserved quantities. | ||
For details on the method, check out our NeurIPS 2021 paper, linked on our | ||
[project website](https://dylandoblar.github.io/noether-networks/). | ||
|
||
For instructions on how to train and evaluate a Noether Network for video prediction, check out | ||
[`video_prediction/README.md`](video_prediction). | ||
|
||
|
||
## Citation | ||
If this work is useful to you, please cite our paper: | ||
``` | ||
@inproceedings{ | ||
alet2021noether, | ||
title={Noether Networks: meta-learning useful conserved quantities}, | ||
author={Ferran Alet and Dylan Doblar and Allan Zhou and Joshua B. Tenenbaum and Kenji Kawaguchi and Chelsea Finn}, | ||
booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, | ||
year={2021}, | ||
url={https://openreview.net/forum?id=_NOwVKCmSo} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
theme: jekyll-theme-cayman | ||
defaults: | ||
- | ||
scope: | ||
path: "" | ||
type: "pages" | ||
values: | ||
title: "Noether Networks: meta-learning useful conserved quantities" | ||
authors: "Ferran Alet*<sup>1</sup>, Dylan Doblar*<sup>1</sup>, Allan Zhou<sup>2</sup>, Joshua B. Tenenbaum<sup>1</sup>, Kenji Kawaguchi<sup>3</sup>, Chelsea Finn<sup>2</sup>" | ||
affiliations: "<sup>1</sup>MIT, <sup>2</sup>Stanford University, <sup>3</sup>National University of Singapore; *Equal contribution" | ||
venue: "NeurIPS 2021; Workshop on Machine Learning and the Physical Sciences at NeurIPS 2021" | ||
paper-url: "/noether-networks/noether_networks_neurips_2021_CR.pdf" | ||
workshop-paper-url: "/noether-networks/noether_networks_neurips_ml4phys_2021_CR.pdf" | ||
poster-url: "/noether-networks/noether_networks_neurips_2021_poster.pdf" | ||
include-footer: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
<!DOCTYPE html> | ||
<html lang="{{ site.lang | default: "en-US" }}"> | ||
<head> | ||
<meta charset="UTF-8"> | ||
|
||
{% seo %} | ||
<link rel="preconnect" href="https://fonts.gstatic.com"> | ||
<link rel="preload" href="https://fonts.googleapis.com/css?family=Open+Sans:400,700&display=swap" as="style" type="text/css" crossorigin> | ||
<meta name="viewport" content="width=device-width, initial-scale=1"> | ||
<meta name="theme-color" content="#157878"> | ||
<meta name="apple-mobile-web-app-status-bar-style" content="black-translucent"> | ||
<link rel="stylesheet" href="{{ '/assets/css/style.css?v=' | append: site.github.build_revision | relative_url }}"> | ||
{% include head-custom.html %} | ||
</head> | ||
<body> | ||
<a id="skip-to-content" href="#content">Skip to the content.</a> | ||
|
||
<header class="page-header" role="banner"> | ||
<h1 class="project-name">{{ page.title | default: site.title | default: site.github.repository_name }}</h1> | ||
<h2 class="project-tagline">{{ page.authors }}</h2> | ||
<h2 class="project-tagline">{{ page.affiliations }}</h2> | ||
<h3 class="project-tagline">{{ page.equal-contrib }}</h3> | ||
<h2 class="project-tagline"><em>{{ page.venue }}</em></h2> | ||
{% if site.github.is_project_page %} | ||
<a href="{{ page.paper-url }}" class="btn">Main Conference Paper</a> | ||
<a href="{{ page.workshop-paper-url }}" class="btn">Workshop Paper</a> | ||
<a href="{{ page.poster-url }}" class="btn">Poster</a> | ||
<a href="{{ site.github.repository_url }}" class="btn">Code</a> | ||
{% endif %} | ||
{% if site.show_downloads %} | ||
<a href="{{ site.github.zip_url }}" class="btn">Download .zip</a> | ||
<a href="{{ site.github.tar_url }}" class="btn">Download .tar.gz</a> | ||
{% endif %} | ||
</header> | ||
|
||
<main id="content" class="main-content" role="main"> | ||
{{ content }} | ||
|
||
{% if page.include-footer %} | ||
<footer class="site-footer"> | ||
{% if site.github.is_project_page %} | ||
<span class="site-footer-owner"><a href="{{ site.github.repository_url }}">{{ site.github.repository_name }}</a> is maintained by <a href="{{ site.github.owner_url }}">{{ site.github.owner_name }}</a>.</span> | ||
{% endif %} | ||
<span class="site-footer-credits">This page was generated by <a href="https://pages.github.com">GitHub Pages</a>.</span> | ||
</footer> | ||
{% endif %} | ||
</main> | ||
</body> | ||
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
## Abstract | ||
Progress in machine learning (ML) stems from a combination of data availability, computational resources, and an appropriate encoding of inductive biases. Useful biases often exploit symmetries in the prediction problem, such as convolutional networks relying on translation equivariance. Automatically discovering these useful symmetries holds the potential to greatly improve the performance of ML systems, but still remains a challenge. In this work, we focus on sequential prediction problems and take inspiration from Noether's theorem to reduce the problem of finding inductive biases to meta-learning useful conserved quantities. We propose Noether Networks: a new type of architecture where a meta-learned conservation loss is optimized inside the prediction function. We show, theoretically and experimentally, that Noether Networks improve prediction quality, providing a general framework for discovering inductive biases in sequential problems. | ||
|
||
## Citation | ||
If this work is useful to you, please cite our paper: | ||
``` | ||
@inproceedings{ | ||
alet2021noether, | ||
title={Noether Networks: meta-learning useful conserved quantities}, | ||
author={Ferran Alet and Dylan Doblar and Allan Zhou and Joshua B. Tenenbaum and Kenji Kawaguchi and Chelsea Finn}, | ||
booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, | ||
year={2021}, | ||
url={https://openreview.net/forum?id=_NOwVKCmSo} | ||
} | ||
``` | ||
<!-- | ||
You can use the [editor on GitHub](https://github.com/dylandoblar/noether-networks/edit/main/docs/index.md) to maintain and preview the content for your website in Markdown files. | ||
Whenever you commit to this repository, GitHub Pages will run [Jekyll](https://jekyllrb.com/) to rebuild the pages in your site, from the content in your Markdown files. | ||
### Markdown | ||
Markdown is a lightweight and easy-to-use syntax for styling your writing. It includes conventions for | ||
```markdown | ||
Syntax highlighted code block | ||
# Header 1 | ||
## Header 2 | ||
### Header 3 | ||
- Bulleted | ||
- List | ||
1. Numbered | ||
2. List | ||
**Bold** and _Italic_ and `Code` text | ||
[Link](url) and ![Image](src) | ||
``` | ||
For more details see [GitHub Flavored Markdown](https://guides.github.com/features/mastering-markdown/). | ||
### Jekyll Themes | ||
Your Pages site will use the layout and styles from the Jekyll theme you have selected in your [repository settings](https://github.com/dylandoblar/noether-networks/settings/pages). The name of this theme is saved in the Jekyll `_config.yml` configuration file. | ||
### Support or Contact | ||
Having trouble with Pages? Check out our [documentation](https://docs.github.com/categories/github-pages-basics/) or [contact support](https://support.github.com/contact) and we’ll help you sort it out. | ||
--> |
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
cover/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
.pybuilder/ | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
# For a library or package, you might want to ignore these files since the code is | ||
# intended to run in multiple environments; otherwise, check them in: | ||
# .python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# pytype static type analyzer | ||
.pytype/ | ||
|
||
# Cython debug symbols | ||
cython_debug/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# Noether Networks for video prediction | ||
|
||
This directory contains code to train and evaluate a Noether Network for video prediction on the | ||
Physics 101 dataset. Much of the model and utility code comes directly from the | ||
[SVG codebase](https://github.com/edenton/svg) (Denton and Fergus); we use SVG as our baseline | ||
video prediction model. | ||
|
||
First, download the Physics 101 dataset, available at the Physics 101 [project | ||
page](http://phys101.csail.mit.edu/), and unzip it in the `./data/phys101/` directory. The ramp | ||
scenario data should be located at `./data/phys101/phys101/scenarios/ramp/`. You can do this by | ||
running | ||
``` | ||
./download_extract_phys101.sh | ||
``` | ||
|
||
Ensure you have all the required dependencies, you can install them with the following command: | ||
``` | ||
pip install requirements.txt | ||
``` | ||
|
||
Then, train a Noether Network with the training script. For example, to train from scratch with a | ||
single inner step, you can run the following command: | ||
``` | ||
python train_noether_net.py \ | ||
--image_width 128 \ | ||
--g_dim 128 \ | ||
--z_dim 64 \ | ||
--dataset phys101 \ | ||
--data_root ./data/phys101/phys101/scenarios/ramp \ | ||
--tailor \ | ||
--num_trials 1 \ | ||
--n_past 2 \ | ||
--n_future 20 \ | ||
--num_threads 6 \ | ||
--ckpt_every 2 \ | ||
--inner_crit_mode mse \ | ||
--enc_dec_type vgg \ | ||
--emb_type conserved \ | ||
--num_epochs_per_val 1 \ | ||
--emb_dim 64 \ | ||
--batch_size 2 \ | ||
--num_inner_steps 1 \ | ||
--num_jump_steps 0 \ | ||
--n_epochs 1000 \ | ||
--train_set_length 311 \ | ||
--test_set_length 78 \ | ||
--inner_lr .0001 \ | ||
--val_inner_lr .0001 \ | ||
--outer_lr 0.0001 \ | ||
--outer_opt_model_weights \ | ||
--random_weights \ | ||
--only_twenty_degree \ | ||
--frame_step 2 \ | ||
--center_crop 1080 \ | ||
--num_emb_frames 2 \ | ||
--horiz_flip \ | ||
--batch_norm_to_group_norm \ | ||
--reuse_lstm_eps \ | ||
--log_dir ./results/phys101/<experiment_id>/ | ||
``` | ||
where `<experiment_id>` specifies the subdirectory where the model checkpoints and tensorboard logs | ||
will be written. | ||
|
||
To train a baseline model, pass in `--num_inner_steps 0`. | ||
|
||
To evaluate, run the evaluation script, passing in the model checkpoint you want to use: | ||
``` | ||
python evaluate_noether_net.py \ | ||
--model_path ./results/phys101/<experiment_id>/model_400.pth \ | ||
--num_inner_steps 1 \ | ||
--n_future 20 \ | ||
--horiz_flip \ | ||
--test_set_length 78 \ | ||
--train_set_length 311 \ | ||
--val_inner_lr .0001 \ | ||
--reuse_lstm_eps \ | ||
--data_root ./data/phys101/phys101/scenarios/ramp \ | ||
--dataset phys101 \ | ||
--n_past 2 \ | ||
--tailor \ | ||
--n_trials 1 \ | ||
--only_twenty_degree \ | ||
--frame_step 2 \ | ||
--crop_upper_right 1080 \ | ||
--center_crop 1080 \ | ||
--batch_size 2 \ | ||
--image_width 128 \ | ||
--num_threads 4 | ||
``` | ||
You can pass `--adam_inner_opt` to use Adam instead of SGD in the inner loop. | ||
This script will run the evaluation script, compute metrics on the test set, and cache these | ||
metrics as numpy arrays. | ||
|
||
You can load and plot the metrics with the `generate_figures.ipynb` notebook, which also contains | ||
code to generate Grad-CAM heatmaps. |
Oops, something went wrong.