diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6d60ba2..22982e2 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,6 +1,6 @@ -# Adapted from https://github.com/pytorch/examples/blob/master/.github/workflows/main.yml +# Example of how to open issues on failed tests: https://git.io/JUfDd -name: Run Tests +name: tests on: push: @@ -8,7 +8,7 @@ on: - '**.py' # only run workflow when source files changed pull_request: paths: - - '**.py' # only run workflow when source files changed + - '**.py' jobs: tests: @@ -16,53 +16,33 @@ jobs: steps: - uses: actions/checkout@v2 + - name: Set up latest Python 3 + uses: actions/setup-python@v2 + with: + python-version: 3.x + - name: Cache dependencies # Adapted from https://git.io/JUfrK uses: actions/cache@v2 id: cache-deps with: - # This path is specific to Ubuntu - path: ~/.cache/pip + path: ${ pip cache dir } # requires pip 20.1+ # Look to see if there is a cache hit for the corresponding environment file - key: ${{ runner.os }}-pip-${{ hashFiles('env.yml') }} - # Ordered list of keys to use for restoring the cache if no cache hit occurred for key + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + # Optional ordered list of alternative keys to use for restoring the cache + # if no cache was found for key restore-keys: | ${{ runner.os }}-pip- - ${{ runner.os }}- - - - name: Set up latest Python 3 - if: steps.cache-deps.outputs.cache-hit != 'true' - uses: actions/setup-python@v2 - with: - python-version: 3.x - name: Install dependencies - if: steps.cache-deps.outputs.cache-hit != 'true' run: | - python -m pip install -U pip - pip install -U pytest sklearn flake8 - # Install CPU-based pytorch - pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + pip install -r requirements.txt pip install . - # Maybe use the CUDA 10.2 version instead? - # pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html - name: Run Tests run: | python -m pytest - - name: Open issue on failure - if: ${{ failure() }} - uses: rishabhgupta/git-action-issue@v2 - with: - token: ${{ secrets.GITHUB_TOKEN }} - title: Test suite failed on CI - body: Commit ${{ github.sha }} caused [CI run](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}) to fail, please check why. - assignees: janosh - - linting: - runs-on: ubuntu-latest - steps: - name: Lint with flake8 # Adapted from https://git.io/JUfPw run: | diff --git a/env.yml b/env.yml deleted file mode 100644 index 120a20a..0000000 --- a/env.yml +++ /dev/null @@ -1,108 +0,0 @@ -name: default -channels: - - defaults -dependencies: - - _pytorch_select=0.1 - - appnope=0.1.0 - - attrs=19.3.0 - - backcall=0.1.0 - - blas=1.0 - - bleach=3.1.0 - - ca-certificates=2019.11.27 - - certifi=2019.11.28 - - cffi=1.13.2 - - cycler=0.10.0 - - dbus=1.13.12 - - decorator=4.4.1 - - defusedxml=0.6.0 - - entrypoints=0.3 - - expat=2.2.6 - - freetype=2.9.1 - - gettext=0.19.8.1 - - glib=2.63.1 - - icu=58.2 - - importlib_metadata=1.2.0 - - intel-openmp=2019.4 - - ipykernel=5.1.3 - - ipython=7.10.1 - - ipython_genutils=0.2.0 - - ipywidgets=7.5.1 - - jedi=0.15.1 - - jinja2=2.10.3 - - joblib=0.14.0 - - jpeg=9b - - jsonschema=3.2.0 - - jupyter=1.0.0 - - jupyter_client=5.3.4 - - jupyter_console=5.2.0 - - jupyter_core=4.6.1 - - kiwisolver=1.1.0 - - libcxx=4.0.1 - - libcxxabi=4.0.1 - - libedit=3.1.20181209 - - libffi=3.2.1 - - libgfortran=3.0.1 - - libiconv=1.15 - - libpng=1.6.37 - - libsodium=1.0.16 - - llvm-openmp=4.0.1 - - markupsafe=1.1.1 - - matplotlib=3.1.1 - - mistune=0.8.4 - - mkl=2019.4 - - mkl-service=2.3.0 - - mkl_fft=1.0.15 - - mkl_random=1.1.0 - - more-itertools=7.2.0 - - nbconvert=5.6.1 - - nbformat=4.4.0 - - ncurses=6.1 - - ninja=1.9.0 - - notebook=6.0.2 - - numpy=1.17.4 - - numpy-base=1.17.4 - - openssl=1.1.1d - - pandoc=2.2.3.2 - - pandocfilters=1.4.2 - - parso=0.5.1 - - pcre=8.43 - - pexpect=4.7.0 - - pickleshare=0.7.5 - - pip=19.3.1 - - prometheus_client=0.7.1 - - prompt_toolkit=3.0.2 - - ptyprocess=0.6.0 - - pycparser=2.19 - - pygments=2.5.2 - - pyparsing=2.4.5 - - pyqt=5.9.2 - - pyrsistent=0.15.6 - - python=3.6.9 - - python-dateutil=2.8.1 - - pytorch=1.3.1 - - pytz=2019.3 - - pyzmq=18.1.0 - - qt=5.9.7 - - qtconsole=4.6.0 - - readline=7.0 - - scikit-learn=0.21.3 - - scipy=1.3.2 - - send2trash=1.5.0 - - setuptools=42.0.2 - - sip=4.19.8 - - six=1.13.0 - - sqlite=3.30.1 - - terminado=0.8.3 - - testpath=0.4.4 - - tk=8.6.8 - - tornado=6.0.3 - - traitlets=4.3.3 - - wcwidth=0.1.7 - - webencodings=0.5.1 - - wheel=0.33.6 - - widgetsnbextension=3.5.1 - - xz=5.2.4 - - zeromq=4.3.1 - - zipp=0.6.0 - - zlib=1.2.11 -prefix: /usr/local/Caskroom/miniconda/base/envs/torch-nf diff --git a/readme.md b/readme.md index 29af8bd..a261f6a 100644 --- a/readme.md +++ b/readme.md @@ -1,4 +1,4 @@ -# Torch MNF +# Torch MNF   [![Test Status](https://github.com/janosh/torch-mnf/workflows/tests/badge.svg)](https://github.com/janosh/torch-mnf/actions) PyTorch implementation of Multiplicative Normalizing Flows [[1]](#mnf-bnn). diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..03b799e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +flake8 +matplotlib +numpy +pandas +pytest +seaborn +scikit-learn +scipy +torch +torchvision diff --git a/torch_mnf/notebooks/2d.py b/torch_mnf/notebooks/2d.py index 4381b7b..afdbb14 100644 --- a/torch_mnf/notebooks/2d.py +++ b/torch_mnf/notebooks/2d.py @@ -71,7 +71,7 @@ print(f"number of params: {sum(p.numel() for p in model.parameters()):,}") -def train(steps=1000, n_samples=128, report_every=100, cb=None): +def train_flow(steps=1000, n_samples=128, report_every=100, cb=None): for step in range(steps + 1): x = sample_target_dist(n_samples) @@ -91,7 +91,7 @@ def train(steps=1000, n_samples=128, report_every=100, cb=None): # %% -train() +train_flow() # %% @@ -202,4 +202,4 @@ def plot_learning(): # %% -train(steps=400, cb=plot_learning) +train_flow(steps=400, cb=plot_learning)