
Commit

Merge pull request #8 from ChenglongChen/master
Add several functionalities to mlp.py
mdenil committed Mar 24, 2014
2 parents 391b022 + e2808ed commit 6f8c362
Showing 1 changed file with 104 additions and 23 deletions.
127 changes: 104 additions & 23 deletions mlp.py
@@ -16,6 +16,22 @@
from load_data import load_umontreal_data, load_mnist


##################################
## Various activation functions ##
##################################
#### rectified linear unit
def ReLU(x):
y = T.maximum(0.0, x)
return(y)
#### sigmoid
def Sigmoid(x):
y = T.nnet.sigmoid(x)
return(y)
#### tanh
def Tanh(x):
y = T.tanh(x)
return(y)

class HiddenLayer(object):
def __init__(self, rng, input, n_in, n_out,
activation, W=None, b=None,
@@ -64,12 +80,12 @@ def _dropout_from_layer(rng, layer, p):

class DropoutHiddenLayer(HiddenLayer):
def __init__(self, rng, input, n_in, n_out,
activation, use_bias, W=None, b=None):
activation, dropout_rate, use_bias, W=None, b=None):
super(DropoutHiddenLayer, self).__init__(
rng=rng, input=input, n_in=n_in, n_out=n_out, W=W, b=b,
activation=activation, use_bias=use_bias)

self.output = _dropout_from_layer(rng, self.output, p=0.5)
self.output = _dropout_from_layer(rng, self.output, p=dropout_rate)


class MLP(object):
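
In the DropoutHiddenLayer above, dropout_rate is the probability of dropping a unit. A minimal numpy sketch of the masking that _dropout_from_layer performs, assuming p is the drop probability (names here are illustrative, not part of the commit):

import numpy as np

rng = np.random.RandomState(1234)
p = 0.5                                       # drop probability for this layer
layer_out = rng.randn(3, 5)                   # stand-in for a layer's activations
mask = rng.binomial(n=1, p=1.0 - p, size=layer_out.shape)   # keep each unit with prob 1 - p
dropped = layer_out * mask                    # dropped units are zeroed out
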
@@ -81,38 +97,44 @@ def __init__(self,
rng,
input,
layer_sizes,
dropout_rates,
activations,
use_bias=True):

rectified_linear_activation = lambda x: T.maximum(0.0, x)
#rectified_linear_activation = lambda x: T.maximum(0.0, x)

# Set up all the hidden layers
weight_matrix_sizes = zip(layer_sizes, layer_sizes[1:])
self.layers = []
self.dropout_layers = []
next_layer_input = input
first_layer = True
# dropout the input with prob 0.2
next_dropout_layer_input = _dropout_from_layer(rng, input, p=0.2)
#first_layer = True
# dropout the input
next_dropout_layer_input = _dropout_from_layer(rng, input, p=dropout_rates[0])
layer_counter = 0
for n_in, n_out in weight_matrix_sizes[:-1]:
next_dropout_layer = DropoutHiddenLayer(rng=rng,
input=next_dropout_layer_input,
activation=rectified_linear_activation,
n_in=n_in, n_out=n_out, use_bias=use_bias)
activation=activations[layer_counter],
n_in=n_in, n_out=n_out, use_bias=use_bias,
dropout_rate=dropout_rates[layer_counter])
self.dropout_layers.append(next_dropout_layer)
next_dropout_layer_input = next_dropout_layer.output

# Reuse the parameters from the dropout layer here, in a different
# path through the graph.
next_layer = HiddenLayer(rng=rng,
input=next_layer_input,
activation=rectified_linear_activation,
W=next_dropout_layer.W * (0.8 if first_layer else 0.5),
activation=activations[layer_counter],
# scale the weight matrix W by (1 - p)
W=next_dropout_layer.W * (1 - dropout_rates[layer_counter]),
b=next_dropout_layer.b,
n_in=n_in, n_out=n_out,
use_bias=use_bias)
self.layers.append(next_layer)
next_layer_input = next_layer.output
first_layer = False
#first_layer = False
layer_counter += 1

# Set up the output layer
n_in, n_out = weight_matrix_sizes[-1]
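
Scaling W by (1 - p) in the non-dropout path keeps its expected pre-activation equal to that of the dropout path. A quick numpy check of that identity (illustrative only, not part of the commit):

import numpy as np

rng = np.random.RandomState(0)
p = 0.5                                       # drop probability on this layer's input
x = rng.randn(1000, 20)                       # a batch of inputs to the layer
W = rng.randn(20, 30)

# training path: random binary masks on the input
masks = rng.binomial(1, 1.0 - p, size=x.shape)
train_mean = (x * masks).dot(W).mean()

# test path: no masks, weights scaled by the keep probability (1 - p)
test_mean = x.dot(W * (1.0 - p)).mean()

print(train_mean, test_mean)                  # agree up to Monte Carlo noise
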
@@ -124,7 +146,8 @@ def __init__(self,
# Again, reuse parameters in the dropout output.
output_layer = LogisticRegression(
input=next_layer_input,
W=dropout_output_layer.W * 0.5,
# scale the weight matrix W by (1 - p)
W=dropout_output_layer.W * (1 - dropout_rates[-1]),
b=dropout_output_layer.b,
n_in=n_in, n_out=n_out)
self.layers.append(output_layer)
@@ -147,11 +170,15 @@ def test_mlp(
squared_filter_length_limit,
n_epochs,
batch_size,
mom_params,
activations,
dropout,
dropout_rates,
results_file_name,
layer_sizes,
dataset,
use_bias):
use_bias,
random_seed=1234):
"""
The dataset is the one from the mlp demo on deeplearning.net. This training
function is lifted from there almost exactly.
@@ -162,6 +189,14 @@ def test_mlp(
"""
assert len(layer_sizes) - 1 == len(dropout_rates)

# extract the params for momentum
mom_start = mom_params["start"]
mom_end = mom_params["end"]
mom_epoch_interval = mom_params["interval"]


datasets = load_mnist(dataset)
train_set_x, train_set_y = datasets[0]
valid_set_x, valid_set_y = datasets[1]
@@ -187,11 +222,14 @@ def test_mlp(
learning_rate = theano.shared(np.asarray(initial_learning_rate,
dtype=theano.config.floatX))

rng = np.random.RandomState(1234)
rng = np.random.RandomState(random_seed)

# construct the MLP class
classifier = MLP(rng=rng, input=x,
layer_sizes=layer_sizes, use_bias=use_bias)
layer_sizes=layer_sizes,
dropout_rates=dropout_rates,
activations=activations,
use_bias=use_bias)

# Build the expression for the cost function.
cost = classifier.negative_log_likelihood(y)
@@ -230,26 +268,42 @@ def test_mlp(
gparams_mom.append(gparam_mom)

# Compute momentum for the current epoch
mom = ifelse(epoch < 500,
0.5*(1. - epoch/500.) + 0.99*(epoch/500.),
0.99)
mom = ifelse(epoch < mom_epoch_interval,
mom_start*(1.0 - epoch/mom_epoch_interval) + mom_end*(epoch/mom_epoch_interval),
mom_end)

# Update the step direction using momentum
updates = OrderedDict()
for gparam_mom, gparam in zip(gparams_mom, gparams):
updates[gparam_mom] = mom * gparam_mom + (1. - mom) * gparam
# Misha Denil's original version
#updates[gparam_mom] = mom * gparam_mom + (1. - mom) * gparam

# change the update rule to match Hinton's dropout paper
updates[gparam_mom] = mom * gparam_mom - (1. - mom) * learning_rate * gparam

# ... and take a step along that direction
for param, gparam_mom in zip(classifier.params, gparams_mom):
stepped_param = param - learning_rate * updates[gparam_mom]
# Misha Denil's original version
#stepped_param = param - learning_rate * updates[gparam_mom]

# since we have included learning_rate in gparam_mom, we don't need it
# here
stepped_param = param + updates[gparam_mom]

# This is a silly hack to constrain the norms of the rows of the weight
# matrices. This just checks if there are two dimensions to the
# parameter and constrains it if so... maybe this is a bit silly but it
# should work for now.
if param.get_value(borrow=True).ndim == 2:
squared_norms = T.sum(stepped_param**2, axis=1).reshape((stepped_param.shape[0],1))
scale = T.clip(T.sqrt(squared_filter_length_limit / squared_norms), 0., 1.)
#squared_norms = T.sum(stepped_param**2, axis=1).reshape((stepped_param.shape[0],1))
#scale = T.clip(T.sqrt(squared_filter_length_limit / squared_norms), 0., 1.)
#updates[param] = stepped_param * scale

# constrain the norms of the COLUMNs of the weight matrix, following
# https://github.com/BVLC/caffe/issues/109
col_norms = T.sqrt(T.sum(T.sqr(stepped_param), axis=0))
desired_norms = T.clip(col_norms, 0, T.sqrt(squared_filter_length_limit))
scale = desired_norms / (1e-7 + col_norms)
updates[param] = stepped_param * scale
else:
updates[param] = stepped_param
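
Putting the two changes together, a minimal numpy sketch of a single parameter update as written above: the learning rate is folded into the velocity, and the column norms of 2-D weights are clipped to sqrt(squared_filter_length_limit) (names are illustrative, not part of the commit):

import numpy as np

def dropout_sgd_step(W, velocity, grad, mom, lr, squared_filter_length_limit=15.0):
    # velocity update matching Hinton's dropout paper: the learning rate is
    # folded into the velocity, so the step below is simply W + velocity
    velocity = mom * velocity - (1.0 - mom) * lr * grad
    W_new = W + velocity

    # constrain the L2 norm of each COLUMN of the 2-D weight matrix;
    # 1-D parameters (biases) would be returned without this constraint
    col_norms = np.sqrt(np.sum(W_new ** 2, axis=0))
    desired = np.clip(col_norms, 0.0, np.sqrt(squared_filter_length_limit))
    W_new = W_new * (desired / (1e-7 + col_norms))
    return W_new, velocity
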
@@ -320,13 +374,36 @@ def test_mlp(

if __name__ == '__main__':
import sys

# set the random seed to enable reproducible results
# It is used for initializing the weight matrices
# and generating the dropout masks for each mini-batch
random_seed = 1234

initial_learning_rate = 1.0
learning_rate_decay = 0.998
squared_filter_length_limit = 15.0
n_epochs = 3000
batch_size = 100
layer_sizes = [ 28*28, 1200, 1200, 10 ]

# dropout rate for each layer
dropout_rates = [ 0.2, 0.5, 0.5 ]
# activation functions for each hidden layer
# For this demo, we don't need to set the activation function for the
# top layer, since it is always a 10-way Softmax
activations = [ ReLU, ReLU ]

#### the params for momentum
mom_start = 0.5
mom_end = 0.99
# for epoch in [0, mom_epoch_interval], the momentum increases linearly
# from mom_start to mom_end. After mom_epoch_interval, it stays at mom_end
mom_epoch_interval = 500
mom_params = {"start": mom_start,
"end": mom_end,
"interval": mom_epoch_interval}

dataset = 'data/mnist_batches.npz'
#dataset = 'data/mnist.pkl.gz'

@@ -352,8 +429,12 @@ def test_mlp(
n_epochs=n_epochs,
batch_size=batch_size,
layer_sizes=layer_sizes,
mom_params=mom_params,
activations=activations,
dropout=dropout,
dropout_rates=dropout_rates,
dataset=dataset,
results_file_name=results_file_name,
use_bias=False)
use_bias=False,
random_seed=random_seed)
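
For reference, the length bookkeeping that ties these settings together, as a sketch (the first assert mirrors the one in test_mlp; the second is an observation, not checked by the code):

layer_sizes   = [28*28, 1200, 1200, 10]       # three weight matrices: 784->1200, 1200->1200, 1200->10
dropout_rates = [0.2, 0.5, 0.5]               # input layer, hidden layer 1, hidden layer 2
activations   = [ReLU, ReLU]                  # hidden layers only; ReLU as defined at the top of mlp.py

assert len(layer_sizes) - 1 == len(dropout_rates)
assert len(activations) == len(layer_sizes) - 2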
