/
graph_2D.py
180 lines (150 loc) · 4.83 KB
/
graph_2D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""
Classes for making simple 2D visualizations.
"""
import numpy as N
from theano.compat.six.moves import xrange
from theano import config
class Graph2D(object):
    """
    A class for plotting simple graphs in two dimensions.

    The graph maps a grid of ``rows x cols`` pixels onto a rectangle of the
    real plane. Pixels are square: ``set_ycenter`` copies the horizontal
    pixel spacing (``delta_x``) into the vertical spacing (``delta_y``).

    Parameters
    ----------
    shape : tuple
        The shape of the display of the graph in (rows, cols)
        format. Units are pixels
    xlim : tuple
        A tuple specifying (xmin, xmax). This determines what
        portion of the real numbers are displayed in the graph.
    ycenter : float
        The coordinate of the center pixel along the y axis.

    Notes
    -----
    ``set_xlim`` reads ``self.cols`` and ``set_ycenter`` reads ``self.rows``
    and ``self.delta_x``. No setter recomputes the attributes derived by the
    others, so after changing the shape or the x range, call the later
    setters again to keep everything consistent.
    """

    def __init__(self, shape, xlim, ycenter):
        self.xmin = 0.
        self.xmax = 0.
        # Order matters: set_xlim needs cols, set_ycenter needs delta_x.
        self.set_shape(shape)
        self.set_xlim(xlim)
        self.set_ycenter(ycenter)
        # Layers drawn by render(), first to last. Each must provide
        # render(prev_layer=..., parent=...) returning the new image.
        self.components = []

    def set_shape(self, shape):
        """
        Sets the shape of the display (in pixels)

        Parameters
        ----------
        shape : tuple
            The (rows, columns) of the display.
        """
        self.rows = shape[0]
        self.cols = shape[1]

    def set_xlim(self, xlim):
        """
        Sets the range of space that is plotted in the graph.

        Parameters
        ----------
        xlim : tuple
            The range (xmin, xmax): the x coordinates of the centers of the
            leftmost and rightmost pixels.

        Notes
        -----
        Requires ``self.cols``. With ``cols == 1`` the pixel spacing is
        undefined and the computation below divides by zero.
        """
        # x coordinate of center of leftmost pixel
        self.xmin = xlim[0]
        # x coordinate of center of rightmost pixel
        self.xmax = xlim[1]
        # cols pixel centers span cols - 1 equal intervals.
        self.delta_x = (self.xmax - self.xmin) / float(self.cols - 1)

    def set_ycenter(self, ycenter):
        """
        Sets the y coordinate of the central pixel of the display.

        Parameters
        ----------
        ycenter : float
            The desired coordinate.

        Notes
        -----
        Requires ``self.rows`` and ``self.delta_x``; sets ``delta_y``,
        ``ymin`` and ``ymax``.
        """
        # Square pixels: reuse the horizontal spacing vertically.
        self.delta_y = self.delta_x
        # Floor division puts the pixel rows // 2 steps above the bottom row
        # exactly at ycenter. True division (Python 3 ``/``) would shift the
        # whole grid by half a pixel whenever rows is odd.
        self.ymin = ycenter - (self.rows // 2) * self.delta_y
        self.ymax = self.ymin + (self.rows - 1) * self.delta_y

    def render(self):
        """
        Renders the graph.

        Starts from a black (all-zero) image and passes it through every
        component in ``self.components`` in order; each component receives
        the previous image and returns the next one.

        Returns
        -------
        output : ndarray
            An ndarray in (rows, cols, RGB) format.
        """
        rval = N.zeros((self.rows, self.cols, 3))
        for component in self.components:
            rval = component.render(prev_layer=rval, parent=self)
            assert rval is not None, (
                "component %r returned None from render()" % (component,))
        return rval

    def get_coords_for_col(self, i):
        """
        Returns the coordinates of every pixel in column i of the
        graph.

        Parameters
        ----------
        i : int
            Column index

        Returns
        -------
        coords : ndarray
            A (rows, 2) design matrix of dtype ``config.floatX``. Row r
            holds the (x, y) real-number coordinates of the pixel in row r
            of column i.
        """
        X = N.zeros((self.rows, 2), dtype=config.floatX)
        X[:, 0] = self.xmin + float(i) * self.delta_x
        # Row 0 gets the largest y (ymax) and the last row gets ymin, i.e.
        # y decreases with the row index as in image coordinates.
        # (N.cast was removed in NumPy 2.0; arange with a dtype is the
        # equivalent.)
        offsets = N.arange(self.rows - 1, -1, -1, dtype=config.floatX)
        X[:, 1] = self.ymin + offsets * self.delta_y
        return X
class HeatMap(object):
    """
    A class for plotting 2-D functions as heatmaps.

    Parameters
    ----------
    f : callable
        A callable that takes a design matrix of 2D coordinates and returns
        either a vector containing the function value at those coordinates
        (painted into every color channel) or a matrix with one row per
        coordinate and one column per color channel.
    normalizer : callable, optional
        None or a callable applied to the whole rendered image array (same
        shape as ``prev_layer``) that returns the normalized image.
    render_mode : str
        * 'o' : opaque.
        * 'r' : render only to the (r)ed channel
    """

    def __init__(self, f, normalizer=None, render_mode='o'):
        self.f = f
        self.normalizer = normalizer
        self.render_mode = render_mode

    def render(self, prev_layer, parent):
        """
        Renders the heatmap.

        Parameters
        ----------
        prev_layer : numpy ndarray
            An image that will be copied into the new output.
            The new image will be rendered on top of the first
            one, i.e., `prev_layer` will be visible through the
            new heatmap if the new heatmap is not rendered in
            fully opaque mode.
        parent : Graph2D
            A Graph2D object that defines the coordinate system
            of the heatmap. Only its ``get_coords_for_col(i)`` method is
            used.

        Returns
        -------
        img : The rendered heatmap, same shape as `prev_layer`.

        Raises
        ------
        NotImplementedError
            If ``render_mode`` is not 'o' or 'r'. This is checked before
            ``f`` is evaluated.
        """
        # Fail fast: evaluating f on every column can be expensive, and an
        # unknown mode would be rejected after that work anyway.
        if self.render_mode not in ('o', 'r'):
            raise NotImplementedError(
                "Unsupported render_mode %r; expected 'o' or 'r'."
                % (self.render_mode,))

        # Blank canvas with the same shape (and a floating dtype) as
        # prev_layer.
        my_img = prev_layer * 0.0
        for i in range(prev_layer.shape[1]):
            X = parent.get_coords_for_col(i)
            values = self.f(X)
            if len(values.shape) == 1:
                # One value per pixel: broadcast it into every channel.
                my_img[:, i, :] = values[:, None]
            else:
                # Already one value per pixel and channel.
                my_img[:, i, :] = values

        if self.normalizer is not None:
            my_img = self.normalizer(my_img)
            assert my_img is not None, "normalizer returned None"

        if self.render_mode == 'r':
            # Restore the previous layer's non-red channels so the heatmap
            # only shows in the red channel.
            my_img[:, :, 1:] = prev_layer[:, :, 1:]
        return my_img