
first commit

Andrej Karpathy committed Apr 22, 2018
0 parents commit 08431d9a9759446f1c298af727bc7e529dd24938
Showing with 248 additions and 0 deletions.
  1. +28 −0 README.md
  2. BIN made.png
  3. +124 −0 made.py
  4. +96 −0 run.py
28 README.md
@@ -0,0 +1,28 @@

# pytorch-made

This code is an implementation of ["MADE: Masked Autoencoder for Distribution Estimation"](https://arxiv.org/abs/1502.03509) by Germain et al., 2015. The core idea is that you can turn an auto-encoder into an autoregressive density model just by appropriately masking the connections in the MLP: pick an ordering of the input dimensions and make sure that each output depends only on inputs that come earlier in that ordering. As with other autoregressive models (char-rnn, PixelCNNs, etc.), evaluating the likelihood is very cheap (a single forward pass), but sampling requires one forward pass per dimension.
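To make the sampling cost concrete, here is a minimal sketch of ancestral sampling from a trained model (not part of this repo; it assumes a MADE trained on D binary dimensions with the natural ordering and one Bernoulli logit per dimension):

```
import torch

@torch.no_grad()
def sample(model, D):
    # one full forward pass per dimension => cost is linear in D
    x = torch.zeros(1, D)
    for i in range(D):
        logits = model(x)                # logits for all D dimensions
        p = torch.sigmoid(logits[0, i])  # p(x_i | x_<i); earlier dims are already fixed
        x[0, i] = torch.bernoulli(p)     # sample dimension i, then move on
    return x
```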

![figure 1](https://raw.github.com/karpathy/pytorch-made/master/made.png)

The authors of the paper also published code [here](https://github.com/mgermain/MADE), but it's a bit wordy, sprawling and in Theano. Hence my own shot at it with only ~150 lines of code and PyTorch <3.

## examples

First we download the [binarized mnist dataset](https://github.com/mgermain/MADE/releases/download/ICML2015/binarized_mnist.npz). Then we can reproduce the first point on the plot in Figure 2 by training a 1-layer MLP of 500 units with a single mask and a single fixed (but random) ordering, like so:

```
python run.py --data-path binarized_mnist.npz -q 500
```

which converges to a binary cross-entropy loss of `94.5`, as reported in the paper. We can then simultaneously train a larger model ensemble (with weight sharing inside the single MLP) and average over all of the models at test time. For instance, we can train with 10 orderings (`-n 10`) and average over 10 masks at inference time (`-s 10`):

```
python run.py --data-path binarized_mnist.npz -q 500 -n 10 -s 10
```

which gives a much better test loss of `79.3`, but at the cost of multiple forward passes at test time. I was not able to reproduce the single-forward-pass gains that the paper alludes to when training with multiple masks; I might be doing something wrong.
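The `-s` averaging above boils down to cycling the masks and averaging the logits over several forward passes; `run.py` does exactly this inside its train/test loop, but in isolation it looks roughly like the following sketch (the helper name `ensemble_logits` is mine):

```
import torch

@torch.no_grad()
def ensemble_logits(model, x, num_samples=10):
    # connectivity-agnostic inference: cycle through the model's masks/orderings
    # and average the logits over num_samples forward passes
    logits = torch.zeros_like(x)
    for _ in range(num_samples):
        model.update_masks()  # switch to the next of the model's num_masks masks
        logits += model(x)
    return logits / num_samples
```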

## License

MIT
BIN +77.9 KB made.png
Binary file not shown.
124 made.py
@@ -0,0 +1,124 @@
"""
Implements MADE: Masked Autoencoder for Distribution Estimation, by Germain et al. 2015
Re-implementation by Andrej Karpathy based on https://arxiv.org/abs/1502.03509
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------------------------------------------------------

class MaskedLinear(nn.Linear):
    """ same as Linear except has a configurable mask on the weights """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.register_buffer('mask', torch.ones(out_features, in_features))

    def set_mask(self, mask):
        self.mask.data.copy_(torch.from_numpy(mask.astype(np.uint8).T))

    def forward(self, input):
        return F.linear(input, self.mask * self.weight, self.bias)

class MADE(nn.Module):
    def __init__(self, nin, hidden_sizes, nout, num_masks=1, natural_ordering=False):
        """
        nin: integer; number of inputs
        hidden_sizes: a list of integers; number of units in hidden layers
        nout: integer; number of outputs, which usually collectively parameterize some kind of 1D distribution
        num_masks: can be used to train ensemble over orderings/connections
        natural_ordering: force natural ordering of dimensions, don't use random permutations
        """

        super().__init__()
        self.nin = nin
        self.hidden_sizes = hidden_sizes

        # define a simple MLP neural net
        self.net = []
        hs = [nin] + hidden_sizes + [nout]
        for h0,h1 in zip(hs, hs[1:]):
            self.net.extend([
                MaskedLinear(h0, h1),
                nn.ReLU(),
            ])
        self.net.pop() # pop the last ReLU for the output layer
        self.net = nn.Sequential(*self.net)

        # seeds for orders/connectivities of the model ensemble
        self.natural_ordering = natural_ordering
        self.num_masks = num_masks
        self.seed = 0 # for cycling through num_masks orderings

        self.m = {}
        self.update_masks() # builds the initial self.m connectivity
        # note, we could also precompute the masks and cache them, but this
        # could get memory expensive for a large number of masks.

    def update_masks(self):
        if self.m and self.num_masks == 1: return # only a single seed, skip for efficiency
        L = len(self.hidden_sizes)

        # fetch the next seed and construct a random stream
        rng = np.random.RandomState(self.seed)
        self.seed = (self.seed + 1) % self.num_masks

        # sample the order of the inputs and the connectivity of all neurons
        self.m[-1] = np.arange(self.nin) if self.natural_ordering else rng.permutation(self.nin)
        for l in range(L):
            self.m[l] = rng.randint(self.m[l-1].min(), self.nin-1, size=self.hidden_sizes[l])

        # construct the mask matrices: a hidden unit may depend on inputs whose order index
        # is <= its own connectivity number, while each output may only depend on inputs
        # strictly earlier in the ordering (hence the strict < for the last mask)
        masks = [self.m[l-1][:,None] <= self.m[l][None,:] for l in range(L)]
        masks.append(self.m[L-1][:,None] < self.m[-1][None,:])

        # set the masks in all MaskedLinear layers
        layers = [l for l in self.net.modules() if isinstance(l, MaskedLinear)]
        for l,m in zip(layers, masks):
            l.set_mask(m)

    def forward(self, x):
        return self.net(x)

# ------------------------------------------------------------------------------

if __name__ == '__main__':
    from torch.autograd import Variable

    # run a quick and dirty test for the autoregressive property
    D = 10
    rng = np.random.RandomState(14)
    x = (rng.rand(1, D) > 0.5).astype(np.float32)

    # check both natural ordering and not
    for natural in [True, False]:
        # check a few configurations of hidden units, depths 1,2,3
        for hiddens in [[200], [200,220], [200,220,230]]:

            print("checking hiddens %s with natural = %s" % (hiddens, natural))
            model = MADE(D, hiddens, D, natural_ordering=natural)

            # run backpropagation for each dimension to compute what other
            # dimensions it depends on.
            res = []
            for k in range(D):
                xtr = Variable(torch.from_numpy(x), requires_grad=True)
                xtrhat = model(xtr)
                loss = xtrhat[0,k]
                loss.backward()

                depends = (xtr.grad[0].numpy() != 0).astype(np.uint8)
                depends_ix = list(np.where(depends)[0])
                isok = k not in depends_ix

                res.append((len(depends_ix), k, depends_ix, isok))

            # pretty print the dependencies
            res.sort()
            for nl, k, ix, isok in res:
                print("output %d depends on inputs: %30s : %s" % (k, ix, "OK" if isok else "NOTOK"))

96 run.py
@@ -0,0 +1,96 @@
"""
Trains MADE on Binarized MNIST, which can be downloaded here:
https://github.com/mgermain/MADE/releases/download/ICML2015/binarized_mnist.npz
"""
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from made import MADE

# ------------------------------------------------------------------------------
def run_epoch(split, upto=None):
    torch.set_grad_enabled(split == 'train') # enable/disable grad for efficiency of forwarding test batches
    model.train() if split == 'train' else model.eval()
    nsamples = 1 if split == 'train' else args.samples
    x = xtr if split == 'train' else xte
    N,D = x.size()
    B = 100 # batch size
    nsteps = N//B if upto is None else min(N//B, upto)
    lossfs = []
    for step in range(nsteps):

        # fetch the next batch of data
        xb = Variable(x[step*B:step*B+B])

        # get the logits, potentially run the same batch a number of times, resampling each time
        xbhat = torch.zeros_like(xb)
        for s in range(nsamples):
            # perform order/connectivity-agnostic training by resampling the masks
            if step % args.resample_every == 0 or split == 'test': # if in test, cycle masks every time
                model.update_masks()
            # forward the model
            xbhat += model(xb)
        xbhat /= nsamples

        # evaluate the binary cross entropy loss
        loss = F.binary_cross_entropy_with_logits(xbhat, xb, size_average=False) / B
        lossf = loss.data.item()
        lossfs.append(lossf)

        # backward/update
        if split == 'train':
            opt.zero_grad()
            loss.backward()
            opt.step()

    print("%s epoch average loss: %f" % (split, np.mean(lossfs)))
# ------------------------------------------------------------------------------

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data-path', required=True, type=str, help="Path to binarized_mnist.npz")
    parser.add_argument('-q', '--hiddens', type=str, default='500', help="Comma separated sizes for hidden layers, e.g. 500, or 500,500")
    parser.add_argument('-n', '--num-masks', type=int, default=1, help="Number of orderings for order/connection-agnostic training")
    parser.add_argument('-r', '--resample-every', type=int, default=20, help="For efficiency we can choose to resample orders/masks only once every this many steps")
    parser.add_argument('-s', '--samples', type=int, default=1, help="How many samples of connectivity/masks to average logits over during inference")
    args = parser.parse_args()
    # --------------------------------------------------------------------------

    # reproducibility is good
    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    # load the dataset
    print("loading binarized mnist from", args.data_path)
    mnist = np.load(args.data_path)
    xtr, xte = mnist['train_data'], mnist['valid_data']
    xtr = torch.from_numpy(xtr).cuda()
    xte = torch.from_numpy(xte).cuda()

    # construct model and ship to GPU
    hidden_list = list(map(int, args.hiddens.split(',')))
    model = MADE(xtr.size(1), hidden_list, xtr.size(1), num_masks=args.num_masks)
    print("number of model parameters:", sum([np.prod(p.size()) for p in model.parameters()]))
    model.cuda()

    # set up the optimizer
    opt = torch.optim.Adam(model.parameters(), 1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=45, gamma=0.1)

    # start the training
    for epoch in range(100):
        print("epoch %d" % (epoch, ))
        scheduler.step(epoch)
        run_epoch('test', upto=5) # run only a few batches for approximate test accuracy
        run_epoch('train')

    print("optimization done. full test set eval:")
    run_epoch('test')
