Skip to content

Commit

Permalink
Plotting of network done (99%)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mihai Maruseac committed Apr 29, 2011
1 parent b39fab3 commit b46d175
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
30 changes: 28 additions & 2 deletions src/grapher.py
Expand Up @@ -24,10 +24,33 @@ def graph(self):
Called when network graph needs to be updated.
"""
self._do_cleanup_draw()

gc = self._w.get_style().black_gc
pcon = self._w.get_pango_context()
for n in self._neurons:
for n in self._units:
n.draw(self.__pixmap, gc, SIZE, pcon)

ngc = self._w.get_window().new_gc()
ngc.copy(gc)
ngc.set_line_attributes(2, gtk.gdk.LINE_SOLID, gtk.gdk.CAP_ROUND,
gtk.gdk.JOIN_BEVEL)
ex, ey = self._end.entry_point(SIZE)
ox, oy = self._output.exit_point(SIZE)
self.__pixmap.draw_line(ngc, ox, oy, ex, ey)

for n in self._neurons:
ex, ey = n.entry_point(SIZE)
for (nn, w) in zip(n._inputs, n._weights):
if nn == n:
# TODO: recurrent networks
pass
if w < 0:
ngc.set_rgb_fg_color(gtk.gdk.Color(red=abs(w)))
else:
ngc.set_rgb_fg_color(gtk.gdk.Color(blue=w))
sx, sy = nn.exit_point(SIZE)
self.__pixmap.draw_line(ngc, sx, sy, ex, ey)

self._img.set_from_pixmap(self.__pixmap, None)

def build_basic_network(self, N, inputs, h1, hidden1, h2, hidden2,
Expand All @@ -53,7 +76,10 @@ def build_basic_network(self, N, inputs, h1, hidden1, h2, hidden2,
x += PAD + SIZE
self._place([end], x)

self._neurons = inputs + hidden1 + hidden2 + [output, end]
self._neurons = hidden1 + hidden2 + [output]
self._end = end
self._output = output
self._units = inputs + self._neurons + [end]

def _place(self, elems, x):
"""
Expand Down
12 changes: 12 additions & 0 deletions src/units.py
Expand Up @@ -43,13 +43,21 @@ def _draw_label(self, pbuff, gc, size, pcon):
def _draw_label_text(self, l):
l.set_text(self._name)

def exit_point(self, size):
return (self._x + size, self._y + size / 2)

def entry_point(self, size):
return (self._x, self._y + size / 2)

class Fixed(Unit):
"""
A unit holding a fixed value, keeping that value constant and not
learning.
"""
def __init__(self, name='', value=1):
super(Fixed, self).__init__(name, value)
self._inputs = []
self._weights = []

def _draw_label_text(self, l):
l.set_text("1")
Expand Down Expand Up @@ -103,6 +111,9 @@ def _draw_image(self, pbuff, gc, size):
(self._x + size - 10, self._y + size)]
pbuff.draw_polygon(gc, False, ps)

def entry_point(self, size):
return (self._x - 10, self._y + size / 2)

class Neuron(Unit):
"""
Actual neuron.
Expand All @@ -111,6 +122,7 @@ class Neuron(Unit):
"""
def __init__(self, minW, maxW, name=''):
super(Neuron, self).__init__(name, None)
print minW, maxW
self._min = minW
self._max = maxW
self._weights = []
Expand Down

0 comments on commit b46d175

Please sign in to comment.