-
Notifications
You must be signed in to change notification settings - Fork 0
/
rsvd_lc
executable file
·80 lines (61 loc) · 2.08 KB
/
rsvd_lc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/python
"""A script to plot the output of rsvd_train.
Plots the training and probe error using matplotlib.
Reads input from stdin:
Usage:
$ rsvd_train --probe probe.arr train.arr x x x | rsvd_lc
"""
import sys
import re
import pylab as pl
import time
# Matches one progress line from rsvd_train, e.g. "3 0.9512 0.9634 12.40":
# an integer followed by three decimal numbers, separated by whitespace.
# main() reads group 1 as the epoch and groups 2 and 3 as the train and probe
# errors; group 4 is captured but not plotted.  Trailing whitespace, including
# the "\r" of a CRLF line ending, is tolerated so such lines still match.
rec = re.compile(r"^(\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)\s*$")
class Curve(object):
    """A single animated line on an Axes that grows one point at a time.

    The underlying Line2D is created with ``animated=True``; callers paint it
    themselves (``ax.draw_artist``) and blit, rather than relying on a full
    figure redraw.

    Attributes:
        curve: the line artist returned by ``ax.plot``.
        xdata, ydata: the accumulated x and y values, in insertion order.
    """

    def __init__(self, ax, lw = 1):
        """Start an empty series and attach an animated line of width *lw* to *ax*."""
        self.xdata = []
        self.ydata = []
        (self.curve,) = ax.plot([], [], animated=True, lw=lw)

    def update(self, x, y):
        """Append the point (*x*, *y*) and hand the whole series to the line."""
        xs, ys = self.xdata, self.ydata
        xs.append(x)
        ys.append(y)
        self.curve.set_data(xs, ys)
def main():
    """Plot rsvd_train's training and probe error live as lines arrive on stdin.

    Builds a figure with two animated line artists (train error, probe error)
    and schedules ``run`` on the GUI event loop 100 ms after ``pl.show()``
    starts it.  ``run`` echoes every stdin line to stdout.  For each line
    matched by ``rec`` it appends one point to each curve and repaints only
    the axes area via blitting.  When stdin reaches EOF, the curves are made
    permanent parts of the figure.
    """
    fig = pl.figure()
    ax = fig.add_subplot(111)
    ax.set_ylim(0.8, 1.2)
    ax.set_xlim(0, 10)
    traincurve = Curve(ax)
    probecurve = Curve(ax)
    pl.ylabel("RMSE")
    pl.xlabel("epochs")
    pl.legend(["train error", "probe error"])

    def snapshot_background():
        """Fully redraw the figure and return a copy of the axes' pixels.

        A full draw skips animated artists, so the snapshot contains only
        the static parts of the plot (frame, ticks, labels, legend).
        """
        fig.canvas.draw()
        return fig.canvas.copy_from_bbox(ax.bbox)

    def run(*args):
        """Consume stdin until EOF, updating the plot for each progress line.

        Positional arguments are accepted and ignored so this can serve as a
        GUI timer callback.
        """
        print("rsvd learn curve reading from stdin...")
        background = snapshot_background()
        # readline() returns "" (never None) at EOF, so use "" as the iter()
        # sentinel.  `for line in sys.stdin` is avoided on purpose: on
        # Python 2 the file iterator reads ahead in blocks, which would delay
        # lines arriving from a live pipe.
        for line in iter(sys.stdin.readline, ""):
            sys.stdout.write(line)  # echo the training log unchanged
            sys.stdout.flush()
            m = rec.match(line)
            if m is None:
                continue
            epoch = int(m.group(1))
            trainerr = float(m.group(2))
            probeerr = float(m.group(3))
            # group(4) is matched by rec but not plotted.

            xmin, xmax = ax.get_xlim()
            if epoch >= xmax:
                # Double the x range as often as needed for this epoch to fit.
                # Starting from at least 1.0 keeps the loop finite even if
                # xmax is not positive.
                new_xmax = max(xmax, 1.0)
                while epoch >= new_xmax:
                    new_xmax *= 2
                ax.set_xlim(xmin, new_xmax)
                # The tick labels changed, so take a fresh static snapshot.
                background = snapshot_background()

            traincurve.update(epoch, trainerr)
            probecurve.update(epoch, probeerr)
            # Restore the clean background first so the lines are not painted
            # over their previous frames.  Then draw only the animated
            # artists and blit the axes rectangle to the screen.
            fig.canvas.restore_region(background)
            ax.draw_artist(traincurve.curve)
            ax.draw_artist(probecurve.curve)
            fig.canvas.blit(ax.bbox)

        # stdin is exhausted.  Turn the curves into ordinary artists so they
        # survive later full redraws (resize, zoom, savefig).
        for c in (traincurve, probecurve):
            c.curve.set_animated(False)
        fig.canvas.draw()

    manager = pl.get_current_fig_manager()
    # NOTE(review): `window.after` is a Tk widget method, so this assumes a
    # Tk-based backend (e.g. TkAgg).  run() loops on stdin inside this
    # callback, so the window is likely unresponsive to user input until EOF.
    manager.window.after(100, run)
    pl.show()
if __name__ == '__main__':
    # Open the plot window only when executed as a script, not on import.
    main()