Commit: MNIST Benchmarks (#1334)
* Rename README.md to README1.md

* Rename examples/mnist/README1.md to examples/mnist/basic-mnist-benchmarks/README.md

* Update README.md

* Update README.md

* Add files via upload

* Update README.md

* Update README.md

* Update README.md

* Add files via upload

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Delete mnist_pytorch.py

* Add files via upload

* Update README.md

* Update README.md

* Update mnist_pytorch.py

* Update README.md

* Update README.md

* Update mnist_dynet_autobatch.py

* Update mnist_dynet_minibatch.py

* Update mnist_pytorch.py

* Update mnist_dynet_minibatch.py

* Update mnist_dynet_autobatch.py
gpengzhi authored and neubig committed Apr 6, 2018
1 parent 857144a commit 4e68164
Showing 4 changed files with 497 additions and 0 deletions.
71 changes: 71 additions & 0 deletions examples/mnist/basic-mnist-benchmarks/README.md
@@ -0,0 +1,71 @@
# MNIST Benchmarks

Here is a comparison between DyNet and PyTorch on the "Hello World" example of deep learning: MNIST digit classification.

## Usage (Dynet)

Download the MNIST dataset from the [official website](http://yann.lecun.com/exdb/mnist/) and decompress it.

<pre>
wget -O - http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz | gunzip > train-images.idx3-ubyte
wget -O - http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz | gunzip > train-labels.idx1-ubyte
wget -O - http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz | gunzip > t10k-images.idx3-ubyte
wget -O - http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz | gunzip > t10k-labels.idx1-ubyte
</pre>

Install the GPU version of DyNet according to the instructions on the [official website](http://dynet.readthedocs.io/en/latest/python.html#installing-a-cutting-edge-and-or-gpu-version).
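
As a quick alternative for a CPU-only run, the prebuilt PyPI package typically suffices; the GPU benchmarks below require the GPU build from the linked instructions:

<pre>
pip install dynet
</pre>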

The convolutional neural network follows the architecture used in the [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/layers).

Here are two Python scripts for DyNet: `mnist_dynet_minibatch.py` uses explicit minibatching, and `mnist_dynet_autobatch.py` uses autobatching; the difference between them is sketched below, after the run commands.

Then, run the training:
<pre>
python mnist_dynet_minibatch.py --dynet_gpus 1
</pre>
or
<pre>
python mnist_dynet_autobatch.py --dynet_gpus 1 --dynet_autobatch 1
</pre>

## Usage (PyTorch)

The code of `mnist_pytorch.py` follows the same lines as `main.py` in [PyTorch Examples](https://github.com/pytorch/examples/tree/master/mnist). We changed the network architecture as follows to match the architecture used in the [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/layers).

<pre>
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(7*7*64, 1024)
        self.fc2 = nn.Linear(1024, 10, bias=False)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.4)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
</pre>
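
For context, the training step that drives this network follows the standard PyTorch MNIST example. Roughly (a sketch, assuming the usual `torch.optim` / `torch.nn.functional as F` imports and a `train_loader` over the MNIST data, not the exact contents of `mnist_pytorch.py`):

<pre>
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)

model.train()
for data, target in train_loader:
    data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)  # pairs with the log_softmax output above
    loss.backward()
    optimizer.step()
</pre>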

Install the CUDA version of PyTorch according to the instructions on the [official website](http://pytorch.org/).

Then, run the training:

<pre>
python mnist_pytorch.py
</pre>

## Benchmark

Batch size: 64, learning rate: 0.01.

| OS | Device | Framework | Speed | Accuracy (After 20 Epochs)|
| --- | --- | --- | --- | --- |
| Ubuntu 16.04 | GeForce GTX 1080 Ti | PyTorch | ~ 4.49±0.11 s per epoch | 98.95% |
| Ubuntu 16.04 | GeForce GTX 1080 Ti | DyNet (autobatch) | ~ 8.58±0.09 s per epoch | 99.14% |
| Ubuntu 16.04 | GeForce GTX 1080 Ti | DyNet (minibatch) | ~ 4.13±0.13 s per epoch | 99.16% |
152 changes: 152 additions & 0 deletions examples/mnist/basic-mnist-benchmarks/mnist_dynet_autobatch.py
@@ -0,0 +1,152 @@
from __future__ import division
import os
import struct
import argparse
import random
import time
import numpy as np
# import dynet as dy
# import dynet_config
# dynet_config.set_gpu()
import dynet as dy

# First, download the MNIST dataset from the official website and decompress it.
# wget -O - http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz | gunzip > train-images.idx3-ubyte
# wget -O - http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz | gunzip > train-labels.idx1-ubyte
# wget -O - http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz | gunzip > t10k-images.idx3-ubyte
# wget -O - http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz | gunzip > t10k-labels.idx1-ubyte

parser = argparse.ArgumentParser(description='DyNet MNIST Example')
parser.add_argument("--path", type=str, default=".",
                    help="Path to the MNIST data files (unzipped).")
parser.add_argument('--batch-size', type=int, default=64,
                    help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=20,
                    help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
parser.add_argument('--log-interval', type=int, default=10,
                    help='how many batches to wait before logging training status')
parser.add_argument("--dynet_autobatch", type=int, default=0,
                    help="Set to 1 to turn on autobatching.")
parser.add_argument("--dynet_gpus", type=int, default=0,
                    help="Set to 1 to train on GPU.")

HIDDEN_DIM = 1024
DROPOUT_RATE = 0.4

# Adapted from https://gist.github.com/akesling/5358964
def read(dataset, path):
    if dataset == "training":
        fname_img = os.path.join(path, "train-images.idx3-ubyte")
        fname_lbl = os.path.join(path, "train-labels.idx1-ubyte")
    elif dataset == "testing":
        fname_img = os.path.join(path, "t10k-images.idx3-ubyte")
        fname_lbl = os.path.join(path, "t10k-labels.idx1-ubyte")
    else:
        raise ValueError("dataset must be 'training' or 'testing'")

    # Skip the 8-byte IDX header (magic number and item count), then read the labels.
    with open(fname_lbl, 'rb') as flbl:
        _, _ = struct.unpack(">II", flbl.read(8))
        lbl = np.fromfile(flbl, dtype=np.int8)

    # Skip the 16-byte IDX header, then read the images and rescale pixels to [0, 1].
    with open(fname_img, 'rb') as fimg:
        _, _, rows, cols = struct.unpack(">IIII", fimg.read(16))
        img = np.multiply(np.fromfile(fimg, dtype=np.uint8).reshape(len(lbl), rows, cols), 1.0/255.0)

    get_img = lambda idx: (lbl[idx], img[idx])

    for i in range(len(lbl)):
        yield get_img(i)

class mnist_network(object):

    def __init__(self, m):
        self.pConv1 = m.add_parameters((5, 5, 1, 32))
        self.pB1 = m.add_parameters((32, ))
        self.pConv2 = m.add_parameters((5, 5, 32, 64))
        self.pB2 = m.add_parameters((64, ))
        self.pW1 = m.add_parameters((HIDDEN_DIM, 7*7*64))
        self.pB3 = m.add_parameters((HIDDEN_DIM, ))
        self.pW2 = m.add_parameters((10, HIDDEN_DIM))

    def __call__(self, inputs, dropout=False):
        # Builds the graph for a single image: two conv/pool/ReLU blocks, a
        # 1024-unit hidden layer (with optional dropout), and a linear output
        # layer producing unnormalized class scores.
        x = dy.inputTensor(inputs)
        conv1 = dy.parameter(self.pConv1)
        b1 = dy.parameter(self.pB1)
        x = dy.conv2d_bias(x, conv1, b1, [1, 1], is_valid=False)
        x = dy.rectify(dy.maxpooling2d(x, [2, 2], [2, 2]))
        conv2 = dy.parameter(self.pConv2)
        b2 = dy.parameter(self.pB2)
        x = dy.conv2d_bias(x, conv2, b2, [1, 1], is_valid=False)
        x = dy.rectify(dy.maxpooling2d(x, [2, 2], [2, 2]))
        x = dy.reshape(x, (7*7*64, 1))
        w1 = dy.parameter(self.pW1)
        b3 = dy.parameter(self.pB3)
        h = dy.rectify(w1*x+b3)
        if dropout:
            h = dy.dropout(h, DROPOUT_RATE)
        w2 = dy.parameter(self.pW2)
        output = w2*h
        # output = dy.softmax(w2*h)
        return output

    def create_network_return_loss(self, inputs, expected_output, dropout=False):
        out = self(inputs, dropout)
        loss = dy.pickneglogsoftmax(out, expected_output)
        # loss = -dy.log(dy.pick(out, expected_output))
        return loss

    def create_network_return_best(self, inputs, dropout=False):
        out = self(inputs, dropout)
        out = dy.softmax(out)
        return np.argmax(out.npvalue())

args = parser.parse_args()
train_data = [(lbl, img) for (lbl, img) in read("training", args.path)]
test_data = [(lbl, img) for (lbl, img) in read("testing", args.path)]

m = dy.ParameterCollection()
network = mnist_network(m)
trainer = dy.SimpleSGDTrainer(m, learning_rate=args.lr)

def train(epoch):
    random.shuffle(train_data)
    i = 0
    epoch_start = time.time()
    while i < len(train_data):
        dy.renew_cg()
        losses = []
        for lbl, img in train_data[i:i+args.batch_size]:
            loss = network.create_network_return_loss(img, lbl, dropout=True)
            losses.append(loss)
        mbloss = dy.average(losses)
        if (int(i/args.batch_size)) % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i, len(train_data),
                100. * i/len(train_data), mbloss.value()))
        mbloss.backward()
        trainer.update()
        i += args.batch_size
    epoch_end = time.time()
    print("{} s per epoch".format(epoch_end-epoch_start))

def test():
    correct = 0
    dy.renew_cg()
    losses = []
    for lbl, img in test_data:
        losses.append(network.create_network_return_loss(img, lbl, dropout=False))
        if lbl == network.create_network_return_best(img, dropout=False):
            correct += 1
    mbloss = dy.average(losses)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        mbloss.value(), correct, len(test_data),
        100. * correct / len(test_data)))

for epoch in range(1, args.epochs + 1):
    train(epoch)
    test()

# m.save("/tmp/tmp.model")
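# The saved parameters could later be reloaded into an identically-structured
# collection with m.populate("/tmp/tmp.model").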
157 changes: 157 additions & 0 deletions examples/mnist/basic-mnist-benchmarks/mnist_dynet_minibatch.py
@@ -0,0 +1,157 @@
from __future__ import division
import os
import struct
import argparse
import random
import time
import numpy as np
# import dynet as dy
# import dynet_config
# dynet_config.set_gpu()
import dynet as dy

# First, download the MNIST dataset from the official website and decompress it.
# wget -O - http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz | gunzip > train-images.idx3-ubyte
# wget -O - http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz | gunzip > train-labels.idx1-ubyte
# wget -O - http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz | gunzip > t10k-images.idx3-ubyte
# wget -O - http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz | gunzip > t10k-labels.idx1-ubyte

parser = argparse.ArgumentParser(description='DyNet MNIST Example')
parser.add_argument("--path", type=str, default=".",
                    help="Path to the MNIST data files (unzipped).")
parser.add_argument('--batch-size', type=int, default=64,
                    help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=20,
                    help='number of epochs to train (default: 20)')
parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
parser.add_argument('--log-interval', type=int, default=10,
                    help='how many batches to wait before logging training status')
parser.add_argument("--dynet_autobatch", type=int, default=0,
                    help="Set to 1 to turn on autobatching.")
parser.add_argument("--dynet_gpus", type=int, default=0,
                    help="Set to 1 to train on GPU.")

HIDDEN_DIM = 1024
DROPOUT_RATE = 0.4

# Adapted from https://gist.github.com/akesling/5358964
def read(dataset, path):
    if dataset == "training":
        fname_img = os.path.join(path, "train-images.idx3-ubyte")
        fname_lbl = os.path.join(path, "train-labels.idx1-ubyte")
    elif dataset == "testing":
        fname_img = os.path.join(path, "t10k-images.idx3-ubyte")
        fname_lbl = os.path.join(path, "t10k-labels.idx1-ubyte")
    else:
        raise ValueError("dataset must be 'training' or 'testing'")

    # Skip the 8-byte IDX header (magic number and item count), then read the labels.
    with open(fname_lbl, 'rb') as flbl:
        _, _ = struct.unpack(">II", flbl.read(8))
        lbl = np.fromfile(flbl, dtype=np.int8)

    # Skip the 16-byte IDX header, then read the images and rescale pixels to [0, 1].
    with open(fname_img, 'rb') as fimg:
        _, _, rows, cols = struct.unpack(">IIII", fimg.read(16))
        img = np.multiply(np.fromfile(fimg, dtype=np.uint8).reshape(len(lbl), rows, cols), 1.0/255.0)

    get_img = lambda idx: (lbl[idx], img[idx])

    for i in range(len(lbl)):
        yield get_img(i)

class mnist_network(object):

    def __init__(self, m):
        self.pConv1 = m.add_parameters((5, 5, 1, 32))
        self.pB1 = m.add_parameters((32, ))
        self.pConv2 = m.add_parameters((5, 5, 32, 64))
        self.pB2 = m.add_parameters((64, ))
        self.pW1 = m.add_parameters((HIDDEN_DIM, 7*7*64))
        self.pB3 = m.add_parameters((HIDDEN_DIM, ))
        self.pW2 = m.add_parameters((10, HIDDEN_DIM))

    def __call__(self, inputs, dropout=False):
        # Builds the graph for a whole minibatch at once: the input is a batched
        # tensor, and every operation below is applied across the batch dimension.
        x = dy.inputTensor(inputs, batched=True)
        batchsize = x.dim()[-1]
        conv1 = dy.parameter(self.pConv1)
        b1 = dy.parameter(self.pB1)
        x = dy.conv2d_bias(x, conv1, b1, [1, 1], is_valid=False)
        x = dy.rectify(dy.maxpooling2d(x, [2, 2], [2, 2]))
        conv2 = dy.parameter(self.pConv2)
        b2 = dy.parameter(self.pB2)
        x = dy.conv2d_bias(x, conv2, b2, [1, 1], is_valid=False)
        x = dy.rectify(dy.maxpooling2d(x, [2, 2], [2, 2]))
        x = dy.reshape(x, (7*7*64, 1), batch_size=batchsize)
        w1 = dy.parameter(self.pW1)
        b3 = dy.parameter(self.pB3)
        h = dy.rectify(w1*x+b3)
        if dropout:
            h = dy.dropout(h, DROPOUT_RATE)
        w2 = dy.parameter(self.pW2)
        output = w2*h
        # output = dy.softmax(w2*h)
        return output

    def create_network_return_loss(self, inputs, expected_output, dropout=False):
        out = self(inputs, dropout)
        loss = dy.pickneglogsoftmax_batch(out, expected_output)
        # loss = -dy.log(dy.pick(out, expected_output))
        return loss

    def create_network_return_best(self, inputs, dropout=False):
        out = self(inputs, dropout)
        out = dy.softmax(out)
        return np.argmax(out.npvalue(), 0)
        # return np.argmax(out.npvalue())

args = parser.parse_args()
train_data = [(lbl, img) for (lbl, img) in read("training", args.path)]
test_data = [(lbl, img) for (lbl, img) in read("testing", args.path)]

m = dy.ParameterCollection()
network = mnist_network(m)
trainer = dy.SimpleSGDTrainer(m, learning_rate=args.lr)

def train(epoch):
    random.shuffle(train_data)
    i = 0
    epoch_start = time.time()
    while i < len(train_data):
        dy.renew_cg()
        lbls = []
        imgs = []
        for lbl, img in train_data[i:i+args.batch_size]:
            lbls.append(lbl)
            imgs.append(img)
        losses = network.create_network_return_loss(imgs, lbls, dropout=True)
        loss = dy.mean_batches(losses)
        if (int(i/args.batch_size)) % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i, len(train_data),
                100. * i/len(train_data), loss.value()))
        loss.backward()
        trainer.update()
        i += args.batch_size
    epoch_end = time.time()
    print("{} s per epoch".format(epoch_end-epoch_start))

def test():
    lbls = []
    imgs = []
    for lbl, img in test_data:
        lbls.append(lbl)
        imgs.append(img)
    dy.renew_cg()
    losses = network.create_network_return_loss(imgs, lbls, dropout=False)
    loss = dy.mean_batches(losses)
    predicts = network.create_network_return_best(imgs, dropout=False)
    correct = np.sum(lbls == predicts[0])
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        loss.value(), correct, len(test_data),
        100. * correct / len(test_data)))

for epoch in range(1, args.epochs + 1):
    train(epoch)
    test()

# m.save("/tmp/tmp.model")
