Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
neural-networks-and-deep-learning/fig/weight_initialization.py
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
89 lines (78 sloc)
3.12 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""weight_initialization
~~~~~~~~~~~~~~~~~~~~~~~~
This program shows how weight initialization affects training.  In
particular, we'll plot out how the classification accuracies improve
using either large starting weights, whose standard deviation is 1, or
the default starting weights, whose standard deviation is 1 over the
square root of the number of input neurons.
"""
# Standard library | |
import json | |
import random | |
import sys | |
# My library | |
sys.path.append('../src/') | |
import mnist_loader | |
import network2 | |
# Third-party libraries | |
import matplotlib.pyplot as plt | |
import numpy as np | |
def main(filename, n, eta):
    """Run the whole experiment and then display its results.

    First calls ``run_network``, which trains the networks and writes
    the results to the file named ``filename``; then calls
    ``make_plot`` to read that file back and draw the comparison.
    ``n`` and ``eta`` are forwarded unchanged to ``run_network``.

    """
    run_network(filename=filename, n=n, eta=eta)
    make_plot(filename=filename)
def run_network(filename, n, eta):
    """Train the network using both the default and the large starting
    weights.  Store the results in the file with name ``filename``,
    where they can later be used by ``make_plot``.

    ``n`` is the middle entry of the layer list ``[784, n, 10]`` given
    to ``network2.Network``, and ``eta`` is forwarded to ``net.SGD`` as
    its fourth positional argument (presumably the learning rate --
    confirm against ``network2``).

    The results are written as JSON: a dictionary with the keys
    ``"default_weight_initialization"`` and
    ``"large_weight_initialization"``, each mapping to the list of the
    four values returned by the corresponding ``net.SGD`` call.

    """
    # Make results more easily reproducible
    random.seed(12345678)
    np.random.seed(12345678)
    # The third element returned by the loader is not used here.
    training_data, validation_data, _ = mnist_loader.load_data_wrapper()
    net = network2.Network([784, n, 10], cost=network2.CrossEntropyCost)
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print("Train the network using the default starting weights.")
    default_vc, default_va, default_tc, default_ta \
        = net.SGD(training_data, 30, 10, eta, lmbda=5.0,
                  evaluation_data=validation_data,
                  monitor_evaluation_accuracy=True)
    print("Train the network using the large starting weights.")
    # Re-initialize the same network before the second training run.
    net.large_weight_initializer()
    large_vc, large_va, large_tc, large_ta \
        = net.SGD(training_data, 30, 10, eta, lmbda=5.0,
                  evaluation_data=validation_data,
                  monitor_evaluation_accuracy=True)
    results = {
        "default_weight_initialization":
        [default_vc, default_va, default_tc, default_ta],
        "large_weight_initialization":
        [large_vc, large_va, large_tc, large_ta],
    }
    # Use a context manager so the file is closed even if dumping fails.
    with open(filename, "w") as f:
        json.dump(results, f)
def make_plot(filename):
    """Load the results from the file ``filename``, and generate the
    corresponding plot.

    The file must contain JSON with the keys
    ``"default_weight_initialization"`` and
    ``"large_weight_initialization"``, each holding a four-element
    list whose second element is the per-epoch accuracy series.  Only
    those accuracy series are plotted.  The x-axis length is taken from
    the loaded data rather than being fixed at 30 epochs.

    """
    # Use a context manager so the file is closed even if parsing fails.
    with open(filename, "r") as f:
        results = json.load(f)
    # Keep the four-way unpacking so a malformed file still fails here;
    # only the second entry of each list is needed for the plot.
    _, default_va, _, _ = results["default_weight_initialization"]
    _, large_va, _, _ = results["large_weight_initialization"]
    # Convert raw classification numbers to percentages, for plotting.
    # NOTE(review): dividing by 100.0 gives a percentage only if the
    # evaluation set has 10,000 examples -- confirm against the loader.
    default_va = [x/100.0 for x in default_va]
    large_va = [x/100.0 for x in large_va]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(np.arange(len(large_va)), large_va, color='#2A6EA6',
            label="Old approach to weight initialization")
    ax.plot(np.arange(len(default_va)), default_va, color='#FFA933',
            label="New approach to weight initialization")
    ax.set_xlim([0, max(len(large_va), len(default_va))])
    ax.set_xlabel('Epoch')
    ax.set_ylim([85, 100])
    ax.set_title('Classification accuracy')
    plt.legend(loc="lower right")
    plt.show()
if __name__ == "__main__":
    # ``main`` requires ``filename``, ``n`` and ``eta``; calling it with
    # no arguments raised a TypeError, so take them from the command line.
    if len(sys.argv) != 4:
        sys.exit("Usage: %s FILENAME N ETA" % sys.argv[0])
    # ``n`` is used as a layer size and ``eta`` as a numeric training
    # parameter, so convert them from their command-line string form.
    main(sys.argv[1], int(sys.argv[2]), float(sys.argv[3]))