
MintNet: Building Invertible Neural Networks with Masked Convolutions


ermongroup/mintnet


This repository contains the PyTorch implementation of our paper MintNet: Building Invertible Neural Networks with Masked Convolutions (NeurIPS 2019). We propose a new way of constructing invertible neural networks by combining simple building blocks with a novel set of composition rules. This yields a rich family of invertible architectures, including ResNet-like ones. Inversion is performed with a locally convergent iterative procedure that is parallelizable and very fast in practice. In addition, the Jacobian determinant can be computed analytically and efficiently, which makes these networks usable as generative flow models.
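To illustrate why masking gives both a cheap determinant and an iterative inverse, here is a minimal, dependency-free sketch. It is not the repository's code: it replaces masked convolutions with a hypothetical dense lower-triangular layer (`masked_forward`, `s`, and `W` are illustrative names), so output i depends only on inputs j <= i. The Jacobian is then triangular, so its log-determinant is the sum of the log-diagonal entries; in a flow this term enters the change-of-variables objective log p(x) = log p_z(f(x)) + log|det df/dx|. The inverse uses a simple parallel fixed-point update, x <- x - alpha * (f(x) - y) / diag(J), chosen here for illustration; the exact iteration and step-size rule used in the paper may differ.

```python
import math


def masked_forward(x, s, W):
    """Hypothetical triangular ("masked") layer (illustration only):
    y_i = exp(s_i) * x_i + sum_{j<i} W[i][j] * tanh(x_j).
    Output i depends only on inputs j <= i, so the Jacobian is lower triangular
    with diagonal entries exp(s_i)."""
    y = []
    for i in range(len(x)):
        acc = math.exp(s[i]) * x[i]
        for j in range(i):  # mask: only strictly earlier inputs contribute
            acc += W[i][j] * math.tanh(x[j])
        y.append(acc)
    return y


def log_abs_det_jacobian(s):
    # The determinant of a triangular matrix is the product of its diagonal,
    # so log|det J| = sum_i s_i, independent of x.
    return sum(s)


def masked_inverse(y, s, W, alpha=1.0, max_iter=100, tol=1e-12):
    """Invert masked_forward with the fixed-point update
    x <- x - alpha * (f(x) - y) / diag(J)  (an illustrative choice)."""
    diag = [math.exp(si) for si in s]
    x = list(y)  # initial guess: the output itself
    for _ in range(max_iter):
        fx = masked_forward(x, s, W)
        residual = [fi - yi for fi, yi in zip(fx, y)]
        if all(abs(r) < tol for r in residual):
            break
        # Every coordinate is updated at once from the previous iterate,
        # which is what makes the scheme parallelizable.
        x = [xi - alpha * ri / di for xi, ri, di in zip(x, residual, diag)]
    return x
```

For example, with `x = [0.5, -1.0, 2.0]`, `s = [0.1, -0.2, 0.3]` and a strictly lower-triangular `W`, `masked_inverse(masked_forward(x, s, W), s, W)` recovers `x` up to floating-point rounding. With `alpha=1.0` the first k coordinates become exact after k iterations, so an n-dimensional input converges in at most n steps; a damped update (0 < alpha < 1) still converges for this triangular map, only more slowly.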

Dependencies

The following packages are needed to run this repo:

  • PyTorch==1.1.0
  • tqdm
  • tensorboardX
  • SciPy
  • PyYAML
  • Numba

Running the experiments

python main.py --runner [runner name] --config [config file] --doc [experiment folder name]

Here runner name is one of the following:

  • DensityEstimationRunner. Experiments on MintNet density estimation.
  • ClassificationRunner. Experiments on MintNet classification.

config file is the name of a YAML file in configs/, and experiment folder name is the name of a folder in run/.

For example, to train a MintNet density estimation model on MNIST, run

python main.py --runner DensityEstimationRunner --config mnist_density_config.yml

Checkpoints

Checkpoints for both density estimation and classification can be downloaded from https://drive.google.com/file/d/12kGMMg0ivJI5y32hRouhZuddr9cJxfiR/view?usp=sharing

Unzip the archive into <root folder>/run.
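If you prefer to script this step, a minimal sketch using Python's standard `zipfile` module follows. The helper `extract_checkpoints` and any archive filename you pass to it are hypothetical; the layout inside the archive is not documented here, so the sketch simply extracts everything into `<root folder>/run`.

```python
import os
import zipfile


def extract_checkpoints(zip_path, root_folder="."):
    """Hypothetical helper: extract a downloaded checkpoint archive into
    <root_folder>/run, creating the directory if needed.
    Returns the run directory and the list of archive members."""
    run_dir = os.path.join(root_folder, "run")
    os.makedirs(run_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        names = zf.namelist()
        zf.extractall(run_dir)
    return run_dir, names
```

For example, `extract_checkpoints("checkpoints.zip", ".")` (with whatever name the downloaded file has) unpacks the archive into `./run`.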
