"""weight_initialization
~~~~~~~~~~~~~~~~~~~~~~~~

This program shows how weight initialization affects training.  In
particular, we'll plot how the classification accuracies improve
using either large starting weights, whose standard deviation is 1, or
the default starting weights, whose standard deviation is 1 over the
square root of the number of input neurons.

"""
# Standard library | |
import json | |
import random | |
import sys | |
# My library | |
sys.path.append('../src/') | |
import mnist_loader | |
import network2 | |
# Third-party libraries | |
import matplotlib.pyplot as plt | |
import numpy as np | |
def main(filename, n, eta):
    """Run the training experiment and then plot what it saved.

    All three arguments are forwarded unchanged to ``run_network``,
    which writes its results to ``filename``; ``make_plot`` then reads
    that same file back and displays the figure.

    """
    run_network(filename=filename, n=n, eta=eta)
    make_plot(filename=filename)
def run_network(filename, n, eta):
    """Train the network using both the default and the large starting
    weights.  Store the results in the file with name ``filename``,
    where they can later be used by ``make_plot``.

    ``n`` is the size of the middle layer of the ``[784, n, 10]``
    network.  ``eta`` is handed to ``net.SGD`` as its fourth positional
    argument (presumably the learning rate -- confirm against
    ``network2``).

    The JSON file maps ``"default_weight_initialization"`` and
    ``"large_weight_initialization"`` to the four values returned by
    the corresponding ``net.SGD`` call, in the order they are returned.

    """
    # Make results more easily reproducible
    random.seed(12345678)
    np.random.seed(12345678)
    training_data, validation_data, _ = mnist_loader.load_data_wrapper()
    net = network2.Network([784, n, 10], cost=network2.CrossEntropyCost)

    # Both runs share every setting; only the starting weights differ.
    # NOTE(review): 30 and 10 look like the epoch count and mini-batch
    # size -- confirm against network2.Network.SGD.
    sgd_options = dict(lmbda=5.0,
                       evaluation_data=validation_data,
                       monitor_evaluation_accuracy=True)

    # Single-argument print(...) emits the same text under Python 2 and
    # also parses under Python 3.
    print("Train the network using the default starting weights.")
    default_vc, default_va, default_tc, default_ta = net.SGD(
        training_data, 30, 10, eta, **sgd_options)

    print("Train the network using the large starting weights.")
    # Re-initialise the same network object in place before retraining.
    net.large_weight_initializer()
    large_vc, large_va, large_tc, large_ta = net.SGD(
        training_data, 30, 10, eta, **sgd_options)

    results = {
        "default_weight_initialization":
        [default_vc, default_va, default_tc, default_ta],
        "large_weight_initialization":
        [large_vc, large_va, large_tc, large_ta],
    }
    # The context manager closes the file even if json.dump raises.
    with open(filename, "w") as f:
        json.dump(results, f)
def make_plot(filename):
    """Load the results from the file ``filename``, and generate the
    corresponding plot.

    The file must contain a JSON object with the keys
    ``"default_weight_initialization"`` and
    ``"large_weight_initialization"``, each holding four entries; only
    the second entry of each (the ``*_va`` series) is plotted.  The
    figure is shown interactively with ``plt.show()``.

    """
    # The context manager closes the file even if json.load raises.
    with open(filename, "r") as f:
        results = json.load(f)
    _, default_va, _, _ = results["default_weight_initialization"]
    _, large_va, _, _ = results["large_weight_initialization"]

    # Convert raw classification numbers to percentages, for plotting.
    # NOTE(review): dividing by 100.0 yields a percentage only if the
    # evaluation set has 10,000 examples -- confirm against the data
    # used in run_network.
    default_va = [x/100.0 for x in default_va]
    large_va = [x/100.0 for x in large_va]

    # Take the x-axis from the data rather than a hard-coded 30, so a
    # results file with a different number of epochs still plots.  With
    # 30 entries per series this matches np.arange(0, 30, 1) and an
    # x-limit of [0, 30].
    num_epochs = max(len(default_va), len(large_va))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(np.arange(len(large_va)), large_va, color='#2A6EA6',
            label="Old approach to weight initialization")
    ax.plot(np.arange(len(default_va)), default_va, color='#FFA933',
            label="New approach to weight initialization")
    ax.set_xlim([0, num_epochs])
    ax.set_xlabel('Epoch')
    ax.set_ylim([85, 100])
    ax.set_title('Classification accuracy')
    plt.legend(loc="lower right")
    plt.show()
if __name__ == "__main__":
    # main() requires filename, n and eta; calling it with no arguments
    # always raised TypeError.  Take the three values from the command
    # line instead.
    if len(sys.argv) != 4:
        sys.exit("Usage: {0} FILENAME N ETA".format(sys.argv[0]))
    main(sys.argv[1], int(sys.argv[2]), float(sys.argv[3]))