Permalink
Browse files

More configs

  • Loading branch information...
1 parent 2002b22 commit c2e9d12f21b02a7c1aaa7f38071ec4bfc611e869 Mihai Maruseac committed May 1, 2011
Showing with 51 additions and 23 deletions.
  1. +27 −3 src/config.py
  2. +0 −8 src/globaldefs.py
  3. +2 −1 src/gui.py
  4. +15 −6 src/network.py
  5. +7 −5 src/units.py
View
30 src/config.py
@@ -22,7 +22,6 @@ def __init__(self, parent, title=''):
self._d = gtk.Dialog(title, parent,
gtk.DIALOG_MODAL | gtk.DIALOG_DESTROY_WITH_PARENT, btn)
self._d.set_deletable(False)
- self._d.set_size_request(420, 220)
self._d.set_resizable(False)
self._build_gui()
self._d.show_all()
@@ -39,6 +38,9 @@ def _build_gui(self):
_checkHBox = gtk.HBox()
self._build_topology_gui(_checkHBox)
self._build_activation_gui(_checkHBox)
+ _topVBox.pack_start(_checkHBox, False, False, 5)
+ _checkHBox = gtk.HBox()
+ self._build_params_gui(_checkHBox)
self._build_extra_gui(_checkHBox)
_topVBox.pack_start(_checkHBox, False, False, 5)
self._d.vbox.add(_topVBox)
@@ -66,6 +68,17 @@ def _build_activation_gui(self, _checkHBox):
_aVBox.add(self._2log)
_aVBox.add(self._tanh)
+ def _build_params_gui(self, _checkHBox):
+ """
+ Builds the GUI for setting the parameters.
+ """
+ _aVBox = self._build_compund_gui_box(_checkHBox, "Parameters:")
+ self._etaCounter = self._build_counter('learning rate:', .1, 5, _aVBox, .1, 1)
+ self._alphaCounter = self._build_counter('momentum rate:', .2, .8, _aVBox, .1, 1, False)
+ self._minRmsCounter = self._build_counter('minimum error:', 0, .1, _aVBox, .01)
+ self._minRmsCounter.get_adjustment().set_value(.01)
+ self._minDeltaRmsCounter = self._build_counter('minimum delta error:', 0, .1, _aVBox, .01)
+
def _build_extra_gui(self, _checkHBox):
"""
Builds the GUI part for extra options.
@@ -76,6 +89,7 @@ def _build_extra_gui(self, _checkHBox):
self._minCounter.get_adjustment().set_value(-1)
self._maxCounter.get_adjustment().set_value(1)
self._momentum = gtk.CheckButton('Use momentum')
+ self._momentum.connect('clicked', self.__on_momentum)
self._recurrent = gtk.CheckButton('Recurent network')
_aVBox.add(self._momentum)
_aVBox.add(self._recurrent)
@@ -88,12 +102,14 @@ def _build_IO_gui(self, _topVBox):
_topVBox VBox holding the widgets built by this function
"""
_fileHBox = gtk.HBox()
- _topVBox.pack_start(_fileHBox, False, False, 5)
+ _fileVBox = gtk.VBox()
+ _topVBox.pack_start(_fileVBox, False, False, 5)
_fileLabel = gtk.Label('Input filename:')
_fileHBox.pack_start(_fileLabel, False, False, 5)
self._fileChoose = gtk.FileChooserButton("Select input filename")
_fileHBox.pack_start(self._fileChoose, True, True, 5)
- self._rCounter = self._build_counter('Max steps:', 1000, 3000, _fileHBox, 100, 0)
+ _fileVBox.pack_start(_fileHBox, False, False, 5)
+ self._rCounter = self._build_counter('Max steps:', 1000, 3000, _fileVBox, 100, 0)
def _build_compund_gui_box(self, _checkHBox, frame):
"""
@@ -239,6 +255,11 @@ def _complete_config(self):
self._configDict['momentum'] = self._momentum.get_active()
self._configDict['recurrent'] = self._recurrent.get_active()
+ self._configDict['alpha'] = self._alphaCounter.get_value()
+ self._configDict['eta'] = self._etaCounter.get_value()
+ self._configDict['min_rms'] = self._minRmsCounter.get_value()
+ self._configDict['min_delta_rms'] = self._minDeltaRmsCounter.get_value()
+
if self._configDict['minW'] > self._configDict['maxW'] - .1:
self._report(gtk.MESSAGE_WARNING, "Invalid interval for weights, ignored")
self._configDict['minW'] = -1
@@ -273,3 +294,6 @@ def _report(self, msg_type, text):
md.run()
md.destroy()
+ def __on_momentum(self, widget, data=None):
+ self._alphaCounter.set_sensitive(not self._alphaCounter.get_sensitive())
+
View
8 src/globaldefs.py
@@ -20,14 +20,6 @@
VAL_PLOT_SUFFIX = '.val.png'
LOG_SUFFIX = '.log'
-XXX = 800
-YYY = 600
-
LOGNAME = 'MAIN.log'
LOGFNAME = LOGNAME
-MIN_RMS = 1e-4
-MIN_DRMS = 0
-ETA = .2
-ALPHA = .5
-
View
3 src/gui.py
@@ -24,7 +24,8 @@ def __init__(self):
widgets.
"""
super(MainWindow, self).__init__()
- self.set_size_request(XXX, YYY)
+ self.set_size_request(800, 600)
+
self.set_title(TITLE)
self.set_icon_from_file(ICON_FILE)
self.connect('delete_event', self.__on_exit)
View
21 src/network.py
@@ -10,7 +10,6 @@
import normalizer
import saver
from units import *
-from globaldefs import *
def log(x):
"""
@@ -64,7 +63,6 @@ def __init__(self, config, gui, graph):
config User configuration.
"""
- print config
self._gui = gui
self._parse_network(config)
self._parse_activation(config)
@@ -97,7 +95,7 @@ def learn_step(self):
Bootstraps the learning phase.
"""
rms = self._do_one_learning_step()
- done = rms < MIN_RMS or abs(rms - self._orms) < MIN_DRMS
+ done = rms < self._MIN_RMS or abs(rms - self._orms) < self._MIN_DRMS
self._orms = rms
self._grapher.graph()
@@ -212,6 +210,14 @@ def _parse_network(self, config):
# momentum?
self._momentum = config['momentum']
+ # params
+ self._eta = config['eta']
+ self._alpha = config['alpha']
+
+ # rms params
+ self._MIN_RMS = config['min_rms']
+ self._MIN_DRMS = config['min_delta_rms']
+
def _prepare_data(self, config):
"""
Reads learning set, normalizing it and preparing the auxiliary lists
@@ -262,7 +268,8 @@ def __build_hidden1(self):
"""
self._hidden1 = []
for i in range(self._h1):
- n = Neuron(self._mW, self._MW, self._f, self._df, self._momentum, 'h1{0}'.format(i))
+ n = Neuron(self._mW, self._MW, self._f, self._df, self._momentum,
+ 'h1{0}'.format(i), self._eta, self._alpha)
n.set_recurrent(self._recurrent)
for inp in self._inputs:
n.connect(inp)
@@ -276,7 +283,8 @@ def __build_hidden2(self):
"""
self._hidden2 = []
for i in range(self._h2):
- n = Neuron(self._mW, self._MW, self._f, self._df, self._momentum, 'h2{0}'.format(i))
+ n = Neuron(self._mW, self._MW, self._f, self._df, self._momentum,
+ 'h2{0}'.format(i), self._eta, self._alpha)
n.set_recurrent(self._recurrent)
if self._h1:
for inp in self._hidden1:
@@ -292,7 +300,8 @@ def __build_output(self):
"""
Builds the output layer and the end of the network.
"""
- self._output = Neuron(self._mW, self._MW, self._f, self._df, self._momentum, 'o')
+ self._output = Neuron(self._mW, self._MW, self._f, self._df,
+ self._momentum, 'o', self._eta, self._alpha)
self._output.set_recurrent(self._recurrent)
if self._h2:
for inp in self._hidden2:
View
12 src/units.py
@@ -166,7 +166,7 @@ class Neuron(Unit):
self.value() will return the output of the neuron
"""
- def __init__(self, minW, maxW, f, df, momentum, name=''):
+ def __init__(self, minW, maxW, f, df, momentum, name, eta, alpha):
super(Neuron, self).__init__(name, 0)
self._min = minW
self._max = maxW
@@ -177,6 +177,8 @@ def __init__(self, minW, maxW, f, df, momentum, name=''):
self._momentum = momentum
if self._momentum:
self._ow = []
+ self._ETA = eta
+ self._ALPHA = alpha
def inputs(self):
return self._inputs
@@ -245,9 +247,9 @@ def report_and_learn_from_error(self):
for i in range(len(self._weights)):
w, inp = self._weights[i], self._inputs[i]
- delta = ETA * self._err * self._df(self._value) * inp.value()
+ delta = self._ETA * self._err * self._df(self._value) * inp.value()
if self._momentum:
- delta += ETA * ALPHA * self._ow[i]
+ delta += self._ETA * self._ALPHA * self._ow[i]
self._ow[i] = delta
_logger.info('Neuron {0}: delta weight{1}: {2}'.format(self._name, i, delta))
self._weights[i] -= delta
@@ -257,9 +259,9 @@ def report_and_learn_from_error(self):
self._weights[i] = 1
_logger.info('Neuron {0}: weight{1}: {2}'.format(self._name, i, self._weights[i]))
if self._selfw:
- delta = ETA * self._err * self._df(self._value) * self._value
+ delta = self._ETA * self._err * self._df(self._value) * self._value
if self._momentum:
- delta += ETA * ALPHA * self._sow
+ delta += self._ETA * self._ALPHA * self._sow
self._sow = delta
_logger.info('Neuron {0}: delta self weight: {1}'.format(self._name, self._selfw))
self._selfw -= delta

0 comments on commit c2e9d12

Please sign in to comment.