Skip to content

Commit

Permalink
More configs
Browse files Browse the repository at this point in the history
  • Loading branch information
Mihai Maruseac committed May 1, 2011
1 parent 2002b22 commit c2e9d12
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 23 deletions.
30 changes: 27 additions & 3 deletions src/config.py
Expand Up @@ -22,7 +22,6 @@ def __init__(self, parent, title=''):
self._d = gtk.Dialog(title, parent,
gtk.DIALOG_MODAL | gtk.DIALOG_DESTROY_WITH_PARENT, btn)
self._d.set_deletable(False)
self._d.set_size_request(420, 220)
self._d.set_resizable(False)
self._build_gui()
self._d.show_all()
Expand All @@ -39,6 +38,9 @@ def _build_gui(self):
_checkHBox = gtk.HBox()
self._build_topology_gui(_checkHBox)
self._build_activation_gui(_checkHBox)
_topVBox.pack_start(_checkHBox, False, False, 5)
_checkHBox = gtk.HBox()
self._build_params_gui(_checkHBox)
self._build_extra_gui(_checkHBox)
_topVBox.pack_start(_checkHBox, False, False, 5)
self._d.vbox.add(_topVBox)
Expand Down Expand Up @@ -66,6 +68,17 @@ def _build_activation_gui(self, _checkHBox):
_aVBox.add(self._2log)
_aVBox.add(self._tanh)

def _build_params_gui(self, _checkHBox):
"""
Builds the GUI for setting the parameters.
"""
_aVBox = self._build_compund_gui_box(_checkHBox, "Parameters:")
self._etaCounter = self._build_counter('learning rate:', .1, 5, _aVBox, .1, 1)
self._alphaCounter = self._build_counter('momentum rate:', .2, .8, _aVBox, .1, 1, False)
self._minRmsCounter = self._build_counter('minimum error:', 0, .1, _aVBox, .01)
self._minRmsCounter.get_adjustment().set_value(.01)
self._minDeltaRmsCounter = self._build_counter('minimum delta error:', 0, .1, _aVBox, .01)

def _build_extra_gui(self, _checkHBox):
"""
Builds the GUI part for extra options.
Expand All @@ -76,6 +89,7 @@ def _build_extra_gui(self, _checkHBox):
self._minCounter.get_adjustment().set_value(-1)
self._maxCounter.get_adjustment().set_value(1)
self._momentum = gtk.CheckButton('Use momentum')
self._momentum.connect('clicked', self.__on_momentum)
self._recurrent = gtk.CheckButton('Recurent network')
_aVBox.add(self._momentum)
_aVBox.add(self._recurrent)
Expand All @@ -88,12 +102,14 @@ def _build_IO_gui(self, _topVBox):
_topVBox VBox holding the widgets built by this function
"""
_fileHBox = gtk.HBox()
_topVBox.pack_start(_fileHBox, False, False, 5)
_fileVBox = gtk.VBox()
_topVBox.pack_start(_fileVBox, False, False, 5)
_fileLabel = gtk.Label('Input filename:')
_fileHBox.pack_start(_fileLabel, False, False, 5)
self._fileChoose = gtk.FileChooserButton("Select input filename")
_fileHBox.pack_start(self._fileChoose, True, True, 5)
self._rCounter = self._build_counter('Max steps:', 1000, 3000, _fileHBox, 100, 0)
_fileVBox.pack_start(_fileHBox, False, False, 5)
self._rCounter = self._build_counter('Max steps:', 1000, 3000, _fileVBox, 100, 0)

def _build_compund_gui_box(self, _checkHBox, frame):
"""
Expand Down Expand Up @@ -239,6 +255,11 @@ def _complete_config(self):
self._configDict['momentum'] = self._momentum.get_active()
self._configDict['recurrent'] = self._recurrent.get_active()

self._configDict['alpha'] = self._alphaCounter.get_value()
self._configDict['eta'] = self._etaCounter.get_value()
self._configDict['min_rms'] = self._minRmsCounter.get_value()
self._configDict['min_delta_rms'] = self._minDeltaRmsCounter.get_value()

if self._configDict['minW'] > self._configDict['maxW'] - .1:
self._report(gtk.MESSAGE_WARNING, "Invalid interval for weights, ignored")
self._configDict['minW'] = -1
Expand Down Expand Up @@ -273,3 +294,6 @@ def _report(self, msg_type, text):
md.run()
md.destroy()

def __on_momentum(self, widget, data=None):
self._alphaCounter.set_sensitive(not self._alphaCounter.get_sensitive())

8 changes: 0 additions & 8 deletions src/globaldefs.py
Expand Up @@ -20,14 +20,6 @@
VAL_PLOT_SUFFIX = '.val.png'
LOG_SUFFIX = '.log'

XXX = 800
YYY = 600

LOGNAME = 'MAIN.log'
LOGFNAME = LOGNAME

MIN_RMS = 1e-4
MIN_DRMS = 0
ETA = .2
ALPHA = .5

3 changes: 2 additions & 1 deletion src/gui.py
Expand Up @@ -24,7 +24,8 @@ def __init__(self):
widgets.
"""
super(MainWindow, self).__init__()
self.set_size_request(XXX, YYY)
self.set_size_request(800, 600)

self.set_title(TITLE)
self.set_icon_from_file(ICON_FILE)
self.connect('delete_event', self.__on_exit)
Expand Down
21 changes: 15 additions & 6 deletions src/network.py
Expand Up @@ -10,7 +10,6 @@
import normalizer
import saver
from units import *
from globaldefs import *

def log(x):
"""
Expand Down Expand Up @@ -64,7 +63,6 @@ def __init__(self, config, gui, graph):
config User configuration.
"""
print config
self._gui = gui
self._parse_network(config)
self._parse_activation(config)
Expand Down Expand Up @@ -97,7 +95,7 @@ def learn_step(self):
Bootstraps the learning phase.
"""
rms = self._do_one_learning_step()
done = rms < MIN_RMS or abs(rms - self._orms) < MIN_DRMS
done = rms < self._MIN_RMS or abs(rms - self._orms) < self._MIN_DRMS
self._orms = rms

self._grapher.graph()
Expand Down Expand Up @@ -212,6 +210,14 @@ def _parse_network(self, config):
# momentum?
self._momentum = config['momentum']

# params
self._eta = config['eta']
self._alpha = config['alpha']

# rms params
self._MIN_RMS = config['min_rms']
self._MIN_DRMS = config['min_delta_rms']

def _prepare_data(self, config):
"""
Reads learning set, normalizing it and preparing the auxiliary lists
Expand Down Expand Up @@ -262,7 +268,8 @@ def __build_hidden1(self):
"""
self._hidden1 = []
for i in range(self._h1):
n = Neuron(self._mW, self._MW, self._f, self._df, self._momentum, 'h1{0}'.format(i))
n = Neuron(self._mW, self._MW, self._f, self._df, self._momentum,
'h1{0}'.format(i), self._eta, self._alpha)
n.set_recurrent(self._recurrent)
for inp in self._inputs:
n.connect(inp)
Expand All @@ -276,7 +283,8 @@ def __build_hidden2(self):
"""
self._hidden2 = []
for i in range(self._h2):
n = Neuron(self._mW, self._MW, self._f, self._df, self._momentum, 'h2{0}'.format(i))
n = Neuron(self._mW, self._MW, self._f, self._df, self._momentum,
'h2{0}'.format(i), self._eta, self._alpha)
n.set_recurrent(self._recurrent)
if self._h1:
for inp in self._hidden1:
Expand All @@ -292,7 +300,8 @@ def __build_output(self):
"""
Builds the output layer and the end of the network.
"""
self._output = Neuron(self._mW, self._MW, self._f, self._df, self._momentum, 'o')
self._output = Neuron(self._mW, self._MW, self._f, self._df,
self._momentum, 'o', self._eta, self._alpha)
self._output.set_recurrent(self._recurrent)
if self._h2:
for inp in self._hidden2:
Expand Down
12 changes: 7 additions & 5 deletions src/units.py
Expand Up @@ -166,7 +166,7 @@ class Neuron(Unit):
self.value() will return the output of the neuron
"""
def __init__(self, minW, maxW, f, df, momentum, name=''):
def __init__(self, minW, maxW, f, df, momentum, name, eta, alpha):
super(Neuron, self).__init__(name, 0)
self._min = minW
self._max = maxW
Expand All @@ -177,6 +177,8 @@ def __init__(self, minW, maxW, f, df, momentum, name=''):
self._momentum = momentum
if self._momentum:
self._ow = []
self._ETA = eta
self._ALPHA = alpha

def inputs(self):
return self._inputs
Expand Down Expand Up @@ -245,9 +247,9 @@ def report_and_learn_from_error(self):

for i in range(len(self._weights)):
w, inp = self._weights[i], self._inputs[i]
delta = ETA * self._err * self._df(self._value) * inp.value()
delta = self._ETA * self._err * self._df(self._value) * inp.value()
if self._momentum:
delta += ETA * ALPHA * self._ow[i]
delta += self._ETA * self._ALPHA * self._ow[i]
self._ow[i] = delta
_logger.info('Neuron {0}: delta weight{1}: {2}'.format(self._name, i, delta))
self._weights[i] -= delta
Expand All @@ -257,9 +259,9 @@ def report_and_learn_from_error(self):
self._weights[i] = 1
_logger.info('Neuron {0}: weight{1}: {2}'.format(self._name, i, self._weights[i]))
if self._selfw:
delta = ETA * self._err * self._df(self._value) * self._value
delta = self._ETA * self._err * self._df(self._value) * self._value
if self._momentum:
delta += ETA * ALPHA * self._sow
delta += self._ETA * self._ALPHA * self._sow
self._sow = delta
_logger.info('Neuron {0}: delta self weight: {1}'.format(self._name, self._selfw))
self._selfw -= delta
Expand Down

0 comments on commit c2e9d12

Please sign in to comment.