Simplify TensorSignal tf loading
Move create_signals from graph_optimizer to tensor_graph

Consolidate dummy testing objects in tests/dummies.py
drasmuss committed Jul 24, 2018
1 parent 3523a6b commit 6ddf68e
Showing 15 changed files with 837 additions and 772 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
@@ -48,6 +48,7 @@ Release History
elements (requires ``tensorflow>=1.9.0``)
- Improved accuracy of ``SoftLIFRate`` implementation for small values (`#45
<https://github.com/nengo/nengo-dl/pull/45>`_)
+- Simplified how ``TensorSignals`` are loaded into the TensorFlow graph

**Fixed**

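The changelog entry above refers to removing the explicit index-loading step from the op builders (the ``load_indices`` calls in the diffs below). A minimal before/after sketch of the pattern; ``sigs`` and ``new_shape`` are placeholders for illustration, not names from the codebase:

    # before this commit: indices had to be loaded into the TensorFlow
    # graph explicitly, after any reshaping of the combined TensorSignal
    data = signals.combine(sigs, load_indices=False)
    data = data.reshape(new_shape)
    data.load_indices(constant=signals.constant)

    # after this commit: combine() returns a TensorSignal that handles
    # index loading itself, so no extra step is needed
    data = signals.combine(sigs)
    data = data.reshape(new_shape)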
166 changes: 1 addition & 165 deletions nengo_dl/graph_optimizer.py
@@ -13,7 +13,7 @@
from nengo.utils.simulator import operator_dependency_graph
import numpy as np

-from nengo_dl import (signals, process_builders, builder, tensor_node,
+from nengo_dl import (process_builders, builder, tensor_node,
op_builders, learning_rule_builders, neuron_builders)

logger = logging.getLogger(__name__)
@@ -1065,170 +1065,6 @@ def noop_order_signals(plan, **_):
return all_signals, plan


def create_signals(sigs, plan, float_type, minibatch_size):
"""Groups signal data together into larger arrays, and represent each
individual signal as a slice into that array.
Parameters
----------
sigs : list of :class:`~nengo:nengo.builder.Signal`
Base signals arranged into the order in which they should reside in
memory (e.g., output from ``order_signals``)
plan : list of tuple of :class:`~nengo:nengo.builder.Operator`
Operator execution plan (only used to get a list of all the operators)
float_type : ``np.float32`` or ``np.float64``
Floating point precision to use for signals
minibatch_size : int
Number of items in each minibatch

Returns
-------
base_arrays : dict of {object : :class:`~numpy:numpy.ndarray`}
combined arrays, containing the initial values for all signals
sig_map : dict of {:class:`~nengo:nengo.builder.Signal`: \
:class:`.signals.TensorSignal`}
mapping from ``nengo`` Signals to ``nengo_dl`` TensorSignals (views
into the base arrays)
"""

base_arrays = OrderedDict()
curr_keys = {}
sig_map = {}
sig_idxs = {s: i for i, s in enumerate(sigs)}

# find the non-overlapping partitions of the signals
breaks = []
diff = defaultdict(int)
for ops in plan:
# note: we don't include Resets, otherwise the big reset block
# overrides most of the partitioning
if not isinstance(ops[0], Reset):
for i in range(len(ops[0].all_signals)):
op_sigs = [op.all_signals[i].base for op in ops]
idxs = [sig_idxs[s] for s in op_sigs]
diff[op_sigs[np.argmin(idxs)]] += 1
diff[op_sigs[np.argmax(idxs)]] -= 1

# find the partition points in signal list
open = 0
for i, s in enumerate(sigs):
if s in diff:
open += diff[s]

if open == 0:
breaks += [i + 1]

logging.debug("partitions")
logging.debug("\n%s", "".join("|" if i in breaks else " "
for i in range(len(sigs))))

# create all the base signals
for i, sig in enumerate(sigs):
assert sig not in sig_map
assert not sig.is_view

if i in breaks:
# start a new array for all current bases
for k in curr_keys:
curr_keys[k] = object()

# convert to appropriate dtype
if np.issubdtype(sig.dtype, np.floating):
dtype = float_type
elif np.issubdtype(sig.dtype, np.integer):
dtype = np.int32
else:
raise NotImplementedError

# resize scalars to length 1 vectors
shape = sig.shape if sig.shape != () else (1,)

# parameters of signal that affect the base array
array_params = (dtype, shape[1:], sig.trainable, sig.minibatched)

# key used to map signals to base arrays
if array_params not in curr_keys:
curr_keys[array_params] = object()
key = curr_keys[array_params]

initial_value = sig.initial_value.astype(dtype, copy=False)

# broadcast scalars up to full size
if initial_value.shape != shape:
initial_value = np.resize(initial_value, shape)

if sig.minibatched:
# duplicate along minibatch dimension
initial_value = np.tile(
initial_value[..., None],
tuple(1 for _ in shape) + (minibatch_size,))

if key in base_arrays:
base_arrays[key][0].append(initial_value)
base_arrays[key][2] += shape[0]
else:
base_arrays[key] = [[initial_value], sig.trainable, shape[0]]

n = base_arrays[key][-1]
indices = np.arange(n - shape[0], n)

sig_map[sig] = signals.TensorSignal(
indices, key, dtype, shape,
minibatch_size if sig.minibatched else None,
label=sig.name)

logger.debug("created base signal")
logger.debug(sig)
logger.debug(sig_map[sig])

for key in base_arrays:
arrs, t, _ = base_arrays[key]
base_arrays[key] = (np.concatenate(arrs, axis=0), t)

# add any signal views to the sig_map
all_views = [sig for ops in plan for op in ops for sig in op.all_signals
if sig.is_view]
for sig in all_views:
if sig.size == sig.base.size:
# reshape view
sig_map[sig] = sig_map[sig.base].reshape(sig.shape)
else:
if sig.shape[1:] != sig.base.shape[1:]:
# TODO: support this?
raise NotImplementedError(
"Slicing on axes > 0 is not supported")

# slice view
assert np.all([x == 1 for x in sig.elemstrides[1:]])

start = sig.elemoffset
stride = sig.elemstrides[0]
stop = start + sig.size * stride
if stop < 0:
stop = None

sig_map[sig] = sig_map[sig.base][slice(start, stop, stride)]

# error checking
for sig, tensor_sig in sig_map.items():
# tensorsignal shapes should match signal shapes
assert tensor_sig.shape == (sig.shape if sig.shape != () else (1,))

# tensorsignal values should match signal values
initial_value = sig.initial_value
if sig.minibatched:
initial_value = initial_value[..., None]

assert np.allclose(base_arrays[tensor_sig.key][0][tensor_sig.indices],
initial_value.astype(dtype))

logger.debug("base arrays")
logger.debug("\n".join([str((k, v.dtype, v.shape, trainable))
for k, (v, trainable) in base_arrays.items()]))

return base_arrays, sig_map


def remove_unmodified_resets(operators):
"""Remove any Reset operators that are targeting a signal that is
never modified.
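The heart of the removed ``create_signals`` is the partition sweep near the top of the function. A self-contained sketch of the same interval-sweep idea; ``find_breaks`` and ``op_groups`` are hypothetical names used for illustration, not part of the nengo_dl API:

    from collections import defaultdict

    import numpy as np

    def find_breaks(sigs, op_groups):
        # count, for each signal, how many operator-group spans open (+1)
        # or close (-1) at that position in the ordered signal list
        sig_idxs = {s: i for i, s in enumerate(sigs)}
        diff = defaultdict(int)
        for group in op_groups:
            idxs = [sig_idxs[s] for s in group]
            diff[group[np.argmin(idxs)]] += 1
            diff[group[np.argmax(idxs)]] -= 1

        # sweep through the signals; wherever no span is open, the bases
        # can be split without separating signals used by the same ops
        breaks, open_spans = [], 0
        for i, s in enumerate(sigs):
            open_spans += diff.get(s, 0)
            if open_spans == 0:
                breaks.append(i + 1)
        return breaks

    # ops touching (a, b) and (c, d) allow a split between b and c
    print(find_breaks(["a", "b", "c", "d"], [("a", "b"), ("c", "d")]))  # [2, 4]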
17 changes: 5 additions & 12 deletions nengo_dl/learning_rule_builders.py
@@ -23,10 +23,9 @@ def __init__(self, ops, signals):

self.pre_data = signals.combine(
[op.pre_filtered for op in ops
-for _ in range(op.post_filtered.shape[0])], load_indices=False)
+for _ in range(op.post_filtered.shape[0])])
self.pre_data = self.pre_data.reshape((self.post_data.shape[0],
ops[0].pre_filtered.shape[0]))
-self.pre_data.load_indices(constant=signals.constant)

self.learning_rate = signals.op_constant(
ops, [op.post_filtered.shape[0] for op in ops], "learning_rate",
@@ -58,10 +57,9 @@ def __init__(self, ops, signals):

self.pre_data = signals.combine(
[op.pre_filtered for op in ops
-for _ in range(op.post_filtered.shape[0])], load_indices=False)
+for _ in range(op.post_filtered.shape[0])])
self.pre_data = self.pre_data.reshape((self.post_data.shape[0],
ops[0].pre_filtered.shape[0]))
-self.pre_data.load_indices(constant=signals.constant)

self.weights_data = signals.combine([op.weights for op in ops])
self.output_data = signals.combine([op.delta for op in ops])
@@ -102,10 +100,9 @@ def __init__(self, ops, signals):

self.pre_data = signals.combine(
[op.pre_decoded for op in ops
-for _ in range(op.post_filtered.shape[0])], load_indices=False)
+for _ in range(op.post_filtered.shape[0])])
self.pre_data = self.pre_data.reshape((self.post_data.shape[0],
ops[0].pre_decoded.shape[0]))
-self.pre_data.load_indices(constant=signals.constant)

self.learning_data = signals.combine(
[op.learning_signal for op in ops
@@ -280,17 +277,13 @@ class SimPESBuilder(OpBuilder):
def __init__(self, ops, signals):
super(SimPESBuilder, self).__init__(ops, signals)

-self.error_data = signals.combine(
-    [op.error for op in ops], load_indices=False)
+self.error_data = signals.combine([op.error for op in ops])
self.error_data = self.error_data.reshape(
(len(ops), ops[0].error.shape[0], 1))
-self.error_data.load_indices(constant=signals.constant)

-self.pre_data = signals.combine(
-    [op.pre_filtered for op in ops], load_indices=False)
+self.pre_data = signals.combine([op.pre_filtered for op in ops])
self.pre_data = self.pre_data.reshape(
(len(ops), 1, ops[0].pre_filtered.shape[0]))
-self.pre_data.load_indices(constant=signals.constant)

self.alpha = signals.op_constant(
ops, [1 for _ in ops], "learning_rate", signals.dtype, ndims=4) * (
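Each of the learning-rule builders above uses the same trick: the pre-synaptic signal is repeated once per post neuron, combined, and reshaped to (n_post, n_pre). A rough NumPy analogy of the resulting layout (illustrative only; the real code operates on TensorSignals, not arrays):

    import numpy as np

    n_pre, n_post = 3, 2
    pre = np.array([0.1, 0.2, 0.3])   # stand-in for op.pre_filtered
    post = np.array([1.0, 2.0])       # stand-in for op.post_filtered

    # one copy of the pre activities per post neuron, mirroring the
    # [op.pre_filtered ... for _ in range(...)] combine in the diffs
    pre_data = np.concatenate([pre for _ in range(n_post)])
    pre_data = pre_data.reshape((n_post, n_pre))

    # row i now holds the pre activities paired with post neuron i, so an
    # outer-product-style weight update becomes a simple broadcast
    delta = post[:, None] * pre_data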
28 changes: 9 additions & 19 deletions nengo_dl/op_builders.py
@@ -41,7 +41,7 @@ def __init__(self, ops, signals):
# bases, which we need to handle
scatters = defaultdict(list)
for op in ops:
-scatters[signals.sig_map[op.dst].key] += [op]
+scatters[signals[op.dst].key] += [op]
self.scatters = []
for group in scatters.values():
value = np.concatenate(
@@ -79,21 +79,19 @@ def __init__(self, ops, signals):
srcs = []
dsts = []
for op in ops:
-srcs += [signals.sig_map[op.src][op.src_slice]]
-dsts += [signals.sig_map[op.dst][op.dst_slice]]
+srcs += [signals[op.src][op.src_slice]]
+dsts += [signals[op.dst][op.dst_slice]]

self.mode = "inc" if ops[0].inc else "update"

-self.src_data = signals.combine(srcs, load_indices=False)
+self.src_data = signals.combine(srcs)
self.dst_data = signals.combine(dsts)

if not self.src_data.minibatched and self.dst_data.minibatched:
# broadcast indices so that the un-minibatched src data gets
# copied to each minibatch dimension in dst
self.src_data = self.src_data.broadcast(-1, signals.minibatch_size)

-self.src_data.load_indices(constant=signals.constant)

def build_step(self, signals):
signals.scatter(self.dst_data, signals.gather(self.src_data),
mode=self.mode)
@@ -123,8 +121,8 @@ def __init__(self, ops, signals):
self.Y_data = signals.combine([op.Y for op in ops])

# group all the A's and X's
-self.A_data = signals.combine([op.A for op in ops], load_indices=False)
-self.X_data = signals.combine([op.X for op in ops], load_indices=False)
+self.A_data = signals.combine([op.A for op in ops])
+self.X_data = signals.combine([op.X for op in ops])

# separate data from each op along the first dimension
if self.A_data.shape[0] != self.X_data.shape[0]:
@@ -141,9 +139,6 @@ def __init__(self, ops, signals):
if not self.A_data.minibatched and self.X_data.minibatched:
self.A_data = self.A_data.reshape(self.A_data.shape + (1,))

-self.A_data.load_indices(constant=signals.constant)
-self.X_data.load_indices(constant=signals.constant)

def build_step(self, signals):
A = signals.gather(self.A_data)
X = signals.gather(self.X_data)
@@ -176,8 +171,8 @@ def __init__(self, ops, signals):
self.Y_data = signals.combine([op.Y for op in ops])

# group all the A's and X's
-A_data = signals.combine([op.A for op in ops], load_indices=False)
-X_data = signals.combine([op.X for op in ops], load_indices=False)
+A_data = signals.combine([op.A for op in ops])
+X_data = signals.combine([op.X for op in ops])

# separate data from each op along the first dimension
self.A_data = A_data.reshape((len(ops), -1, A_data.shape[1]))
@@ -206,9 +201,6 @@ def __init__(self, ops, signals):
# if not self.A_data.minibatched and self.X_data.minibatched:
# self.A_data = self.A_data.reshape(self.A_data.shape + (1,))

-self.A_data.load_indices(constant=signals.constant)
-self.X_data.load_indices(constant=signals.constant)

def build_step(self, signals):
A = signals.gather(self.A_data)
X = signals.gather(self.X_data)
@@ -287,10 +279,8 @@ def __init__(self, ops, signals):
self.Y_data = signals.combine([op.Y for op in ops])

# group all the A's and X's
-self.A_data = signals.combine([op.A for op in ops],
-                              load_indices=False)
+self.A_data = signals.combine([op.A for op in ops])
self.A_data = self.A_data.reshape((-1,))
-self.A_data.load_indices(constant=signals.constant)
self.X_data = signals.combine([op.X for op in ops])

assert not self.A_data.minibatched
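All of the op builders above follow the same two-phase pattern: indices and shapes are set up once in __init__, then each simulation step gathers inputs, computes, and scatters the result. A minimal skeleton under the simplified API (hypothetical MyOp with signals A and Y; the method names come from the diffs above):

    class MyOpBuilder(OpBuilder):
        def __init__(self, ops, signals):
            super(MyOpBuilder, self).__init__(ops, signals)

            # combine() now returns a ready-to-use TensorSignal; the old
            # load_indices(constant=...) step is gone
            self.A_data = signals.combine([op.A for op in ops])
            self.Y_data = signals.combine([op.Y for op in ops])

        def build_step(self, signals):
            # gather current values, compute, and scatter the result back
            A = signals.gather(self.A_data)
            signals.scatter(self.Y_data, 2 * A, mode="update")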